1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

chore: harden lua rules

This commit is contained in:
Roman Gershman 2024-05-16 16:08:54 +03:00
parent 73e44a941e
commit c70c159abb
No known key found for this signature in database
GPG key ID: 6568CCAB9736B618
3 changed files with 144 additions and 66 deletions

View file

@ -70,6 +70,56 @@ void PushError(lua_State* lua, string_view error, bool trace = true) {
lua_settable(lua, -3);
}
static int ProtectedTableError(lua_State* lua) {
int argc = lua_gettop(lua);
if (argc != 2) {
LOG(DFATAL) << "Unexpected number of arguments " << argc;
return luaL_error(lua, "Wrong number of arguments to ProtectedTableError");
}
if (!lua_isstring(lua, -1) && !lua_isnumber(lua, -1)) {
return luaL_error(lua, "Second argument to ProtectedTableError must be a string or number");
}
const char* variable_name = lua_tostring(lua, -1);
VLOG(1) << "Undeclared variable " << variable_name;
return luaL_error(lua, "Script attempted to access nonexistent global variable '%s'",
variable_name);
}
static int NewIndexAllowList(lua_State* lua) {
int argc = lua_gettop(lua);
lua_Debug dbg;
if (argc != 3 || !lua_istable(lua, -3) || !lua_getstack(lua, 0, &dbg)) {
LOG(DFATAL) << "Bad arguments " << argc;
return luaL_error(lua, "Wrong number of arguments to NewIndexAllowList");
}
if (!lua_isstring(lua, -2) && !lua_isnumber(lua, -2)) {
return luaL_error(lua, "Second argument to NewIndexAllowList must be a string or number");
}
if (!lua_getinfo(lua, "S", &dbg)) {
return luaL_error(lua, "Could not find debug info");
}
/*
local w = dbg.getinfo(2, "S").what
if w ~= "main" and w ~= "C" then
error("Script attempted to create global variable '"..tostring(n).."'", 2)
end
*/
const char* variable_name = lua_tostring(lua, -2);
VLOG(1) << "New global variable " << variable_name << " " << dbg.what;
if (dbg.what[0] != 'C') {
return luaL_error(lua, "Can attempted to create global variable %s", variable_name);
}
return luaL_error(lua, "boo");
lua_rawset(lua, -3);
return 0;
}
// Custom object explorer that collects all values into string array
struct StringCollectorTranslator : public ObjectExplorer {
void OnString(std::string_view str) final {
@ -288,51 +338,29 @@ void InitLua(lua_State* lua) {
{
const char errh_func[] =
"local dbg = debug\n"
"function __redis__err__handler(err)\n"
"debug = nil\n"
"function __df__err__handler(err)\n"
" local i = dbg.getinfo(2,'nSl')\n"
" if i and i.what == 'C' then\n"
" i = dbg.getinfo(3,'nSl')\n"
" end\n"
" if type(err) ~= 'table' then\n"
" err = {err='ERR ' .. tostring(err)}"
" end"
" if i then\n"
" return i.source .. ':' .. i.currentline .. ': ' .. err\n"
" else\n"
" return err\n"
" end\n"
" err['source'] = i.source\n"
" err['line'] = i.currentline\n"
" end"
" return err\n"
"end\n";
RunSafe(lua, errh_func, "@err_handler_def");
}
{
const char code[] = R"(
local dbg=debug
local mt = {}
setmetatable(_G, mt)
mt.__newindex = function (t, n, v)
if dbg.getinfo(2) then
local w = dbg.getinfo(2, "S").what
if w ~= "main" and w ~= "C" then
error("Script attempted to create global variable '"..tostring(n).."'", 2)
end
end
rawset(t, n, v)
end
mt.__index = function (t, n)
if dbg.getinfo(2) and dbg.getinfo(2, "S").what ~= "C" then
error("Script attempted to access nonexistent global variable '"..tostring(n).."'", 2)
end
return rawget(t, n)
end
debug = nil
)";
RunSafe(lua, code, "@enable_strict_lua");
for (const char* func : {"loadfile", "dofile", "print"}) {
lua_pushnil(lua);
lua_setglobal(lua, func);
}
lua_pushnil(lua);
lua_setglobal(lua, "loadfile");
lua_pushnil(lua);
lua_setglobal(lua, "dofile");
// Register deprecated or removed functions to maintain compatibility with 5.1
register_polyfills(lua);
}
@ -566,6 +594,17 @@ Interpreter::Interpreter() {
/* Finally set the table as 'redis' global var. */
lua_setglobal(lua_, "redis");
// Add
lua_rawgeti(lua_, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS);
lua_newtable(lua_); /* push metatable */
lua_pushcfunction(lua_, ProtectedTableError); /* push get error handler */
lua_setfield(lua_, -2, "__index");
lua_pushcfunction(lua_, NewIndexAllowList); /* push get error handler */
lua_setfield(lua_, -2, "__newindex");
lua_setmetatable(lua_, -2);
lua_pop(lua_, 1); // LUA_RIDX_GLOBALS
CHECK(lua_checkstack(lua_, 64));
}
@ -583,20 +622,12 @@ void Interpreter::FuncSha1(string_view body, char* fp) {
}
auto Interpreter::AddFunction(string_view sha, string_view body, string* result) -> AddResult {
char funcname[43];
funcname[0] = 'f';
funcname[1] = '_';
DCHECK(sha.size() == 40);
memcpy(funcname + 2, sha.data(), sha.size());
funcname[42] = '\0';
bool exists = LuaNameExists(sha);
int type = lua_getglobal(lua_, funcname);
lua_pop(lua_, 1);
if (type == LUA_TNIL && !AddInternal(funcname, body, result))
if (!exists && !AddInternal(sha, body, result))
return COMPILE_ERR;
return type == LUA_TNIL ? ADD_OK : ALREADY_EXISTS;
return exists ? ALREADY_EXISTS : ADD_OK;
}
bool Interpreter::Exists(string_view sha) const {
@ -605,16 +636,7 @@ bool Interpreter::Exists(string_view sha) const {
if (sha.size() != 40)
return false;
char fname[43];
fname[0] = 'f';
fname[1] = '_';
fname[42] = '\0';
memcpy(fname + 2, sha.data(), 40);
int type = lua_getglobal(lua_, fname);
lua_pop(lua_, 1);
return type == LUA_TFUNCTION;
return LuaNameExists(sha);
}
auto Interpreter::RunFunction(string_view sha, std::string* error) -> RunResult {
@ -622,7 +644,7 @@ auto Interpreter::RunFunction(string_view sha, std::string* error) -> RunResult
DCHECK_EQ(40u, sha.size());
lua_getglobal(lua_, "__redis__err__handler");
lua_getglobal(lua_, "__df__err__handler");
char fname[43];
fname[0] = 'f';
fname[1] = '_';
@ -643,7 +665,14 @@ auto Interpreter::RunFunction(string_view sha, std::string* error) -> RunResult
int err = lua_pcall(lua_, 0, 1, -2);
if (err) {
*error = lua_tostring(lua_, -1);
const char* msg = "execution failure";
if (lua_istable(lua_, -1)) {
msg = "TBD";
} else if (lua_isstring(lua_, -1)) {
msg = lua_tostring(lua_, -1);
}
*error = absl::StrCat("Error running script ", sha, " ", msg);
}
return err == 0 ? RUN_OK : RUN_ERR;
@ -714,7 +743,7 @@ optional<string> Interpreter::DetectPossibleAsyncCalls(string_view body_sv) {
for (auto pos : targets)
body.insert(pos, "a");
VLOG(1) << "Detected " << targets.size() << " aync calls in script";
VLOG(1) << "Detected " << targets.size() << " async calls in script";
return body;
}
@ -737,8 +766,8 @@ bool Interpreter::IsResultSafe() const {
return res;
}
bool Interpreter::AddInternal(const char* f_id, string_view body, string* error) {
string script = absl::StrCat("function ", f_id, "() \n");
bool Interpreter::AddInternal(string_view sha, string_view body, string* error) {
string script = absl::StrCat("function f_", sha, "() \n");
absl::StrAppend(&script, body, "\nend");
int res = luaL_loadbuffer(lua_, script.data(), script.size(), "@user_script");
@ -1041,6 +1070,22 @@ int Interpreter::RedisAPCallCommand(lua_State* lua) {
return reinterpret_cast<Interpreter*>(*ptr)->RedisGenericCommand(false, true);
}
bool Interpreter::LuaNameExists(std::string_view sha) const {
char funcname[43];
funcname[0] = 'f';
funcname[1] = '_';
DCHECK_EQ(40u, sha.size());
memcpy(funcname + 2, sha.data(), sha.size());
funcname[42] = '\0';
lua_pushglobaltable(lua_);
lua_pushlstring(lua_, funcname, 42);
int type = lua_rawget(lua_, -2);
lua_pop(lua_, 2); // global table and the value.
return type != LUA_TNIL;
}
InterpreterManager::Stats& InterpreterManager::Stats::operator+=(const Stats& other) {
this->used_bytes += other.used_bytes;
this->interpreter_cnt += other.interpreter_cnt;

View file

@ -126,8 +126,9 @@ class Interpreter {
private:
// Returns true if function was successfully added,
// otherwise returns false and sets the error.
bool AddInternal(const char* f_id, std::string_view body, std::string* error);
bool AddInternal(std::string_view sha, std::string_view body, std::string* error);
bool IsTableSafe() const;
bool LuaNameExists(std::string_view sha) const;
static int RedisCallCommand(lua_State* lua);
static int RedisPCallCommand(lua_State* lua);

View file

@ -180,10 +180,18 @@ TEST_F(InterpreterTest, UnknownFunc) {
return myunknownfunc(1, n)
end)");
CHECK_EQ(0, luaL_loadbuffer(lua(), code.data(), code.size(), "code1"));
CHECK_EQ(0, lua_pcall(lua(), 0, 0, 0));
int type = lua_getglobal(lua(), "myunknownfunc");
ASSERT_EQ(LUA_TNIL, type);
RunInline(code, "code1");
lua_pushglobaltable(lua()); // Push the global table
lua_pushstring(lua(), "myunknownfunc"); // Push the function name onto the stack
int type = lua_rawget(lua(), -2); // Perform a raw access to get the value
ASSERT_EQ(LUA_TNIL, type); // does not exist
lua_pop(lua(), 1);
ASSERT_EQ(LUA_TTABLE, lua_type(lua(), -1));
lua_pushstring(lua(), "foo");
type = lua_rawget(lua(), -2);
ASSERT_EQ(LUA_TFUNCTION, type);
lua_pop(lua(), 2);
}
TEST_F(InterpreterTest, Stack) {
@ -267,6 +275,21 @@ TEST_F(InterpreterTest, Execute) {
EXPECT_TRUE(Execute("return {map={a=1,b=2}}"));
EXPECT_THAT(ser_.res, testing::AnyOf("{str(a) i(1) str(b) i(2)}", "{str(b) i(2) str(a) i(1)}"));
EXPECT_FALSE(Execute("var = 45"));
EXPECT_EQ(error_, "");
// Unknown variables.
EXPECT_FALSE(Execute("return {foo, bar}"));
// f_61fb78a9b8dad5c413e203473e5714dc00c982fc is a function registered due to
// previous executions.
//
// TODO: we inject here a new implementation, so when we run the script corresponding
// to the same sha again, we run something else. It's a security breach and must to be protected.
EXPECT_TRUE(Execute("f_61fb78a9b8dad5c413e203473e5714dc00c982fc = function() return 42 end"));
EXPECT_TRUE(Execute("return {foo, bar}"));
EXPECT_EQ(ser_.res, "i(42)");
}
TEST_F(InterpreterTest, Call) {
@ -485,11 +508,20 @@ TEST_F(InterpreterTest, Log) {
EXPECT_EQ("nil", ser_.res);
}
TEST_F(InterpreterTest, Robust) {
TEST_F(InterpreterTest, NonExistent) {
EXPECT_FALSE(Execute(R"(eval "local a = {}
setmetatable(a,{__index=function() foo() end})
return a")"));
EXPECT_EQ("", ser_.res);
}
TEST_F(InterpreterTest, ModifyG) {
EXPECT_TRUE(Execute(R"(
setmetatable(_G, {})
return _G
)"));
EXPECT_EQ("[]", ser_.res);
// EXPECT_TRUE(Execute("local g = getmetatable(_G); g.__index = {}"));
}
} // namespace dfly