1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

Collect errors from async lua calls (#1092)

Error collection from async calls
This commit is contained in:
Vladislav 2023-04-22 09:02:22 +03:00 committed by GitHub
parent 71147c20a9
commit 74a1454409
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 229 additions and 80 deletions

View file

@ -377,6 +377,11 @@ Interpreter::Interpreter() {
lua_pushcfunction(lua_, RedisACallCommand);
lua_settable(lua_, -3);
/* redis.apcall */
lua_pushstring(lua_, "apcall");
lua_pushcfunction(lua_, RedisAPCallCommand);
lua_settable(lua_, -3);
lua_pushstring(lua_, "sha1hex");
lua_pushcfunction(lua_, RedisSha1Command);
lua_settable(lua_, -3);
@ -720,7 +725,7 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
* and this way we guaranty we will have room on the stack for the result. */
lua_pop(lua_, argc);
RedisTranslator translator(lua_);
redis_func_(CallArgs{MutSliceSpan{args}, &buffer, &translator, async});
redis_func_(CallArgs{MutSliceSpan{args}, &buffer, &translator, async, raise_error, &raise_error});
cmd_depth_--;
// Raise error for regular 'call' command if needed.
@ -746,6 +751,11 @@ int Interpreter::RedisPCallCommand(lua_State* lua) {
}
int Interpreter::RedisACallCommand(lua_State* lua) {
void** ptr = static_cast<void**>(lua_getextraspace(lua));
return reinterpret_cast<Interpreter*>(*ptr)->RedisGenericCommand(true, true);
}
int Interpreter::RedisAPCallCommand(lua_State* lua) {
void** ptr = static_cast<void**>(lua_getextraspace(lua));
return reinterpret_cast<Interpreter*>(*ptr)->RedisGenericCommand(false, true);
}

View file

@ -43,7 +43,12 @@ class Interpreter {
ObjectExplorer* translator;
bool async; // async by redis.acall
bool async; // async by acall
bool error_abort; // abort on errors (not pcall)
// The function can request an abort due to an error, even if error_abort is false.
// It happens when async cmds are flushed and result in an uncatched error.
bool* requested_abort;
};
using RedisFunc = std::function<void(CallArgs)>;
@ -112,10 +117,12 @@ class Interpreter {
bool IsTableSafe() const;
int RedisGenericCommand(bool raise_error, bool async);
int RedisACallErrorsCommand();
static int RedisCallCommand(lua_State* lua);
static int RedisPCallCommand(lua_State* lua);
static int RedisACallCommand(lua_State* lua);
static int RedisAPCallCommand(lua_State* lua);
lua_State* lua_;
unsigned cmd_depth_ = 0;

View file

@ -13,6 +13,13 @@
namespace facade {
// Reply mode allows filtering replies.
enum class ReplyMode {
NONE, // No replies are recorded
ONLY_ERR, // Only errors are recorded
FULL // All replies are recorded
};
class SinkReplyBuilder {
public:
struct ResponseValue {

View file

@ -6,31 +6,42 @@
#include "base/logging.h"
#include "reply_capture.h"
#define SKIP_LESS(needed) \
if (reply_mode_ < needed) { \
current_ = monostate{}; \
return; \
}
namespace facade {
using namespace std;
void CapturingReplyBuilder::SendError(std::string_view str, std::string_view type) {
SKIP_LESS(ReplyMode::ONLY_ERR);
Capture(Error{str, type});
}
void CapturingReplyBuilder::SendMGetResponse(absl::Span<const OptResp> arr) {
SKIP_LESS(ReplyMode::FULL);
Capture(vector<OptResp>{arr.begin(), arr.end()});
}
void CapturingReplyBuilder::SendError(OpStatus status) {
SKIP_LESS(ReplyMode::ONLY_ERR);
Capture(status);
}
void CapturingReplyBuilder::SendNullArray() {
SKIP_LESS(ReplyMode::FULL);
Capture(unique_ptr<CollectionPayload>{nullptr});
}
void CapturingReplyBuilder::SendEmptyArray() {
SKIP_LESS(ReplyMode::FULL);
Capture(make_unique<CollectionPayload>(0, ARRAY));
}
void CapturingReplyBuilder::SendSimpleStrArr(StrSpan arr) {
SKIP_LESS(ReplyMode::FULL);
DCHECK_EQ(current_.index(), 0u);
WrappedStrSpan warr{arr};
@ -42,6 +53,7 @@ void CapturingReplyBuilder::SendSimpleStrArr(StrSpan arr) {
}
void CapturingReplyBuilder::SendStringArr(StrSpan arr, CollectionType type) {
SKIP_LESS(ReplyMode::FULL);
DCHECK_EQ(current_.index(), 0u);
// TODO: 1. Allocate all strings at once 2. Allow movable types
@ -54,31 +66,38 @@ void CapturingReplyBuilder::SendStringArr(StrSpan arr, CollectionType type) {
}
void CapturingReplyBuilder::SendNull() {
SKIP_LESS(ReplyMode::FULL);
Capture(nullptr_t{});
}
void CapturingReplyBuilder::SendLong(long val) {
SKIP_LESS(ReplyMode::FULL);
Capture(val);
}
void CapturingReplyBuilder::SendDouble(double val) {
SKIP_LESS(ReplyMode::FULL);
Capture(val);
}
void CapturingReplyBuilder::SendSimpleString(std::string_view str) {
SKIP_LESS(ReplyMode::FULL);
Capture(SimpleString{string{str}});
}
void CapturingReplyBuilder::SendBulkString(std::string_view str) {
SKIP_LESS(ReplyMode::FULL);
Capture(BulkString{string{str}});
}
void CapturingReplyBuilder::SendScoredArray(const std::vector<std::pair<std::string, double>>& arr,
bool with_scores) {
SKIP_LESS(ReplyMode::FULL);
Capture(ScoredArray{arr, with_scores});
}
void CapturingReplyBuilder::StartCollection(unsigned len, CollectionType type) {
SKIP_LESS(ReplyMode::FULL);
stack_.emplace(make_unique<CollectionPayload>(len, type), type == MAP ? len * 2 : len);
// If we added an empty collection, it must be collapsed immediately.
@ -92,6 +111,17 @@ CapturingReplyBuilder::Payload CapturingReplyBuilder::Take() {
return pl;
}
void CapturingReplyBuilder::SendDirect(Payload&& val) {
bool is_err = holds_alternative<Error>(val) || holds_alternative<OpStatus>(val);
ReplyMode min_mode = is_err ? ReplyMode::ONLY_ERR : ReplyMode::FULL;
if (reply_mode_ >= min_mode) {
DCHECK_EQ(current_.index(), 0u);
current_ = move(val);
} else {
current_ = monostate{};
}
}
void CapturingReplyBuilder::Capture(Payload val) {
if (!stack_.empty()) {
stack_.top().first->arr.push_back(std::move(val));
@ -183,8 +213,25 @@ struct CaptureVisitor {
};
void CapturingReplyBuilder::Apply(Payload&& pl, RedisReplyBuilder* rb) {
if (auto* crb = dynamic_cast<CapturingReplyBuilder*>(rb); crb != nullptr) {
crb->SendDirect(move(pl));
return;
}
CaptureVisitor cv{rb};
visit(cv, pl);
}
void CapturingReplyBuilder::SetReplyMode(ReplyMode mode) {
reply_mode_ = mode;
current_ = monostate{};
}
optional<CapturingReplyBuilder::ErrorRef> CapturingReplyBuilder::GetError(const Payload& pl) {
if (auto* err = get_if<Error>(&pl); err != nullptr) {
return ErrorRef{err->first, err->second};
}
return nullopt;
}
} // namespace facade

View file

@ -45,7 +45,7 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
void StartCollection(unsigned len, CollectionType type) override;
private:
using Error = std::pair<std::string, std::string>; // SendError
using Error = std::pair<std::string, std::string>; // SendError (msg, type)
using Null = std::nullptr_t; // SendNull or SendNullArray
struct SimpleString : public std::string {}; // SendSimpleString
struct BulkString : public std::string {}; // SendBulkString
@ -64,19 +64,28 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
};
public:
CapturingReplyBuilder() : RedisReplyBuilder{nullptr}, stack_{}, current_{} {
CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL)
: RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} {
}
using Payload = std::variant<std::monostate, Null, Error, OpStatus, long, double, SimpleString,
BulkString, StrArrPayload, std::unique_ptr<CollectionPayload>,
std::vector<OptResp>, ScoredArray>;
// Non owned Error based on SendError arguments (msg, type)
using ErrorRef = std::pair<std::string_view, std::string_view>;
void SetReplyMode(ReplyMode mode);
// Take payload and clear state.
Payload Take();
// Send payload to builder.
static void Apply(Payload&& pl, RedisReplyBuilder* builder);
// If an error is stored inside payload, get a reference to it.
static std::optional<ErrorRef> GetError(const Payload& pl);
private:
struct CollectionPayload {
CollectionPayload(unsigned len, CollectionType type);
@ -87,12 +96,18 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
};
private:
// Send payload directly, bypassing external interface. For efficient passing between two
// captures.
void SendDirect(Payload&& val);
// Capture value and store eiter in current topmost collection or as a standalone value.
void Capture(Payload val);
// While topmost collection in stack is full, finalize it and add it as a regular value.
void CollapseFilledCollections();
ReplyMode reply_mode_;
// List of nested active collections that are being built.
std::stack<std::pair<std::unique_ptr<CollectionPayload>, int>> stack_;

View file

@ -16,8 +16,8 @@ namespace dfly {
using namespace std;
using namespace facade;
StoredCmd::StoredCmd(const CommandId* cid, CmdArgList args)
: cid_{cid}, buffer_{}, sizes_(args.size()) {
StoredCmd::StoredCmd(const CommandId* cid, CmdArgList args, facade::ReplyMode mode)
: cid_{cid}, buffer_{}, sizes_(args.size()), reply_mode_{mode} {
size_t total_size = 0;
for (auto args : args)
total_size += args.size();
@ -31,8 +31,8 @@ StoredCmd::StoredCmd(const CommandId* cid, CmdArgList args)
}
}
StoredCmd::StoredCmd(string&& buffer, const CommandId* cid, CmdArgList args)
: cid_{cid}, buffer_{move(buffer)}, sizes_(args.size()) {
StoredCmd::StoredCmd(string&& buffer, const CommandId* cid, CmdArgList args, facade::ReplyMode mode)
: cid_{cid}, buffer_{move(buffer)}, sizes_(args.size()), reply_mode_{mode} {
for (unsigned i = 0; i < args.size(); i++) {
// Assume tightly packed list.
DCHECK(i + 1 == args.size() || args[i].data() + args[i].size() == args[i + 1].data());
@ -53,6 +53,23 @@ size_t StoredCmd::NumArgs() const {
return sizes_.size();
}
facade::ReplyMode StoredCmd::ReplyMode() const {
return reply_mode_;
}
template <typename C> size_t IsStoredInlined(const C& c) {
const char* start = reinterpret_cast<const char*>(&c);
const char* end = start + sizeof(C);
const char* data = reinterpret_cast<const char*>(c.data());
return data >= start && data <= end;
}
size_t StoredCmd::UsedHeapMemory() const {
size_t buffer_size = IsStoredInlined(buffer_) ? 0 : buffer_.size();
size_t sz_size = IsStoredInlined(sizes_) ? 0 : sizes_.size() * sizeof(uint32_t);
return buffer_size + sz_size;
}
const CommandId* StoredCmd::Cid() const {
return cid_;
}

View file

@ -22,23 +22,30 @@ class ChannelStore;
// Used for storing MULTI/EXEC commands.
class StoredCmd {
public:
StoredCmd(const CommandId* cid, CmdArgList args);
StoredCmd(const CommandId* cid, CmdArgList args,
facade::ReplyMode mode = facade::ReplyMode::FULL);
// Create on top of already filled tightly-packed buffer.
StoredCmd(std::string&& buffer, const CommandId* cid, CmdArgList args);
StoredCmd(std::string&& buffer, const CommandId* cid, CmdArgList args,
facade::ReplyMode mode = facade::ReplyMode::FULL);
size_t NumArgs() const;
size_t UsedHeapMemory() const;
// Fill the arg list with stored arguments, it should be at least of size NumArgs().
// Between filling and invocation, cmd should NOT be moved.
void Fill(CmdArgList args);
const CommandId* Cid() const;
facade::ReplyMode ReplyMode() const;
private:
const CommandId* cid_; // underlying command
std::string buffer_; // underlying buffer
absl::FixedArray<uint32_t, 4> sizes_; // sizes of arg parts
absl::FixedArray<uint32_t, 4> sizes_; // sizes of arg part
facade::ReplyMode reply_mode_; // reply mode
};
struct ConnectionState {
@ -75,7 +82,10 @@ struct ConnectionState {
struct ScriptInfo {
bool is_write = true;
absl::flat_hash_set<std::string_view> keys; // declared keys
std::vector<StoredCmd> async_cmds; // aggregated by acall
size_t async_cmds_heap_mem = 0; // bytes used by async_cmds
size_t async_cmds_heap_limit = 0; // max bytes allowed for async_cmds
std::vector<StoredCmd> async_cmds; // aggregated by acall
};
// PUB-SUB messaging related data.
@ -117,7 +127,7 @@ struct ConnectionState {
ExecInfo exec_info;
ReplicationInfo replicaiton_info;
std::optional<ScriptInfo> script_info;
std::unique_ptr<ScriptInfo> script_info;
std::unique_ptr<SubscribeInfo> subscribe_info;
};

View file

@ -41,17 +41,20 @@ extern "C" {
#include "util/varz.h"
using namespace std;
using dfly::operator""_KB;
ABSL_FLAG(uint32_t, port, 6379, "Redis port");
ABSL_FLAG(uint32_t, memcache_port, 0, "Memcached port");
ABSL_FLAG(uint32_t, num_shards, 0, "Number of database shards, 0 - to choose automatically");
ABSL_FLAG(uint32_t, multi_exec_mode, 1,
"Set multi exec atomicity mode: 1 for global, 2 for locking ahead, 3 for non atomic");
ABSL_FLAG(bool, multi_exec_squash, true,
"Whether multi exec will squash single shard commands to optimize performance");
ABSL_FLAG(uint32_t, num_shards, 0, "Number of database shards, 0 - to choose automatically");
ABSL_FLAG(uint32_t, multi_eval_squash_buffer, 4_KB, "Max buffer for squashed commands per script");
namespace dfly {
@ -154,7 +157,7 @@ std::string MakeMonitorMessage(const ConnectionState& conn_state,
const facade::Connection* connection, CmdArgList args) {
std::string message = absl::StrCat(CreateMonitorTimestamp(), " [", conn_state.db_index);
if (conn_state.script_info.has_value()) {
if (conn_state.script_info) {
absl::StrAppend(&message, " lua] ");
} else {
auto endpoint = connection == nullptr ? "REPLICATION:0" : connection->RemoteEndpointStr();
@ -598,7 +601,7 @@ bool Service::VerifyCommand(const CommandId* cid, CmdArgList args,
ConnectionContext* dfly_cntx = static_cast<ConnectionContext*>(cntx);
bool is_trans_cmd = (cmd_str == "EXEC" || cmd_str == "MULTI" || cmd_str == "DISCARD");
bool under_script = dfly_cntx->conn_state.script_info.has_value();
bool under_script = bool(dfly_cntx->conn_state.script_info);
absl::Cleanup multi_error([dfly_cntx] { SetMultiExecErrorFlag(dfly_cntx); });
@ -702,7 +705,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
ToUpper(&args[0]);
ConnectionContext* dfly_cntx = static_cast<ConnectionContext*>(cntx);
bool under_script = dfly_cntx->conn_state.script_info.has_value();
bool under_script = bool(dfly_cntx->conn_state.script_info);
if (VLOG_IS_ON(2) &&
cntx->owner()) { // owner may not exists in case of this being called from replica context
@ -1015,32 +1018,34 @@ void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendOk();
}
template <typename F> void WithoutReplies(ConnectionContext* cntx, F&& f) {
io::NullSink null_sink;
facade::RedisReplyBuilder rrb{&null_sink};
auto* old_rrb = cntx->Inject(&rrb);
template <typename F> void WithReplies(CapturingReplyBuilder* crb, ConnectionContext* cntx, F&& f) {
SinkReplyBuilder* old_rrb = nullptr;
old_rrb = cntx->Inject(crb);
f();
cntx->Inject(old_rrb);
}
void Service::FlushEvalAsyncCmds(ConnectionContext* cntx, bool force) {
const int kMaxAsyncCmds = 100;
optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionContext* cntx,
bool force) {
auto& info = cntx->conn_state.script_info;
if ((!force && info->async_cmds.size() <= kMaxAsyncCmds) || info->async_cmds.empty())
return;
size_t used_mem = info->async_cmds_heap_mem + info->async_cmds.size() * sizeof(StoredCmd);
if ((info->async_cmds.empty() || !force) && used_mem < info->async_cmds_heap_limit)
return nullopt;
auto* eval_cid = registry_.Find("EVAL");
DCHECK(eval_cid);
cntx->transaction->MultiSwitchCmd(eval_cid);
WithoutReplies(cntx,
[&] { MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx); });
CapturingReplyBuilder crb{ReplyMode::ONLY_ERR};
WithReplies(&crb, cntx,
[&] { MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx, true); });
info->async_cmds_heap_mem = 0;
info->async_cmds.clear();
auto reply = move(crb.Take());
return CapturingReplyBuilder::GetError(reply) ? make_optional(move(reply)) : nullopt;
}
void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) {
@ -1051,20 +1056,25 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca)
auto& info = cntx->conn_state.script_info;
auto* cid = registry_.Find(facade::ToSV(ca.args[0]));
bool valid = true;
WithoutReplies(cntx, [&] { valid = VerifyCommand(cid, ca.args, cntx); });
if (!valid) // TODO: collect errors with capturing reply builder.
if (!VerifyCommand(cid, ca.args, cntx))
return;
info->async_cmds.emplace_back(move(*ca.buffer), cid, ca.args.subspan(1));
FlushEvalAsyncCmds(cntx, false);
auto replies = ca.error_abort ? ReplyMode::ONLY_ERR : ReplyMode::NONE;
info->async_cmds.emplace_back(move(*ca.buffer), cid, ca.args.subspan(1), replies);
info->async_cmds_heap_mem += info->async_cmds.back().UsedHeapMemory();
}
InterpreterReplier replier(ca.translator);
if (auto err = FlushEvalAsyncCmds(cntx, !ca.async); err) {
CapturingReplyBuilder::Apply(move(*err), &replier); // forward error to lua
*ca.requested_abort = true;
return;
}
FlushEvalAsyncCmds(cntx, true);
if (ca.async)
return;
InterpreterReplier replier(ca.translator);
facade::SinkReplyBuilder* orig = cntx->Inject(&replier);
DispatchCommand(ca.args, cntx);
@ -1208,10 +1218,12 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
// TODO: to determine whether the script is RO by scanning all "redis.p?call" calls
// and checking whether all invocations consist of RO commands.
// we can do it once during script insertion into script mgr.
cntx->conn_state.script_info.emplace(ConnectionState::ScriptInfo{});
auto& sinfo = cntx->conn_state.script_info;
sinfo.reset(new ConnectionState::ScriptInfo{});
for (size_t i = 0; i < eval_args.keys.size(); ++i) {
cntx->conn_state.script_info->keys.insert(ArgS(eval_args.keys, i));
sinfo->keys.insert(ArgS(eval_args.keys, i));
}
sinfo->async_cmds_heap_limit = absl::GetFlag(FLAGS_multi_eval_squash_buffer);
DCHECK(cntx->transaction);
bool scheduled = StartMultiEval(cntx->db_index(), eval_args.keys, *params, cntx->transaction);
@ -1223,7 +1235,11 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
Interpreter::RunResult result = interpreter->RunFunction(eval_args.sha, &error);
absl::Cleanup clean = [interpreter]() { interpreter->ResetStack(); };
FlushEvalAsyncCmds(cntx, true);
if (auto err = FlushEvalAsyncCmds(cntx, true); err) {
auto err_ref = CapturingReplyBuilder::GetError(*err);
result = Interpreter::RUN_ERR;
error = absl::StrCat(err_ref->first);
}
cntx->conn_state.script_info.reset(); // reset script_info

View file

@ -120,7 +120,10 @@ class Service : public facade::ServiceInterface {
void EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, ConnectionContext* cntx);
void FlushEvalAsyncCmds(ConnectionContext* cntx, bool force = false);
// Return optional payload - first received error that occured when executing commands.
std::optional<facade::CapturingReplyBuilder::Payload> FlushEvalAsyncCmds(ConnectionContext* cntx,
bool force = false);
void CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& args);
void RegisterCommands();

View file

@ -24,8 +24,9 @@ template <typename F> void IterateKeys(CmdArgList args, KeyIndex keys, F&& f) {
} // namespace
MultiCommandSquasher::MultiCommandSquasher(absl::Span<StoredCmd> cmds, ConnectionContext* cntx)
: cmds_{cmds}, cntx_{cntx}, base_cid_{cntx->transaction->GetCId()} {
MultiCommandSquasher::MultiCommandSquasher(absl::Span<StoredCmd> cmds, ConnectionContext* cntx,
bool error_abort)
: cmds_{cmds}, cntx_{cntx}, base_cid_{cntx->transaction->GetCId()}, error_abort_{error_abort} {
auto mode = cntx->transaction->GetMultiMode();
track_keys_ = mode == Transaction::NON_ATOMIC;
}
@ -38,9 +39,6 @@ MultiCommandSquasher::ShardExecInfo& MultiCommandSquasher::PrepareShardInfo(Shar
if (!sinfo.local_tx)
sinfo.local_tx = new Transaction{cntx_->transaction};
if (!sinfo.reply_chan)
sinfo.reply_chan = make_unique<ReplyChan>(kChanBufferSize, 1);
return sinfo;
}
@ -85,7 +83,7 @@ MultiCommandSquasher::SquashResult MultiCommandSquasher::TrySquash(StoredCmd* cm
// Because the squashed hop is currently blocking, we cannot add more than the max channel size,
// otherwise a deadlock occurs.
bool need_flush = sinfo.cmds.size() >= kChanBufferSize - 1;
bool need_flush = sinfo.cmds.size() >= kMaxSquashing - 1;
return need_flush ? SquashResult::SQUASHED_FULL : SquashResult::SQUASHED;
}
@ -118,6 +116,7 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard
for (auto* cmd : sinfo.cmds) {
local_tx->MultiSwitchCmd(cmd->Cid());
local_cntx.cid = cmd->Cid();
crb.SetReplyMode(cmd->ReplyMode());
arg_vec.resize(cmd->NumArgs());
auto args = absl::MakeSpan(arg_vec);
@ -126,18 +125,20 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard
local_tx->InitByArgs(parent_tx->GetDbIndex(), args);
cmd->Cid()->Invoke(args, &local_cntx);
sinfo.reply_chan->Push(crb.Take());
sinfo.replies.emplace_back(crb.Take());
}
// ConnectionContext deletes the reply builder upon destruction, so
// remove our local pointer from it.
local_cntx.Inject(nullptr);
reverse(sinfo.replies.begin(), sinfo.replies.end());
return OpStatus::OK;
}
void MultiCommandSquasher::ExecuteSquashed() {
bool MultiCommandSquasher::ExecuteSquashed() {
if (order_.empty())
return;
return false;
Transaction* tx = cntx_->transaction;
@ -149,14 +150,26 @@ void MultiCommandSquasher::ExecuteSquashed() {
tx->PrepareSquashedMultiHop(base_cid_, cb);
}
for (auto& sd : sharded_)
sd.replies.reserve(sd.cmds.size());
cntx_->cid = base_cid_;
tx->ScheduleSingleHop([this](auto* tx, auto* es) { return SquashedHopCb(tx, es); });
facade::CapturingReplyBuilder::Payload payload;
bool aborted = false;
RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx_->reply_builder());
for (auto idx : order_) {
CHECK(sharded_[idx].reply_chan->Pop(payload));
CapturingReplyBuilder::Apply(move(payload), rb);
auto& replies = sharded_[idx].replies;
CHECK(!replies.empty());
aborted |= error_abort_ && CapturingReplyBuilder::GetError(replies.back());
CapturingReplyBuilder::Apply(move(replies.back()), rb);
replies.pop_back();
if (aborted)
break;
}
for (auto& sinfo : sharded_)
@ -164,6 +177,7 @@ void MultiCommandSquasher::ExecuteSquashed() {
order_.clear();
collected_keys_.clear();
return aborted;
}
void MultiCommandSquasher::Run() {
@ -173,8 +187,10 @@ void MultiCommandSquasher::Run() {
if (res == SquashResult::ERROR)
break;
if (res == SquashResult::NOT_SQUASHED || res == SquashResult::SQUASHED_FULL)
ExecuteSquashed();
if (res == SquashResult::NOT_SQUASHED || res == SquashResult::SQUASHED_FULL) {
if (ExecuteSquashed())
break;
}
if (res == SquashResult::NOT_SQUASHED)
ExecuteStandalone(&cmd);

View file

@ -16,32 +16,29 @@ namespace dfly {
// thus greatly decreasing the dispatch overhead for them.
class MultiCommandSquasher {
public:
static void Execute(absl::Span<StoredCmd> cmds, ConnectionContext* cntx) {
MultiCommandSquasher{cmds, cntx}.Run();
static void Execute(absl::Span<StoredCmd> cmds, ConnectionContext* cntx,
bool error_abort = false) {
MultiCommandSquasher{cmds, cntx, error_abort}.Run();
}
private:
using ReplyChan =
dfly::SimpleChannel<facade::CapturingReplyBuilder::Payload,
base::mpmc_bounded_queue<facade::CapturingReplyBuilder::Payload>>;
// Per-shard exection info.
struct ShardExecInfo {
ShardExecInfo() : had_writes{false}, cmds{}, reply_chan{nullptr}, local_tx{nullptr} {
ShardExecInfo() : had_writes{false}, cmds{}, replies{}, local_tx{nullptr} {
}
bool had_writes;
std::vector<StoredCmd*> cmds; // accumulated commands
std::unique_ptr<ReplyChan> reply_chan;
std::vector<facade::CapturingReplyBuilder::Payload> replies;
boost::intrusive_ptr<Transaction> local_tx; // stub-mode tx for use inside shard
};
enum class SquashResult { SQUASHED, SQUASHED_FULL, NOT_SQUASHED, ERROR };
static constexpr int kChanBufferSize = 32;
static constexpr int kMaxSquashing = 32;
private:
MultiCommandSquasher(absl::Span<StoredCmd> cmds, ConnectionContext* cntx);
MultiCommandSquasher(absl::Span<StoredCmd> cmds, ConnectionContext* cntx, bool error_abort);
// Lazy initialize shard info.
ShardExecInfo& PrepareShardInfo(ShardId sid);
@ -55,8 +52,8 @@ class MultiCommandSquasher {
// Callback that runs on shards during squashed hop.
facade::OpStatus SquashedHopCb(Transaction* parent_tx, EngineShard* es);
// Execute all currently squashed commands.
void ExecuteSquashed();
// Execute all currently squashed commands. Return true if aborting on error.
bool ExecuteSquashed();
// Run all commands until completion.
void Run();
@ -66,6 +63,8 @@ class MultiCommandSquasher {
ConnectionContext* cntx_; // Underlying context
const CommandId* base_cid_; // either EVAL or EXEC, used for squashed hops
bool error_abort_ = false; // Abort upon receiving error
std::vector<ShardExecInfo> sharded_;
std::vector<ShardId> order_; // reply order for squashed cmds

View file

@ -177,7 +177,6 @@ async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100):
jobs = [asyncio.create_task(enqueue_worker(
f"q-{queue}")) for queue in range(num_queues)]
collected = 0
async def dequeue_worker():
@ -202,20 +201,23 @@ async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100):
for job in jobs:
await job
ERROR_CALL_SCRIPT = """
redis.call('ECHO', 'I', 'want', 'an', 'error')
ERROR_CALL_SCRIPT_TEMPLATE = """
redis.{}('LTRIM', 'l', 'a', 'b')
"""
ERROR_PCALL_SCRIPT = """
redis.pcall('ECHO', 'I', 'want', 'an', 'error')
"""
@dfly_args({"proactor_threads": 1})
@pytest.mark.asyncio
async def test_eval_error_propagation(async_client):
assert await async_client.eval(ERROR_PCALL_SCRIPT, 0) is None
CMDS = ['call', 'pcall', 'acall', 'apcall']
try:
await async_client.eval(ERROR_CALL_SCRIPT, 0)
assert False, "Eval must have thrown an error"
except aioredis.RedisError as e:
pass
for cmd in CMDS:
does_abort = 'p' not in cmd
try:
await async_client.eval(ERROR_CALL_SCRIPT_TEMPLATE.format(cmd), 1, 'l')
if does_abort:
assert False, "Eval must have thrown an error: " + cmd
except aioredis.RedisError as e:
if not does_abort:
assert False, "Error should have been ignored: " + cmd