From 0705bbb536168ac8efdde5c724c204256b6ed853 Mon Sep 17 00:00:00 2001 From: Kostas Kyrimis Date: Fri, 30 Aug 2024 15:41:28 +0300 Subject: [PATCH] feat(acl): add pub/sub (#3574) * add support for pub/sub * add tests --------- Signed-off-by: kostas --- src/facade/acl_commands_def.h | 14 +++ src/facade/command_id.h | 13 +++ src/facade/conn_context.h | 2 + src/facade/dragonfly_connection.cc | 1 + src/facade/dragonfly_connection.h | 1 + src/facade/facade.cc | 5 + src/server/acl/acl_family.cc | 112 +++++++++++++++++--- src/server/acl/acl_family.h | 6 +- src/server/acl/acl_family_test.cc | 93 +++++++++++----- src/server/acl/acl_log.h | 2 +- src/server/acl/user.cc | 22 ++++ src/server/acl/user.h | 21 ++++ src/server/acl/user_registry.cc | 5 +- src/server/acl/validator.cc | 58 ++++++++-- src/server/acl/validator.h | 8 ++ src/server/conn_context.cc | 2 + src/server/server_family.cc | 1 + tests/dragonfly/acl_family_test.py | 164 ++++++++++++++++++++++------- 18 files changed, 445 insertions(+), 85 deletions(-) diff --git a/src/facade/acl_commands_def.h b/src/facade/acl_commands_def.h index b4aaadd35..13bbc6366 100644 --- a/src/facade/acl_commands_def.h +++ b/src/facade/acl_commands_def.h @@ -21,13 +21,27 @@ using GlobType = std::pair; struct AclKeys { std::vector key_globs; + // The user is allowed to "touch" any key. No glob matching required. + // Alias for ~* bool all_keys = false; }; +// The second bool denotes if the pattern contains an asterisk and it's +// used to pattern match PSUBSCRIBE that requires exact literals +using GlobTypePubSub = std::pair; + +struct AclPubSub { + std::vector globs; + // The user can execute any variant of pub/sub/psub. No glob matching required. + // Alias for &* just like all_keys for AclKeys above. + bool all_channels = false; +}; + struct UserCredentials { uint32_t acl_categories{0}; std::vector acl_commands; AclKeys keys; + AclPubSub pub_sub; std::string ns; }; diff --git a/src/facade/command_id.h b/src/facade/command_id.h index edbc180aa..085d88385 100644 --- a/src/facade/command_id.h +++ b/src/facade/command_id.h @@ -86,6 +86,16 @@ class CommandId { static uint32_t OptCount(uint32_t mask); + // PUBLISH/SUBSCRIBE/UNSUBSCRIBE variant + bool IsPubSub() const { + return is_pub_sub_; + } + + // PSUBSCRIBE/PUNSUBSCRIBE variant + bool IsPSub() const { + return is_p_sub_; + } + protected: std::string name_; @@ -102,6 +112,9 @@ class CommandId { // Whether the command can only be used by admin connections. bool restricted_ = false; + + bool is_pub_sub_ = false; + bool is_p_sub_ = false; }; } // namespace facade diff --git a/src/facade/conn_context.h b/src/facade/conn_context.h index 8a71797df..3dd3b6ac9 100644 --- a/src/facade/conn_context.h +++ b/src/facade/conn_context.h @@ -107,6 +107,8 @@ class ConnectionContext { std::vector acl_commands; // keys dfly::acl::AclKeys keys{{}, true}; + // pub/sub + dfly::acl::AclPubSub pub_sub{{}, true}; private: Connection* owner_; diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 816380038..9add861b0 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -451,6 +451,7 @@ void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) { if (msg.username == self->cntx()->authed_username) { self->cntx()->acl_commands = msg.commands; self->cntx()->keys = msg.keys; + self->cntx()->pub_sub = msg.pub_sub; } } } diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index e01f1486f..3d3067f48 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -113,6 +113,7 @@ class Connection : public util::Connection { std::string username; std::vector commands; dfly::acl::AclKeys keys; + dfly::acl::AclPubSub pub_sub; }; // Migration request message, the dispatch fiber stops to give way for thread migration. diff --git a/src/facade/facade.cc b/src/facade/facade.cc index 4567135a4..fd06def1d 100644 --- a/src/facade/facade.cc +++ b/src/facade/facade.cc @@ -131,6 +131,11 @@ CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first first_key_(first_key), last_key_(last_key), acl_categories_(acl_categories) { + if (name_ == "PUBLISH" || name_ == "SUBSCRIBE" || name_ == "UNSUBSCRIBE") { + is_pub_sub_ = true; + } else if (name_ == "PSUBSCRIBE" || name_ == "PUNSUBSCRIBE") { + is_p_sub_ = true; + } } uint32_t CommandId::OptCount(uint32_t mask) { diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index 8791c8758..d510bbded 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -55,6 +55,8 @@ MaterializedContents MaterializeFileContents(std::vector* usernames std::string_view file_contents); std::string AclKeysToString(const AclKeys& keys); + +std::string AclPubSubToString(const AclPubSub& pub_sub); } // namespace AclFamily::AclFamily(UserRegistry* registry, util::ProactorPool* pool) @@ -76,6 +78,9 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) { const std::string password = PasswordsToString(user.Passwords(), user.HasNopass(), false); const std::string acl_keys = AclKeysToString(user.Keys()); + + const std::string acl_pub_sub = AclPubSubToString(user.PubSub()); + const std::string maybe_space_com = acl_keys.empty() ? "" : " "; const std::string acl_cat_and_commands = @@ -84,7 +89,7 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) { using namespace std::string_view_literals; absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, - acl_keys, maybe_space_com, acl_cat_and_commands); + acl_keys, maybe_space_com, acl_pub_sub, " ", acl_cat_and_commands); cntx->SendSimpleString(buffer); } @@ -92,14 +97,15 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) { void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, const Commands& update_commands, - const AclKeys& update_keys) { + const AclKeys& update_keys, + const AclPubSub& update_pub_sub) { auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) { DCHECK(conn); auto connection = static_cast(conn); if (connection->protocol() == facade::Protocol::REDIS && !connection->IsHttp() && connection->cntx()) { connection->SendAclUpdateAsync( - facade::Connection::AclUpdateMessage{user, update_commands, update_keys}); + facade::Connection::AclUpdateMessage{user, update_commands, update_keys, update_pub_sub}); } }; @@ -128,11 +134,20 @@ void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) { user.Update(std::move(default_req), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex()); } + const bool reset_channels = req.reset_channels; user.Update(std::move(req), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex()); - if (exists) { - StreamUpdatesToAllProactorConnections(std::string(username), user.AclCommands(), user.Keys()); - } + // Send ok first because the connection might get evicted cntx->SendOk(); + if (exists) { + if (!reset_channels) { + StreamUpdatesToAllProactorConnections(std::string(username), user.AclCommands(), + user.Keys(), user.PubSub()); + } + // We evict connections that had their channels reseted + else { + EvictOpenConnectionsOnAllProactors({username}); + } + } }; std::visit(Overloaded{error_case, update_case}, std::move(req)); @@ -208,15 +223,18 @@ std::string AclFamily::RegistryToString() const { const std::string password = PasswordsToString(user.Passwords(), user.HasNopass(), true); const std::string acl_keys = AclKeysToString(user.Keys()); + const std::string maybe_space = acl_keys.empty() ? "" : " "; + const std::string acl_pub_sub = AclPubSubToString(user.PubSub()); + const std::string acl_cat_and_commands = AclCatAndCommandToString(user.CatChanges(), user.CmdChanges()); using namespace std::string_view_literals; absl::StrAppend(&result, command, username, " ", user.IsActive() ? "ON "sv : "OFF "sv, password, - acl_keys, maybe_space, acl_cat_and_commands, "\n"); + acl_keys, maybe_space, acl_pub_sub, " ", acl_cat_and_commands, "\n"); } return result; @@ -391,6 +409,8 @@ void AclFamily::Log(CmdArgList args, ConnectionContext* cntx) { reason = "COMMAND"; } else if (entry.reason == Reason::KEY) { reason = "KEY"; + } else if (entry.reason == Reason::PUB_SUB) { + reason = "PUB_SUB"; } else { reason = "AUTH"; } @@ -511,7 +531,7 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) { } auto* rb = static_cast(cntx->reply_builder()); - rb->StartArray(8); + rb->StartArray(10); rb->SendSimpleString("flags"); const size_t total_elements = (pass != "nopass") ? 1 : 2; @@ -541,6 +561,10 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) { } else { rb->SendEmptyArray(); } + + rb->SendSimpleString("channels"); + std::string pub_sub = AclPubSubToString(user.PubSub()); + rb->SendSimpleString(pub_sub); } void AclFamily::GenPass(CmdArgList args, ConnectionContext* cntx) { @@ -815,6 +839,31 @@ std::optional MaybeParseAclKey(std::string_view command) { return ParseKeyResult{std::string(key), op}; } +struct ParsePubSubResult { + std::string glob; + bool has_asterisk{false}; + bool all_channels{false}; + bool reset_channels{false}; +}; + +std::optional MaybeParseAclPubSub(std::string_view command) { + if (absl::EqualsIgnoreCase(command, "ALLCHANNELS") || command == "&*") { + return ParsePubSubResult{"", false, true, false}; + } + + if (absl::EqualsIgnoreCase(command, "RESETCHANNELS")) { + return ParsePubSubResult{"", false, false, true}; + } + + if (absl::StartsWith(command, "&") && command.size() >= 2) { + const auto glob = command.substr(1); + const bool has_asterisk = glob.find('*') != std::string_view::npos; + return ParsePubSubResult{std::string(glob), has_asterisk}; + } + + return {}; +} + std::string PrettyPrintSha(std::string_view pass, bool all) { if (all) { return absl::BytesToHexString(pass); @@ -886,6 +935,24 @@ std::string AclKeysToString(const AclKeys& keys) { return result; } +std::string AclPubSubToString(const AclPubSub& pub_sub) { + if (pub_sub.all_channels) { + return "&*"; + } + + std::string result = "resetchannels "; + + for (const auto& [glob, has_asterisk] : pub_sub.globs) { + absl::StrAppend(&result, "&", glob, " "); + } + + if (result.back() == ' ') { + result.pop_back(); + } + + return result; +} + } // namespace std::string AclFamily::AclCatAndCommandToString(const User::CategoryChanges& cat, @@ -976,7 +1043,7 @@ std::pair AclFamily::MaybeParseAclCommand( using facade::ErrorReply; std::variant AclFamily::ParseAclSetUser( - const facade::ArgRange& args, bool hashed, bool has_all_keys) const { + const facade::ArgRange& args, bool hashed, bool has_all_keys, bool has_all_channels) const { User::UpdateRequest req; for (std::string_view arg : args) { @@ -993,10 +1060,11 @@ std::variant AclFamily::ParseAclSetUser( auto& [glob, op, all_keys, reset_keys] = *res; if ((has_all_keys && !all_keys && !reset_keys) || (req.allow_all_keys && !all_keys && !reset_keys)) { - return ErrorReply( - "Error in ACL SETUSER modifier '~tmp': Adding a pattern after the * pattern (or the " + return ErrorReply(absl::StrCat( + "Error in ACL SETUSER modifier \'", facade::ToSV(arg), + "\': Adding a pattern after the * pattern (or the " "'allkeys' flag) is not valid and does not have any effect. Try 'resetkeys' to start " - "with an empty list of patterns"); + "with an empty list of patterns")); } req.allow_all_keys = all_keys; @@ -1008,6 +1076,26 @@ std::variant AclFamily::ParseAclSetUser( continue; } + if (auto res = MaybeParseAclPubSub(facade::ToSV(arg)); res) { + auto& [glob, has_asterisk, all_channels, reset_channels] = *res; + if ((has_all_channels && !all_channels && !reset_channels) || + (req.all_channels && !all_channels && !reset_channels)) { + return ErrorReply( + absl::StrCat("ERR Error in ACL SETUSER modifier \'", facade::ToSV(arg), + "\': Adding a pattern after the * pattern (or the 'allchannels' flag) is " + "not valid and does not have any effect. Try 'resetchannels' to start " + "with an empty list of channels")); + } + + req.all_channels = all_channels; + req.reset_channels = reset_channels; + if (reset_channels) { + has_all_channels = false; + } + req.pub_sub.push_back({std::move(glob), has_asterisk, all_channels, reset_channels}); + continue; + } + std::string command = absl::AsciiStrToUpper(arg); if (auto status = MaybeParseStatus(command); status) { diff --git a/src/server/acl/acl_family.h b/src/server/acl/acl_family.h index 8ffce0567..22d05df87 100644 --- a/src/server/acl/acl_family.h +++ b/src/server/acl/acl_family.h @@ -52,7 +52,8 @@ class AclFamily final { using Commands = std::vector; void StreamUpdatesToAllProactorConnections(const std::string& user, const Commands& update_commands, - const AclKeys& update_keys); + const AclKeys& update_keys, + const AclPubSub& update_pub_sub); // Helper function that closes all open connection from the deleted user void EvictOpenConnectionsOnAllProactors(const absl::flat_hash_set& user); @@ -83,7 +84,8 @@ class AclFamily final { std::optional MaybeParseNamespace(std::string_view command) const; std::variant ParseAclSetUser( - const facade::ArgRange& args, bool hashed = false, bool has_all_keys = false) const; + const facade::ArgRange& args, bool hashed = false, bool has_all_keys = false, + bool has_all_channels = false) const; void BuildIndexers(RevCommandsIndexStore families); diff --git a/src/server/acl/acl_family_test.cc b/src/server/acl/acl_family_test.cc index 7c41ed404..c1f2f83e8 100644 --- a/src/server/acl/acl_family_test.cc +++ b/src/server/acl/acl_family_test.cc @@ -48,15 +48,16 @@ TEST_F(AclFamilyTest, AclSetUser) { EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); auto vec = resp.GetVec(); - EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* +@all", "user vlad off -@all")); + EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* &* +@all", + "user vlad off resetchannels -@all")); resp = Run({"ACL", "SETUSER", "vlad", "+ACL"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); vec = resp.GetVec(); - EXPECT_THAT(vec, - UnorderedElementsAre("user default on nopass ~* +@all", "user vlad off -@all +acl")); + EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* &* +@all", + "user vlad off resetchannels -@all +acl")); resp = Run({"ACL", "SETUSER", "vlad", "on", ">pass", ">temp"}); EXPECT_THAT(resp, "OK"); @@ -65,9 +66,11 @@ TEST_F(AclFamilyTest, AclSetUser) { vec = resp.GetVec(); EXPECT_THAT(vec.size(), 2); auto contains_vlad = [](const auto& vec) { - const std::string default_user = "user default on nopass ~* +@all"; - const std::string a_permutation = "user vlad on #a6864eb339b0e1f #d74ff0ee8da3b98 -@all +acl"; - const std::string b_permutation = "user vlad on #d74ff0ee8da3b98 #a6864eb339b0e1f -@all +acl"; + const std::string default_user = "user default on nopass ~* &* +@all"; + const std::string a_permutation = + "user vlad on #a6864eb339b0e1f #d74ff0ee8da3b98 resetchannels -@all +acl"; + const std::string b_permutation = + "user vlad on #d74ff0ee8da3b98 #a6864eb339b0e1f resetchannels -@all +acl"; std::string_view other; if (vec[0] == default_user) { other = vec[1].GetView(); @@ -107,8 +110,8 @@ TEST_F(AclFamilyTest, AclSetUser) { resp = Run({"ACL", "LIST"}); vec = resp.GetVec(); - EXPECT_THAT(vec, - UnorderedElementsAre("user default on nopass ~* +@all", "user vlad on -@all +acl")); + EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* &* +@all", + "user vlad on resetchannels -@all +acl")); // +@NONE should not exist anymore. It's not in the spec. resp = Run({"ACL", "SETUSER", "rand", "+@NONE"}); @@ -139,7 +142,7 @@ TEST_F(AclFamilyTest, AclDelUser) { EXPECT_THAT(resp, IntArg(0)); resp = Run({"ACL", "LIST"}); - EXPECT_THAT(resp.GetString(), "user default on nopass ~* +@all"); + EXPECT_THAT(resp.GetString(), "user default on nopass ~* &* +@all"); Run({"ACL", "SETUSER", "michael", "ON"}); Run({"ACL", "SETUSER", "kobe", "ON"}); @@ -160,9 +163,10 @@ TEST_F(AclFamilyTest, AclList) { resp = Run({"ACL", "LIST"}); auto vec = resp.GetVec(); - EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* +@all", - "user kostas off #d74ff0ee8da3b98 -@all +@admin", - "user adi off #d74ff0ee8da3b98 -@all +@fast")); + EXPECT_THAT(vec, + UnorderedElementsAre("user default on nopass ~* &* +@all", + "user kostas off #d74ff0ee8da3b98 resetchannels -@all +@admin", + "user adi off #d74ff0ee8da3b98 resetchannels -@all +@fast")); } TEST_F(AclFamilyTest, AclAuth) { @@ -210,17 +214,19 @@ TEST_F(AclFamilyTest, TestAllCategories) { EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); - EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass ~* +@all", - absl::StrCat("user kostas off -@all ", "+@", - absl::AsciiStrToLower(cat)))); + EXPECT_THAT(resp.GetVec(), + UnorderedElementsAre("user default on nopass ~* &* +@all", + absl::StrCat("user kostas off resetchannels -@all ", "+@", + absl::AsciiStrToLower(cat)))); resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-@", cat)}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); - EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass ~* +@all", - absl::StrCat("user kostas off -@all ", "-@", - absl::AsciiStrToLower(cat)))); + EXPECT_THAT(resp.GetVec(), + UnorderedElementsAre("user default on nopass ~* &* +@all", + absl::StrCat("user kostas off resetchannels -@all ", "-@", + absl::AsciiStrToLower(cat)))); resp = Run({"ACL", "DELUSER", "kostas"}); EXPECT_THAT(resp, IntArg(1)); @@ -259,16 +265,16 @@ TEST_F(AclFamilyTest, TestAllCommands) { resp = Run({"ACL", "LIST"}); EXPECT_THAT(resp.GetVec(), - UnorderedElementsAre("user default on nopass ~* +@all", - absl::StrCat("user kostas off -@all ", "+", + UnorderedElementsAre("user default on nopass ~* &* +@all", + absl::StrCat("user kostas off resetchannels -@all ", "+", absl::AsciiStrToLower(command_name)))); resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-", command_name)}); resp = Run({"ACL", "LIST"}); EXPECT_THAT(resp.GetVec(), - UnorderedElementsAre("user default on nopass ~* +@all", - absl::StrCat("user kostas off ", "-@all ", "-", + UnorderedElementsAre("user default on nopass ~* &* +@all", + absl::StrCat("user kostas off resetchannels -@all ", "-", absl::AsciiStrToLower(command_name)))); resp = Run({"ACL", "DELUSER", "kostas"}); @@ -321,6 +327,8 @@ TEST_F(AclFamilyTest, TestGetUser) { EXPECT_THAT(vec[5], "+@all"); EXPECT_THAT(vec[6], "keys"); EXPECT_THAT(vec[7], "~*"); + EXPECT_THAT(vec[8], "channels"); + EXPECT_THAT(vec[9], "&*"); resp = Run({"ACL", "SETUSER", "kostas", "+@STRING", "+HSET"}); resp = Run({"ACL", "GETUSER", "kostas"}); @@ -331,6 +339,10 @@ TEST_F(AclFamilyTest, TestGetUser) { EXPECT_TRUE(kvec[3].GetVec().empty()); EXPECT_THAT(kvec[4], "commands"); EXPECT_THAT(kvec[5], "-@all +@string +hset"); + EXPECT_THAT(kvec[6], "keys"); + EXPECT_THAT(kvec[7], RespArray(ElementsAre())); + EXPECT_THAT(kvec[8], "channels"); + EXPECT_THAT(kvec[9], "resetchannels"); } TEST_F(AclFamilyTest, TestDryRun) { @@ -431,7 +443,7 @@ TEST_F(AclFamilyTest, TestKeys) { EXPECT_THAT(vec[7], "~foo ~bar*"); resp = Run({"ACL", "SETUSER", "temp", "~*", "~foo"}); - EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '~tmp': Adding a pattern after the * " + EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '~foo': Adding a pattern after the * " "pattern (or the 'allkeys' flag) is not valid and does not have any " "effect. Try 'resetkeys' to start with an empty list of patterns")); @@ -439,7 +451,7 @@ TEST_F(AclFamilyTest, TestKeys) { EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "SETUSER", "temp", "~foo"}); - EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '~tmp': Adding a pattern after the * " + EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '~foo': Adding a pattern after the * " "pattern (or the 'allkeys' flag) is not valid and does not have any " "effect. Try 'resetkeys' to start with an empty list of patterns")); @@ -474,4 +486,37 @@ TEST_F(AclFamilyTest, TestKeys) { EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter %RFOO")); } +TEST_F(AclFamilyTest, TestPubSub) { + TestInitAclFam(); + + auto resp = Run({"ACL", "SETUSER", "temp", "&foo", "&b*r"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"ACL", "GETUSER", "temp"}); + auto vec = resp.GetVec(); + EXPECT_THAT(vec[8], "channels"); + EXPECT_THAT(vec[9], "resetchannels &foo &b*r"); + + resp = Run({"ACL", "SETUSER", "temp", "allchannels", "&bar"}); + EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '&bar': Adding a pattern after the * " + "pattern (or the 'allchannels' flag) is " + "not valid and does not have any effect. Try 'resetchannels' to start " + "with an empty list of channels")); + + resp = Run({"ACL", "SETUSER", "temp", "allchannels"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"ACL", "GETUSER", "temp"}); + vec = resp.GetVec(); + EXPECT_THAT(vec[8], "channels"); + EXPECT_THAT(vec[9], "&*"); + + resp = Run({"ACL", "SETUSER", "temp", "resetchannels", "&foo"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"ACL", "GETUSER", "temp"}); + vec = resp.GetVec(); + EXPECT_THAT(vec[8], "channels"); + EXPECT_THAT(vec[9], "resetchannels &foo"); +} } // namespace dfly diff --git a/src/server/acl/acl_log.h b/src/server/acl/acl_log.h index bb206ac9f..158ccc461 100644 --- a/src/server/acl/acl_log.h +++ b/src/server/acl/acl_log.h @@ -18,7 +18,7 @@ class AclLog { public: explicit AclLog(); - enum class Reason { COMMAND, AUTH, KEY }; + enum class Reason { COMMAND, AUTH, KEY, PUB_SUB }; struct LogEntry { std::string username; diff --git a/src/server/acl/user.cc b/src/server/acl/user.cc index 341e6b546..bc0e5e34b 100644 --- a/src/server/acl/user.cc +++ b/src/server/acl/user.cc @@ -76,6 +76,10 @@ void User::Update(UpdateRequest&& req, const CategoryToIdxStore& cat_to_id, SetKeyGlobs(std::move(req.keys)); } + if (!req.pub_sub.empty()) { + SetPubSub(std::move(req.pub_sub)); + } + if (req.is_active) { SetIsActive(*req.is_active); } @@ -214,6 +218,10 @@ const AclKeys& User::Keys() const { return keys_; } +const AclPubSub& User::PubSub() const { + return pub_sub_; +} + const User::CategoryChanges& User::CatChanges() const { return cat_changes_; } @@ -236,6 +244,20 @@ void User::SetKeyGlobs(std::vector keys) { } } +void User::SetPubSub(std::vector pub_sub) { + for (auto& pattern : pub_sub) { + if (pattern.all_channels) { + pub_sub_.globs.clear(); + pub_sub_.all_channels = true; + } else if (pattern.reset_channels) { + pub_sub_.globs.clear(); + pub_sub_.all_channels = false; + } else { + pub_sub_.globs.push_back({std::move(pattern.pattern), pattern.has_asterisk}); + } + } +} + void User::SetNopass() { nopass_ = true; password_hashes_.clear(); diff --git a/src/server/acl/user.h b/src/server/acl/user.h index 187856728..134c8aeb7 100644 --- a/src/server/acl/user.h +++ b/src/server/acl/user.h @@ -40,6 +40,13 @@ class User final { bool is_hashed{false}; }; + struct UpdatePubSub { + std::string pattern; + bool has_asterisk{false}; + bool all_channels{false}; + bool reset_channels{false}; + }; + struct UpdateRequest { std::vector passwords; @@ -59,6 +66,11 @@ class User final { bool reset_all_keys{false}; bool allow_all_keys{false}; + // pub/sub + std::vector pub_sub; + bool reset_channels{false}; + bool all_channels{false}; + // TODO allow reset all // bool reset_all{false}; @@ -107,6 +119,8 @@ class User final { const AclKeys& Keys() const; + const AclPubSub& PubSub() const; + const std::string& Namespace() const; using CategoryChanges = absl::flat_hash_map; @@ -140,6 +154,10 @@ class User final { // For ACL key globs void SetKeyGlobs(std::vector keys); + + // For ACL pub/sub + void SetPubSub(std::vector pub_sub); + void SetNamespace(const std::string& ns); // Set NOPASS and remove all passwords @@ -170,6 +188,9 @@ class User final { // Glob patterns for the keys that a user is allowed to read/write AclKeys keys_; + // Glob patterns for pub/sub channels + AclPubSub pub_sub_; + // if the user is on/off bool is_active_{false}; diff --git a/src/server/acl/user_registry.cc b/src/server/acl/user_registry.cc index 6c5d946af..309d944f1 100644 --- a/src/server/acl/user_registry.cc +++ b/src/server/acl/user_registry.cc @@ -35,8 +35,8 @@ UserCredentials UserRegistry::GetCredentials(std::string_view username) const { if (it == registry_.end()) { return {}; } - return {it->second.AclCategory(), it->second.AclCommands(), it->second.Keys(), - it->second.Namespace()}; + auto& user = it->second; + return {user.AclCategory(), user.AclCommands(), user.Keys(), user.PubSub(), user.Namespace()}; } bool UserRegistry::IsUserActive(std::string_view username) const { @@ -80,6 +80,7 @@ User::UpdateRequest UserRegistry::DefaultUserUpdateRequest() const { req.is_active = true; req.updates = {std::pair{User::Sign::PLUS, acl::ALL}}; req.keys = {User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}}; + req.pub_sub = {User::UpdatePubSub{"", false, true, false}}; return req; } diff --git a/src/server/acl/validator.cc b/src/server/acl/validator.cc index 9dcb1dba6..5513ece55 100644 --- a/src/server/acl/validator.cc +++ b/src/server/acl/validator.cc @@ -23,8 +23,17 @@ namespace dfly::acl { return true; } - const auto [is_authed, reason] = - IsUserAllowedToInvokeCommandGeneric(cntx.acl_commands, cntx.keys, tail_args, id); + std::pair auth_res; + + if (id.IsPubSub()) { + auth_res = IsPubSubCommandAuthorized(false, cntx.acl_commands, cntx.pub_sub, tail_args, id); + } else if (id.IsPSub()) { + auth_res = IsPubSubCommandAuthorized(true, cntx.acl_commands, cntx.pub_sub, tail_args, id); + } else { + auth_res = IsUserAllowedToInvokeCommandGeneric(cntx.acl_commands, cntx.keys, tail_args, id); + } + + const auto [is_authed, reason] = auth_res; if (!is_authed) { auto& log = ServerState::tlocal()->acl_log; @@ -40,16 +49,18 @@ namespace dfly::acl { #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" #endif -[[nodiscard]] std::pair IsUserAllowedToInvokeCommandGeneric( - const std::vector& acl_commands, const AclKeys& keys, CmdArgList tail_args, - const CommandId& id) { +static bool ValidateCommand(const std::vector& acl_commands, const CommandId& id) { const size_t index = id.GetFamily(); const uint64_t command_mask = id.GetBitIndex(); DCHECK_LT(index, acl_commands.size()); - const bool command = (acl_commands[index] & command_mask) != 0; + return (acl_commands[index] & command_mask) != 0; +} - if (!command) { +[[nodiscard]] std::pair IsUserAllowedToInvokeCommandGeneric( + const std::vector& acl_commands, const AclKeys& keys, CmdArgList tail_args, + const CommandId& id) { + if (!ValidateCommand(acl_commands, id)) { return {false, AclLog::Reason::COMMAND}; } @@ -86,6 +97,39 @@ namespace dfly::acl { return {keys_allowed, AclLog::Reason::KEY}; } +[[nodiscard]] std::pair IsPubSubCommandAuthorized( + bool literal_match, const std::vector& acl_commands, const AclPubSub& pub_sub, + CmdArgList tail_args, const CommandId& id) { + if (!ValidateCommand(acl_commands, id)) { + return {false, AclLog::Reason::COMMAND}; + } + + auto match = [](std::string_view pattern, std::string_view target) { + return stringmatchlen(pattern.data(), pattern.size(), target.data(), target.size(), 0); + }; + + auto iterate_globs = [&](std::string_view target) { + for (auto& [glob, has_asterisk] : pub_sub.globs) { + if (literal_match && (glob == target)) { + return true; + } + if (!literal_match && match(glob, target)) { + return true; + } + } + return false; + }; + + bool allowed = true; + if (!pub_sub.all_channels) { + for (auto channel : tail_args) { + allowed &= iterate_globs(facade::ToSV(channel)); + } + } + + return {allowed, AclLog::Reason::PUB_SUB}; +} + #pragma GCC diagnostic pop } // namespace dfly::acl diff --git a/src/server/acl/validator.h b/src/server/acl/validator.h index 4dd461b1a..70d849f97 100644 --- a/src/server/acl/validator.h +++ b/src/server/acl/validator.h @@ -13,6 +13,7 @@ namespace dfly::acl { struct AclKeys; +struct AclPubSub; std::pair IsUserAllowedToInvokeCommandGeneric( const std::vector& acl_commands, const AclKeys& keys, facade::CmdArgList tail_args, @@ -20,4 +21,11 @@ std::pair IsUserAllowedToInvokeCommandGeneric( bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const CommandId& id, facade::CmdArgList tail_args); + +std::pair IsPubSubCommandAuthorized(bool literal_match, + const std::vector& acl_commands, + const AclPubSub& pub_sub, + facade::CmdArgList tail_args, + const CommandId& id); + } // namespace dfly::acl diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 79bb4bd9f..5f8bfe4ec 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -89,6 +89,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own } keys = std::move(cred.keys); + pub_sub = std::move(cred.pub_sub); if (cred.acl_commands.empty()) { acl_commands = std::vector(acl::NumberOfFamilies(), acl::NONE_COMMANDS); } else { @@ -102,6 +103,7 @@ ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction if (owner) { acl_commands = owner->acl_commands; keys = owner->keys; + pub_sub = owner->pub_sub; skip_acl_validation = owner->skip_acl_validation; ns = owner->ns; } else { diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 2dc5c8e76..59e7f4c66 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -1701,6 +1701,7 @@ bool ServerFamily::DoAuth(ConnectionContext* cntx, std::string_view username, auto cred = registry->GetCredentials(username); cntx->acl_commands = cred.acl_commands; cntx->keys = std::move(cred.keys); + cntx->pub_sub = std::move(cred.pub_sub); cntx->ns = &namespaces.GetOrInsert(cred.ns); cntx->authenticated = true; } diff --git a/tests/dragonfly/acl_family_test.py b/tests/dragonfly/acl_family_test.py index f7ae8ecea..81ee14ea6 100644 --- a/tests/dragonfly/acl_family_test.py +++ b/tests/dragonfly/acl_family_test.py @@ -7,60 +7,69 @@ import tempfile import asyncio import os from . import dfly_args +import async_timeout @pytest.mark.asyncio async def test_acl_setuser(async_client): await async_client.execute_command("ACL SETUSER kostas") - result = await async_client.execute_command("ACL list") + result = await async_client.execute_command("ACL LIST") assert 2 == len(result) - assert "user kostas off -@all" in result + assert "user kostas off resetchannels -@all" in result await async_client.execute_command("ACL SETUSER kostas ON") - result = await async_client.execute_command("ACL list") - assert "user kostas on -@all" in result + result = await async_client.execute_command("ACL LIST") + assert "user kostas on resetchannels -@all" in result await async_client.execute_command("ACL SETUSER kostas +@list +@string +@admin") - result = await async_client.execute_command("ACL list") + result = await async_client.execute_command("ACL LIST") # TODO consider printing to lowercase - assert "user kostas on -@all +@list +@string +@admin" in result + assert "user kostas on resetchannels -@all +@list +@string +@admin" in result await async_client.execute_command("ACL SETUSER kostas -@list -@admin") - result = await async_client.execute_command("ACL list") - assert "user kostas on -@all +@string -@list -@admin" in result + result = await async_client.execute_command("ACL LIST") + assert "user kostas on resetchannels -@all +@string -@list -@admin" in result # mix and match await async_client.execute_command("ACL SETUSER kostas +@list -@string") - result = await async_client.execute_command("ACL list") - assert "user kostas on -@all -@admin +@list -@string" in result + result = await async_client.execute_command("ACL LIST") + assert "user kostas on resetchannels -@all -@admin +@list -@string" in result # mix and match interleaved await async_client.execute_command("ACL SETUSER kostas +@set -@set +@set") - result = await async_client.execute_command("ACL list") - assert "user kostas on -@all -@admin +@list -@string +@set" in result + result = await async_client.execute_command("ACL LIST") + assert "user kostas on resetchannels -@all -@admin +@list -@string +@set" in result await async_client.execute_command("ACL SETUSER kostas +@all") - result = await async_client.execute_command("ACL list") - assert "user kostas on -@admin +@list -@string +@set +@all" in result + result = await async_client.execute_command("ACL LIST") + assert "user kostas on resetchannels -@admin +@list -@string +@set +@all" in result # commands await async_client.execute_command("ACL SETUSER kostas +set +get +hset") - result = await async_client.execute_command("ACL list") - assert "user kostas on -@admin +@list -@string +@set +@all +set +get +hset" in result + result = await async_client.execute_command("ACL LIST") + assert ( + "user kostas on resetchannels -@admin +@list -@string +@set +@all +set +get +hset" in result + ) await async_client.execute_command("ACL SETUSER kostas -set -get +hset") - result = await async_client.execute_command("ACL list") - assert "user kostas on -@admin +@list -@string +@set +@all -set -get +hset" in result + result = await async_client.execute_command("ACL LIST") + assert ( + "user kostas on resetchannels -@admin +@list -@string +@set +@all -set -get +hset" in result + ) # interleaved await async_client.execute_command("ACL SETUSER kostas -hset +get -get -@all") - result = await async_client.execute_command("ACL list") - assert "user kostas on -@admin +@list -@string +@set -set -hset -get -@all" in result + result = await async_client.execute_command("ACL LIST") + assert ( + "user kostas on resetchannels -@admin +@list -@string +@set -set -hset -get -@all" in result + ) # interleaved with categories await async_client.execute_command("ACL SETUSER kostas +@string +get -get +set") - result = await async_client.execute_command("ACL list") - assert "user kostas on -@admin +@list +@set -hset -@all +@string -get +set" in result + result = await async_client.execute_command("ACL LIST") + assert ( + "user kostas on resetchannels -@admin +@list +@set -hset -@all +@string -get +set" in result + ) @pytest.mark.asyncio @@ -324,7 +333,7 @@ async def test_bad_acl_file(df_factory, tmp_dir): async def test_good_acl_file(df_factory, tmp_dir): # The hash below is password temp acl = create_temp_file( - "USER MrFoo ON #a6864eb339b0e1f6e00d75293a8840abf069a2c0fe82e6e53af6ac099793c1d5 >mypass", + "USER MrFoo ON #a6864eb339b0e1f6e00d75293a8840abf069a2c0fe82e6e53af6ac099793c1d5 >mypass &bar &r*nd", tmp_dir, ) df = df_factory.create(aclfile=acl) @@ -333,13 +342,14 @@ async def test_good_acl_file(df_factory, tmp_dir): client = df.client() await client.execute_command("ACL LOAD") - result = await client.execute_command("ACL list") + result = await client.execute_command("ACL LIST") assert 2 == len(result) assert ( - "user MrFoo on #ea71c25a7a60224 #a6864eb339b0e1f -@all" in result - or "user MrFoo on #a6864eb339b0e1f #ea71c25a7a60224 -@all" in result + "user MrFoo on #ea71c25a7a60224 #a6864eb339b0e1f resetchannels &bar &r*nd -@all" in result + or "user MrFoo on #a6864eb339b0e1f #ea71c25a7a60224 resetchannels &bar &r*nd -@all" + in result ) - assert "user default on nopass ~* +@all" in result + assert "user default on nopass ~* &* +@all" in result await client.execute_command("ACL SETUSER MrFoo +@all") # Check multiple passwords work assert "OK" == await client.execute_command("AUTH mypass") @@ -351,12 +361,12 @@ async def test_good_acl_file(df_factory, tmp_dir): await client.execute_command("ACL SETUSER shahar >mypass +@set") await client.execute_command("ACL SETUSER vlad ~foo ~bar* +@string") - result = await client.execute_command("ACL list") + result = await client.execute_command("ACL LIST") assert 4 == len(result) - assert "user roy on #ea71c25a7a60224 -@all +@string +hset" in result - assert "user shahar off #ea71c25a7a60224 -@all +@set" in result - assert "user vlad off ~foo ~bar* -@all +@string" in result - assert "user default on nopass ~* +@all" in result + assert "user roy on #ea71c25a7a60224 resetchannels -@all +@string +hset" in result + assert "user shahar off #ea71c25a7a60224 resetchannels -@all +@set" in result + assert "user vlad off ~foo ~bar* resetchannels -@all +@string" in result + assert "user default on nopass ~* &* +@all" in result result = await client.execute_command("ACL DELUSER shahar") assert result == 1 @@ -365,11 +375,11 @@ async def test_good_acl_file(df_factory, tmp_dir): result = await client.execute_command("ACL LOAD") - result = await client.execute_command("ACL list") + result = await client.execute_command("ACL LIST") assert 3 == len(result) - assert "user roy on #ea71c25a7a60224 -@all +@string +hset" in result - assert "user vlad off ~foo ~bar* -@all +@string" in result - assert "user default on nopass ~* +@all" in result + assert "user roy on #ea71c25a7a60224 resetchannels -@all +@string +hset" in result + assert "user vlad off ~foo ~bar* resetchannels -@all +@string" in result + assert "user default on nopass ~* &* +@all" in result await client.close() @@ -483,7 +493,7 @@ async def test_set_acl_file(async_client: aioredis.Redis, tmp_dir): await async_client.execute_command("ACL LOAD") - result = await async_client.execute_command("ACL list") + result = await async_client.execute_command("ACL LIST") assert 3 == len(result) result = await async_client.execute_command("AUTH roy mypass") @@ -635,3 +645,83 @@ async def test_auth_resp3_bug(df_factory): } await client.close() + + +@pytest.mark.asyncio +async def test_acl_pub_sub_auth(df_factory): + df = df_factory.create() + df.start() + client = df.client() + await client.execute_command("ACL SETUSER kostas on >tmp +subscribe +psubscribe &f*o &bar") + assert await client.execute_command("AUTH kostas tmp") == "OK" + + res = await client.execute_command("SUBSCRIBE bar") + assert res == ["subscribe", "bar", 1] + + res = await client.execute_command("SUBSCRIBE foo") + assert res == ["subscribe", "foo", 2] + + with pytest.raises(redis.exceptions.NoPermissionError): + res = await client.execute_command("SUBSCRIBE my_channel") + + # PSUBSCRIBE only matches pure literals, no asterisks + with pytest.raises(redis.exceptions.NoPermissionError): + res = await client.execute_command("PSUBSCRIBE foo") + + # my_channel is not in our list so the command should fail + with pytest.raises(redis.exceptions.NoPermissionError): + res = await client.execute_command("PSUBSCRIBE bar my_channel") + + res = await client.execute_command("PSUBSCRIBE bar") + assert res == ["psubscribe", "bar", 3] + + +@pytest.mark.asyncio +async def test_acl_revoke_pub_sub_while_subscribed(df_factory): + df = df_factory.create() + df.start() + publisher = df.client() + + async def publish_worker(client): + for i in range(0, 10): + await client.publish("channel", "message") + + async def subscribe_worker(channel: aioredis.client.PubSub): + total_msgs = 1 + async with async_timeout.timeout(10): + while total_msgs != 10: + res = await channel.get_message(ignore_subscribe_messages=True, timeout=5) + if total_msgs is not None: + total_msgs = total_msgs + 1 + + await publisher.execute_command("ACL SETUSER kostas >tmp ON +@slow +SUBSCRIBE allchannels") + + subscriber = aioredis.Redis( + username="kostas", password="tmp", port=df.port, decode_responses=True + ) + subscriber_obj = subscriber.pubsub() + await subscriber_obj.subscribe("channel") + + subscribe_task = asyncio.create_task(subscribe_worker(subscriber_obj)) + await publish_worker(publisher) + await subscribe_task + + subscribe_task = asyncio.create_task(subscribe_worker(subscriber_obj)) + # Already subscribed, we should still be able to receive messages on channel + # We should not be able to unsubscribe + await publisher.execute_command("ACL SETUSER kostas -SUBSCRIBE -UNSUBSCRIBE") + await publish_worker(publisher) + await subscribe_task + # unsubscribe is not marked async and it's such a mess that it throws the error + # once we try to resubscribe. Instead I use the raw execute command to check that + # permission changes work + with pytest.raises(redis.exceptions.NoPermissionError): + await subscriber.execute_command("UNSUBSCRIBE channel") + + await publisher.execute_command("ACL SETUSER kostas +SUBSCRIBE +UNSUBSCRIBE") + + subscribe_task = asyncio.create_task(subscribe_worker(subscriber_obj)) + await publisher.execute_command("ACL SETUSER kostas resetchannels") + await publish_worker(publisher) + with pytest.raises(redis.exceptions.ConnectionError): + await subscribe_task