mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-14 11:58:02 +00:00
Wire support for redis.call and redis.pcall commands
This commit is contained in:
parent
fc56a8e61a
commit
7a2f4baeec
6 changed files with 146 additions and 7 deletions
|
@ -37,6 +37,70 @@ void Require(lua_State* lua, const char* name, lua_CFunction openf) {
|
|||
lua_pop(lua, 1); /* remove lib */
|
||||
}
|
||||
|
||||
/*
|
||||
* Save the give pointer on Lua registry, used to save the Lua context and
|
||||
* function context so we can retrieve them from lua_State.
|
||||
*/
|
||||
void SaveOnRegistry(lua_State* lua, const char* name, void* ptr) {
|
||||
lua_pushstring(lua, name);
|
||||
if (ptr) {
|
||||
lua_pushlightuserdata(lua, ptr);
|
||||
} else {
|
||||
lua_pushnil(lua);
|
||||
}
|
||||
lua_settable(lua, LUA_REGISTRYINDEX);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get a saved pointer from registry
|
||||
*/
|
||||
void* GetFromRegistry(lua_State* lua, const char* name) {
|
||||
lua_pushstring(lua, name);
|
||||
lua_gettable(lua, LUA_REGISTRYINDEX);
|
||||
|
||||
/* must be light user data */
|
||||
DCHECK(lua_islightuserdata(lua, -1));
|
||||
|
||||
void* ptr = (void*)lua_topointer(lua, -1);
|
||||
DCHECK(ptr);
|
||||
|
||||
/* pops the value */
|
||||
lua_pop(lua, 1);
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
/* This function is used in order to push an error on the Lua stack in the
|
||||
* format used by redis.pcall to return errors, which is a lua table
|
||||
* with a single "err" field set to the error string. Note that this
|
||||
* table is never a valid reply by proper commands, since the returned
|
||||
* tables are otherwise always indexed by integers, never by strings. */
|
||||
void PushError(lua_State* lua, const char* error) {
|
||||
lua_Debug dbg;
|
||||
|
||||
lua_newtable(lua);
|
||||
lua_pushstring(lua, "err");
|
||||
|
||||
/* Attempt to figure out where this function was called, if possible */
|
||||
if (lua_getstack(lua, 1, &dbg) && lua_getinfo(lua, "nSl", &dbg)) {
|
||||
string msg = absl::StrCat(dbg.source, ": ", dbg.currentline, ": ", error);
|
||||
lua_pushstring(lua, msg.c_str());
|
||||
} else {
|
||||
lua_pushstring(lua, error);
|
||||
}
|
||||
lua_settable(lua, -3);
|
||||
}
|
||||
|
||||
/* In case the error set into the Lua stack by PushError() was generated
|
||||
* by the non-error-trapping version of redis.pcall(), which is redis.call(),
|
||||
* this function will raise the Lua error so that the execution of the
|
||||
* script will be halted. */
|
||||
int RaiseError(lua_State* lua) {
|
||||
lua_pushstring(lua, "err");
|
||||
lua_gettable(lua, -2);
|
||||
return lua_error(lua);
|
||||
}
|
||||
|
||||
void InitLua(lua_State* lua) {
|
||||
Require(lua, "", luaopen_base);
|
||||
Require(lua, LUA_TABLIBNAME, luaopen_table);
|
||||
|
@ -90,6 +154,11 @@ debug = nil
|
|||
)";
|
||||
RunSafe(lua, code, "@enable_strict_lua");
|
||||
}
|
||||
|
||||
lua_pushnil(lua);
|
||||
lua_setglobal(lua, "loadfile");
|
||||
lua_pushnil(lua);
|
||||
lua_setglobal(lua, "dofile");
|
||||
}
|
||||
|
||||
void ToHex(const uint8_t* src, char* dest) {
|
||||
|
@ -115,11 +184,14 @@ optional<int> FetchKey(lua_State* lua, const char* key) {
|
|||
return type;
|
||||
}
|
||||
|
||||
const char* kInstanceKey = "_INSTANCE";
|
||||
|
||||
} // namespace
|
||||
|
||||
Interpreter::Interpreter() {
|
||||
lua_ = luaL_newstate();
|
||||
InitLua(lua_);
|
||||
SaveOnRegistry(lua_, kInstanceKey, this);
|
||||
}
|
||||
|
||||
Interpreter::~Interpreter() {
|
||||
|
@ -241,7 +313,7 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
|
|||
if (len > 0) { // array
|
||||
serializer->OnArrayStart(len);
|
||||
for (unsigned i = 0; i < len; ++i) {
|
||||
t = lua_rawgeti(lua_, -1, i + 1); // push table element
|
||||
t = lua_rawgeti(lua_, -1, i + 1); // push table element
|
||||
|
||||
// TODO: we should make sure that we have enough stack space
|
||||
// to traverse each object. This can be done as a dry-run before doing real serialization.
|
||||
|
@ -279,4 +351,53 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
|
|||
return res;
|
||||
}
|
||||
|
||||
// Returns number of results, which is always 1 in this case.
|
||||
// Please note that lua resets the stack once the function returns so no need
|
||||
// to unwind the stack manually in the function (though lua allows doing this).
|
||||
int Interpreter::RedisGenericCommand(bool raise_error) {
|
||||
/* By using Lua debug hooks it is possible to trigger a recursive call
|
||||
* to luaRedisGenericCommand(), which normally should never happen.
|
||||
* To make this function reentrant is futile and makes it slower, but
|
||||
* we should at least detect such a misuse, and abort. */
|
||||
if (cmd_depth_) {
|
||||
const char* recursion_warning =
|
||||
"luaRedisGenericCommand() recursive call detected. "
|
||||
"Are you doing funny stuff with Lua debug hooks?";
|
||||
PushError(lua_, recursion_warning);
|
||||
return 1;
|
||||
}
|
||||
|
||||
cmd_depth_++;
|
||||
int argc = lua_gettop(lua_);
|
||||
|
||||
/* Require at least one argument */
|
||||
if (argc == 0) {
|
||||
PushError(lua_, "Please specify at least one argument for redis.call()");
|
||||
cmd_depth_--;
|
||||
|
||||
return raise_error ? RaiseError(lua_) : 1;
|
||||
}
|
||||
|
||||
// TODO: to prepare arguments.
|
||||
|
||||
/* Pop all arguments from the stack, we do not need them anymore
|
||||
* and this way we guaranty we will have room on the stack for the result. */
|
||||
lua_pop(lua_, argc);
|
||||
|
||||
cmd_depth_--;
|
||||
lua_pushinteger(lua_, 42);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int Interpreter::RedisCallCommand(lua_State* lua) {
|
||||
void* me = GetFromRegistry(lua, kInstanceKey);
|
||||
return reinterpret_cast<Interpreter*>(me)->RedisGenericCommand(true);
|
||||
}
|
||||
|
||||
int Interpreter::RedisPCallCommand(lua_State* lua) {
|
||||
void* me = GetFromRegistry(lua, kInstanceKey);
|
||||
return reinterpret_cast<Interpreter*>(me)->RedisGenericCommand(false);
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -57,7 +57,13 @@ class Interpreter {
|
|||
private:
|
||||
bool AddInternal(const char* f_id, std::string_view body, std::string* result);
|
||||
|
||||
int RedisGenericCommand(bool raise_error);
|
||||
|
||||
static int RedisCallCommand(lua_State *lua);
|
||||
static int RedisPCallCommand(lua_State *lua);
|
||||
|
||||
lua_State* lua_;
|
||||
unsigned cmd_depth_ = 0;
|
||||
};
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -216,6 +216,9 @@ TEST_F(DflyEngineTest, FlushDb) {
|
|||
TEST_F(DflyEngineTest, Eval) {
|
||||
auto resp = Run({"eval", "return 42", "0"});
|
||||
EXPECT_THAT(resp[0], IntArg(42));
|
||||
|
||||
resp = Run({"eval", "return redis.call('get', 'foo')", "0"});
|
||||
EXPECT_THAT(resp[0], IntArg(42)); // TODO.
|
||||
}
|
||||
|
||||
// TODO: to test transactions with a single shard since then all transactions become local.
|
||||
|
|
|
@ -55,15 +55,19 @@ class EvalSerializer : public ObjectExplorer {
|
|||
}
|
||||
|
||||
void OnBool(bool b) final {
|
||||
LOG(FATAL) << "TBD";
|
||||
if (b) {
|
||||
rb_->SendLong(1);
|
||||
} else {
|
||||
rb_->SendNull();
|
||||
}
|
||||
}
|
||||
|
||||
void OnString(std::string_view str) final {
|
||||
LOG(FATAL) << "TBD";
|
||||
rb_->SendBulkString(str);
|
||||
}
|
||||
|
||||
void OnDouble(double d) final {
|
||||
LOG(FATAL) << "TBD";
|
||||
rb_->SendDouble(d);
|
||||
}
|
||||
|
||||
void OnInt(int64_t val) final {
|
||||
|
@ -79,15 +83,15 @@ class EvalSerializer : public ObjectExplorer {
|
|||
}
|
||||
|
||||
void OnNil() final {
|
||||
LOG(FATAL) << "TBD";
|
||||
rb_->SendNull();
|
||||
}
|
||||
|
||||
void OnStatus(std::string_view str) {
|
||||
LOG(FATAL) << "TBD";
|
||||
rb_->SendSimpleRespString(str);
|
||||
}
|
||||
|
||||
void OnError(std::string_view str) {
|
||||
LOG(FATAL) << "TBD";
|
||||
rb_->SendError(str);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -183,6 +183,10 @@ void ReplyBuilder::SendLong(long num) {
|
|||
as_resp()->SendDirect(str);
|
||||
}
|
||||
|
||||
void ReplyBuilder::SendDouble(double val) {
|
||||
SendBulkString(absl::StrCat(val));
|
||||
}
|
||||
|
||||
void ReplyBuilder::SendMGetResponse(const StrOrNil* arr, uint32_t count) {
|
||||
string res = absl::StrCat("*", count, kCRLF);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
|
|
|
@ -113,6 +113,7 @@ class ReplyBuilder {
|
|||
}
|
||||
|
||||
void SendLong(long val);
|
||||
void SendDouble(double val);
|
||||
|
||||
void SendBulkString(std::string_view str) {
|
||||
as_resp()->SendBulkString(str);
|
||||
|
|
Loading…
Reference in a new issue