diff --git a/src/server/command_registry.cc b/src/server/command_registry.cc index 193a60f98..adf590fb0 100644 --- a/src/server/command_registry.cc +++ b/src/server/command_registry.cc @@ -85,7 +85,8 @@ bool CommandId::IsTransactional() const { if (first_key_ > 0 || (opt_mask_ & CO::GLOBAL_TRANS) || (opt_mask_ & CO::NO_KEY_TRANSACTIONAL)) return true; - if (name_ == "EVAL" || name_ == "EVALSHA" || name_ == "EXEC") + if (name_ == "EVAL" || name_ == "EVALSHA" || name_ == "EVAL_RO" || name_ == "EVALSHA_RO" || + name_ == "EXEC") return true; return false; diff --git a/src/server/command_registry.h b/src/server/command_registry.h index 8fdb8b71f..ea5eab996 100644 --- a/src/server/command_registry.h +++ b/src/server/command_registry.h @@ -64,7 +64,8 @@ constexpr inline bool IsTransKind(std::string_view name) { return (name == "EXEC") || (name == "MULTI") || (name == "DISCARD"); } -static_assert(IsEvalKind("EVAL") && IsEvalKind("EVALSHA")); +static_assert(IsEvalKind("EVAL") && IsEvalKind("EVAL_RO") && IsEvalKind("EVALSHA") && + IsEvalKind("EVALSHA_RO")); static_assert(!IsEvalKind("")); }; // namespace CO diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 79b7588fb..db5f76f2d 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -104,6 +104,7 @@ struct ConnectionState { size_t UsedMemory() const; absl::flat_hash_set lock_tags; // declared tags + bool read_only = false; size_t async_cmds_heap_mem = 0; // bytes used by async_cmds size_t async_cmds_heap_limit = 0; // max bytes allowed for async_cmds diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 66614638b..9fb27b0bb 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1121,6 +1121,10 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return err.value(); } } + + if (dfly_cntx.conn_state.script_info->read_only && is_write_cmd) { + return ErrorReply{"Write commands are not allowed from read-only scripts"}; + } } return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions", tail_args); @@ -1764,7 +1768,7 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) } void Service::Eval(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, - ConnectionContext* cntx) { + ConnectionContext* cntx, bool read_only) { string_view body = ArgS(args, 0); auto* rb = static_cast(builder); @@ -1779,19 +1783,29 @@ void Service::Eval(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, string sha{std::move(res.value())}; - CallSHA(args, sha, interpreter, builder, cntx); + CallSHA(args, sha, interpreter, builder, cntx, read_only); +} + +void Service::EvalRo(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, + ConnectionContext* cntx) { + Eval(args, tx, builder, cntx, true); } void Service::EvalSha(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, - ConnectionContext* cntx) { + ConnectionContext* cntx, bool read_only) { string sha = absl::AsciiStrToLower(ArgS(args, 0)); BorrowedInterpreter interpreter{cntx->transaction, &cntx->conn_state}; - CallSHA(args, sha, interpreter, builder, cntx); + CallSHA(args, sha, interpreter, builder, cntx, read_only); +} + +void Service::EvalShaRo(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, + ConnectionContext* cntx) { + EvalSha(args, tx, builder, cntx, true); } void Service::CallSHA(CmdArgList args, string_view sha, Interpreter* interpreter, - SinkReplyBuilder* builder, ConnectionContext* cntx) { + SinkReplyBuilder* builder, ConnectionContext* cntx, bool read_only) { uint32_t num_keys; CHECK(absl::SimpleAtoi(ArgS(args, 1), &num_keys)); // we already validated this @@ -1801,7 +1815,7 @@ void Service::CallSHA(CmdArgList args, string_view sha, Interpreter* interpreter ev_args.args = args.subspan(2 + num_keys); uint64_t start = absl::GetCurrentTimeNanos(); - EvalInternal(args, ev_args, interpreter, builder, cntx); + EvalInternal(args, ev_args, interpreter, builder, cntx, read_only); uint64_t end = absl::GetCurrentTimeNanos(); ServerState::tlocal()->RecordCallLatency(sha, (end - start) / 1000); @@ -1887,7 +1901,7 @@ static bool CanRunSingleShardMulti(optional sid, const ScriptMgr::Scrip } void Service::EvalInternal(CmdArgList args, const EvalArgs& eval_args, Interpreter* interpreter, - SinkReplyBuilder* builder, ConnectionContext* cntx) { + SinkReplyBuilder* builder, ConnectionContext* cntx, bool read_only) { DCHECK(!eval_args.sha.empty()); // Sanitizing the input to avoid code injection. @@ -1909,6 +1923,7 @@ void Service::EvalInternal(CmdArgList args, const EvalArgs& eval_args, Interpret auto& sinfo = cntx->conn_state.script_info; sinfo = make_unique(); sinfo->lock_tags.reserve(eval_args.keys.size()); + sinfo->read_only = read_only; optional sid; @@ -2594,7 +2609,9 @@ constexpr uint32_t kWatch = FAST | TRANSACTION; constexpr uint32_t kUnwatch = FAST | TRANSACTION; constexpr uint32_t kDiscard = FAST | TRANSACTION; constexpr uint32_t kEval = SLOW | SCRIPTING; +constexpr uint32_t kEvalRo = SLOW | SCRIPTING; constexpr uint32_t kEvalSha = SLOW | SCRIPTING; +constexpr uint32_t kEvalShaRo = SLOW | SCRIPTING; constexpr uint32_t kExec = SLOW | TRANSACTION; constexpr uint32_t kPublish = PUBSUB | FAST; constexpr uint32_t kSubscribe = PUBSUB | SLOW; @@ -2619,9 +2636,16 @@ void Service::Register(CommandRegistry* registry) { << CI{"EVAL", CO::NOSCRIPT | CO::VARIADIC_KEYS, -3, 3, 3, acl::kEval} .MFUNC(Eval) .SetValidator(&EvalValidator) + << CI{"EVAL_RO", CO::NOSCRIPT | CO::READONLY | CO::VARIADIC_KEYS, -3, 3, 3, acl::kEvalRo} + .MFUNC(EvalRo) + .SetValidator(&EvalValidator) << CI{"EVALSHA", CO::NOSCRIPT | CO::VARIADIC_KEYS, -3, 3, 3, acl::kEvalSha} .MFUNC(EvalSha) .SetValidator(&EvalValidator) + << CI{"EVALSHA_RO", CO::NOSCRIPT | CO::READONLY | CO::VARIADIC_KEYS, -3, 3, 3, + acl::kEvalShaRo} + .MFUNC(EvalShaRo) + .SetValidator(&EvalValidator) << CI{"EXEC", CO::LOADING | CO::NOSCRIPT, 1, 0, 0, acl::kExec}.MFUNC(Exec) << CI{"PUBLISH", CO::LOADING | CO::FAST, 3, 0, 0, acl::kPublish}.MFUNC(Publish) << CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, acl::kSubscribe}.MFUNC(Subscribe) diff --git a/src/server/main_service.h b/src/server/main_service.h index 471184ffd..fe1dc1a29 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -133,9 +133,13 @@ class Service : public facade::ServiceInterface { void Discard(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, ConnectionContext* cntx); - void Eval(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, ConnectionContext* cntx); - void EvalSha(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, - ConnectionContext* cntx); + void Eval(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, ConnectionContext* cntx, + bool read_only = false); + void EvalRo(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, ConnectionContext* cntx); + void EvalSha(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, ConnectionContext* cntx, + bool read_only = false); + void EvalShaRo(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, + ConnectionContext* cntx); void Exec(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, ConnectionContext* cntx); void Publish(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, ConnectionContext* cntx); @@ -169,9 +173,9 @@ class Service : public facade::ServiceInterface { const ConnectionContext& dfly_cntx); void EvalInternal(CmdArgList args, const EvalArgs& eval_args, Interpreter* interpreter, - SinkReplyBuilder* builder, ConnectionContext* cntx); + SinkReplyBuilder* builder, ConnectionContext* cntx, bool read_only); void CallSHA(CmdArgList args, std::string_view sha, Interpreter* interpreter, - SinkReplyBuilder* builder, ConnectionContext* cntx); + SinkReplyBuilder* builder, ConnectionContext* cntx, bool read_only); // Return optional payload - first received error that occured when executing commands. std::optional FlushEvalAsyncCmds(ConnectionContext* cntx, diff --git a/src/server/multi_test.cc b/src/server/multi_test.cc index 161f291de..727045900 100644 --- a/src/server/multi_test.cc +++ b/src/server/multi_test.cc @@ -1185,4 +1185,38 @@ TEST_F(MultiTest, MultiTypes) { RespArray(ElementsAre("none", "none", "none", "none", "none", "none"))); } +TEST_F(MultiTest, EvalRo) { + RespExpr resp; + + resp = Run({"set", "foo", "bar"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"eval_ro", "return redis.call('get', KEYS[1])", "1", "foo"}); + EXPECT_THAT(resp, "bar"); + + resp = Run({"eval_ro", "return redis.call('set', KEYS[1], 'car')", "1", "foo"}); + EXPECT_THAT(resp, ErrArg("Write commands are not allowed from read-only scripts")); +} + +TEST_F(MultiTest, EvalShaRo) { + RespExpr resp; + + const char* read_script = "return redis.call('get', KEYS[1]);"; + const char* write_script = "return redis.call('set', KEYS[1], 'car');"; + + auto sha_resp = Run({"script", "load", read_script}); + auto read_sha = facade::ToSV(sha_resp.GetBuf()); + sha_resp = Run({"script", "load", write_script}); + auto write_sha = facade::ToSV(sha_resp.GetBuf()); + + resp = Run({"set", "foo", "bar"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"evalsha_ro", read_sha, "1", "foo"}); + EXPECT_THAT(resp, "bar"); + + resp = Run({"evalsha_ro", write_sha, "1", "foo"}); + EXPECT_THAT(resp, ErrArg("Write commands are not allowed from read-only scripts")); +} + } // namespace dfly diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index c76ac9747..42608fce5 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -408,7 +408,8 @@ RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) { DCHECK(context->transaction == nullptr); auto cmd = absl::AsciiStrToUpper(slice.front()); - if (cmd == "EVAL" || cmd == "EVALSHA" || cmd == "EXEC") { + if (cmd == "EVAL" || cmd == "EVALSHA" || cmd == "EVAL_RO" || cmd == "EVALSHA_RO" || + cmd == "EXEC") { shard_set->AwaitRunningOnShardQueue([](auto*) {}); // Wait for async UnlockMulti. } diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 3064cd512..67f910438 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -176,7 +176,8 @@ void Transaction::Shutdown() { Transaction::Transaction(const CommandId* cid) : cid_{cid} { InitTxTime(); string_view cmd_name(cid_->name()); - if (cmd_name == "EXEC" || cmd_name == "EVAL" || cmd_name == "EVALSHA") { + if (cmd_name == "EXEC" || cmd_name == "EVAL" || cmd_name == "EVAL_RO" || cmd_name == "EVALSHA" || + cmd_name == "EVALSHA_RO") { multi_.reset(new MultiData); multi_->mode = NOT_DETERMINED; multi_->role = DEFAULT;