1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-15 17:51:06 +00:00

Wire support for redis.call and redis.pcall commands

This commit is contained in:
Roman Gershman 2022-02-01 19:02:43 +02:00
parent fc56a8e61a
commit 7a2f4baeec
6 changed files with 146 additions and 7 deletions

View file

@ -37,6 +37,70 @@ void Require(lua_State* lua, const char* name, lua_CFunction openf) {
lua_pop(lua, 1); /* remove lib */ lua_pop(lua, 1); /* remove lib */
} }
/*
* Save the give pointer on Lua registry, used to save the Lua context and
* function context so we can retrieve them from lua_State.
*/
void SaveOnRegistry(lua_State* lua, const char* name, void* ptr) {
lua_pushstring(lua, name);
if (ptr) {
lua_pushlightuserdata(lua, ptr);
} else {
lua_pushnil(lua);
}
lua_settable(lua, LUA_REGISTRYINDEX);
}
/*
* Get a saved pointer from registry
*/
void* GetFromRegistry(lua_State* lua, const char* name) {
lua_pushstring(lua, name);
lua_gettable(lua, LUA_REGISTRYINDEX);
/* must be light user data */
DCHECK(lua_islightuserdata(lua, -1));
void* ptr = (void*)lua_topointer(lua, -1);
DCHECK(ptr);
/* pops the value */
lua_pop(lua, 1);
return ptr;
}
/* This function is used in order to push an error on the Lua stack in the
* format used by redis.pcall to return errors, which is a lua table
* with a single "err" field set to the error string. Note that this
* table is never a valid reply by proper commands, since the returned
* tables are otherwise always indexed by integers, never by strings. */
void PushError(lua_State* lua, const char* error) {
lua_Debug dbg;
lua_newtable(lua);
lua_pushstring(lua, "err");
/* Attempt to figure out where this function was called, if possible */
if (lua_getstack(lua, 1, &dbg) && lua_getinfo(lua, "nSl", &dbg)) {
string msg = absl::StrCat(dbg.source, ": ", dbg.currentline, ": ", error);
lua_pushstring(lua, msg.c_str());
} else {
lua_pushstring(lua, error);
}
lua_settable(lua, -3);
}
/* In case the error set into the Lua stack by PushError() was generated
* by the non-error-trapping version of redis.pcall(), which is redis.call(),
* this function will raise the Lua error so that the execution of the
* script will be halted. */
int RaiseError(lua_State* lua) {
lua_pushstring(lua, "err");
lua_gettable(lua, -2);
return lua_error(lua);
}
void InitLua(lua_State* lua) { void InitLua(lua_State* lua) {
Require(lua, "", luaopen_base); Require(lua, "", luaopen_base);
Require(lua, LUA_TABLIBNAME, luaopen_table); Require(lua, LUA_TABLIBNAME, luaopen_table);
@ -90,6 +154,11 @@ debug = nil
)"; )";
RunSafe(lua, code, "@enable_strict_lua"); RunSafe(lua, code, "@enable_strict_lua");
} }
lua_pushnil(lua);
lua_setglobal(lua, "loadfile");
lua_pushnil(lua);
lua_setglobal(lua, "dofile");
} }
void ToHex(const uint8_t* src, char* dest) { void ToHex(const uint8_t* src, char* dest) {
@ -115,11 +184,14 @@ optional<int> FetchKey(lua_State* lua, const char* key) {
return type; return type;
} }
const char* kInstanceKey = "_INSTANCE";
} // namespace } // namespace
Interpreter::Interpreter() { Interpreter::Interpreter() {
lua_ = luaL_newstate(); lua_ = luaL_newstate();
InitLua(lua_); InitLua(lua_);
SaveOnRegistry(lua_, kInstanceKey, this);
} }
Interpreter::~Interpreter() { Interpreter::~Interpreter() {
@ -241,7 +313,7 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
if (len > 0) { // array if (len > 0) { // array
serializer->OnArrayStart(len); serializer->OnArrayStart(len);
for (unsigned i = 0; i < len; ++i) { for (unsigned i = 0; i < len; ++i) {
t = lua_rawgeti(lua_, -1, i + 1); // push table element t = lua_rawgeti(lua_, -1, i + 1); // push table element
// TODO: we should make sure that we have enough stack space // TODO: we should make sure that we have enough stack space
// to traverse each object. This can be done as a dry-run before doing real serialization. // to traverse each object. This can be done as a dry-run before doing real serialization.
@ -279,4 +351,53 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
return res; return res;
} }
// Returns number of results, which is always 1 in this case.
// Please note that lua resets the stack once the function returns so no need
// to unwind the stack manually in the function (though lua allows doing this).
int Interpreter::RedisGenericCommand(bool raise_error) {
/* By using Lua debug hooks it is possible to trigger a recursive call
* to luaRedisGenericCommand(), which normally should never happen.
* To make this function reentrant is futile and makes it slower, but
* we should at least detect such a misuse, and abort. */
if (cmd_depth_) {
const char* recursion_warning =
"luaRedisGenericCommand() recursive call detected. "
"Are you doing funny stuff with Lua debug hooks?";
PushError(lua_, recursion_warning);
return 1;
}
cmd_depth_++;
int argc = lua_gettop(lua_);
/* Require at least one argument */
if (argc == 0) {
PushError(lua_, "Please specify at least one argument for redis.call()");
cmd_depth_--;
return raise_error ? RaiseError(lua_) : 1;
}
// TODO: to prepare arguments.
/* Pop all arguments from the stack, we do not need them anymore
* and this way we guaranty we will have room on the stack for the result. */
lua_pop(lua_, argc);
cmd_depth_--;
lua_pushinteger(lua_, 42);
return 1;
}
int Interpreter::RedisCallCommand(lua_State* lua) {
void* me = GetFromRegistry(lua, kInstanceKey);
return reinterpret_cast<Interpreter*>(me)->RedisGenericCommand(true);
}
int Interpreter::RedisPCallCommand(lua_State* lua) {
void* me = GetFromRegistry(lua, kInstanceKey);
return reinterpret_cast<Interpreter*>(me)->RedisGenericCommand(false);
}
} // namespace dfly } // namespace dfly

View file

@ -57,7 +57,13 @@ class Interpreter {
private: private:
bool AddInternal(const char* f_id, std::string_view body, std::string* result); bool AddInternal(const char* f_id, std::string_view body, std::string* result);
int RedisGenericCommand(bool raise_error);
static int RedisCallCommand(lua_State *lua);
static int RedisPCallCommand(lua_State *lua);
lua_State* lua_; lua_State* lua_;
unsigned cmd_depth_ = 0;
}; };
} // namespace dfly } // namespace dfly

View file

@ -216,6 +216,9 @@ TEST_F(DflyEngineTest, FlushDb) {
TEST_F(DflyEngineTest, Eval) { TEST_F(DflyEngineTest, Eval) {
auto resp = Run({"eval", "return 42", "0"}); auto resp = Run({"eval", "return 42", "0"});
EXPECT_THAT(resp[0], IntArg(42)); EXPECT_THAT(resp[0], IntArg(42));
resp = Run({"eval", "return redis.call('get', 'foo')", "0"});
EXPECT_THAT(resp[0], IntArg(42)); // TODO.
} }
// TODO: to test transactions with a single shard since then all transactions become local. // TODO: to test transactions with a single shard since then all transactions become local.

View file

@ -55,15 +55,19 @@ class EvalSerializer : public ObjectExplorer {
} }
void OnBool(bool b) final { void OnBool(bool b) final {
LOG(FATAL) << "TBD"; if (b) {
rb_->SendLong(1);
} else {
rb_->SendNull();
}
} }
void OnString(std::string_view str) final { void OnString(std::string_view str) final {
LOG(FATAL) << "TBD"; rb_->SendBulkString(str);
} }
void OnDouble(double d) final { void OnDouble(double d) final {
LOG(FATAL) << "TBD"; rb_->SendDouble(d);
} }
void OnInt(int64_t val) final { void OnInt(int64_t val) final {
@ -79,15 +83,15 @@ class EvalSerializer : public ObjectExplorer {
} }
void OnNil() final { void OnNil() final {
LOG(FATAL) << "TBD"; rb_->SendNull();
} }
void OnStatus(std::string_view str) { void OnStatus(std::string_view str) {
LOG(FATAL) << "TBD"; rb_->SendSimpleRespString(str);
} }
void OnError(std::string_view str) { void OnError(std::string_view str) {
LOG(FATAL) << "TBD"; rb_->SendError(str);
} }
private: private:

View file

@ -183,6 +183,10 @@ void ReplyBuilder::SendLong(long num) {
as_resp()->SendDirect(str); as_resp()->SendDirect(str);
} }
void ReplyBuilder::SendDouble(double val) {
SendBulkString(absl::StrCat(val));
}
void ReplyBuilder::SendMGetResponse(const StrOrNil* arr, uint32_t count) { void ReplyBuilder::SendMGetResponse(const StrOrNil* arr, uint32_t count) {
string res = absl::StrCat("*", count, kCRLF); string res = absl::StrCat("*", count, kCRLF);
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {

View file

@ -113,6 +113,7 @@ class ReplyBuilder {
} }
void SendLong(long val); void SendLong(long val);
void SendDouble(double val);
void SendBulkString(std::string_view str) { void SendBulkString(std::string_view str) {
as_resp()->SendBulkString(str); as_resp()->SendBulkString(str);