mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-14 11:58:02 +00:00
Detect possible async calls in scripts (#1122)
Automatically detect possible async calls for lua scripts based on regex
This commit is contained in:
parent
3fd4e277d3
commit
89072228e5
8 changed files with 205 additions and 24 deletions
|
@ -11,6 +11,8 @@
|
|||
|
||||
#include <cstring>
|
||||
#include <optional>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
|
||||
#include "core/interpreter_polyfill.h"
|
||||
|
||||
|
@ -479,6 +481,72 @@ void Interpreter::SetGlobalArray(const char* name, MutSliceSpan args) {
|
|||
SetGlobalArrayInternal(lua_, name, args);
|
||||
}
|
||||
|
||||
optional<string> Interpreter::DetectPossibleAsyncCalls(string_view body_sv) {
|
||||
// We want to detect `redis.call` expressions with unused return values, i.e. they are a
|
||||
// standalone statement, not part of a expression, condition, function call or assignment.
|
||||
//
|
||||
// We search for all `redis.(p)call` statements, that are preceeded on the same line by
|
||||
// - `do` or `then` -> first statement in a new block, certainly unused value
|
||||
// - no tokens -> we need to check the previous line, if its part of a multi-line expression.
|
||||
//
|
||||
// If we need to check the previous line, we search for the last word (before comments, if it has
|
||||
// one).
|
||||
static const regex kRegex{"(?:(\\S+)(\\s*--.*?)*\\s*\n|(then)|(do)|(^))\\s*redis\\.(p*call)"};
|
||||
|
||||
// Taken from https://www.lua.org/manual/5.4/manual.html - 3.1 - Lexical conventions
|
||||
|
||||
// If a line ends with it, then most likely the next line belongs to it as well
|
||||
static const set<string_view> kContOperators = {
|
||||
"+", "-", "*", "/", "%", "^", "#", "&", "~", "|", "<<", ">>", "//", "==",
|
||||
"~=", "<=", ">=", "<", ">", "=", "(", "{", "[", "::", ":", ",", ".", ".."};
|
||||
|
||||
// If a line ends with it, then most likely the next line belongs to it as well
|
||||
static const set<string_view> kContTokens = {"and", "else", "elseif", "for", "goto",
|
||||
"if", "in", "local", "not", "or",
|
||||
"repeat", "return", "until", "while"};
|
||||
|
||||
auto last_n = [](const string& s, size_t n) {
|
||||
return s.size() < n ? s : s.substr(s.size() - n, n);
|
||||
};
|
||||
|
||||
smatch sm;
|
||||
string body{body_sv};
|
||||
vector<size_t> targets;
|
||||
|
||||
// We don't handle comment blocks yet.
|
||||
if (body.find("--[[") != string::npos)
|
||||
return {};
|
||||
|
||||
sregex_iterator it{body.begin(), body.end(), kRegex};
|
||||
sregex_iterator end{};
|
||||
|
||||
for (; it != end; it++) {
|
||||
auto last_word = it->str(1);
|
||||
|
||||
if (kContOperators.count(last_n(last_word, 2)) > 0 ||
|
||||
kContOperators.count(last_n(last_word, 1)) > 0)
|
||||
continue;
|
||||
|
||||
if (kContTokens.count(last_word) > 0)
|
||||
continue;
|
||||
|
||||
targets.push_back(it->position(it->size() - 1));
|
||||
}
|
||||
|
||||
if (targets.empty())
|
||||
return nullopt;
|
||||
|
||||
// Insert 'a' before 'call' and 'pcall'. Reverse order to preserve positions
|
||||
reverse(targets.begin(), targets.end());
|
||||
body.reserve(body.size() + targets.size());
|
||||
for (auto pos : targets)
|
||||
body.insert(pos, "a");
|
||||
|
||||
VLOG(1) << "Detected " << targets.size() << " aync calls in script";
|
||||
|
||||
return body;
|
||||
}
|
||||
|
||||
bool Interpreter::IsResultSafe() const {
|
||||
int top = lua_gettop(lua_);
|
||||
if (top >= 128)
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <string_view>
|
||||
|
||||
#include "core/core_types.h"
|
||||
|
@ -106,6 +107,8 @@ class Interpreter {
|
|||
// fp[40] will be set to '\0'.
|
||||
static void FuncSha1(std::string_view body, char* fp);
|
||||
|
||||
static std::optional<std::string> DetectPossibleAsyncCalls(std::string_view body);
|
||||
|
||||
template <typename U> void SetRedisFunc(U&& u) {
|
||||
redis_func_ = std::forward<U>(u);
|
||||
}
|
||||
|
|
|
@ -371,4 +371,85 @@ TEST_F(InterpreterTest, Compatibility) {
|
|||
EXPECT_FALSE(Execute("local t = {}; local a = 1; table.setn(t, 100); return a+123;"));
|
||||
}
|
||||
|
||||
TEST_F(InterpreterTest, AsyncReplacement) {
|
||||
const string_view kCases[] = {
|
||||
R"(
|
||||
redis.[A]call('INCR', 'A')
|
||||
redis.[A]call('INCR', 'A')
|
||||
)",
|
||||
R"(
|
||||
function test()
|
||||
redis.[A]call('INCR', 'A')
|
||||
end
|
||||
)",
|
||||
R"(
|
||||
local b = redis.call('GET', 'A') + redis.call('GET', 'B')
|
||||
)",
|
||||
R"(
|
||||
if redis.call('EXISTS', 'A') then redis.[A]call('SET', 'B', 1) end
|
||||
)",
|
||||
R"(
|
||||
while redis.call('EXISTS', 'A') do redis.[A]call('SET', 'B', 1) end
|
||||
)",
|
||||
R"(
|
||||
while
|
||||
redis.call('EXISTS', 'A') do
|
||||
print("OK")
|
||||
end
|
||||
)",
|
||||
R"(
|
||||
print(redis.call('GET', 'A'))
|
||||
)",
|
||||
R"(
|
||||
local table = {
|
||||
redis.call('GET', 'A')
|
||||
}
|
||||
)",
|
||||
R"(
|
||||
while true do
|
||||
redis.[A]call('INCR', 'A')
|
||||
end
|
||||
)",
|
||||
R"(
|
||||
if 1 + -- now this is a tricky comment
|
||||
redis.call('GET', 'A')
|
||||
> 0
|
||||
then end
|
||||
)",
|
||||
R"(
|
||||
print('Output'
|
||||
..
|
||||
redis.call('GET', 'A')
|
||||
)
|
||||
)",
|
||||
R"(
|
||||
while
|
||||
0 < -- we have a comment here unfortunately
|
||||
redis.call('GET', 'A')
|
||||
then end
|
||||
)",
|
||||
R"(
|
||||
while
|
||||
-- we have
|
||||
-- a tricky
|
||||
-- multiline comment
|
||||
redis.call('EXISTS')
|
||||
do end
|
||||
)",
|
||||
R"(
|
||||
--[[ WE SKIP COMMENT BLOCKS FOR NOW ]]
|
||||
redis.call('ECHO', 'TEST')
|
||||
)"};
|
||||
|
||||
for (auto test : kCases) {
|
||||
auto expected = absl::StrReplaceAll(test, {{"[A]", "a"}});
|
||||
auto input = absl::StrReplaceAll(test, {{"[A]", ""}});
|
||||
|
||||
auto result = Interpreter::DetectPossibleAsyncCalls(input);
|
||||
string_view output = result ? *result : input;
|
||||
|
||||
EXPECT_EQ(expected, output);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -429,8 +429,8 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
|
|||
if (saver->Mode() == SaveMode::SUMMARY) {
|
||||
auto scripts = sf_->script_mgr()->GetAll();
|
||||
StringVec script_bodies;
|
||||
for (auto& script : scripts) {
|
||||
script_bodies.push_back(move(script.second));
|
||||
for (auto& [sha, data] : scripts) {
|
||||
script_bodies.push_back(move(data.orig_body));
|
||||
}
|
||||
ec = saver->SaveHeader(script_bodies);
|
||||
} else {
|
||||
|
|
|
@ -524,7 +524,8 @@ TEST_F(MultiTest, EvalOOO) {
|
|||
return;
|
||||
}
|
||||
|
||||
const char* kScript = "redis.call('MGET', unpack(KEYS)); return 'OK'";
|
||||
// Assign to prevent asyc optimization.
|
||||
const char* kScript = "local r = redis.call('MGET', unpack(KEYS)); return 'OK'";
|
||||
|
||||
// Check single call.
|
||||
{
|
||||
|
|
|
@ -29,6 +29,10 @@ ABSL_FLAG(std::string, default_lua_config, "",
|
|||
"separated by space, for example 'allow-undeclared-keys disable-atomicity' runs scripts "
|
||||
"non-atomically and allows accessing undeclared keys");
|
||||
|
||||
ABSL_FLAG(
|
||||
bool, lua_auto_async, false,
|
||||
"If enabled, call/pcall with discarded values are automatically replaced with acall/apcall.");
|
||||
|
||||
namespace dfly {
|
||||
|
||||
using namespace std;
|
||||
|
@ -150,12 +154,14 @@ void ScriptMgr::ConfigCmd(CmdArgList args, ConnectionContext* cntx) {
|
|||
}
|
||||
|
||||
void ScriptMgr::ListCmd(ConnectionContext* cntx) const {
|
||||
vector<pair<string, string>> scripts = GetAll();
|
||||
vector<pair<string, ScriptData>> scripts = GetAll();
|
||||
(*cntx)->StartArray(scripts.size());
|
||||
for (const auto& k_v : scripts) {
|
||||
(*cntx)->StartArray(2);
|
||||
(*cntx)->SendBulkString(k_v.first);
|
||||
(*cntx)->SendBulkString(k_v.second);
|
||||
for (const auto& [sha, data] : scripts) {
|
||||
(*cntx)->StartArray(data.orig_body.empty() ? 2 : 3);
|
||||
(*cntx)->SendBulkString(sha);
|
||||
(*cntx)->SendBulkString(data.body);
|
||||
if (!data.orig_body.empty())
|
||||
(*cntx)->SendBulkString(data.orig_body);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -197,15 +203,34 @@ io::Result<optional<ScriptMgr::ScriptParams>, GenericError> DeduceParams(string_
|
|||
return params;
|
||||
}
|
||||
|
||||
unique_ptr<char[]> CharBufFromSV(string_view sv) {
|
||||
auto ptr = make_unique<char[]>(sv.size() + 1);
|
||||
memcpy(ptr.get(), sv.data(), sv.size());
|
||||
ptr[sv.size()] = '\0';
|
||||
return ptr;
|
||||
}
|
||||
|
||||
io::Result<string, GenericError> ScriptMgr::Insert(string_view body, Interpreter* interpreter) {
|
||||
// Calculate hash before removing shebang (#!lua).
|
||||
char sha_buf[64];
|
||||
Interpreter::FuncSha1(body, sha_buf);
|
||||
string_view sha{sha_buf, std::strlen(sha_buf)};
|
||||
|
||||
auto params = DeduceParams(&body);
|
||||
if (!params)
|
||||
return params.get_unexpected();
|
||||
string_view orig_body = body;
|
||||
|
||||
auto params_opt = DeduceParams(&body);
|
||||
if (!params_opt)
|
||||
return params_opt.get_unexpected();
|
||||
auto params = params_opt->value_or(default_params_);
|
||||
|
||||
// If the script is atomic, check for possible squashing optimizations.
|
||||
// For non atomic modes, squashing increases the time locks are held, which
|
||||
// can decrease throughput with frequently accessed keys.
|
||||
optional<string> async_body;
|
||||
if (params.atomic && absl::GetFlag(FLAGS_lua_auto_async)) {
|
||||
if (async_body = Interpreter::DetectPossibleAsyncCalls(body); async_body)
|
||||
body = *async_body;
|
||||
}
|
||||
|
||||
string result;
|
||||
Interpreter::AddResult add_result = interpreter->AddFunction(sha, body, &result);
|
||||
|
@ -213,12 +238,12 @@ io::Result<string, GenericError> ScriptMgr::Insert(string_view body, Interpreter
|
|||
return nonstd::make_unexpected(GenericError{move(result)});
|
||||
|
||||
lock_guard lk{mu_};
|
||||
auto [it, _] = db_.emplace(sha, InternalScriptData{params->value_or(default_params_), nullptr});
|
||||
auto [it, _] = db_.emplace(sha, InternalScriptData{params, nullptr});
|
||||
|
||||
if (auto& body_ptr = it->second.body; !body_ptr) {
|
||||
body_ptr.reset(new char[body.size() + 1]);
|
||||
memcpy(body_ptr.get(), body.data(), body.size());
|
||||
body_ptr[body.size()] = '\0';
|
||||
if (!it->second.body) {
|
||||
it->second.body = CharBufFromSV(body);
|
||||
if (body != orig_body)
|
||||
it->second.orig_body = CharBufFromSV(orig_body);
|
||||
}
|
||||
|
||||
UpdateScriptCaches(sha, it->second);
|
||||
|
@ -232,18 +257,19 @@ optional<ScriptMgr::ScriptData> ScriptMgr::Find(std::string_view sha) const {
|
|||
|
||||
lock_guard lk{mu_};
|
||||
if (auto it = db_.find(sha); it != db_.end() && it->second.body)
|
||||
return ScriptData{it->second, it->second.body.get()};
|
||||
return ScriptData{it->second, it->second.body.get(), {}};
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
vector<pair<string, string>> ScriptMgr::GetAll() const {
|
||||
vector<pair<string, string>> res;
|
||||
vector<pair<string, ScriptMgr::ScriptData>> ScriptMgr::GetAll() const {
|
||||
vector<pair<string, ScriptData>> res;
|
||||
|
||||
lock_guard lk{mu_};
|
||||
res.reserve(db_.size());
|
||||
for (const auto& [sha, data] : db_) {
|
||||
res.emplace_back(string{sha.data(), sha.size()}, data.body.get());
|
||||
res.emplace_back(string{sha.data(), sha.size()},
|
||||
ScriptData{data, data.body.get(), data.orig_body.get()});
|
||||
}
|
||||
|
||||
return res;
|
||||
|
|
|
@ -31,7 +31,8 @@ class ScriptMgr {
|
|||
};
|
||||
|
||||
struct ScriptData : public ScriptParams {
|
||||
const char* body = nullptr;
|
||||
std::string body;
|
||||
std::string orig_body;
|
||||
};
|
||||
|
||||
struct ScriptKey : public std::array<char, 40> {
|
||||
|
@ -51,7 +52,7 @@ class ScriptMgr {
|
|||
std::optional<ScriptData> Find(std::string_view sha) const;
|
||||
|
||||
// Returns a list of all scripts in the database with their sha and body.
|
||||
std::vector<std::pair<std::string, std::string>> GetAll() const;
|
||||
std::vector<std::pair<std::string, ScriptData>> GetAll() const;
|
||||
|
||||
private:
|
||||
void ExistsCmd(CmdArgList args, ConnectionContext* cntx) const;
|
||||
|
@ -65,6 +66,7 @@ class ScriptMgr {
|
|||
private:
|
||||
struct InternalScriptData : public ScriptParams {
|
||||
std::unique_ptr<char[]> body{};
|
||||
std::unique_ptr<char[]> orig_body{};
|
||||
};
|
||||
|
||||
ScriptParams default_params_;
|
||||
|
|
|
@ -991,8 +991,8 @@ GenericError ServerFamily::DoSave(bool new_version, Transaction* trans) {
|
|||
auto get_scripts = [this] {
|
||||
auto scripts = script_mgr_->GetAll();
|
||||
StringVec script_bodies;
|
||||
for (const auto& script : scripts) {
|
||||
script_bodies.push_back(move(script.second));
|
||||
for (auto& [sha, data] : scripts) {
|
||||
script_bodies.push_back(move(data.body));
|
||||
}
|
||||
return script_bodies;
|
||||
};
|
||||
|
|
Loading…
Reference in a new issue