1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

feat(acl): add pub/sub (#3574)

* add support for pub/sub
* add tests
---------

Signed-off-by: kostas <kostas@dragonflydb.io>
This commit is contained in:
Kostas Kyrimis 2024-08-30 15:41:28 +03:00 committed by GitHub
parent a22eff15dc
commit 0705bbb536
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 445 additions and 85 deletions

View file

@ -21,13 +21,27 @@ using GlobType = std::pair<std::string, KeyOp>;
struct AclKeys {
std::vector<GlobType> key_globs;
// The user is allowed to "touch" any key. No glob matching required.
// Alias for ~*
bool all_keys = false;
};
// The second bool denotes if the pattern contains an asterisk and it's
// used to pattern match PSUBSCRIBE that requires exact literals
using GlobTypePubSub = std::pair<std::string, bool>;
struct AclPubSub {
std::vector<GlobTypePubSub> globs;
// The user can execute any variant of pub/sub/psub. No glob matching required.
// Alias for &* just like all_keys for AclKeys above.
bool all_channels = false;
};
struct UserCredentials {
uint32_t acl_categories{0};
std::vector<uint64_t> acl_commands;
AclKeys keys;
AclPubSub pub_sub;
std::string ns;
};

View file

@ -86,6 +86,16 @@ class CommandId {
static uint32_t OptCount(uint32_t mask);
// PUBLISH/SUBSCRIBE/UNSUBSCRIBE variant
bool IsPubSub() const {
return is_pub_sub_;
}
// PSUBSCRIBE/PUNSUBSCRIBE variant
bool IsPSub() const {
return is_p_sub_;
}
protected:
std::string name_;
@ -102,6 +112,9 @@ class CommandId {
// Whether the command can only be used by admin connections.
bool restricted_ = false;
bool is_pub_sub_ = false;
bool is_p_sub_ = false;
};
} // namespace facade

View file

@ -107,6 +107,8 @@ class ConnectionContext {
std::vector<uint64_t> acl_commands;
// keys
dfly::acl::AclKeys keys{{}, true};
// pub/sub
dfly::acl::AclPubSub pub_sub{{}, true};
private:
Connection* owner_;

View file

@ -451,6 +451,7 @@ void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) {
if (msg.username == self->cntx()->authed_username) {
self->cntx()->acl_commands = msg.commands;
self->cntx()->keys = msg.keys;
self->cntx()->pub_sub = msg.pub_sub;
}
}
}

View file

@ -113,6 +113,7 @@ class Connection : public util::Connection {
std::string username;
std::vector<uint64_t> commands;
dfly::acl::AclKeys keys;
dfly::acl::AclPubSub pub_sub;
};
// Migration request message, the dispatch fiber stops to give way for thread migration.

View file

@ -131,6 +131,11 @@ CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first
first_key_(first_key),
last_key_(last_key),
acl_categories_(acl_categories) {
if (name_ == "PUBLISH" || name_ == "SUBSCRIBE" || name_ == "UNSUBSCRIBE") {
is_pub_sub_ = true;
} else if (name_ == "PSUBSCRIBE" || name_ == "PUNSUBSCRIBE") {
is_p_sub_ = true;
}
}
uint32_t CommandId::OptCount(uint32_t mask) {

View file

@ -55,6 +55,8 @@ MaterializedContents MaterializeFileContents(std::vector<std::string>* usernames
std::string_view file_contents);
std::string AclKeysToString(const AclKeys& keys);
std::string AclPubSubToString(const AclPubSub& pub_sub);
} // namespace
AclFamily::AclFamily(UserRegistry* registry, util::ProactorPool* pool)
@ -76,6 +78,9 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
const std::string password = PasswordsToString(user.Passwords(), user.HasNopass(), false);
const std::string acl_keys = AclKeysToString(user.Keys());
const std::string acl_pub_sub = AclPubSubToString(user.PubSub());
const std::string maybe_space_com = acl_keys.empty() ? "" : " ";
const std::string acl_cat_and_commands =
@ -84,7 +89,7 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
using namespace std::string_view_literals;
absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password,
acl_keys, maybe_space_com, acl_cat_and_commands);
acl_keys, maybe_space_com, acl_pub_sub, " ", acl_cat_and_commands);
cntx->SendSimpleString(buffer);
}
@ -92,14 +97,15 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user,
const Commands& update_commands,
const AclKeys& update_keys) {
const AclKeys& update_keys,
const AclPubSub& update_pub_sub) {
auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) {
DCHECK(conn);
auto connection = static_cast<facade::Connection*>(conn);
if (connection->protocol() == facade::Protocol::REDIS && !connection->IsHttp() &&
connection->cntx()) {
connection->SendAclUpdateAsync(
facade::Connection::AclUpdateMessage{user, update_commands, update_keys});
facade::Connection::AclUpdateMessage{user, update_commands, update_keys, update_pub_sub});
}
};
@ -128,11 +134,20 @@ void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {
user.Update(std::move(default_req), CategoryToIdx(), reverse_cat_table_,
CategoryToCommandsIndex());
}
const bool reset_channels = req.reset_channels;
user.Update(std::move(req), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex());
if (exists) {
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCommands(), user.Keys());
}
// Send ok first because the connection might get evicted
cntx->SendOk();
if (exists) {
if (!reset_channels) {
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCommands(),
user.Keys(), user.PubSub());
}
// We evict connections that had their channels reseted
else {
EvictOpenConnectionsOnAllProactors({username});
}
}
};
std::visit(Overloaded{error_case, update_case}, std::move(req));
@ -208,15 +223,18 @@ std::string AclFamily::RegistryToString() const {
const std::string password = PasswordsToString(user.Passwords(), user.HasNopass(), true);
const std::string acl_keys = AclKeysToString(user.Keys());
const std::string maybe_space = acl_keys.empty() ? "" : " ";
const std::string acl_pub_sub = AclPubSubToString(user.PubSub());
const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());
using namespace std::string_view_literals;
absl::StrAppend(&result, command, username, " ", user.IsActive() ? "ON "sv : "OFF "sv, password,
acl_keys, maybe_space, acl_cat_and_commands, "\n");
acl_keys, maybe_space, acl_pub_sub, " ", acl_cat_and_commands, "\n");
}
return result;
@ -391,6 +409,8 @@ void AclFamily::Log(CmdArgList args, ConnectionContext* cntx) {
reason = "COMMAND";
} else if (entry.reason == Reason::KEY) {
reason = "KEY";
} else if (entry.reason == Reason::PUB_SUB) {
reason = "PUB_SUB";
} else {
reason = "AUTH";
}
@ -511,7 +531,7 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
}
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
rb->StartArray(8);
rb->StartArray(10);
rb->SendSimpleString("flags");
const size_t total_elements = (pass != "nopass") ? 1 : 2;
@ -541,6 +561,10 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
} else {
rb->SendEmptyArray();
}
rb->SendSimpleString("channels");
std::string pub_sub = AclPubSubToString(user.PubSub());
rb->SendSimpleString(pub_sub);
}
void AclFamily::GenPass(CmdArgList args, ConnectionContext* cntx) {
@ -815,6 +839,31 @@ std::optional<ParseKeyResult> MaybeParseAclKey(std::string_view command) {
return ParseKeyResult{std::string(key), op};
}
struct ParsePubSubResult {
std::string glob;
bool has_asterisk{false};
bool all_channels{false};
bool reset_channels{false};
};
std::optional<ParsePubSubResult> MaybeParseAclPubSub(std::string_view command) {
if (absl::EqualsIgnoreCase(command, "ALLCHANNELS") || command == "&*") {
return ParsePubSubResult{"", false, true, false};
}
if (absl::EqualsIgnoreCase(command, "RESETCHANNELS")) {
return ParsePubSubResult{"", false, false, true};
}
if (absl::StartsWith(command, "&") && command.size() >= 2) {
const auto glob = command.substr(1);
const bool has_asterisk = glob.find('*') != std::string_view::npos;
return ParsePubSubResult{std::string(glob), has_asterisk};
}
return {};
}
std::string PrettyPrintSha(std::string_view pass, bool all) {
if (all) {
return absl::BytesToHexString(pass);
@ -886,6 +935,24 @@ std::string AclKeysToString(const AclKeys& keys) {
return result;
}
std::string AclPubSubToString(const AclPubSub& pub_sub) {
if (pub_sub.all_channels) {
return "&*";
}
std::string result = "resetchannels ";
for (const auto& [glob, has_asterisk] : pub_sub.globs) {
absl::StrAppend(&result, "&", glob, " ");
}
if (result.back() == ' ') {
result.pop_back();
}
return result;
}
} // namespace
std::string AclFamily::AclCatAndCommandToString(const User::CategoryChanges& cat,
@ -976,7 +1043,7 @@ std::pair<AclFamily::OptCommand, bool> AclFamily::MaybeParseAclCommand(
using facade::ErrorReply;
std::variant<User::UpdateRequest, ErrorReply> AclFamily::ParseAclSetUser(
const facade::ArgRange& args, bool hashed, bool has_all_keys) const {
const facade::ArgRange& args, bool hashed, bool has_all_keys, bool has_all_channels) const {
User::UpdateRequest req;
for (std::string_view arg : args) {
@ -993,10 +1060,11 @@ std::variant<User::UpdateRequest, ErrorReply> AclFamily::ParseAclSetUser(
auto& [glob, op, all_keys, reset_keys] = *res;
if ((has_all_keys && !all_keys && !reset_keys) ||
(req.allow_all_keys && !all_keys && !reset_keys)) {
return ErrorReply(
"Error in ACL SETUSER modifier '~tmp': Adding a pattern after the * pattern (or the "
return ErrorReply(absl::StrCat(
"Error in ACL SETUSER modifier \'", facade::ToSV(arg),
"\': Adding a pattern after the * pattern (or the "
"'allkeys' flag) is not valid and does not have any effect. Try 'resetkeys' to start "
"with an empty list of patterns");
"with an empty list of patterns"));
}
req.allow_all_keys = all_keys;
@ -1008,6 +1076,26 @@ std::variant<User::UpdateRequest, ErrorReply> AclFamily::ParseAclSetUser(
continue;
}
if (auto res = MaybeParseAclPubSub(facade::ToSV(arg)); res) {
auto& [glob, has_asterisk, all_channels, reset_channels] = *res;
if ((has_all_channels && !all_channels && !reset_channels) ||
(req.all_channels && !all_channels && !reset_channels)) {
return ErrorReply(
absl::StrCat("ERR Error in ACL SETUSER modifier \'", facade::ToSV(arg),
"\': Adding a pattern after the * pattern (or the 'allchannels' flag) is "
"not valid and does not have any effect. Try 'resetchannels' to start "
"with an empty list of channels"));
}
req.all_channels = all_channels;
req.reset_channels = reset_channels;
if (reset_channels) {
has_all_channels = false;
}
req.pub_sub.push_back({std::move(glob), has_asterisk, all_channels, reset_channels});
continue;
}
std::string command = absl::AsciiStrToUpper(arg);
if (auto status = MaybeParseStatus(command); status) {

View file

@ -52,7 +52,8 @@ class AclFamily final {
using Commands = std::vector<uint64_t>;
void StreamUpdatesToAllProactorConnections(const std::string& user,
const Commands& update_commands,
const AclKeys& update_keys);
const AclKeys& update_keys,
const AclPubSub& update_pub_sub);
// Helper function that closes all open connection from the deleted user
void EvictOpenConnectionsOnAllProactors(const absl::flat_hash_set<std::string_view>& user);
@ -83,7 +84,8 @@ class AclFamily final {
std::optional<std::string> MaybeParseNamespace(std::string_view command) const;
std::variant<User::UpdateRequest, facade::ErrorReply> ParseAclSetUser(
const facade::ArgRange& args, bool hashed = false, bool has_all_keys = false) const;
const facade::ArgRange& args, bool hashed = false, bool has_all_keys = false,
bool has_all_channels = false) const;
void BuildIndexers(RevCommandsIndexStore families);

View file

@ -48,15 +48,16 @@ TEST_F(AclFamilyTest, AclSetUser) {
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
auto vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* +@all", "user vlad off -@all"));
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* &* +@all",
"user vlad off resetchannels -@all"));
resp = Run({"ACL", "SETUSER", "vlad", "+ACL"});
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
vec = resp.GetVec();
EXPECT_THAT(vec,
UnorderedElementsAre("user default on nopass ~* +@all", "user vlad off -@all +acl"));
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* &* +@all",
"user vlad off resetchannels -@all +acl"));
resp = Run({"ACL", "SETUSER", "vlad", "on", ">pass", ">temp"});
EXPECT_THAT(resp, "OK");
@ -65,9 +66,11 @@ TEST_F(AclFamilyTest, AclSetUser) {
vec = resp.GetVec();
EXPECT_THAT(vec.size(), 2);
auto contains_vlad = [](const auto& vec) {
const std::string default_user = "user default on nopass ~* +@all";
const std::string a_permutation = "user vlad on #a6864eb339b0e1f #d74ff0ee8da3b98 -@all +acl";
const std::string b_permutation = "user vlad on #d74ff0ee8da3b98 #a6864eb339b0e1f -@all +acl";
const std::string default_user = "user default on nopass ~* &* +@all";
const std::string a_permutation =
"user vlad on #a6864eb339b0e1f #d74ff0ee8da3b98 resetchannels -@all +acl";
const std::string b_permutation =
"user vlad on #d74ff0ee8da3b98 #a6864eb339b0e1f resetchannels -@all +acl";
std::string_view other;
if (vec[0] == default_user) {
other = vec[1].GetView();
@ -107,8 +110,8 @@ TEST_F(AclFamilyTest, AclSetUser) {
resp = Run({"ACL", "LIST"});
vec = resp.GetVec();
EXPECT_THAT(vec,
UnorderedElementsAre("user default on nopass ~* +@all", "user vlad on -@all +acl"));
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* &* +@all",
"user vlad on resetchannels -@all +acl"));
// +@NONE should not exist anymore. It's not in the spec.
resp = Run({"ACL", "SETUSER", "rand", "+@NONE"});
@ -139,7 +142,7 @@ TEST_F(AclFamilyTest, AclDelUser) {
EXPECT_THAT(resp, IntArg(0));
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetString(), "user default on nopass ~* +@all");
EXPECT_THAT(resp.GetString(), "user default on nopass ~* &* +@all");
Run({"ACL", "SETUSER", "michael", "ON"});
Run({"ACL", "SETUSER", "kobe", "ON"});
@ -160,9 +163,10 @@ TEST_F(AclFamilyTest, AclList) {
resp = Run({"ACL", "LIST"});
auto vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* +@all",
"user kostas off #d74ff0ee8da3b98 -@all +@admin",
"user adi off #d74ff0ee8da3b98 -@all +@fast"));
EXPECT_THAT(vec,
UnorderedElementsAre("user default on nopass ~* &* +@all",
"user kostas off #d74ff0ee8da3b98 resetchannels -@all +@admin",
"user adi off #d74ff0ee8da3b98 resetchannels -@all +@fast"));
}
TEST_F(AclFamilyTest, AclAuth) {
@ -210,17 +214,19 @@ TEST_F(AclFamilyTest, TestAllCategories) {
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass ~* +@all",
absl::StrCat("user kostas off -@all ", "+@",
absl::AsciiStrToLower(cat))));
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass ~* &* +@all",
absl::StrCat("user kostas off resetchannels -@all ", "+@",
absl::AsciiStrToLower(cat))));
resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-@", cat)});
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass ~* +@all",
absl::StrCat("user kostas off -@all ", "-@",
absl::AsciiStrToLower(cat))));
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass ~* &* +@all",
absl::StrCat("user kostas off resetchannels -@all ", "-@",
absl::AsciiStrToLower(cat))));
resp = Run({"ACL", "DELUSER", "kostas"});
EXPECT_THAT(resp, IntArg(1));
@ -259,16 +265,16 @@ TEST_F(AclFamilyTest, TestAllCommands) {
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass ~* +@all",
absl::StrCat("user kostas off -@all ", "+",
UnorderedElementsAre("user default on nopass ~* &* +@all",
absl::StrCat("user kostas off resetchannels -@all ", "+",
absl::AsciiStrToLower(command_name))));
resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-", command_name)});
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass ~* +@all",
absl::StrCat("user kostas off ", "-@all ", "-",
UnorderedElementsAre("user default on nopass ~* &* +@all",
absl::StrCat("user kostas off resetchannels -@all ", "-",
absl::AsciiStrToLower(command_name))));
resp = Run({"ACL", "DELUSER", "kostas"});
@ -321,6 +327,8 @@ TEST_F(AclFamilyTest, TestGetUser) {
EXPECT_THAT(vec[5], "+@all");
EXPECT_THAT(vec[6], "keys");
EXPECT_THAT(vec[7], "~*");
EXPECT_THAT(vec[8], "channels");
EXPECT_THAT(vec[9], "&*");
resp = Run({"ACL", "SETUSER", "kostas", "+@STRING", "+HSET"});
resp = Run({"ACL", "GETUSER", "kostas"});
@ -331,6 +339,10 @@ TEST_F(AclFamilyTest, TestGetUser) {
EXPECT_TRUE(kvec[3].GetVec().empty());
EXPECT_THAT(kvec[4], "commands");
EXPECT_THAT(kvec[5], "-@all +@string +hset");
EXPECT_THAT(kvec[6], "keys");
EXPECT_THAT(kvec[7], RespArray(ElementsAre()));
EXPECT_THAT(kvec[8], "channels");
EXPECT_THAT(kvec[9], "resetchannels");
}
TEST_F(AclFamilyTest, TestDryRun) {
@ -431,7 +443,7 @@ TEST_F(AclFamilyTest, TestKeys) {
EXPECT_THAT(vec[7], "~foo ~bar*");
resp = Run({"ACL", "SETUSER", "temp", "~*", "~foo"});
EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '~tmp': Adding a pattern after the * "
EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '~foo': Adding a pattern after the * "
"pattern (or the 'allkeys' flag) is not valid and does not have any "
"effect. Try 'resetkeys' to start with an empty list of patterns"));
@ -439,7 +451,7 @@ TEST_F(AclFamilyTest, TestKeys) {
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "SETUSER", "temp", "~foo"});
EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '~tmp': Adding a pattern after the * "
EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '~foo': Adding a pattern after the * "
"pattern (or the 'allkeys' flag) is not valid and does not have any "
"effect. Try 'resetkeys' to start with an empty list of patterns"));
@ -474,4 +486,37 @@ TEST_F(AclFamilyTest, TestKeys) {
EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter %RFOO"));
}
TEST_F(AclFamilyTest, TestPubSub) {
TestInitAclFam();
auto resp = Run({"ACL", "SETUSER", "temp", "&foo", "&b*r"});
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "GETUSER", "temp"});
auto vec = resp.GetVec();
EXPECT_THAT(vec[8], "channels");
EXPECT_THAT(vec[9], "resetchannels &foo &b*r");
resp = Run({"ACL", "SETUSER", "temp", "allchannels", "&bar"});
EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '&bar': Adding a pattern after the * "
"pattern (or the 'allchannels' flag) is "
"not valid and does not have any effect. Try 'resetchannels' to start "
"with an empty list of channels"));
resp = Run({"ACL", "SETUSER", "temp", "allchannels"});
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "GETUSER", "temp"});
vec = resp.GetVec();
EXPECT_THAT(vec[8], "channels");
EXPECT_THAT(vec[9], "&*");
resp = Run({"ACL", "SETUSER", "temp", "resetchannels", "&foo"});
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "GETUSER", "temp"});
vec = resp.GetVec();
EXPECT_THAT(vec[8], "channels");
EXPECT_THAT(vec[9], "resetchannels &foo");
}
} // namespace dfly

View file

@ -18,7 +18,7 @@ class AclLog {
public:
explicit AclLog();
enum class Reason { COMMAND, AUTH, KEY };
enum class Reason { COMMAND, AUTH, KEY, PUB_SUB };
struct LogEntry {
std::string username;

View file

@ -76,6 +76,10 @@ void User::Update(UpdateRequest&& req, const CategoryToIdxStore& cat_to_id,
SetKeyGlobs(std::move(req.keys));
}
if (!req.pub_sub.empty()) {
SetPubSub(std::move(req.pub_sub));
}
if (req.is_active) {
SetIsActive(*req.is_active);
}
@ -214,6 +218,10 @@ const AclKeys& User::Keys() const {
return keys_;
}
const AclPubSub& User::PubSub() const {
return pub_sub_;
}
const User::CategoryChanges& User::CatChanges() const {
return cat_changes_;
}
@ -236,6 +244,20 @@ void User::SetKeyGlobs(std::vector<UpdateKey> keys) {
}
}
void User::SetPubSub(std::vector<UpdatePubSub> pub_sub) {
for (auto& pattern : pub_sub) {
if (pattern.all_channels) {
pub_sub_.globs.clear();
pub_sub_.all_channels = true;
} else if (pattern.reset_channels) {
pub_sub_.globs.clear();
pub_sub_.all_channels = false;
} else {
pub_sub_.globs.push_back({std::move(pattern.pattern), pattern.has_asterisk});
}
}
}
void User::SetNopass() {
nopass_ = true;
password_hashes_.clear();

View file

@ -40,6 +40,13 @@ class User final {
bool is_hashed{false};
};
struct UpdatePubSub {
std::string pattern;
bool has_asterisk{false};
bool all_channels{false};
bool reset_channels{false};
};
struct UpdateRequest {
std::vector<UpdatePass> passwords;
@ -59,6 +66,11 @@ class User final {
bool reset_all_keys{false};
bool allow_all_keys{false};
// pub/sub
std::vector<UpdatePubSub> pub_sub;
bool reset_channels{false};
bool all_channels{false};
// TODO allow reset all
// bool reset_all{false};
@ -107,6 +119,8 @@ class User final {
const AclKeys& Keys() const;
const AclPubSub& PubSub() const;
const std::string& Namespace() const;
using CategoryChanges = absl::flat_hash_map<CategoryChange, ChangeMetadata>;
@ -140,6 +154,10 @@ class User final {
// For ACL key globs
void SetKeyGlobs(std::vector<UpdateKey> keys);
// For ACL pub/sub
void SetPubSub(std::vector<UpdatePubSub> pub_sub);
void SetNamespace(const std::string& ns);
// Set NOPASS and remove all passwords
@ -170,6 +188,9 @@ class User final {
// Glob patterns for the keys that a user is allowed to read/write
AclKeys keys_;
// Glob patterns for pub/sub channels
AclPubSub pub_sub_;
// if the user is on/off
bool is_active_{false};

View file

@ -35,8 +35,8 @@ UserCredentials UserRegistry::GetCredentials(std::string_view username) const {
if (it == registry_.end()) {
return {};
}
return {it->second.AclCategory(), it->second.AclCommands(), it->second.Keys(),
it->second.Namespace()};
auto& user = it->second;
return {user.AclCategory(), user.AclCommands(), user.Keys(), user.PubSub(), user.Namespace()};
}
bool UserRegistry::IsUserActive(std::string_view username) const {
@ -80,6 +80,7 @@ User::UpdateRequest UserRegistry::DefaultUserUpdateRequest() const {
req.is_active = true;
req.updates = {std::pair<User::Sign, uint32_t>{User::Sign::PLUS, acl::ALL}};
req.keys = {User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}};
req.pub_sub = {User::UpdatePubSub{"", false, true, false}};
return req;
}

View file

@ -23,8 +23,17 @@ namespace dfly::acl {
return true;
}
const auto [is_authed, reason] =
IsUserAllowedToInvokeCommandGeneric(cntx.acl_commands, cntx.keys, tail_args, id);
std::pair<bool, AclLog::Reason> auth_res;
if (id.IsPubSub()) {
auth_res = IsPubSubCommandAuthorized(false, cntx.acl_commands, cntx.pub_sub, tail_args, id);
} else if (id.IsPSub()) {
auth_res = IsPubSubCommandAuthorized(true, cntx.acl_commands, cntx.pub_sub, tail_args, id);
} else {
auth_res = IsUserAllowedToInvokeCommandGeneric(cntx.acl_commands, cntx.keys, tail_args, id);
}
const auto [is_authed, reason] = auth_res;
if (!is_authed) {
auto& log = ServerState::tlocal()->acl_log;
@ -40,16 +49,18 @@ namespace dfly::acl {
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
[[nodiscard]] std::pair<bool, AclLog::Reason> IsUserAllowedToInvokeCommandGeneric(
const std::vector<uint64_t>& acl_commands, const AclKeys& keys, CmdArgList tail_args,
const CommandId& id) {
static bool ValidateCommand(const std::vector<uint64_t>& acl_commands, const CommandId& id) {
const size_t index = id.GetFamily();
const uint64_t command_mask = id.GetBitIndex();
DCHECK_LT(index, acl_commands.size());
const bool command = (acl_commands[index] & command_mask) != 0;
return (acl_commands[index] & command_mask) != 0;
}
if (!command) {
[[nodiscard]] std::pair<bool, AclLog::Reason> IsUserAllowedToInvokeCommandGeneric(
const std::vector<uint64_t>& acl_commands, const AclKeys& keys, CmdArgList tail_args,
const CommandId& id) {
if (!ValidateCommand(acl_commands, id)) {
return {false, AclLog::Reason::COMMAND};
}
@ -86,6 +97,39 @@ namespace dfly::acl {
return {keys_allowed, AclLog::Reason::KEY};
}
[[nodiscard]] std::pair<bool, AclLog::Reason> IsPubSubCommandAuthorized(
bool literal_match, const std::vector<uint64_t>& acl_commands, const AclPubSub& pub_sub,
CmdArgList tail_args, const CommandId& id) {
if (!ValidateCommand(acl_commands, id)) {
return {false, AclLog::Reason::COMMAND};
}
auto match = [](std::string_view pattern, std::string_view target) {
return stringmatchlen(pattern.data(), pattern.size(), target.data(), target.size(), 0);
};
auto iterate_globs = [&](std::string_view target) {
for (auto& [glob, has_asterisk] : pub_sub.globs) {
if (literal_match && (glob == target)) {
return true;
}
if (!literal_match && match(glob, target)) {
return true;
}
}
return false;
};
bool allowed = true;
if (!pub_sub.all_channels) {
for (auto channel : tail_args) {
allowed &= iterate_globs(facade::ToSV(channel));
}
}
return {allowed, AclLog::Reason::PUB_SUB};
}
#pragma GCC diagnostic pop
} // namespace dfly::acl

View file

@ -13,6 +13,7 @@
namespace dfly::acl {
struct AclKeys;
struct AclPubSub;
std::pair<bool, AclLog::Reason> IsUserAllowedToInvokeCommandGeneric(
const std::vector<uint64_t>& acl_commands, const AclKeys& keys, facade::CmdArgList tail_args,
@ -20,4 +21,11 @@ std::pair<bool, AclLog::Reason> IsUserAllowedToInvokeCommandGeneric(
bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const CommandId& id,
facade::CmdArgList tail_args);
std::pair<bool, AclLog::Reason> IsPubSubCommandAuthorized(bool literal_match,
const std::vector<uint64_t>& acl_commands,
const AclPubSub& pub_sub,
facade::CmdArgList tail_args,
const CommandId& id);
} // namespace dfly::acl

View file

@ -89,6 +89,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own
}
keys = std::move(cred.keys);
pub_sub = std::move(cred.pub_sub);
if (cred.acl_commands.empty()) {
acl_commands = std::vector<uint64_t>(acl::NumberOfFamilies(), acl::NONE_COMMANDS);
} else {
@ -102,6 +103,7 @@ ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction
if (owner) {
acl_commands = owner->acl_commands;
keys = owner->keys;
pub_sub = owner->pub_sub;
skip_acl_validation = owner->skip_acl_validation;
ns = owner->ns;
} else {

View file

@ -1701,6 +1701,7 @@ bool ServerFamily::DoAuth(ConnectionContext* cntx, std::string_view username,
auto cred = registry->GetCredentials(username);
cntx->acl_commands = cred.acl_commands;
cntx->keys = std::move(cred.keys);
cntx->pub_sub = std::move(cred.pub_sub);
cntx->ns = &namespaces.GetOrInsert(cred.ns);
cntx->authenticated = true;
}

View file

@ -7,60 +7,69 @@ import tempfile
import asyncio
import os
from . import dfly_args
import async_timeout
@pytest.mark.asyncio
async def test_acl_setuser(async_client):
await async_client.execute_command("ACL SETUSER kostas")
result = await async_client.execute_command("ACL list")
result = await async_client.execute_command("ACL LIST")
assert 2 == len(result)
assert "user kostas off -@all" in result
assert "user kostas off resetchannels -@all" in result
await async_client.execute_command("ACL SETUSER kostas ON")
result = await async_client.execute_command("ACL list")
assert "user kostas on -@all" in result
result = await async_client.execute_command("ACL LIST")
assert "user kostas on resetchannels -@all" in result
await async_client.execute_command("ACL SETUSER kostas +@list +@string +@admin")
result = await async_client.execute_command("ACL list")
result = await async_client.execute_command("ACL LIST")
# TODO consider printing to lowercase
assert "user kostas on -@all +@list +@string +@admin" in result
assert "user kostas on resetchannels -@all +@list +@string +@admin" in result
await async_client.execute_command("ACL SETUSER kostas -@list -@admin")
result = await async_client.execute_command("ACL list")
assert "user kostas on -@all +@string -@list -@admin" in result
result = await async_client.execute_command("ACL LIST")
assert "user kostas on resetchannels -@all +@string -@list -@admin" in result
# mix and match
await async_client.execute_command("ACL SETUSER kostas +@list -@string")
result = await async_client.execute_command("ACL list")
assert "user kostas on -@all -@admin +@list -@string" in result
result = await async_client.execute_command("ACL LIST")
assert "user kostas on resetchannels -@all -@admin +@list -@string" in result
# mix and match interleaved
await async_client.execute_command("ACL SETUSER kostas +@set -@set +@set")
result = await async_client.execute_command("ACL list")
assert "user kostas on -@all -@admin +@list -@string +@set" in result
result = await async_client.execute_command("ACL LIST")
assert "user kostas on resetchannels -@all -@admin +@list -@string +@set" in result
await async_client.execute_command("ACL SETUSER kostas +@all")
result = await async_client.execute_command("ACL list")
assert "user kostas on -@admin +@list -@string +@set +@all" in result
result = await async_client.execute_command("ACL LIST")
assert "user kostas on resetchannels -@admin +@list -@string +@set +@all" in result
# commands
await async_client.execute_command("ACL SETUSER kostas +set +get +hset")
result = await async_client.execute_command("ACL list")
assert "user kostas on -@admin +@list -@string +@set +@all +set +get +hset" in result
result = await async_client.execute_command("ACL LIST")
assert (
"user kostas on resetchannels -@admin +@list -@string +@set +@all +set +get +hset" in result
)
await async_client.execute_command("ACL SETUSER kostas -set -get +hset")
result = await async_client.execute_command("ACL list")
assert "user kostas on -@admin +@list -@string +@set +@all -set -get +hset" in result
result = await async_client.execute_command("ACL LIST")
assert (
"user kostas on resetchannels -@admin +@list -@string +@set +@all -set -get +hset" in result
)
# interleaved
await async_client.execute_command("ACL SETUSER kostas -hset +get -get -@all")
result = await async_client.execute_command("ACL list")
assert "user kostas on -@admin +@list -@string +@set -set -hset -get -@all" in result
result = await async_client.execute_command("ACL LIST")
assert (
"user kostas on resetchannels -@admin +@list -@string +@set -set -hset -get -@all" in result
)
# interleaved with categories
await async_client.execute_command("ACL SETUSER kostas +@string +get -get +set")
result = await async_client.execute_command("ACL list")
assert "user kostas on -@admin +@list +@set -hset -@all +@string -get +set" in result
result = await async_client.execute_command("ACL LIST")
assert (
"user kostas on resetchannels -@admin +@list +@set -hset -@all +@string -get +set" in result
)
@pytest.mark.asyncio
@ -324,7 +333,7 @@ async def test_bad_acl_file(df_factory, tmp_dir):
async def test_good_acl_file(df_factory, tmp_dir):
# The hash below is password temp
acl = create_temp_file(
"USER MrFoo ON #a6864eb339b0e1f6e00d75293a8840abf069a2c0fe82e6e53af6ac099793c1d5 >mypass",
"USER MrFoo ON #a6864eb339b0e1f6e00d75293a8840abf069a2c0fe82e6e53af6ac099793c1d5 >mypass &bar &r*nd",
tmp_dir,
)
df = df_factory.create(aclfile=acl)
@ -333,13 +342,14 @@ async def test_good_acl_file(df_factory, tmp_dir):
client = df.client()
await client.execute_command("ACL LOAD")
result = await client.execute_command("ACL list")
result = await client.execute_command("ACL LIST")
assert 2 == len(result)
assert (
"user MrFoo on #ea71c25a7a60224 #a6864eb339b0e1f -@all" in result
or "user MrFoo on #a6864eb339b0e1f #ea71c25a7a60224 -@all" in result
"user MrFoo on #ea71c25a7a60224 #a6864eb339b0e1f resetchannels &bar &r*nd -@all" in result
or "user MrFoo on #a6864eb339b0e1f #ea71c25a7a60224 resetchannels &bar &r*nd -@all"
in result
)
assert "user default on nopass ~* +@all" in result
assert "user default on nopass ~* &* +@all" in result
await client.execute_command("ACL SETUSER MrFoo +@all")
# Check multiple passwords work
assert "OK" == await client.execute_command("AUTH mypass")
@ -351,12 +361,12 @@ async def test_good_acl_file(df_factory, tmp_dir):
await client.execute_command("ACL SETUSER shahar >mypass +@set")
await client.execute_command("ACL SETUSER vlad ~foo ~bar* +@string")
result = await client.execute_command("ACL list")
result = await client.execute_command("ACL LIST")
assert 4 == len(result)
assert "user roy on #ea71c25a7a60224 -@all +@string +hset" in result
assert "user shahar off #ea71c25a7a60224 -@all +@set" in result
assert "user vlad off ~foo ~bar* -@all +@string" in result
assert "user default on nopass ~* +@all" in result
assert "user roy on #ea71c25a7a60224 resetchannels -@all +@string +hset" in result
assert "user shahar off #ea71c25a7a60224 resetchannels -@all +@set" in result
assert "user vlad off ~foo ~bar* resetchannels -@all +@string" in result
assert "user default on nopass ~* &* +@all" in result
result = await client.execute_command("ACL DELUSER shahar")
assert result == 1
@ -365,11 +375,11 @@ async def test_good_acl_file(df_factory, tmp_dir):
result = await client.execute_command("ACL LOAD")
result = await client.execute_command("ACL list")
result = await client.execute_command("ACL LIST")
assert 3 == len(result)
assert "user roy on #ea71c25a7a60224 -@all +@string +hset" in result
assert "user vlad off ~foo ~bar* -@all +@string" in result
assert "user default on nopass ~* +@all" in result
assert "user roy on #ea71c25a7a60224 resetchannels -@all +@string +hset" in result
assert "user vlad off ~foo ~bar* resetchannels -@all +@string" in result
assert "user default on nopass ~* &* +@all" in result
await client.close()
@ -483,7 +493,7 @@ async def test_set_acl_file(async_client: aioredis.Redis, tmp_dir):
await async_client.execute_command("ACL LOAD")
result = await async_client.execute_command("ACL list")
result = await async_client.execute_command("ACL LIST")
assert 3 == len(result)
result = await async_client.execute_command("AUTH roy mypass")
@ -635,3 +645,83 @@ async def test_auth_resp3_bug(df_factory):
}
await client.close()
@pytest.mark.asyncio
async def test_acl_pub_sub_auth(df_factory):
df = df_factory.create()
df.start()
client = df.client()
await client.execute_command("ACL SETUSER kostas on >tmp +subscribe +psubscribe &f*o &bar")
assert await client.execute_command("AUTH kostas tmp") == "OK"
res = await client.execute_command("SUBSCRIBE bar")
assert res == ["subscribe", "bar", 1]
res = await client.execute_command("SUBSCRIBE foo")
assert res == ["subscribe", "foo", 2]
with pytest.raises(redis.exceptions.NoPermissionError):
res = await client.execute_command("SUBSCRIBE my_channel")
# PSUBSCRIBE only matches pure literals, no asterisks
with pytest.raises(redis.exceptions.NoPermissionError):
res = await client.execute_command("PSUBSCRIBE foo")
# my_channel is not in our list so the command should fail
with pytest.raises(redis.exceptions.NoPermissionError):
res = await client.execute_command("PSUBSCRIBE bar my_channel")
res = await client.execute_command("PSUBSCRIBE bar")
assert res == ["psubscribe", "bar", 3]
@pytest.mark.asyncio
async def test_acl_revoke_pub_sub_while_subscribed(df_factory):
df = df_factory.create()
df.start()
publisher = df.client()
async def publish_worker(client):
for i in range(0, 10):
await client.publish("channel", "message")
async def subscribe_worker(channel: aioredis.client.PubSub):
total_msgs = 1
async with async_timeout.timeout(10):
while total_msgs != 10:
res = await channel.get_message(ignore_subscribe_messages=True, timeout=5)
if total_msgs is not None:
total_msgs = total_msgs + 1
await publisher.execute_command("ACL SETUSER kostas >tmp ON +@slow +SUBSCRIBE allchannels")
subscriber = aioredis.Redis(
username="kostas", password="tmp", port=df.port, decode_responses=True
)
subscriber_obj = subscriber.pubsub()
await subscriber_obj.subscribe("channel")
subscribe_task = asyncio.create_task(subscribe_worker(subscriber_obj))
await publish_worker(publisher)
await subscribe_task
subscribe_task = asyncio.create_task(subscribe_worker(subscriber_obj))
# Already subscribed, we should still be able to receive messages on channel
# We should not be able to unsubscribe
await publisher.execute_command("ACL SETUSER kostas -SUBSCRIBE -UNSUBSCRIBE")
await publish_worker(publisher)
await subscribe_task
# unsubscribe is not marked async and it's such a mess that it throws the error
# once we try to resubscribe. Instead I use the raw execute command to check that
# permission changes work
with pytest.raises(redis.exceptions.NoPermissionError):
await subscriber.execute_command("UNSUBSCRIBE channel")
await publisher.execute_command("ACL SETUSER kostas +SUBSCRIBE +UNSUBSCRIBE")
subscribe_task = asyncio.create_task(subscribe_worker(subscriber_obj))
await publisher.execute_command("ACL SETUSER kostas resetchannels")
await publish_worker(publisher)
with pytest.raises(redis.exceptions.ConnectionError):
await subscribe_task