mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-15 17:51:06 +00:00
Wire support for redis.call and redis.pcall commands
This commit is contained in:
parent
fc56a8e61a
commit
7a2f4baeec
6 changed files with 146 additions and 7 deletions
|
@ -37,6 +37,70 @@ void Require(lua_State* lua, const char* name, lua_CFunction openf) {
|
||||||
lua_pop(lua, 1); /* remove lib */
|
lua_pop(lua, 1); /* remove lib */
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Save the give pointer on Lua registry, used to save the Lua context and
|
||||||
|
* function context so we can retrieve them from lua_State.
|
||||||
|
*/
|
||||||
|
void SaveOnRegistry(lua_State* lua, const char* name, void* ptr) {
|
||||||
|
lua_pushstring(lua, name);
|
||||||
|
if (ptr) {
|
||||||
|
lua_pushlightuserdata(lua, ptr);
|
||||||
|
} else {
|
||||||
|
lua_pushnil(lua);
|
||||||
|
}
|
||||||
|
lua_settable(lua, LUA_REGISTRYINDEX);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Get a saved pointer from registry
|
||||||
|
*/
|
||||||
|
void* GetFromRegistry(lua_State* lua, const char* name) {
|
||||||
|
lua_pushstring(lua, name);
|
||||||
|
lua_gettable(lua, LUA_REGISTRYINDEX);
|
||||||
|
|
||||||
|
/* must be light user data */
|
||||||
|
DCHECK(lua_islightuserdata(lua, -1));
|
||||||
|
|
||||||
|
void* ptr = (void*)lua_topointer(lua, -1);
|
||||||
|
DCHECK(ptr);
|
||||||
|
|
||||||
|
/* pops the value */
|
||||||
|
lua_pop(lua, 1);
|
||||||
|
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* This function is used in order to push an error on the Lua stack in the
|
||||||
|
* format used by redis.pcall to return errors, which is a lua table
|
||||||
|
* with a single "err" field set to the error string. Note that this
|
||||||
|
* table is never a valid reply by proper commands, since the returned
|
||||||
|
* tables are otherwise always indexed by integers, never by strings. */
|
||||||
|
void PushError(lua_State* lua, const char* error) {
|
||||||
|
lua_Debug dbg;
|
||||||
|
|
||||||
|
lua_newtable(lua);
|
||||||
|
lua_pushstring(lua, "err");
|
||||||
|
|
||||||
|
/* Attempt to figure out where this function was called, if possible */
|
||||||
|
if (lua_getstack(lua, 1, &dbg) && lua_getinfo(lua, "nSl", &dbg)) {
|
||||||
|
string msg = absl::StrCat(dbg.source, ": ", dbg.currentline, ": ", error);
|
||||||
|
lua_pushstring(lua, msg.c_str());
|
||||||
|
} else {
|
||||||
|
lua_pushstring(lua, error);
|
||||||
|
}
|
||||||
|
lua_settable(lua, -3);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* In case the error set into the Lua stack by PushError() was generated
|
||||||
|
* by the non-error-trapping version of redis.pcall(), which is redis.call(),
|
||||||
|
* this function will raise the Lua error so that the execution of the
|
||||||
|
* script will be halted. */
|
||||||
|
int RaiseError(lua_State* lua) {
|
||||||
|
lua_pushstring(lua, "err");
|
||||||
|
lua_gettable(lua, -2);
|
||||||
|
return lua_error(lua);
|
||||||
|
}
|
||||||
|
|
||||||
void InitLua(lua_State* lua) {
|
void InitLua(lua_State* lua) {
|
||||||
Require(lua, "", luaopen_base);
|
Require(lua, "", luaopen_base);
|
||||||
Require(lua, LUA_TABLIBNAME, luaopen_table);
|
Require(lua, LUA_TABLIBNAME, luaopen_table);
|
||||||
|
@ -90,6 +154,11 @@ debug = nil
|
||||||
)";
|
)";
|
||||||
RunSafe(lua, code, "@enable_strict_lua");
|
RunSafe(lua, code, "@enable_strict_lua");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lua_pushnil(lua);
|
||||||
|
lua_setglobal(lua, "loadfile");
|
||||||
|
lua_pushnil(lua);
|
||||||
|
lua_setglobal(lua, "dofile");
|
||||||
}
|
}
|
||||||
|
|
||||||
void ToHex(const uint8_t* src, char* dest) {
|
void ToHex(const uint8_t* src, char* dest) {
|
||||||
|
@ -115,11 +184,14 @@ optional<int> FetchKey(lua_State* lua, const char* key) {
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char* kInstanceKey = "_INSTANCE";
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Interpreter::Interpreter() {
|
Interpreter::Interpreter() {
|
||||||
lua_ = luaL_newstate();
|
lua_ = luaL_newstate();
|
||||||
InitLua(lua_);
|
InitLua(lua_);
|
||||||
|
SaveOnRegistry(lua_, kInstanceKey, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
Interpreter::~Interpreter() {
|
Interpreter::~Interpreter() {
|
||||||
|
@ -241,7 +313,7 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
|
||||||
if (len > 0) { // array
|
if (len > 0) { // array
|
||||||
serializer->OnArrayStart(len);
|
serializer->OnArrayStart(len);
|
||||||
for (unsigned i = 0; i < len; ++i) {
|
for (unsigned i = 0; i < len; ++i) {
|
||||||
t = lua_rawgeti(lua_, -1, i + 1); // push table element
|
t = lua_rawgeti(lua_, -1, i + 1); // push table element
|
||||||
|
|
||||||
// TODO: we should make sure that we have enough stack space
|
// TODO: we should make sure that we have enough stack space
|
||||||
// to traverse each object. This can be done as a dry-run before doing real serialization.
|
// to traverse each object. This can be done as a dry-run before doing real serialization.
|
||||||
|
@ -279,4 +351,53 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns number of results, which is always 1 in this case.
|
||||||
|
// Please note that lua resets the stack once the function returns so no need
|
||||||
|
// to unwind the stack manually in the function (though lua allows doing this).
|
||||||
|
int Interpreter::RedisGenericCommand(bool raise_error) {
|
||||||
|
/* By using Lua debug hooks it is possible to trigger a recursive call
|
||||||
|
* to luaRedisGenericCommand(), which normally should never happen.
|
||||||
|
* To make this function reentrant is futile and makes it slower, but
|
||||||
|
* we should at least detect such a misuse, and abort. */
|
||||||
|
if (cmd_depth_) {
|
||||||
|
const char* recursion_warning =
|
||||||
|
"luaRedisGenericCommand() recursive call detected. "
|
||||||
|
"Are you doing funny stuff with Lua debug hooks?";
|
||||||
|
PushError(lua_, recursion_warning);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd_depth_++;
|
||||||
|
int argc = lua_gettop(lua_);
|
||||||
|
|
||||||
|
/* Require at least one argument */
|
||||||
|
if (argc == 0) {
|
||||||
|
PushError(lua_, "Please specify at least one argument for redis.call()");
|
||||||
|
cmd_depth_--;
|
||||||
|
|
||||||
|
return raise_error ? RaiseError(lua_) : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: to prepare arguments.
|
||||||
|
|
||||||
|
/* Pop all arguments from the stack, we do not need them anymore
|
||||||
|
* and this way we guaranty we will have room on the stack for the result. */
|
||||||
|
lua_pop(lua_, argc);
|
||||||
|
|
||||||
|
cmd_depth_--;
|
||||||
|
lua_pushinteger(lua_, 42);
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int Interpreter::RedisCallCommand(lua_State* lua) {
|
||||||
|
void* me = GetFromRegistry(lua, kInstanceKey);
|
||||||
|
return reinterpret_cast<Interpreter*>(me)->RedisGenericCommand(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
int Interpreter::RedisPCallCommand(lua_State* lua) {
|
||||||
|
void* me = GetFromRegistry(lua, kInstanceKey);
|
||||||
|
return reinterpret_cast<Interpreter*>(me)->RedisGenericCommand(false);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dfly
|
} // namespace dfly
|
||||||
|
|
|
@ -57,7 +57,13 @@ class Interpreter {
|
||||||
private:
|
private:
|
||||||
bool AddInternal(const char* f_id, std::string_view body, std::string* result);
|
bool AddInternal(const char* f_id, std::string_view body, std::string* result);
|
||||||
|
|
||||||
|
int RedisGenericCommand(bool raise_error);
|
||||||
|
|
||||||
|
static int RedisCallCommand(lua_State *lua);
|
||||||
|
static int RedisPCallCommand(lua_State *lua);
|
||||||
|
|
||||||
lua_State* lua_;
|
lua_State* lua_;
|
||||||
|
unsigned cmd_depth_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace dfly
|
} // namespace dfly
|
||||||
|
|
|
@ -216,6 +216,9 @@ TEST_F(DflyEngineTest, FlushDb) {
|
||||||
TEST_F(DflyEngineTest, Eval) {
|
TEST_F(DflyEngineTest, Eval) {
|
||||||
auto resp = Run({"eval", "return 42", "0"});
|
auto resp = Run({"eval", "return 42", "0"});
|
||||||
EXPECT_THAT(resp[0], IntArg(42));
|
EXPECT_THAT(resp[0], IntArg(42));
|
||||||
|
|
||||||
|
resp = Run({"eval", "return redis.call('get', 'foo')", "0"});
|
||||||
|
EXPECT_THAT(resp[0], IntArg(42)); // TODO.
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: to test transactions with a single shard since then all transactions become local.
|
// TODO: to test transactions with a single shard since then all transactions become local.
|
||||||
|
|
|
@ -55,15 +55,19 @@ class EvalSerializer : public ObjectExplorer {
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnBool(bool b) final {
|
void OnBool(bool b) final {
|
||||||
LOG(FATAL) << "TBD";
|
if (b) {
|
||||||
|
rb_->SendLong(1);
|
||||||
|
} else {
|
||||||
|
rb_->SendNull();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnString(std::string_view str) final {
|
void OnString(std::string_view str) final {
|
||||||
LOG(FATAL) << "TBD";
|
rb_->SendBulkString(str);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnDouble(double d) final {
|
void OnDouble(double d) final {
|
||||||
LOG(FATAL) << "TBD";
|
rb_->SendDouble(d);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnInt(int64_t val) final {
|
void OnInt(int64_t val) final {
|
||||||
|
@ -79,15 +83,15 @@ class EvalSerializer : public ObjectExplorer {
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnNil() final {
|
void OnNil() final {
|
||||||
LOG(FATAL) << "TBD";
|
rb_->SendNull();
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnStatus(std::string_view str) {
|
void OnStatus(std::string_view str) {
|
||||||
LOG(FATAL) << "TBD";
|
rb_->SendSimpleRespString(str);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnError(std::string_view str) {
|
void OnError(std::string_view str) {
|
||||||
LOG(FATAL) << "TBD";
|
rb_->SendError(str);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -183,6 +183,10 @@ void ReplyBuilder::SendLong(long num) {
|
||||||
as_resp()->SendDirect(str);
|
as_resp()->SendDirect(str);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ReplyBuilder::SendDouble(double val) {
|
||||||
|
SendBulkString(absl::StrCat(val));
|
||||||
|
}
|
||||||
|
|
||||||
void ReplyBuilder::SendMGetResponse(const StrOrNil* arr, uint32_t count) {
|
void ReplyBuilder::SendMGetResponse(const StrOrNil* arr, uint32_t count) {
|
||||||
string res = absl::StrCat("*", count, kCRLF);
|
string res = absl::StrCat("*", count, kCRLF);
|
||||||
for (size_t i = 0; i < count; ++i) {
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
|
|
@ -113,6 +113,7 @@ class ReplyBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
void SendLong(long val);
|
void SendLong(long val);
|
||||||
|
void SendDouble(double val);
|
||||||
|
|
||||||
void SendBulkString(std::string_view str) {
|
void SendBulkString(std::string_view str) {
|
||||||
as_resp()->SendBulkString(str);
|
as_resp()->SendBulkString(str);
|
||||||
|
|
Loading…
Reference in a new issue