1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

Detect possible async calls in scripts (#1122)

Automatically detect possible async calls for lua scripts based on regex
This commit is contained in:
Vladislav 2023-05-01 15:03:51 +03:00 committed by GitHub
parent 3fd4e277d3
commit 89072228e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 205 additions and 24 deletions

View file

@ -11,6 +11,8 @@
#include <cstring>
#include <optional>
#include <regex>
#include <set>
#include "core/interpreter_polyfill.h"
@ -479,6 +481,72 @@ void Interpreter::SetGlobalArray(const char* name, MutSliceSpan args) {
SetGlobalArrayInternal(lua_, name, args);
}
// Rewrites `redis.call` / `redis.pcall` statements whose return value looks unused into
// `redis.acall` / `redis.apcall` by inserting an 'a' in front of `call` / `pcall`.
//
// Returns the rewritten script, or an empty optional when nothing was rewritten: either no
// candidate statement was found, or the script contains a `--[[` comment block (not handled yet).
optional<string> Interpreter::DetectPossibleAsyncCalls(string_view body_sv) {
  // We want to detect `redis.call` expressions with unused return values, i.e. they are a
  // standalone statement, not part of an expression, condition, function call or assignment.
  //
  // We search for all `redis.(p)call` statements that are preceded on the same line by
  // - `do` or `then` -> first statement in a new block, certainly unused value
  // - no tokens -> we need to check the previous line, in case it is part of a multi-line
  //   expression.
  //
  // If we need to check the previous line, we search for its last word (before comments, if it
  // has any). That word is capture group 1; the `call` / `pcall` suffix is the last group.
  static const regex kRegex{"(?:(\\S+)(\\s*--.*?)*\\s*\n|(then)|(do)|(^))\\s*redis\\.(p*call)"};

  // Taken from https://www.lua.org/manual/5.4/manual.html - 3.1 - Lexical conventions
  // If a line ends with one of these, then most likely the next line belongs to it as well.
  static const set<string_view> kContOperators = {
      "+",  "-",  "*",  "/", "%", "^", "#", "&", "~", "|",  "<<", ">>", "//", "==",
      "~=", "<=", ">=", "<", ">", "=", "(", "{", "[", "::", ":",  ",",  ".",  ".."};

  // If a line ends with one of these, then most likely the next line belongs to it as well.
  static const set<string_view> kContTokens = {"and",    "else",   "elseif", "for",  "goto",
                                               "if",     "in",     "local",  "not",  "or",
                                               "repeat", "return", "until",  "while"};

  // Suffix of at most `n` characters. Returns a view into `s`, so no allocation is needed.
  auto last_n = [](string_view s, size_t n) {
    return s.size() < n ? s : s.substr(s.size() - n);
  };

  // We don't handle comment blocks yet. Checked on the view, before the body is copied.
  if (body_sv.find("--[[") != string_view::npos)
    return nullopt;

  string body{body_sv};
  vector<size_t> targets;  // Offsets of the `call` / `pcall` suffixes that will be prefixed.

  const sregex_iterator end{};
  for (sregex_iterator it{body.begin(), body.end(), kRegex}; it != end; ++it) {
    // Empty when the match was anchored by `then`, `do` or the start of the script.
    const string last_word = it->str(1);

    // The previous line ends with an operator: the call continues that expression.
    if (kContOperators.count(last_n(last_word, 2)) > 0 ||
        kContOperators.count(last_n(last_word, 1)) > 0)
      continue;

    // The previous line ends with a keyword that expects an expression or statement after it.
    if (kContTokens.count(last_word) > 0)
      continue;

    targets.push_back(it->position(it->size() - 1));
  }

  if (targets.empty())
    return nullopt;

  // Insert 'a' before 'call' and 'pcall'. Reverse order to preserve positions.
  reverse(targets.begin(), targets.end());
  body.reserve(body.size() + targets.size());
  for (auto pos : targets)
    body.insert(pos, "a");

  VLOG(1) << "Detected " << targets.size() << " async calls in script";
  return body;
}
bool Interpreter::IsResultSafe() const {
int top = lua_gettop(lua_);
if (top >= 128)

View file

@ -5,6 +5,7 @@
#pragma once
#include <functional>
#include <optional>
#include <string_view>
#include "core/core_types.h"
@ -106,6 +107,8 @@ class Interpreter {
// fp[40] will be set to '\0'.
static void FuncSha1(std::string_view body, char* fp);
// Scans a Lua script body for redis.call / redis.pcall statements that appear to discard
// their return value and returns a copy of the body with those calls rewritten to
// redis.acall / redis.apcall. Returns an empty optional if no such statement was found or if
// the body contains a "--[[" comment block.
static std::optional<std::string> DetectPossibleAsyncCalls(std::string_view body);
// Stores the given callable in redis_func_, forwarding it with its original value category.
template <typename Callback> void SetRedisFunc(Callback&& u) {
  redis_func_ = std::forward<Callback>(u);
}

View file

@ -371,4 +371,85 @@ TEST_F(InterpreterTest, Compatibility) {
EXPECT_FALSE(Execute("local t = {}; local a = 1; table.setn(t, 100); return a+123;"));
}
// Checks Interpreter::DetectPossibleAsyncCalls against a list of Lua snippets.
//
// Each snippet is a template: the marker "[A]" is placed where the function is expected to
// insert an 'a' (turning call into acall). The input is the snippet with the markers removed,
// the expected output is the snippet with every marker replaced by "a". Snippets without a
// marker must come back unchanged.
TEST_F(InterpreterTest, AsyncReplacement) {
const string_view kCases[] = {
R"(
redis.[A]call('INCR', 'A')
redis.[A]call('INCR', 'A')
)",
R"(
function test()
redis.[A]call('INCR', 'A')
end
)",
R"(
local b = redis.call('GET', 'A') + redis.call('GET', 'B')
)",
R"(
if redis.call('EXISTS', 'A') then redis.[A]call('SET', 'B', 1) end
)",
R"(
while redis.call('EXISTS', 'A') do redis.[A]call('SET', 'B', 1) end
)",
R"(
while
redis.call('EXISTS', 'A') do
print("OK")
end
)",
R"(
print(redis.call('GET', 'A'))
)",
R"(
local table = {
redis.call('GET', 'A')
}
)",
R"(
while true do
redis.[A]call('INCR', 'A')
end
)",
R"(
if 1 + -- now this is a tricky comment
redis.call('GET', 'A')
> 0
then end
)",
R"(
print('Output'
..
redis.call('GET', 'A')
)
)",
R"(
while
0 < -- we have a comment here unfortunately
redis.call('GET', 'A')
then end
)",
R"(
while
-- we have
-- a tricky
-- multiline comment
redis.call('EXISTS')
do end
)",
R"(
--[[ WE SKIP COMMENT BLOCKS FOR NOW ]]
redis.call('ECHO', 'TEST')
)"};
for (auto test : kCases) {
// Build the expected result and the raw input from the same template.
auto expected = absl::StrReplaceAll(test, {{"[A]", "a"}});
auto input = absl::StrReplaceAll(test, {{"[A]", ""}});
auto result = Interpreter::DetectPossibleAsyncCalls(input);
// An empty optional means "nothing rewritten", so compare against the unchanged input.
string_view output = result ? *result : input;
EXPECT_EQ(expected, output);
}
}
} // namespace dfly

View file

@ -429,8 +429,8 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
if (saver->Mode() == SaveMode::SUMMARY) {
auto scripts = sf_->script_mgr()->GetAll();
StringVec script_bodies;
for (auto& script : scripts) {
script_bodies.push_back(move(script.second));
for (auto& [sha, data] : scripts) {
script_bodies.push_back(move(data.orig_body));
}
ec = saver->SaveHeader(script_bodies);
} else {

View file

@ -524,7 +524,8 @@ TEST_F(MultiTest, EvalOOO) {
return;
}
const char* kScript = "redis.call('MGET', unpack(KEYS)); return 'OK'";
// Assign to prevent async optimization.
const char* kScript = "local r = redis.call('MGET', unpack(KEYS)); return 'OK'";
// Check single call.
{

View file

@ -29,6 +29,10 @@ ABSL_FLAG(std::string, default_lua_config, "",
"separated by space, for example 'allow-undeclared-keys disable-atomicity' runs scripts "
"non-atomically and allows accessing undeclared keys");
// Off by default. ScriptMgr::Insert reads this flag and, only for scripts whose params are
// atomic, passes the body through Interpreter::DetectPossibleAsyncCalls.
ABSL_FLAG(
bool, lua_auto_async, false,
"If enabled, call/pcall with discarded values are automatically replaced with acall/apcall.");
namespace dfly {
using namespace std;
@ -150,12 +154,14 @@ void ScriptMgr::ConfigCmd(CmdArgList args, ConnectionContext* cntx) {
}
void ScriptMgr::ListCmd(ConnectionContext* cntx) const {
vector<pair<string, string>> scripts = GetAll();
vector<pair<string, ScriptData>> scripts = GetAll();
(*cntx)->StartArray(scripts.size());
for (const auto& k_v : scripts) {
(*cntx)->StartArray(2);
(*cntx)->SendBulkString(k_v.first);
(*cntx)->SendBulkString(k_v.second);
for (const auto& [sha, data] : scripts) {
(*cntx)->StartArray(data.orig_body.empty() ? 2 : 3);
(*cntx)->SendBulkString(sha);
(*cntx)->SendBulkString(data.body);
if (!data.orig_body.empty())
(*cntx)->SendBulkString(data.orig_body);
}
}
@ -197,15 +203,34 @@ io::Result<optional<ScriptMgr::ScriptParams>, GenericError> DeduceParams(string_
return params;
}
// Copies `sv` into a freshly allocated, null-terminated char buffer of sv.size() + 1 bytes.
// Bytes of `sv` are copied verbatim, including any embedded '\0'.
std::unique_ptr<char[]> CharBufFromSV(std::string_view sv) {
  auto ptr = std::make_unique<char[]>(sv.size() + 1);

  // A default-constructed string_view has data() == nullptr, and passing a null pointer to
  // memcpy is undefined behaviour even when the length is zero, so skip the copy entirely.
  if (!sv.empty())
    std::memcpy(ptr.get(), sv.data(), sv.size());

  ptr[sv.size()] = '\0';
  return ptr;
}
io::Result<string, GenericError> ScriptMgr::Insert(string_view body, Interpreter* interpreter) {
// Calculate hash before removing shebang (#!lua).
char sha_buf[64];
Interpreter::FuncSha1(body, sha_buf);
string_view sha{sha_buf, std::strlen(sha_buf)};
auto params = DeduceParams(&body);
if (!params)
return params.get_unexpected();
string_view orig_body = body;
auto params_opt = DeduceParams(&body);
if (!params_opt)
return params_opt.get_unexpected();
auto params = params_opt->value_or(default_params_);
// If the script is atomic, check for possible squashing optimizations.
// For non atomic modes, squashing increases the time locks are held, which
// can decrease throughput with frequently accessed keys.
optional<string> async_body;
if (params.atomic && absl::GetFlag(FLAGS_lua_auto_async)) {
if (async_body = Interpreter::DetectPossibleAsyncCalls(body); async_body)
body = *async_body;
}
string result;
Interpreter::AddResult add_result = interpreter->AddFunction(sha, body, &result);
@ -213,12 +238,12 @@ io::Result<string, GenericError> ScriptMgr::Insert(string_view body, Interpreter
return nonstd::make_unexpected(GenericError{move(result)});
lock_guard lk{mu_};
auto [it, _] = db_.emplace(sha, InternalScriptData{params->value_or(default_params_), nullptr});
auto [it, _] = db_.emplace(sha, InternalScriptData{params, nullptr});
if (auto& body_ptr = it->second.body; !body_ptr) {
body_ptr.reset(new char[body.size() + 1]);
memcpy(body_ptr.get(), body.data(), body.size());
body_ptr[body.size()] = '\0';
if (!it->second.body) {
it->second.body = CharBufFromSV(body);
if (body != orig_body)
it->second.orig_body = CharBufFromSV(orig_body);
}
UpdateScriptCaches(sha, it->second);
@ -232,18 +257,19 @@ optional<ScriptMgr::ScriptData> ScriptMgr::Find(std::string_view sha) const {
lock_guard lk{mu_};
if (auto it = db_.find(sha); it != db_.end() && it->second.body)
return ScriptData{it->second, it->second.body.get()};
return ScriptData{it->second, it->second.body.get(), {}};
return std::nullopt;
}
vector<pair<string, string>> ScriptMgr::GetAll() const {
vector<pair<string, string>> res;
vector<pair<string, ScriptMgr::ScriptData>> ScriptMgr::GetAll() const {
vector<pair<string, ScriptData>> res;
lock_guard lk{mu_};
res.reserve(db_.size());
for (const auto& [sha, data] : db_) {
res.emplace_back(string{sha.data(), sha.size()}, data.body.get());
res.emplace_back(string{sha.data(), sha.size()},
ScriptData{data, data.body.get(), data.orig_body.get()});
}
return res;

View file

@ -31,7 +31,8 @@ class ScriptMgr {
};
struct ScriptData : public ScriptParams {
const char* body = nullptr;
std::string body;
std::string orig_body;
};
struct ScriptKey : public std::array<char, 40> {
@ -51,7 +52,7 @@ class ScriptMgr {
std::optional<ScriptData> Find(std::string_view sha) const;
// Returns a list of all scripts in the database with their sha and body.
std::vector<std::pair<std::string, std::string>> GetAll() const;
std::vector<std::pair<std::string, ScriptData>> GetAll() const;
private:
void ExistsCmd(CmdArgList args, ConnectionContext* cntx) const;
@ -65,6 +66,7 @@ class ScriptMgr {
private:
// Per-script record: the script's params plus owned, null-terminated copies of its text.
struct InternalScriptData : public ScriptParams {
  // Body as it was registered with the interpreter; null until filled in.
  std::unique_ptr<char[]> body = nullptr;
  // Body as originally submitted; filled in only when it differs from `body`.
  std::unique_ptr<char[]> orig_body = nullptr;
};
ScriptParams default_params_;

View file

@ -991,8 +991,8 @@ GenericError ServerFamily::DoSave(bool new_version, Transaction* trans) {
auto get_scripts = [this] {
auto scripts = script_mgr_->GetAll();
StringVec script_bodies;
for (const auto& script : scripts) {
script_bodies.push_back(move(script.second));
for (auto& [sha, data] : scripts) {
script_bodies.push_back(move(data.body));
}
return script_bodies;
};