diff --git a/helio b/helio index c8ccbbdf9..963304405 160000 --- a/helio +++ b/helio @@ -1 +1 @@ -Subproject commit c8ccbbdf9113e5d3f1dc16c6cb96396ac7e3694d +Subproject commit 96330440550013c69da14ae173049bf80e1e9257 diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 4ae87b462..f95250167 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -11,6 +11,7 @@ #include "base/flags.h" #include "base/logging.h" #include "facade/conn_context.h" +#include "facade/dragonfly_listener.h" #include "facade/memcache_parser.h" #include "facade/redis_parser.h" #include "facade/service_interface.h" @@ -403,8 +404,7 @@ uint32_t Connection::GetClientId() const { } bool Connection::IsAdmin() const { - uint16_t admin_port = absl::GetFlag(FLAGS_admin_port); - return socket_->LocalEndpoint().port() == admin_port; + return static_cast<Listener*>(owner())->IsAdminInterface(); } io::Result<bool> Connection::CheckForHttpProto(FiberSocketBase* peer) { diff --git a/src/facade/dragonfly_listener.cc b/src/facade/dragonfly_listener.cc index 528461e5f..034de828c 100644 --- a/src/facade/dragonfly_listener.cc +++ b/src/facade/dragonfly_listener.cc @@ -201,6 +201,14 @@ bool Listener::AwaitDispatches(absl::Duration timeout, return false; } +bool Listener::IsAdminInterface() const { + return is_admin_; +} + +void Listener::SetAdminInterface(bool is_admin) { + is_admin_ = is_admin; +} + void Listener::PreShutdown() { // Iterate on all connections and allow them to finish their commands for // a short period. @@ -266,6 +274,10 @@ void Listener::OnConnectionClose(util::Connection* conn) { } } +void Listener::OnMaxConnectionsReached(util::FiberSocketBase* sock) { + sock->Write(io::Buffer("-ERR max number of clients reached\r\n")); +} + // We can limit number of threads handling dragonfly connections. 
ProactorBase* Listener::PickConnectionProactor(util::FiberSocketBase* sock) { util::ProactorPool* pp = pool(); diff --git a/src/facade/dragonfly_listener.h b/src/facade/dragonfly_listener.h index d81832241..b5a404143 100644 --- a/src/facade/dragonfly_listener.h +++ b/src/facade/dragonfly_listener.h @@ -35,12 +35,16 @@ class Listener : public util::ListenerInterface { bool AwaitDispatches(absl::Duration timeout, const std::function<bool(util::Connection*)>& filter); + bool IsAdminInterface() const; + void SetAdminInterface(bool is_admin = true); + private: util::Connection* NewConnection(ProactorBase* proactor) final; ProactorBase* PickConnectionProactor(util::FiberSocketBase* sock) final; void OnConnectionStart(util::Connection* conn) final; void OnConnectionClose(util::Connection* conn) final; + void OnMaxConnectionsReached(util::FiberSocketBase* sock) final; void PreAcceptLoop(ProactorBase* pb) final; void PreShutdown() final; @@ -58,6 +62,8 @@ class Listener : public util::ListenerInterface { std::atomic_uint32_t next_id_{0}; + bool is_admin_ = false; + uint32_t conn_cnt_{0}; uint32_t min_cnt_thread_id_{0}; int32_t min_cnt_{0}; diff --git a/src/server/config_registry.cc b/src/server/config_registry.cc index d0622bd58..a7d65bd3b 100644 --- a/src/server/config_registry.cc +++ b/src/server/config_registry.cc @@ -43,6 +43,17 @@ bool ConfigRegistry::Set(std::string_view config_name, std::string_view value) { return cb(*flag); } +std::optional<std::string> ConfigRegistry::Get(std::string_view config_name) { + unique_lock lk(mu_); + if (!registry_.contains(config_name)) + return std::nullopt; + lk.unlock(); + + absl::CommandLineFlag* flag = absl::FindCommandLineFlag(config_name); + CHECK(flag); + return flag->CurrentValue(); +} + void ConfigRegistry::Reset() { unique_lock lk(mu_); registry_.clear(); diff --git a/src/server/config_registry.h b/src/server/config_registry.h index b1ac336a0..ab72eee0e 100644 --- a/src/server/config_registry.h +++ b/src/server/config_registry.h @@ -20,6 +20,8 @@ class 
ConfigRegistry { // Returns true if the value was updated. bool Set(std::string_view config_name, std::string_view value); + std::optional<std::string> Get(std::string_view config_name); + void Reset(); private: diff --git a/src/server/dfly_main.cc b/src/server/dfly_main.cc index 37a1b0765..2d7d7f805 100644 --- a/src/server/dfly_main.cc +++ b/src/server/dfly_main.cc @@ -392,6 +392,7 @@ bool RunEngine(ProactorPool* pool, AcceptServer* acceptor) { const std::string printable_addr = absl::StrCat("admin socket ", interface_addr ? interface_addr : "any", ":", admin_port); Listener* admin_listener = new Listener{Protocol::REDIS, &service}; + admin_listener->SetAdminInterface(); error_code ec = acceptor->AddListener(interface_addr, admin_port, admin_listener); if (ec) { diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 30d802a7a..584659e30 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -74,6 +74,8 @@ ABSL_FLAG(string, dbfilename, "dump-{timestamp}", "the filename to save/load the ABSL_FLAG(string, requirepass, "", "password for AUTH authentication. 
" "If empty can also be set with DFLY_PASSWORD environment variable."); +ABSL_FLAG(uint32_t, maxclients, 64000, "Maximum number of concurrent clients allowed."); + ABSL_FLAG(string, save_schedule, "", "glob spec for the UTC time to save a snapshot which matches HH:MM 24h time"); ABSL_FLAG(string, snapshot_cron, "", @@ -898,12 +900,28 @@ ServerFamily::ServerFamily(Service* service) : service_(*service) { ServerFamily::~ServerFamily() { } +void SetMaxClients(std::vector& listeners, uint32_t maxclients) { + for (auto* listener : listeners) { + if (!listener->IsAdminInterface()) { + listener->SetMaxClients(maxclients); + } + } +} + void ServerFamily::Init(util::AcceptServer* acceptor, std::vector listeners) { CHECK(acceptor_ == nullptr); acceptor_ = acceptor; listeners_ = std::move(listeners); dfly_cmd_ = make_unique(this); + SetMaxClients(listeners_, absl::GetFlag(FLAGS_maxclients)); + config_registry.Register("maxclients", [this](const absl::CommandLineFlag& flag) { + auto res = flag.TryGet(); + if (res.has_value()) + SetMaxClients(listeners_, res.value()); + return res.has_value(); + }); + pb_task_ = shard_set->pool()->GetNextProactor(); if (pb_task_->GetKind() == ProactorBase::EPOLL) { fq_threadpool_.reset(new FiberQueueThreadPool(absl::GetFlag(FLAGS_epoll_file_threads))); @@ -1621,9 +1639,10 @@ void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) { if (param == "databases") { res.emplace_back(param); res.push_back(absl::StrCat(absl::GetFlag(FLAGS_dbnum))); - } else if (param == "maxmemory") { + } else if (auto value_from_registry = config_registry.Get(param); + value_from_registry.has_value()) { res.emplace_back(param); - res.push_back(absl::StrCat(max_memory_limit)); + res.push_back(*value_from_registry); } return (*cntx)->SendStringArr(res, RedisReplyBuilder::MAP); diff --git a/tests/dragonfly/config_test.py b/tests/dragonfly/config_test.py new file mode 100644 index 000000000..903e7e535 --- /dev/null +++ b/tests/dragonfly/config_test.py @@ -0,0 
+1,27 @@ +import pytest +import redis +from redis.asyncio import Redis as RedisClient +from .utility import * +from . import DflyStartException + + +async def test_maxclients(df_factory): + # Needs some authentication + server = df_factory.create(port=1111, maxclients=1, admin_port=1112) + server.start() + + async with server.client() as client1: + assert [b"maxclients", b"1"] == await client1.execute_command("CONFIG GET maxclients") + + with pytest.raises(redis.exceptions.ConnectionError): + async with server.client() as client2: + await client2.get("test") + + # Check that admin connections are not limited. + async with RedisClient(port=server.admin_port) as admin_client: + await admin_client.get("test") + + await client1.execute_command("CONFIG SET maxclients 3") + assert [b"maxclients", b"3"] == await client1.execute_command("CONFIG GET maxclients") + async with server.client() as client2: + await client2.get("test")