1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

Add SCRIPT LOAD/EVALSHA

This commit is contained in:
Roman Gershman 2022-02-05 11:50:50 +02:00
parent ec70cc9e9f
commit c567a70244
14 changed files with 264 additions and 57 deletions

View file

@ -397,9 +397,10 @@ void Interpreter::FuncSha1(string_view body, char* fp) {
auto Interpreter::AddFunction(string_view body, string* result) -> AddResult {
char funcname[43];
FuncSha1(body, funcname + 2);
funcname[0] = 'f';
funcname[1] = '_';
FuncSha1(body, funcname + 2);
funcname[42] = '\0';
int type = lua_getglobal(lua_, funcname);
lua_pop(lua_, 1);
@ -409,17 +410,40 @@ auto Interpreter::AddFunction(string_view body, string* result) -> AddResult {
result->assign(funcname + 2);
return type == LUA_TNIL ? OK : ALREADY_EXISTS;
return type == LUA_TNIL ? ADD_OK : ALREADY_EXISTS;
}
bool Interpreter::RunFunction(const char* f_id, std::string* error) {
bool Interpreter::Exists(string_view sha) const {
if (sha.size() != 40)
return false;
char fname[43];
fname[0] = 'f';
fname[1] = '_';
fname[42] = '\0';
memcpy(fname + 2, sha.data(), 40);
int type = lua_getglobal(lua_, fname);
lua_pop(lua_, 1);
return type == LUA_TFUNCTION;
}
auto Interpreter::RunFunction(string_view sha, std::string* error) -> RunResult {
DCHECK_EQ(40u, sha.size());
lua_getglobal(lua_, "__redis__err__handler");
int type = lua_getglobal(lua_, f_id);
char fname[43];
fname[0] = 'f';
fname[1] = '_';
memcpy(fname + 2, sha.data(), 40);
fname[42] = '\0';
int type = lua_getglobal(lua_, fname);
if (type != LUA_TFUNCTION) {
error->assign("function not found"); // TODO: noscripterr.
lua_pop(lua_, 2);
return false;
return NOT_EXISTS;
}
/* We have zero arguments and expect
@ -429,13 +453,15 @@ bool Interpreter::RunFunction(const char* f_id, std::string* error) {
if (err) {
*error = lua_tostring(lua_, -1);
}
return err == 0;
return err == 0 ? RUN_OK : RUN_ERR;
}
void Interpreter::SetGlobalArray(const char* name, MutSliceSpan args) {
SetGlobalArrayInternal(lua_, name, args);
}
/*
bool Interpreter::Execute(string_view body, char f_id[41], string* error) {
lua_getglobal(lua_, "__redis__err__handler");
char fname[43];
@ -464,6 +490,7 @@ bool Interpreter::Execute(string_view body, char f_id[41], string* error) {
return err == 0;
}
*/
bool Interpreter::AddInternal(const char* f_id, string_view body, string* error) {
string script = absl::StrCat("function ", f_id, "() \n");

View file

@ -6,6 +6,7 @@
#include <absl/types/span.h>
#include <boost/fiber/mutex.hpp>
#include <functional>
#include <string_view>
@ -48,7 +49,7 @@ class Interpreter {
}
enum AddResult {
OK = 0,
ADD_OK = 0,
ALREADY_EXISTS = 1,
COMPILE_ERR = 2,
};
@ -58,13 +59,22 @@ class Interpreter {
// function id is sha1 of the function body.
AddResult AddFunction(std::string_view body, std::string* result);
// Runs already added function f_id returned by a successful call to AddFunction().
bool Exists(std::string_view sha) const;
enum RunResult {
RUN_OK = 0,
NOT_EXISTS = 1,
RUN_ERR = 2,
};
// Runs already added function sha returned by a successful call to AddFunction().
// Returns: true if the call succeeded, otherwise fills error and returns false.
bool RunFunction(const char* f_id, std::string* err);
// sha must be 40 char length.
RunResult RunFunction(std::string_view sha, std::string* err);
void SetGlobalArray(const char* name, MutSliceSpan args);
bool Execute(std::string_view body, char f_id[41], std::string* err);
// bool Execute(std::string_view body, char f_id[41], std::string* err);
bool Serialize(ObjectExplorer* serializer, std::string* err);
// fp must point to buffer with at least 41 chars.
@ -75,6 +85,15 @@ class Interpreter {
redis_func_ = std::forward<U>(u);
}
// We have interpreter per thread, not per connection.
// Since we might preempt into different fibers when operating on interpreter
// we must lock it until we finish using it per request.
// Only RunFunction with companions require locking since other functions peform atomically
// without preemptions.
std::lock_guard<::boost::fibers::mutex> Lock() {
return std::lock_guard<::boost::fibers::mutex>{mu_};
}
private:
// Returns true if function was successfully added,
// otherwise returns false and sets the error.
@ -88,6 +107,11 @@ class Interpreter {
lua_State* lua_;
unsigned cmd_depth_ = 0;
RedisFunc redis_func_;
// We have interpreter per thread, not per connection.
// Since we might preempt into different fibers when operating on interpreter
// we must lock it until we finish using it per request.
::boost::fibers::mutex mu_;
};
} // namespace dfly

View file

@ -88,7 +88,6 @@ class InterpreterTest : public ::testing::Test {
bool Execute(string_view script);
Interpreter intptr_;
TestSerializer ser_;
string error_;
@ -103,9 +102,18 @@ void InterpreterTest::SetGlobalArray(const char* name, vector<string> vec) {
}
bool InterpreterTest::Execute(string_view script) {
char buf[48];
string result;
Interpreter::AddResult add_res = intptr_.AddFunction(script, &result);
if (add_res == Interpreter::COMPILE_ERR) {
error_ = result;
return false;
}
return intptr_.Execute(script, buf, &error_) && Serialize(&error_);
Interpreter::RunResult run_res = intptr_.RunFunction(result, &error_);
if (run_res != Interpreter::RUN_OK) {
return false;
}
return Serialize(&error_);
}
TEST_F(InterpreterTest, Basic) {
@ -155,16 +163,17 @@ TEST_F(InterpreterTest, Basic) {
TEST_F(InterpreterTest, Add) {
string res1, res2;
EXPECT_EQ(Interpreter::OK, intptr_.AddFunction("return 0", &res1));
EXPECT_EQ(Interpreter::ADD_OK, intptr_.AddFunction("return 0", &res1));
EXPECT_EQ(0, lua_gettop(lua()));
EXPECT_EQ(Interpreter::COMPILE_ERR, intptr_.AddFunction("foobar", &res2));
EXPECT_THAT(res2, testing::HasSubstr("syntax error"));
EXPECT_EQ(0, lua_gettop(lua()));
EXPECT_TRUE(intptr_.Exists(res1));
}
// Test cases taken from scripting.tcl
TEST_F(InterpreterTest, Execute) {
EXPECT_TRUE(Execute("return 42"));
ASSERT_TRUE(Execute("return 42"));
EXPECT_EQ("i(42)", ser_.res);
EXPECT_TRUE(Execute("return 'hello'"));
@ -213,7 +222,7 @@ TEST_F(InterpreterTest, Call) {
};
intptr_.SetRedisFunc(cb);
EXPECT_TRUE(Execute("local var = redis.call('string'); return {type(var), var}"));
ASSERT_TRUE(Execute("local var = redis.call('string'); return {type(var), var}"));
EXPECT_EQ("[str(string) str(foo)]", ser_.res);
EXPECT_TRUE(Execute("local var = redis.call('double'); return {type(var), var}"));

View file

@ -66,6 +66,7 @@ const char kInvalidIntErr[] = "value is not an integer or out of range";
const char kUintErr[] = "value is out of range, must be positive";
const char kDbIndOutOfRangeErr[] = "DB index is out of range";
const char kInvalidDbIndErr[] = "invalid DB index";
const char kScriptNotFound[] = "-NOSCRIPT No matching script. Please use EVAL.";
const char* GlobalState::Name(S s) {
switch (s) {

View file

@ -83,6 +83,12 @@ inline void ToUpper(const MutableSlice* val) {
}
}
inline void ToLower(const MutableSlice* val) {
for (auto& c : *val) {
c = absl::ascii_tolower(c);
}
}
} // namespace dfly
namespace std {

View file

@ -4,6 +4,7 @@
#include <absl/strings/str_join.h>
#include <absl/strings/strip.h>
#include <absl/strings/ascii.h>
#include <gmock/gmock.h>
#include "base/gtest.h"
@ -214,13 +215,33 @@ TEST_F(DflyEngineTest, FlushDb) {
}
TEST_F(DflyEngineTest, Eval) {
auto resp = Run({"eval", "return 42", "0"});
EXPECT_THAT(resp[0], IntArg(42));
auto resp = Run({"eval", "return 43", "0"});
EXPECT_THAT(resp[0], IntArg(43));
// resp = Run({"eval", "return redis.call('get', 'foo')", "0"});
// EXPECT_THAT(resp[0], IntArg(42)); // TODO.
}
TEST_F(DflyEngineTest, EvalSha) {
auto resp = Run({"script", "load", "return 5"});
EXPECT_THAT(resp, ElementsAre(ArgType(RespExpr::STRING)));
string sha{ToSV(resp[0].GetBuf())};
resp = Run({"evalsha", sha, "0"});
EXPECT_THAT(resp[0], IntArg(5));
resp = Run({"script", "load", " return 5 "});
EXPECT_THAT(resp, ElementsAre(StrArg(sha)));
absl::AsciiStrToUpper(&sha);
resp = Run({"evalsha", sha, "0"});
EXPECT_THAT(resp[0], IntArg(5));
resp = Run({"evalsha", "foobar", "0"});
EXPECT_THAT(resp[0], ErrArg("No matching"));
}
// TODO: to test transactions with a single shard since then all transactions become local.
// To consider having a parameter in dragonfly engine controlling number of shards
// unconditionally from number of cpus. TO TEST BLPOP under multi for single/multi argument case.

View file

@ -18,6 +18,7 @@ extern const char kInvalidIntErr[];
extern const char kUintErr[];
extern const char kDbIndOutOfRangeErr[];
extern const char kInvalidDbIndErr[];
extern const char kScriptNotFound[];
#ifndef RETURN_ON_ERR

View file

@ -20,6 +20,7 @@ extern "C" {
#include "server/error.h"
#include "server/generic_family.h"
#include "server/list_family.h"
#include "server/script_mgr.h"
#include "server/server_state.h"
#include "server/string_family.h"
#include "server/transaction.h"
@ -343,7 +344,44 @@ void Service::CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionC
}
void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
string_view num_keys_str = ArgS(args, 2);
int32_t num_keys;
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
return (*cntx)->SendError(kInvalidIntErr);
}
if (unsigned(num_keys) > args.size() - 3) {
return (*cntx)->SendError("Number of keys can't be greater than number of args");
}
string_view body = ArgS(args, 1);
body = absl::StripAsciiWhitespace(body);
if (body.empty()) {
return (*cntx)->SendNull();
}
ServerState* ss = ServerState::tlocal();
Interpreter& script = ss->GetInterpreter();
string result;
Interpreter::AddResult add_result = script.AddFunction(body, &result);
if (add_result == Interpreter::COMPILE_ERR) {
return (*cntx)->SendError(result);
}
if (add_result == Interpreter::ADD_OK) {
server_family_.script_mgr()->InsertFunction(result, body);
}
EvalArgs eval_args;
eval_args.sha = result;
eval_args.keys = args.subspan(3, num_keys);
eval_args.args = args.subspan(3 + num_keys);
EvalInternal(eval_args, &script, cntx);
}
void Service::EvalSha(CmdArgList args, ConnectionContext* cntx) {
string_view num_keys_str = ArgS(args, 2);
int32_t num_keys;
@ -355,31 +393,82 @@ void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError("Number of keys can't be greater than number of args");
}
ToLower(&args[1]);
string_view sha = ArgS(args, 1);
ServerState* ss = ServerState::tlocal();
lock_guard lk(ss->interpreter_mutex);
Interpreter& script = ss->GetInterpreter();
bool exists = script.Exists(sha);
CmdArgList eval_keys = args.subspan(3, num_keys);
CmdArgList eval_args = args.subspan(3 + num_keys);
if (!exists) {
const char* body = (sha.size() == 40) ? server_family_.script_mgr()->Find(sha) : nullptr;
if (!body) {
return (*cntx)->SendError(kScriptNotFound);
}
string res;
CHECK_EQ(Interpreter::ADD_OK, script.AddFunction(body, &res));
CHECK_EQ(res, sha);
}
EvalArgs ev_args;
ev_args.sha = sha;
ev_args.keys = args.subspan(3, num_keys);
ev_args.args = args.subspan(3 + num_keys);
EvalInternal(ev_args, &script, cntx);
}
void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
ConnectionContext* cntx) {
DCHECK(!eval_args.sha.empty());
if (eval_args.sha.size() != 40) {
return (*cntx)->SendError(kScriptNotFound);
}
bool exists = interpreter->Exists(eval_args.sha);
if (!exists) {
const char* body = server_family_.script_mgr()->Find(eval_args.sha);
if (!body) {
return (*cntx)->SendError(kScriptNotFound);
}
string res;
CHECK_EQ(Interpreter::ADD_OK, interpreter->AddFunction(body, &res));
CHECK_EQ(res, eval_args.sha);
}
script.SetGlobalArray("KEYS", eval_keys);
script.SetGlobalArray("ARGV", eval_args);
string error;
char f_id[48];
bool success = script.Execute(body, f_id, &error);
auto lk = interpreter->Lock();
interpreter->SetGlobalArray("KEYS", eval_args.keys);
interpreter->SetGlobalArray("ARGV", eval_args.args);
interpreter->SetRedisFunc(
[cntx, this](CmdArgList args, ObjectExplorer* reply) { CallFromScript(args, reply, cntx); });
bool success = false;
if (eval_args.sha.empty()) {
} else {
Interpreter::RunResult result = interpreter->RunFunction(eval_args.sha, &error);
if (result == Interpreter::RUN_ERR) {
return (*cntx)->SendError(error);
}
CHECK(result == Interpreter::RUN_OK);
success = true;
}
if (success) {
EvalSerializer ser{static_cast<RedisReplyBuilder*>(cntx->reply_builder())};
string error;
script.SetRedisFunc([cntx, this](CmdArgList args, ObjectExplorer* reply) {
CallFromScript(args, reply, cntx);
});
if (!script.Serialize(&ser, &error)) {
if (!interpreter->Serialize(&ser, &error)) {
(*cntx)->SendError(error);
}
} else {
string resp = absl::StrCat("Error running script (call to ", f_id, "): ", error);
string resp = absl::StrCat("Error running script (call to ", eval_args.sha, "): ", error);
return (*cntx)->SendError(resp);
}
}
@ -441,8 +530,8 @@ VarzValue::Map Service::GetVarzStats() {
using ServiceFunc = void (Service::*)(CmdArgList, ConnectionContext* cntx);
#define HFUNC(x) SetHandler(&Service::x)
#define MFUNC(x) SetHandler([this](CmdArgList sp, ConnectionContext* cntx) { \
this->x(std::move(sp), cntx); })
#define MFUNC(x) \
SetHandler([this](CmdArgList sp, ConnectionContext* cntx) { this->x(std::move(sp), cntx); })
void Service::RegisterCommands() {
using CI = CommandId;
@ -450,9 +539,9 @@ void Service::RegisterCommands() {
constexpr auto kExecMask = CO::LOADING | CO::NOSCRIPT | CO::GLOBAL_TRANS;
registry_ << CI{"QUIT", CO::READONLY | CO::FAST, 1, 0, 0, 0}.HFUNC(Quit)
<< CI{"MULTI", CO::NOSCRIPT | CO::FAST | CO::LOADING, 1, 0, 0, 0}.HFUNC(
Multi)
<< CI{"MULTI", CO::NOSCRIPT | CO::FAST | CO::LOADING, 1, 0, 0, 0}.HFUNC(Multi)
<< CI{"EVAL", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(Eval)
<< CI{"EVALSHA", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(EvalSha)
<< CI{"EXEC", kExecMask, 1, 0, 0, 0}.MFUNC(Exec);
StringFamily::Register(&registry_);

View file

@ -17,6 +17,7 @@ class AcceptServer;
namespace dfly {
class Interpreter;
class ObjectExplorer; // for Interpreter
class Service {
@ -68,7 +69,16 @@ class Service {
static void Multi(CmdArgList args, ConnectionContext* cntx);
void Eval(CmdArgList args, ConnectionContext* cntx);
void EvalSha(CmdArgList args, ConnectionContext* cntx);
void Exec(CmdArgList args, ConnectionContext* cntx);
struct EvalArgs {
std::string_view sha; // only one of them is defined.
CmdArgList keys, args;
};
void EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, ConnectionContext* cntx);
void CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionContext* cntx);
void RegisterCommands();

View file

@ -24,7 +24,7 @@ void ScriptMgr::Run(CmdArgList args, ConnectionContext* cntx) {
string_view kHelp[] = {
"SCRIPT <subcommand> [<arg> [value] [opt] ...]",
"Subcommands are:",
"EXISTS <sha1> [<s ha1> ...]",
"EXISTS <sha1> [<sha1> ...]",
" Return information about the existence of the scripts in the script cache.",
"LOAD <script>",
" Load a script into the scripts cache without executing it.",
@ -41,6 +41,7 @@ void ScriptMgr::Run(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError("Refuse to load empty script");
Interpreter& interpreter = ServerState::tlocal()->GetInterpreter();
// no need to lock the interpreter since we do not mess the stack.
string error_or_id;
Interpreter::AddResult add_result = interpreter.AddFunction(body, &error_or_id);
if (add_result == Interpreter::ALREADY_EXISTS) {
@ -50,17 +51,7 @@ void ScriptMgr::Run(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError(error_or_id);
}
ScriptKey sha1;
CHECK_EQ(sha1.size(), error_or_id.size());
memcpy(sha1.data(), error_or_id.data(), sha1.size());
lock_guard lk(mu_);
auto [it, inserted] = db_.emplace(sha1, nullptr);
if (inserted) {
it->second.reset(new char[body.size() + 1]);
memcpy(it->second.get(), body.data(), body.size());
it->second[body.size()] = '\0';
}
InsertFunction(error_or_id, body);
return (*cntx)->SendBulkString(error_or_id);
}
@ -68,4 +59,32 @@ void ScriptMgr::Run(CmdArgList args, ConnectionContext* cntx) {
"Unknown subcommand or wrong number of arguments for '", subcmd, "'. Try SCRIPT HELP."));
} // namespace dfly
bool ScriptMgr::InsertFunction(std::string_view id, std::string_view body) {
ScriptKey key;
CHECK_EQ(key.size(), id.size());
memcpy(key.data(), id.data(), key.size());
lock_guard lk(mu_);
auto [it, inserted] = db_.emplace(key, nullptr);
if (inserted) {
it->second.reset(new char[body.size() + 1]);
memcpy(it->second.get(), body.data(), body.size());
it->second[body.size()] = '\0';
}
return inserted;
}
const char* ScriptMgr::Find(std::string_view sha) const {
ScriptKey key;
CHECK_EQ(key.size(), sha.size());
memcpy(key.data(), sha.data(), key.size());
lock_guard lk(mu_);
auto it = db_.find(key);
if (it == db_.end())
return nullptr;
return it->second.get();
}
} // namespace dfly

View file

@ -22,11 +22,14 @@ class ScriptMgr {
void Run(CmdArgList args, ConnectionContext* cntx);
bool InsertFunction(std::string_view sha, std::string_view body);
const char* Find(std::string_view sha) const;
private:
EngineShardSet* ess_;
using ScriptKey = std::array<char, 40>;
absl::flat_hash_map<ScriptKey, std::unique_ptr<char[]>> db_;
::boost::fibers::mutex mu_;
mutable ::boost::fibers::mutex mu_;
};
} // namespace dfly

View file

@ -42,6 +42,10 @@ class ServerFamily {
return &global_state_;
}
ScriptMgr* script_mgr() {
return script_mgr_.get();
}
private:
uint32_t shard_count() const {
return ess_.size();
@ -59,7 +63,6 @@ class ServerFamily {
void Script(CmdArgList args, ConnectionContext* cntx);
void Sync(CmdArgList args, ConnectionContext* cntx);
void _Shutdown(CmdArgList args, ConnectionContext* cntx);
void SyncGeneric(std::string_view repl_master_id, uint64_t offs, ConnectionContext* cntx);

View file

@ -4,7 +4,6 @@
#pragma once
#include <boost/fiber/mutex.hpp>
#include <optional>
#include <vector>
@ -63,11 +62,6 @@ class ServerState { // public struct - to allow initialization.
Interpreter& GetInterpreter();
// We have interpreter per thread, not per connection.
// Since we might preempt into different fibers when operating on interpreter
// we must lock it until we finish using it per request.
::boost::fibers::mutex interpreter_mutex;
private:
int64_t live_transactions_ = 0;
std::optional<Interpreter> interpreter_;

View file

@ -116,7 +116,7 @@ class BaseFamilyTest : public ::testing::Test {
unsigned num_threads_ = 3;
struct TestConn {
io::StringSink sink;
::io::StringSink sink;
std::unique_ptr<Connection> dummy_conn;
ConnectionContext cmd_cntx;