1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

Add (lua) interpreter with lua 5.4.4

Small fixes all around.
This commit is contained in:
Roman Gershman 2022-01-31 21:40:40 +02:00
parent 8a3207f23e
commit 8fbe19d3c5
13 changed files with 368 additions and 14 deletions

View file

@ -1,6 +1,8 @@
add_library(dfly_core compact_object.cc dragonfly_core.cc tx_queue.cc)
cxx_link(dfly_core base absl::flat_hash_map redis_lib)
add_library(dfly_core compact_object.cc dragonfly_core.cc interpreter.cc
tx_queue.cc)
cxx_link(dfly_core base absl::flat_hash_map redis_lib TRDP::lua crypto)
cxx_test(dfly_core_test dfly_core LABELS DFLY)
cxx_test(compact_object_test dfly_core LABELS DFLY)
cxx_test(dash_test dfly_core LABELS DFLY)
cxx_test(dash_test dfly_core LABELS DFLY)
cxx_test(interpreter_test dfly_core LABELS DFLY)

170
core/interpreter.cc Normal file
View file

@ -0,0 +1,170 @@
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "core/interpreter.h"
#include <absl/strings/str_cat.h>
#include <openssl/sha.h>
#include <cstring>
extern "C" {
#include <lauxlib.h>
#include <lua.h>
#include <lualib.h>
}
#include "base/logging.h"
namespace dfly {
using namespace std;
namespace {
void RunSafe(lua_State* lua, string_view buf, const char* name) {
CHECK_EQ(0, luaL_loadbuffer(lua, buf.data(), buf.size(), name));
int err = lua_pcall(lua, 0, 0, 0);
if (err) {
const char* errstr = lua_tostring(lua, -1);
LOG(FATAL) << "Error running " << name << " " << errstr;
}
}
void Require(lua_State* lua, const char* name, lua_CFunction openf) {
luaL_requiref(lua, name, openf, 1);
lua_pop(lua, 1); /* remove lib */
}
void InitLua(lua_State* lua) {
Require(lua, "", luaopen_base);
Require(lua, LUA_TABLIBNAME, luaopen_table);
Require(lua, LUA_STRLIBNAME, luaopen_string);
Require(lua, LUA_MATHLIBNAME, luaopen_math);
Require(lua, LUA_DBLIBNAME, luaopen_debug);
/* Add a helper function we use for pcall error reporting.
* Note that when the error is in the C function we want to report the
* information about the caller, that's what makes sense from the point
* of view of the user debugging a script. */
{
const char errh_func[] =
"local dbg = debug\n"
"function __redis__err__handler(err)\n"
" local i = dbg.getinfo(2,'nSl')\n"
" if i and i.what == 'C' then\n"
" i = dbg.getinfo(3,'nSl')\n"
" end\n"
" if i then\n"
" return i.source .. ':' .. i.currentline .. ': ' .. err\n"
" else\n"
" return err\n"
" end\n"
"end\n";
RunSafe(lua, errh_func, "@err_handler_def");
}
{
const char code[] = R"(
local dbg=debug
local mt = {}
setmetatable(_G, mt)
mt.__newindex = function (t, n, v)
if dbg.getinfo(2) then
local w = dbg.getinfo(2, "S").what
if w ~= "main" and w ~= "C" then
error("Script attempted to create global variable '"..tostring(n).."'", 2)
end
end
rawset(t, n, v)
end
mt.__index = function (t, n)
if dbg.getinfo(2) and dbg.getinfo(2, "S").what ~= "C" then
error("Script attempted to access nonexistent global variable '"..tostring(n).."'", 2)
end
return rawget(t, n)
end
debug = nil
)";
RunSafe(lua, code, "@enable_strict_lua");
}
}
void ToHex(const uint8_t* src, char* dest) {
const char cset[] = "0123456789abcdef";
for (size_t j = 0; j < 20; j++) {
dest[j * 2] = cset[((src[j] & 0xF0) >> 4)];
dest[j * 2 + 1] = cset[(src[j] & 0xF)];
}
dest[40] = '\0';
}
} // namespace
Interpreter::Interpreter() {
lua_ = luaL_newstate();
InitLua(lua_);
}
Interpreter::~Interpreter() {
lua_close(lua_);
}
void Interpreter::Fingerprint(string_view body, char* fp) {
SHA_CTX ctx;
uint8_t buf[20];
SHA1_Init(&ctx);
SHA1_Update(&ctx, body.data(), body.size());
SHA1_Final(buf, &ctx);
fp[0] = 'f';
fp[1] = '_';
ToHex(buf, fp + 2);
}
bool Interpreter::AddFunction(string_view body, string* result) {
char funcname[43];
Fingerprint(body, funcname);
string script = absl::StrCat("function ", funcname, "() \n");
absl::StrAppend(&script, body, "\nend");
int res = luaL_loadbuffer(lua_, script.data(), script.size(), "@user_script");
if (res == 0) {
res = lua_pcall(lua_, 0, 0, 0); // run func definition code
}
if (res) {
result->assign(lua_tostring(lua_, -1));
lua_pop(lua_, 1); // Remove the error.
return false;
}
result->assign(funcname);
return true;
}
bool Interpreter::RunFunction(const char* f_id, std::string* error) {
lua_getglobal(lua_, "__redis__err__handler");
int type = lua_getglobal(lua_, f_id);
if (type != LUA_TFUNCTION) {
error->assign("function not found"); // TODO: noscripterr.
lua_pop(lua_, 2);
return false;
}
/* We have zero arguments and expect
* a single return value. */
int err = lua_pcall(lua_, 0, 1, -2);
if (err) {
*error = lua_tostring(lua_, -1);
}
return err == 0;
}
} // namespace dfly

43
core/interpreter.h Normal file
View file

@ -0,0 +1,43 @@
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <string_view>
typedef struct lua_State lua_State;
namespace dfly {
class Interpreter {
public:
Interpreter();
~Interpreter();
Interpreter(const Interpreter&) = delete;
void operator=(const Interpreter&) = delete;
// Note: We leak the state for now.
// Production code should not access this method.
lua_State* lua() {
return lua_;
}
// returns false if an error happenned, sets error string into result.
// otherwise, returns true and sets result to function id.
bool AddFunction(std::string_view body, std::string* result);
// Runs already added function f_id returned by a successful call to AddFunction().
// Returns: true if the call succeeded, otherwise fills error and returns false.
bool RunFunction(const char* f_id, std::string* err);
// fp must point to buffer with at least 43 chars.
// fp[42] will be set to '\0'.
static void Fingerprint(std::string_view body, char* fp);
private:
lua_State* lua_;
};
} // namespace dfly

67
core/interpreter_test.cc Normal file
View file

@ -0,0 +1,67 @@
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "core/interpreter.h"
extern "C" {
#include <lauxlib.h>
#include <lua.h>
}
#include <gmock/gmock.h>
#include "base/gtest.h"
#include "base/logging.h"
namespace dfly {
using namespace std;
class InterpreterTest : public ::testing::Test {
protected:
InterpreterTest() {
}
lua_State* lua() {
return intptr_.lua();
}
void RunInline(string_view buf, const char* name) {
CHECK_EQ(0, luaL_loadbuffer(lua(), buf.data(), buf.size(), name));
CHECK_EQ(0, lua_pcall(lua(), 0, 0, 0));
}
Interpreter intptr_;
};
TEST_F(InterpreterTest, Basic) {
RunInline(R"(
function foo(n)
return n,n+1
end)",
"code1");
int type = lua_getglobal(lua(), "foo");
ASSERT_EQ(LUA_TFUNCTION, type);
lua_pushnumber(lua(), 42);
lua_pcall(lua(), 1, 2, 0);
int val1 = lua_tointeger(lua(), -1);
int val2 = lua_tointeger(lua(), -2);
lua_pop(lua(), 2);
EXPECT_EQ(43, val1);
EXPECT_EQ(42, val2);
EXPECT_EQ(0, lua_gettop(lua()));
}
TEST_F(InterpreterTest, Add) {
string res1, res2;
EXPECT_TRUE(intptr_.AddFunction("return 0", &res1));
EXPECT_EQ(0, lua_gettop(lua()));
EXPECT_FALSE(intptr_.AddFunction("foobar", &res2));
EXPECT_THAT(res2, testing::HasSubstr("syntax error"));
EXPECT_EQ(0, lua_gettop(lua()));
}
} // namespace dfly

2
helio

@ -1 +1 @@
Subproject commit 3ee017cce280493c845a010a28caa4cf1d0f4e9b
Subproject commit a300a704e193d115333f41a81438ba74d3df8c51

View file

@ -24,4 +24,5 @@ cxx_test(rdb_test dfly_test_lib DATA testdata/empty.rdb testdata/small.rdb LABEL
add_custom_target(check_dfly WORKING_DIRECTORY .. COMMAND ctest -L DFLY)
add_dependencies(check_dfly dragonfly_test list_family_test
generic_family_test memcache_parser_test redis_parser_test string_family_test)
generic_family_test memcache_parser_test rdb_test
redis_parser_test string_family_test)

View file

@ -7,7 +7,6 @@
#include "base/logging.h"
#include "server/common_types.h"
#include "server/error.h"
#include "server/global_state.h"
#include "server/server_state.h"
namespace dfly {
@ -22,6 +21,23 @@ ServerState::ServerState() {
ServerState::~ServerState() {
}
void ServerState::Init() {
gstate_ = GlobalState::IDLE;
}
void ServerState::Shutdown() {
gstate_ = GlobalState::SHUTTING_DOWN;
interpreter_.reset();
}
Interpreter& ServerState::GetInterpreter() {
if (!interpreter_) {
interpreter_.emplace();
}
return interpreter_.value();
}
#define ADD(x) (x) += o.x
ConnectionStats& ConnectionStats::operator+=(const ConnectionStats& o) {

View file

@ -31,6 +31,10 @@ struct ConnectionState {
enum Mask : uint32_t {
ASYNC_DISPATCH = 1, // whether a command is handled via async dispatch.
CONN_CLOSING = 2, // could be because of unrecoverable error or planned action.
// Whether this connection belongs to replica, i.e. a dragonfly slave is connected to this
// host (master) via this connection to sync from it.
REPL_CONNECTION = 2,
};
uint32_t mask = 0; // A bitmask of Mask values.

View file

@ -280,6 +280,11 @@ finish:
stats->read_buf_capacity -= io_buf_.Capacity();
// Update num_replicas if this was a replica connection.
if (cc_->conn_state.mask & ConnectionState::REPL_CONNECTION) {
--stats->num_replicas;
}
if (cc_->ec()) {
ec = cc_->ec();
} else {

View file

@ -72,6 +72,8 @@ void Service::Init(util::AcceptServer* acceptor, const InitOpts& opts) {
shard_set_.Init(shard_num);
pp_.Await([&](uint32_t index, ProactorBase* pb) {
ServerState::tlocal()->Init();
if (index < shard_count()) {
shard_set_.InitThreadLocal(pb, !opts.disable_time_update);
}
@ -96,6 +98,8 @@ void Service::Shutdown() {
request_latency_usec.Shutdown();
ping_qps.Shutdown();
pp_.AwaitFiberOnAll([](ProactorBase* pb) { ServerState::tlocal()->Shutdown(); });
// to shutdown all the runtime components that depend on EngineShard.
server_family_.Shutdown();
StringFamily::Shutdown();
@ -129,7 +133,20 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendError(absl::StrCat("unknown command `", cmd_str, "`"));
}
if (etl.gstate() == GlobalState::LOADING || etl.gstate() == GlobalState::SHUTTING_DOWN) {
string err = absl::StrCat("Can not execute during ", GlobalState::Name(etl.gstate()));
cntx->SendError(err);
return;
}
bool is_write_cmd = cid->opt_mask() & CO::WRITE;
bool under_multi = cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd;
if (!etl.is_master && is_write_cmd) {
cntx->SendError("-READONLY You can't write against a read only replica.");
return;
}
if ((cid->arity() > 0 && args.size() != size_t(cid->arity())) ||
(cid->arity() < 0 && args.size() < size_t(-cid->arity()))) {
return cntx->SendError(WrongNumArgsError(cmd_str));
@ -272,6 +289,12 @@ void Service::Multi(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendOk();
}
void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
Interpreter& script = ServerState::tlocal()->GetInterpreter();
script.lua();
return cntx->SendOk();
}
void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
if (cntx->conn_state.exec_state == ConnectionState::EXEC_INACTIVE) {
return cntx->SendError("EXEC without MULTI");
@ -315,9 +338,12 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
VarzValue::Map Service::GetVarzStats() {
VarzValue::Map res;
atomic_ulong num_keys{0};
shard_set_.RunBriefInParallel([&](EngineShard* es) { num_keys += es->db_slice().DbSize(0); });
res.emplace_back("keys", VarzValue::FromInt(num_keys.load()));
Metrics m = server_family_.GetMetrics();
res.emplace_back("keys", VarzValue::FromInt(m.db.key_count));
res.emplace_back("obj_mem_usage", VarzValue::FromInt(m.db.obj_memory_usage));
double load = double(m.db.key_count) / (1 + m.db.bucket_count);
res.emplace_back("table_load_factor", VarzValue::FromDouble(load));
return res;
}
@ -338,6 +364,7 @@ void Service::RegisterCommands() {
registry_ << CI{"QUIT", CO::READONLY | CO::FAST, 1, 0, 0, 0}.HFUNC(Quit)
<< CI{"MULTI", CO::NOSCRIPT | CO::FAST | CO::LOADING | CO::STALE, 1, 0, 0, 0}.HFUNC(
Multi)
<< CI{"EVAL", CO::NOSCRIPT, -3, 0, 0, 0}.HFUNC(Eval)
<< CI{"EXEC", kExecMask, 1, 0, 0, 0}.SetHandler(cb_exec);
StringFamily::Register(&registry_);

View file

@ -64,6 +64,8 @@ class Service {
private:
static void Quit(CmdArgList args, ConnectionContext* cntx);
static void Multi(CmdArgList args, ConnectionContext* cntx);
static void Eval(CmdArgList args, ConnectionContext* cntx);
void Exec(CmdArgList args, ConnectionContext* cntx);

View file

@ -63,6 +63,7 @@ error_code CreateDirs(fs::path dir_path) {
}
return ec;
}
} // namespace
ServerFamily::ServerFamily(Service* engine)
@ -80,6 +81,7 @@ void ServerFamily::Init(util::AcceptServer* acceptor) {
void ServerFamily::Shutdown() {
VLOG(1) << "ServerFamily::Shutdown";
pp_.GetNextProactor()->Await([this] {
unique_lock lk(replica_of_mu_);
if (replica_) {
@ -176,7 +178,7 @@ void ServerFamily::Save(CmdArgList args, ConnectionContext* cntx) {
return;
}
pp_.Await([](auto*) { ServerState::tlocal()->state = GlobalState::SAVING; });
pp_.Await([](auto*) { ServerState::tlocal()->set_gstate(GlobalState::SAVING); });
unique_ptr<::io::WriteFile> wf(*res);
auto start = absl::Now();
@ -200,7 +202,7 @@ void ServerFamily::Save(CmdArgList args, ConnectionContext* cntx) {
return;
}
pp_.Await([](auto*) { ServerState::tlocal()->state = GlobalState::IDLE; });
pp_.Await([](auto*) { ServerState::tlocal()->set_gstate(GlobalState::IDLE); });
CHECK_EQ(GlobalState::SAVING, global_state_.Clear());
absl::Duration dur = absl::Now() - start;
@ -243,7 +245,7 @@ Metrics ServerFamily::GetMetrics() const {
void ServerFamily::Info(CmdArgList args, ConnectionContext* cntx) {
const char kInfo1[] =
R"(# Server
redis_version:6.2.0
redis_version:1.9.9
redis_mode:standalone
arch_bits:64
multiplexing_api:iouring
@ -292,6 +294,9 @@ tcp_port:)";
absl::StrAppend(&info, "master_last_io_seconds_ago:", rinfo.master_last_io_sec, "\n");
absl::StrAppend(&info, "master_sync_in_progress:", rinfo.sync_in_progress, "\n");
}
absl::StrAppend(&info, "\n# Keyspace\n");
absl::StrAppend(&info, "db0:keys=xxx,expires=yyy,avg_ttl=zzz\n"); // TODO
cntx->SendBulkString(info);
}
@ -308,7 +313,6 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) {
auto repl_ptr = replica_;
CHECK(repl_ptr);
pp_.AwaitFiberOnAll([&](util::ProactorBase* pb) { ServerState::tlocal()->is_master = true; });
replica_->Stop();
replica_.reset();
@ -378,6 +382,8 @@ void ServerFamily::SyncGeneric(std::string_view repl_master_id, uint64_t offs,
return;
}
cntx->conn_state.mask |= ConnectionState::REPL_CONNECTION;
ServerState::tl_connection_stats()->num_replicas += 1;
// TBD.
}

View file

@ -4,10 +4,12 @@
#pragma once
#include <optional>
#include <vector>
#include "server/common_types.h"
#include "server/global_state.h"
#include "core/interpreter.h"
namespace dfly {
@ -32,7 +34,9 @@ class ServerState { // public struct - to allow initialization.
ServerState();
~ServerState();
GlobalState::S state = GlobalState::IDLE;
void Init();
void Shutdown();
bool is_master = true;
ConnectionStats connection_stats;
@ -49,8 +53,15 @@ class ServerState { // public struct - to allow initialization.
return live_transactions_;
}
GlobalState::S gstate() const { return gstate_;}
void set_gstate(GlobalState::S s) { gstate_ = s; }
Interpreter& GetInterpreter();
private:
int64_t live_transactions_ = 0;
std::optional<Interpreter> interpreter_;
GlobalState::S gstate_ = GlobalState::IDLE;
static thread_local ServerState state_;
};