From 74a1454409bad4e3eb3d4915c3df1bd33ffe0403 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Sat, 22 Apr 2023 09:02:22 +0300 Subject: [PATCH] Collect errors from async lua calls (#1092) Error collection from async calls --- src/core/interpreter.cc | 12 ++++- src/core/interpreter.h | 9 +++- src/facade/reply_builder.h | 7 +++ src/facade/reply_capture.cc | 47 ++++++++++++++++++ src/facade/reply_capture.h | 19 +++++++- src/server/conn_context.cc | 25 ++++++++-- src/server/conn_context.h | 20 ++++++-- src/server/main_service.cc | 72 +++++++++++++++++----------- src/server/main_service.h | 5 +- src/server/multi_command_squasher.cc | 44 +++++++++++------ src/server/multi_command_squasher.h | 23 +++++---- tests/dragonfly/eval_test.py | 26 +++++----- 12 files changed, 229 insertions(+), 80 deletions(-) diff --git a/src/core/interpreter.cc b/src/core/interpreter.cc index 79882f304..35b041b7b 100644 --- a/src/core/interpreter.cc +++ b/src/core/interpreter.cc @@ -377,6 +377,11 @@ Interpreter::Interpreter() { lua_pushcfunction(lua_, RedisACallCommand); lua_settable(lua_, -3); + /* redis.apcall */ + lua_pushstring(lua_, "apcall"); + lua_pushcfunction(lua_, RedisAPCallCommand); + lua_settable(lua_, -3); + lua_pushstring(lua_, "sha1hex"); lua_pushcfunction(lua_, RedisSha1Command); lua_settable(lua_, -3); @@ -720,7 +725,7 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) { * and this way we guaranty we will have room on the stack for the result. */ lua_pop(lua_, argc); RedisTranslator translator(lua_); - redis_func_(CallArgs{MutSliceSpan{args}, &buffer, &translator, async}); + redis_func_(CallArgs{MutSliceSpan{args}, &buffer, &translator, async, raise_error, &raise_error}); cmd_depth_--; // Raise error for regular 'call' command if needed. @@ -746,6 +751,11 @@ int Interpreter::RedisPCallCommand(lua_State* lua) { } int Interpreter::RedisACallCommand(lua_State* lua) { + void** ptr = static_cast(lua_getextraspace(lua)); + return reinterpret_cast(*ptr)->RedisGenericCommand(true, true); +} + +int Interpreter::RedisAPCallCommand(lua_State* lua) { void** ptr = static_cast(lua_getextraspace(lua)); return reinterpret_cast(*ptr)->RedisGenericCommand(false, true); } diff --git a/src/core/interpreter.h b/src/core/interpreter.h index 653c210a4..1f25d260e 100644 --- a/src/core/interpreter.h +++ b/src/core/interpreter.h @@ -43,7 +43,12 @@ class Interpreter { ObjectExplorer* translator; - bool async; // async by redis.acall + bool async; // async by acall + bool error_abort; // abort on errors (not pcall) + + // The function can request an abort due to an error, even if error_abort is false. + // It happens when async cmds are flushed and result in an uncatched error. + bool* requested_abort; }; using RedisFunc = std::function; @@ -112,10 +117,12 @@ class Interpreter { bool IsTableSafe() const; int RedisGenericCommand(bool raise_error, bool async); + int RedisACallErrorsCommand(); static int RedisCallCommand(lua_State* lua); static int RedisPCallCommand(lua_State* lua); static int RedisACallCommand(lua_State* lua); + static int RedisAPCallCommand(lua_State* lua); lua_State* lua_; unsigned cmd_depth_ = 0; diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index 9d71e15cb..3bb2eafae 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -13,6 +13,13 @@ namespace facade { +// Reply mode allows filtering replies. +enum class ReplyMode { + NONE, // No replies are recorded + ONLY_ERR, // Only errors are recorded + FULL // All replies are recorded +}; + class SinkReplyBuilder { public: struct ResponseValue { diff --git a/src/facade/reply_capture.cc b/src/facade/reply_capture.cc index 4e975d83b..6e7a1a26b 100644 --- a/src/facade/reply_capture.cc +++ b/src/facade/reply_capture.cc @@ -6,31 +6,42 @@ #include "base/logging.h" #include "reply_capture.h" +#define SKIP_LESS(needed) \ + if (reply_mode_ < needed) { \ + current_ = monostate{}; \ + return; \ + } namespace facade { using namespace std; void CapturingReplyBuilder::SendError(std::string_view str, std::string_view type) { + SKIP_LESS(ReplyMode::ONLY_ERR); Capture(Error{str, type}); } void CapturingReplyBuilder::SendMGetResponse(absl::Span arr) { + SKIP_LESS(ReplyMode::FULL); Capture(vector{arr.begin(), arr.end()}); } void CapturingReplyBuilder::SendError(OpStatus status) { + SKIP_LESS(ReplyMode::ONLY_ERR); Capture(status); } void CapturingReplyBuilder::SendNullArray() { + SKIP_LESS(ReplyMode::FULL); Capture(unique_ptr{nullptr}); } void CapturingReplyBuilder::SendEmptyArray() { + SKIP_LESS(ReplyMode::FULL); Capture(make_unique(0, ARRAY)); } void CapturingReplyBuilder::SendSimpleStrArr(StrSpan arr) { + SKIP_LESS(ReplyMode::FULL); DCHECK_EQ(current_.index(), 0u); WrappedStrSpan warr{arr}; @@ -42,6 +53,7 @@ void CapturingReplyBuilder::SendSimpleStrArr(StrSpan arr) { } void CapturingReplyBuilder::SendStringArr(StrSpan arr, CollectionType type) { + SKIP_LESS(ReplyMode::FULL); DCHECK_EQ(current_.index(), 0u); // TODO: 1. Allocate all strings at once 2. Allow movable types @@ -54,31 +66,38 @@ void CapturingReplyBuilder::SendStringArr(StrSpan arr, CollectionType type) { } void CapturingReplyBuilder::SendNull() { + SKIP_LESS(ReplyMode::FULL); Capture(nullptr_t{}); } void CapturingReplyBuilder::SendLong(long val) { + SKIP_LESS(ReplyMode::FULL); Capture(val); } void CapturingReplyBuilder::SendDouble(double val) { + SKIP_LESS(ReplyMode::FULL); Capture(val); } void CapturingReplyBuilder::SendSimpleString(std::string_view str) { + SKIP_LESS(ReplyMode::FULL); Capture(SimpleString{string{str}}); } void CapturingReplyBuilder::SendBulkString(std::string_view str) { + SKIP_LESS(ReplyMode::FULL); Capture(BulkString{string{str}}); } void CapturingReplyBuilder::SendScoredArray(const std::vector>& arr, bool with_scores) { + SKIP_LESS(ReplyMode::FULL); Capture(ScoredArray{arr, with_scores}); } void CapturingReplyBuilder::StartCollection(unsigned len, CollectionType type) { + SKIP_LESS(ReplyMode::FULL); stack_.emplace(make_unique(len, type), type == MAP ? len * 2 : len); // If we added an empty collection, it must be collapsed immediately. @@ -92,6 +111,17 @@ CapturingReplyBuilder::Payload CapturingReplyBuilder::Take() { return pl; } +void CapturingReplyBuilder::SendDirect(Payload&& val) { + bool is_err = holds_alternative(val) || holds_alternative(val); + ReplyMode min_mode = is_err ? ReplyMode::ONLY_ERR : ReplyMode::FULL; + if (reply_mode_ >= min_mode) { + DCHECK_EQ(current_.index(), 0u); + current_ = move(val); + } else { + current_ = monostate{}; + } +} + void CapturingReplyBuilder::Capture(Payload val) { if (!stack_.empty()) { stack_.top().first->arr.push_back(std::move(val)); @@ -183,8 +213,25 @@ struct CaptureVisitor { }; void CapturingReplyBuilder::Apply(Payload&& pl, RedisReplyBuilder* rb) { + if (auto* crb = dynamic_cast(rb); crb != nullptr) { + crb->SendDirect(move(pl)); + return; + } + CaptureVisitor cv{rb}; visit(cv, pl); } +void CapturingReplyBuilder::SetReplyMode(ReplyMode mode) { + reply_mode_ = mode; + current_ = monostate{}; +} + +optional CapturingReplyBuilder::GetError(const Payload& pl) { + if (auto* err = get_if(&pl); err != nullptr) { + return ErrorRef{err->first, err->second}; + } + return nullopt; +} + } // namespace facade diff --git a/src/facade/reply_capture.h b/src/facade/reply_capture.h index dddc27b18..97f8d5a8d 100644 --- a/src/facade/reply_capture.h +++ b/src/facade/reply_capture.h @@ -45,7 +45,7 @@ class CapturingReplyBuilder : public RedisReplyBuilder { void StartCollection(unsigned len, CollectionType type) override; private: - using Error = std::pair; // SendError + using Error = std::pair; // SendError (msg, type) using Null = std::nullptr_t; // SendNull or SendNullArray struct SimpleString : public std::string {}; // SendSimpleString struct BulkString : public std::string {}; // SendBulkString @@ -64,19 +64,28 @@ class CapturingReplyBuilder : public RedisReplyBuilder { }; public: - CapturingReplyBuilder() : RedisReplyBuilder{nullptr}, stack_{}, current_{} { + CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL) + : RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} { } using Payload = std::variant, std::vector, ScoredArray>; + // Non owned Error based on SendError arguments (msg, type) + using ErrorRef = std::pair; + + void SetReplyMode(ReplyMode mode); + // Take payload and clear state. Payload Take(); // Send payload to builder. static void Apply(Payload&& pl, RedisReplyBuilder* builder); + // If an error is stored inside payload, get a reference to it. + static std::optional GetError(const Payload& pl); + private: struct CollectionPayload { CollectionPayload(unsigned len, CollectionType type); @@ -87,12 +96,18 @@ class CapturingReplyBuilder : public RedisReplyBuilder { }; private: + // Send payload directly, bypassing external interface. For efficient passing between two + // captures. + void SendDirect(Payload&& val); + // Capture value and store eiter in current topmost collection or as a standalone value. void Capture(Payload val); // While topmost collection in stack is full, finalize it and add it as a regular value. void CollapseFilledCollections(); + ReplyMode reply_mode_; + // List of nested active collections that are being built. std::stack, int>> stack_; diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 6135ec317..2ee5f4086 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -16,8 +16,8 @@ namespace dfly { using namespace std; using namespace facade; -StoredCmd::StoredCmd(const CommandId* cid, CmdArgList args) - : cid_{cid}, buffer_{}, sizes_(args.size()) { +StoredCmd::StoredCmd(const CommandId* cid, CmdArgList args, facade::ReplyMode mode) + : cid_{cid}, buffer_{}, sizes_(args.size()), reply_mode_{mode} { size_t total_size = 0; for (auto args : args) total_size += args.size(); @@ -31,8 +31,8 @@ StoredCmd::StoredCmd(const CommandId* cid, CmdArgList args) } } -StoredCmd::StoredCmd(string&& buffer, const CommandId* cid, CmdArgList args) - : cid_{cid}, buffer_{move(buffer)}, sizes_(args.size()) { +StoredCmd::StoredCmd(string&& buffer, const CommandId* cid, CmdArgList args, facade::ReplyMode mode) + : cid_{cid}, buffer_{move(buffer)}, sizes_(args.size()), reply_mode_{mode} { for (unsigned i = 0; i < args.size(); i++) { // Assume tightly packed list. DCHECK(i + 1 == args.size() || args[i].data() + args[i].size() == args[i + 1].data()); @@ -53,6 +53,23 @@ size_t StoredCmd::NumArgs() const { return sizes_.size(); } +facade::ReplyMode StoredCmd::ReplyMode() const { + return reply_mode_; +} + +template size_t IsStoredInlined(const C& c) { + const char* start = reinterpret_cast(&c); + const char* end = start + sizeof(C); + const char* data = reinterpret_cast(c.data()); + return data >= start && data <= end; +} + +size_t StoredCmd::UsedHeapMemory() const { + size_t buffer_size = IsStoredInlined(buffer_) ? 0 : buffer_.size(); + size_t sz_size = IsStoredInlined(sizes_) ? 0 : sizes_.size() * sizeof(uint32_t); + return buffer_size + sz_size; +} + const CommandId* StoredCmd::Cid() const { return cid_; } diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 1a351e9b0..27f068b12 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -22,23 +22,30 @@ class ChannelStore; // Used for storing MULTI/EXEC commands. class StoredCmd { public: - StoredCmd(const CommandId* cid, CmdArgList args); + StoredCmd(const CommandId* cid, CmdArgList args, + facade::ReplyMode mode = facade::ReplyMode::FULL); // Create on top of already filled tightly-packed buffer. - StoredCmd(std::string&& buffer, const CommandId* cid, CmdArgList args); + StoredCmd(std::string&& buffer, const CommandId* cid, CmdArgList args, + facade::ReplyMode mode = facade::ReplyMode::FULL); size_t NumArgs() const; + size_t UsedHeapMemory() const; + // Fill the arg list with stored arguments, it should be at least of size NumArgs(). // Between filling and invocation, cmd should NOT be moved. void Fill(CmdArgList args); const CommandId* Cid() const; + facade::ReplyMode ReplyMode() const; + private: const CommandId* cid_; // underlying command std::string buffer_; // underlying buffer - absl::FixedArray sizes_; // sizes of arg parts + absl::FixedArray sizes_; // sizes of arg part + facade::ReplyMode reply_mode_; // reply mode }; struct ConnectionState { @@ -75,7 +82,10 @@ struct ConnectionState { struct ScriptInfo { bool is_write = true; absl::flat_hash_set keys; // declared keys - std::vector async_cmds; // aggregated by acall + + size_t async_cmds_heap_mem = 0; // bytes used by async_cmds + size_t async_cmds_heap_limit = 0; // max bytes allowed for async_cmds + std::vector async_cmds; // aggregated by acall }; // PUB-SUB messaging related data. @@ -117,7 +127,7 @@ struct ConnectionState { ExecInfo exec_info; ReplicationInfo replicaiton_info; - std::optional script_info; + std::unique_ptr script_info; std::unique_ptr subscribe_info; }; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index bec534ba6..e8a1b4879 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -41,17 +41,20 @@ extern "C" { #include "util/varz.h" using namespace std; +using dfly::operator""_KB; ABSL_FLAG(uint32_t, port, 6379, "Redis port"); ABSL_FLAG(uint32_t, memcache_port, 0, "Memcached port"); +ABSL_FLAG(uint32_t, num_shards, 0, "Number of database shards, 0 - to choose automatically"); + ABSL_FLAG(uint32_t, multi_exec_mode, 1, "Set multi exec atomicity mode: 1 for global, 2 for locking ahead, 3 for non atomic"); ABSL_FLAG(bool, multi_exec_squash, true, "Whether multi exec will squash single shard commands to optimize performance"); -ABSL_FLAG(uint32_t, num_shards, 0, "Number of database shards, 0 - to choose automatically"); +ABSL_FLAG(uint32_t, multi_eval_squash_buffer, 4_KB, "Max buffer for squashed commands per script"); namespace dfly { @@ -154,7 +157,7 @@ std::string MakeMonitorMessage(const ConnectionState& conn_state, const facade::Connection* connection, CmdArgList args) { std::string message = absl::StrCat(CreateMonitorTimestamp(), " [", conn_state.db_index); - if (conn_state.script_info.has_value()) { + if (conn_state.script_info) { absl::StrAppend(&message, " lua] "); } else { auto endpoint = connection == nullptr ? "REPLICATION:0" : connection->RemoteEndpointStr(); @@ -598,7 +601,7 @@ bool Service::VerifyCommand(const CommandId* cid, CmdArgList args, ConnectionContext* dfly_cntx = static_cast(cntx); bool is_trans_cmd = (cmd_str == "EXEC" || cmd_str == "MULTI" || cmd_str == "DISCARD"); - bool under_script = dfly_cntx->conn_state.script_info.has_value(); + bool under_script = bool(dfly_cntx->conn_state.script_info); absl::Cleanup multi_error([dfly_cntx] { SetMultiExecErrorFlag(dfly_cntx); }); @@ -702,7 +705,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) ToUpper(&args[0]); ConnectionContext* dfly_cntx = static_cast(cntx); - bool under_script = dfly_cntx->conn_state.script_info.has_value(); + bool under_script = bool(dfly_cntx->conn_state.script_info); if (VLOG_IS_ON(2) && cntx->owner()) { // owner may not exists in case of this being called from replica context @@ -1015,32 +1018,34 @@ void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) { return (*cntx)->SendOk(); } -template void WithoutReplies(ConnectionContext* cntx, F&& f) { - io::NullSink null_sink; - facade::RedisReplyBuilder rrb{&null_sink}; - auto* old_rrb = cntx->Inject(&rrb); - +template void WithReplies(CapturingReplyBuilder* crb, ConnectionContext* cntx, F&& f) { + SinkReplyBuilder* old_rrb = nullptr; + old_rrb = cntx->Inject(crb); f(); - cntx->Inject(old_rrb); } -void Service::FlushEvalAsyncCmds(ConnectionContext* cntx, bool force) { - const int kMaxAsyncCmds = 100; - +optional Service::FlushEvalAsyncCmds(ConnectionContext* cntx, + bool force) { auto& info = cntx->conn_state.script_info; - if ((!force && info->async_cmds.size() <= kMaxAsyncCmds) || info->async_cmds.empty()) - return; + size_t used_mem = info->async_cmds_heap_mem + info->async_cmds.size() * sizeof(StoredCmd); + if ((info->async_cmds.empty() || !force) && used_mem < info->async_cmds_heap_limit) + return nullopt; auto* eval_cid = registry_.Find("EVAL"); DCHECK(eval_cid); cntx->transaction->MultiSwitchCmd(eval_cid); - WithoutReplies(cntx, - [&] { MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx); }); + CapturingReplyBuilder crb{ReplyMode::ONLY_ERR}; + WithReplies(&crb, cntx, + [&] { MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx, true); }); + info->async_cmds_heap_mem = 0; info->async_cmds.clear(); + + auto reply = move(crb.Take()); + return CapturingReplyBuilder::GetError(reply) ? make_optional(move(reply)) : nullopt; } void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) { @@ -1051,20 +1056,25 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) auto& info = cntx->conn_state.script_info; auto* cid = registry_.Find(facade::ToSV(ca.args[0])); - bool valid = true; - WithoutReplies(cntx, [&] { valid = VerifyCommand(cid, ca.args, cntx); }); - - if (!valid) // TODO: collect errors with capturing reply builder. + if (!VerifyCommand(cid, ca.args, cntx)) return; - info->async_cmds.emplace_back(move(*ca.buffer), cid, ca.args.subspan(1)); - FlushEvalAsyncCmds(cntx, false); + auto replies = ca.error_abort ? ReplyMode::ONLY_ERR : ReplyMode::NONE; + info->async_cmds.emplace_back(move(*ca.buffer), cid, ca.args.subspan(1), replies); + info->async_cmds_heap_mem += info->async_cmds.back().UsedHeapMemory(); + } + + InterpreterReplier replier(ca.translator); + + if (auto err = FlushEvalAsyncCmds(cntx, !ca.async); err) { + CapturingReplyBuilder::Apply(move(*err), &replier); // forward error to lua + *ca.requested_abort = true; return; } - FlushEvalAsyncCmds(cntx, true); + if (ca.async) + return; - InterpreterReplier replier(ca.translator); facade::SinkReplyBuilder* orig = cntx->Inject(&replier); DispatchCommand(ca.args, cntx); @@ -1208,10 +1218,12 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, // TODO: to determine whether the script is RO by scanning all "redis.p?call" calls // and checking whether all invocations consist of RO commands. // we can do it once during script insertion into script mgr. - cntx->conn_state.script_info.emplace(ConnectionState::ScriptInfo{}); + auto& sinfo = cntx->conn_state.script_info; + sinfo.reset(new ConnectionState::ScriptInfo{}); for (size_t i = 0; i < eval_args.keys.size(); ++i) { - cntx->conn_state.script_info->keys.insert(ArgS(eval_args.keys, i)); + sinfo->keys.insert(ArgS(eval_args.keys, i)); } + sinfo->async_cmds_heap_limit = absl::GetFlag(FLAGS_multi_eval_squash_buffer); DCHECK(cntx->transaction); bool scheduled = StartMultiEval(cntx->db_index(), eval_args.keys, *params, cntx->transaction); @@ -1223,7 +1235,11 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, Interpreter::RunResult result = interpreter->RunFunction(eval_args.sha, &error); absl::Cleanup clean = [interpreter]() { interpreter->ResetStack(); }; - FlushEvalAsyncCmds(cntx, true); + if (auto err = FlushEvalAsyncCmds(cntx, true); err) { + auto err_ref = CapturingReplyBuilder::GetError(*err); + result = Interpreter::RUN_ERR; + error = absl::StrCat(err_ref->first); + } cntx->conn_state.script_info.reset(); // reset script_info diff --git a/src/server/main_service.h b/src/server/main_service.h index c731bab72..c73b4eed5 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -120,7 +120,10 @@ class Service : public facade::ServiceInterface { void EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, ConnectionContext* cntx); - void FlushEvalAsyncCmds(ConnectionContext* cntx, bool force = false); + // Return optional payload - first received error that occured when executing commands. + std::optional FlushEvalAsyncCmds(ConnectionContext* cntx, + bool force = false); + void CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& args); void RegisterCommands(); diff --git a/src/server/multi_command_squasher.cc b/src/server/multi_command_squasher.cc index 49ce61c10..4b27340eb 100644 --- a/src/server/multi_command_squasher.cc +++ b/src/server/multi_command_squasher.cc @@ -24,8 +24,9 @@ template void IterateKeys(CmdArgList args, KeyIndex keys, F&& f) { } // namespace -MultiCommandSquasher::MultiCommandSquasher(absl::Span cmds, ConnectionContext* cntx) - : cmds_{cmds}, cntx_{cntx}, base_cid_{cntx->transaction->GetCId()} { +MultiCommandSquasher::MultiCommandSquasher(absl::Span cmds, ConnectionContext* cntx, + bool error_abort) + : cmds_{cmds}, cntx_{cntx}, base_cid_{cntx->transaction->GetCId()}, error_abort_{error_abort} { auto mode = cntx->transaction->GetMultiMode(); track_keys_ = mode == Transaction::NON_ATOMIC; } @@ -38,9 +39,6 @@ MultiCommandSquasher::ShardExecInfo& MultiCommandSquasher::PrepareShardInfo(Shar if (!sinfo.local_tx) sinfo.local_tx = new Transaction{cntx_->transaction}; - if (!sinfo.reply_chan) - sinfo.reply_chan = make_unique(kChanBufferSize, 1); - return sinfo; } @@ -85,7 +83,7 @@ MultiCommandSquasher::SquashResult MultiCommandSquasher::TrySquash(StoredCmd* cm // Because the squashed hop is currently blocking, we cannot add more than the max channel size, // otherwise a deadlock occurs. - bool need_flush = sinfo.cmds.size() >= kChanBufferSize - 1; + bool need_flush = sinfo.cmds.size() >= kMaxSquashing - 1; return need_flush ? SquashResult::SQUASHED_FULL : SquashResult::SQUASHED; } @@ -118,6 +116,7 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard for (auto* cmd : sinfo.cmds) { local_tx->MultiSwitchCmd(cmd->Cid()); local_cntx.cid = cmd->Cid(); + crb.SetReplyMode(cmd->ReplyMode()); arg_vec.resize(cmd->NumArgs()); auto args = absl::MakeSpan(arg_vec); @@ -126,18 +125,20 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard local_tx->InitByArgs(parent_tx->GetDbIndex(), args); cmd->Cid()->Invoke(args, &local_cntx); - sinfo.reply_chan->Push(crb.Take()); + sinfo.replies.emplace_back(crb.Take()); } // ConnectionContext deletes the reply builder upon destruction, so // remove our local pointer from it. local_cntx.Inject(nullptr); + + reverse(sinfo.replies.begin(), sinfo.replies.end()); return OpStatus::OK; } -void MultiCommandSquasher::ExecuteSquashed() { +bool MultiCommandSquasher::ExecuteSquashed() { if (order_.empty()) - return; + return false; Transaction* tx = cntx_->transaction; @@ -149,14 +150,26 @@ void MultiCommandSquasher::ExecuteSquashed() { tx->PrepareSquashedMultiHop(base_cid_, cb); } + for (auto& sd : sharded_) + sd.replies.reserve(sd.cmds.size()); + cntx_->cid = base_cid_; tx->ScheduleSingleHop([this](auto* tx, auto* es) { return SquashedHopCb(tx, es); }); - facade::CapturingReplyBuilder::Payload payload; + bool aborted = false; + RedisReplyBuilder* rb = static_cast(cntx_->reply_builder()); for (auto idx : order_) { - CHECK(sharded_[idx].reply_chan->Pop(payload)); - CapturingReplyBuilder::Apply(move(payload), rb); + auto& replies = sharded_[idx].replies; + CHECK(!replies.empty()); + + aborted |= error_abort_ && CapturingReplyBuilder::GetError(replies.back()); + + CapturingReplyBuilder::Apply(move(replies.back()), rb); + replies.pop_back(); + + if (aborted) + break; } for (auto& sinfo : sharded_) @@ -164,6 +177,7 @@ void MultiCommandSquasher::ExecuteSquashed() { order_.clear(); collected_keys_.clear(); + return aborted; } void MultiCommandSquasher::Run() { @@ -173,8 +187,10 @@ void MultiCommandSquasher::Run() { if (res == SquashResult::ERROR) break; - if (res == SquashResult::NOT_SQUASHED || res == SquashResult::SQUASHED_FULL) - ExecuteSquashed(); + if (res == SquashResult::NOT_SQUASHED || res == SquashResult::SQUASHED_FULL) { + if (ExecuteSquashed()) + break; + } if (res == SquashResult::NOT_SQUASHED) ExecuteStandalone(&cmd); diff --git a/src/server/multi_command_squasher.h b/src/server/multi_command_squasher.h index 54a2805aa..08257f060 100644 --- a/src/server/multi_command_squasher.h +++ b/src/server/multi_command_squasher.h @@ -16,32 +16,29 @@ namespace dfly { // thus greatly decreasing the dispatch overhead for them. class MultiCommandSquasher { public: - static void Execute(absl::Span cmds, ConnectionContext* cntx) { - MultiCommandSquasher{cmds, cntx}.Run(); + static void Execute(absl::Span cmds, ConnectionContext* cntx, + bool error_abort = false) { + MultiCommandSquasher{cmds, cntx, error_abort}.Run(); } private: - using ReplyChan = - dfly::SimpleChannel>; - // Per-shard exection info. struct ShardExecInfo { - ShardExecInfo() : had_writes{false}, cmds{}, reply_chan{nullptr}, local_tx{nullptr} { + ShardExecInfo() : had_writes{false}, cmds{}, replies{}, local_tx{nullptr} { } bool had_writes; std::vector cmds; // accumulated commands - std::unique_ptr reply_chan; + std::vector replies; boost::intrusive_ptr local_tx; // stub-mode tx for use inside shard }; enum class SquashResult { SQUASHED, SQUASHED_FULL, NOT_SQUASHED, ERROR }; - static constexpr int kChanBufferSize = 32; + static constexpr int kMaxSquashing = 32; private: - MultiCommandSquasher(absl::Span cmds, ConnectionContext* cntx); + MultiCommandSquasher(absl::Span cmds, ConnectionContext* cntx, bool error_abort); // Lazy initialize shard info. ShardExecInfo& PrepareShardInfo(ShardId sid); @@ -55,8 +52,8 @@ class MultiCommandSquasher { // Callback that runs on shards during squashed hop. facade::OpStatus SquashedHopCb(Transaction* parent_tx, EngineShard* es); - // Execute all currently squashed commands. - void ExecuteSquashed(); + // Execute all currently squashed commands. Return true if aborting on error. + bool ExecuteSquashed(); // Run all commands until completion. void Run(); @@ -66,6 +63,8 @@ class MultiCommandSquasher { ConnectionContext* cntx_; // Underlying context const CommandId* base_cid_; // either EVAL or EXEC, used for squashed hops + bool error_abort_ = false; // Abort upon receiving error + std::vector sharded_; std::vector order_; // reply order for squashed cmds diff --git a/tests/dragonfly/eval_test.py b/tests/dragonfly/eval_test.py index 941e67ed2..2ca9a22df 100644 --- a/tests/dragonfly/eval_test.py +++ b/tests/dragonfly/eval_test.py @@ -177,7 +177,6 @@ async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100): jobs = [asyncio.create_task(enqueue_worker( f"q-{queue}")) for queue in range(num_queues)] - collected = 0 async def dequeue_worker(): @@ -202,20 +201,23 @@ async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100): for job in jobs: await job -ERROR_CALL_SCRIPT = """ -redis.call('ECHO', 'I', 'want', 'an', 'error') + +ERROR_CALL_SCRIPT_TEMPLATE = """ +redis.{}('LTRIM', 'l', 'a', 'b') """ -ERROR_PCALL_SCRIPT = """ -redis.pcall('ECHO', 'I', 'want', 'an', 'error') -""" +@dfly_args({"proactor_threads": 1}) @pytest.mark.asyncio async def test_eval_error_propagation(async_client): - assert await async_client.eval(ERROR_PCALL_SCRIPT, 0) is None + CMDS = ['call', 'pcall', 'acall', 'apcall'] - try: - await async_client.eval(ERROR_CALL_SCRIPT, 0) - assert False, "Eval must have thrown an error" - except aioredis.RedisError as e: - pass + for cmd in CMDS: + does_abort = 'p' not in cmd + try: + await async_client.eval(ERROR_CALL_SCRIPT_TEMPLATE.format(cmd), 1, 'l') + if does_abort: + assert False, "Eval must have thrown an error: " + cmd + except aioredis.RedisError as e: + if not does_abort: + assert False, "Error should have been ignored: " + cmd