From 18ca61d29b280f5f6527f6bdf8e0bc2ecd4e973d Mon Sep 17 00:00:00 2001 From: Shahar Mike Date: Tue, 16 Jul 2024 19:34:49 +0300 Subject: [PATCH] feat(namespaces): Initial support for multi-tenant (#3260) * feat(namespaces): Initial support for multi-tenant #3050 This PR introduces a way to create multiple, separate and isolated namespaces in Dragonfly. Each user can be associated with a single namespace, and will not be able to interact with other namespaces. This is still experimental, and lacks some important features, such as: * Replication and RDB saving completely ignore non-default namespaces * Defrag and statistics either use the default namespace or all namespaces without separation To associate a user with a namespace, use the `ACL` command with the `NAMESPACE:` flag: ``` ACL SETUSER user NAMESPACE:namespace1 ON >user_pass +@all ~* ``` For more examples and up to date info check `tests/dragonfly/acl_family_test.py` - specifically the `test_namespaces` function. --- docs/namespaces.md | 2 +- src/facade/acl_commands_def.h | 1 + src/server/CMakeLists.txt | 2 +- src/server/acl/acl_family.cc | 14 +++ src/server/acl/acl_family.h | 2 + src/server/acl/user.cc | 10 ++ src/server/acl/user.h | 8 ++ src/server/acl/user_registry.cc | 14 ++- src/server/blocking_controller.cc | 4 +- src/server/blocking_controller.h | 4 +- src/server/blocking_controller_test.cc | 13 ++- src/server/cluster/cluster_family.cc | 5 +- src/server/cluster/cluster_utility.cc | 6 +- src/server/cluster/outgoing_slot_migration.cc | 11 +- src/server/conn_context.h | 1 + src/server/container_utils.cc | 7 +- src/server/db_slice.cc | 2 +- src/server/debugcmd.cc | 36 +++--- src/server/detail/save_stages_controller.cc | 3 +- src/server/dflycmd.cc | 7 +- src/server/dragonfly_test.cc | 8 +- src/server/engine_shard_set.cc | 89 +++++++++------ src/server/engine_shard_set.h | 20 +--- src/server/generic_family.cc | 28 +++-- src/server/hset_family.cc | 2 +- src/server/journal/executor.cc | 1 + 
src/server/list_family.cc | 15 ++- src/server/list_family_test.cc | 11 +- src/server/main_service.cc | 52 +++++---- src/server/main_service.h | 2 +- src/server/memory_cmd.cc | 6 +- src/server/multi_command_squasher.cc | 4 +- src/server/multi_test.cc | 40 +++---- src/server/namespaces.cc | 103 ++++++++++++++++++ src/server/namespaces.h | 70 ++++++++++++ src/server/rdb_load.cc | 11 +- src/server/rdb_save.cc | 7 +- src/server/replica.cc | 1 + src/server/server_family.cc | 46 ++++---- src/server/server_family.h | 7 +- src/server/stream_family.cc | 15 +-- src/server/string_family.cc | 5 +- src/server/string_family_test.cc | 2 +- src/server/test_utils.cc | 31 ++++-- src/server/test_utils.h | 2 +- src/server/transaction.cc | 50 +++++---- src/server/transaction.h | 21 ++-- src/server/tx_base.cc | 8 +- src/server/tx_base.h | 2 + src/server/zset_family.cc | 5 +- tests/dragonfly/acl_family_test.py | 39 ++++++- 51 files changed, 600 insertions(+), 255 deletions(-) create mode 100644 src/server/namespaces.cc create mode 100644 src/server/namespaces.h diff --git a/docs/namespaces.md b/docs/namespaces.md index 75939d49e..019596711 100644 --- a/docs/namespaces.md +++ b/docs/namespaces.md @@ -1,6 +1,6 @@ # Namespaces in Dragonfly -Dragonfly will soon add an _experimental_ feature, allowing complete separation of data by different users. +Dragonfly added an _experimental_ feature, allowing complete separation of data by different users. We call this feature _namespaces_, and it allows using a single Dragonfly server with multiple tenants, each using their own data, without being able to mix them together. 
diff --git a/src/facade/acl_commands_def.h b/src/facade/acl_commands_def.h index 62e131fdd..b4aaadd35 100644 --- a/src/facade/acl_commands_def.h +++ b/src/facade/acl_commands_def.h @@ -28,6 +28,7 @@ struct UserCredentials { uint32_t acl_categories{0}; std::vector acl_commands; AclKeys keys; + std::string ns; }; } // namespace dfly::acl diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index b2f2a5f7b..3fc7ac064 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -28,7 +28,7 @@ endif() add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc command_registry.cc cluster/cluster_utility.cc - journal/tx_executor.cc + journal/tx_executor.cc namespaces.cc common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index fc7b00a18..8791c8758 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -942,6 +942,14 @@ std::pair AclFamily::MaybeParseAclCategory(std::string_view comman return {}; } +std::optional AclFamily::MaybeParseNamespace(std::string_view command) const { + constexpr std::string_view kPrefix = "NAMESPACE:"; + if (absl::StartsWith(command, kPrefix)) { + return std::string(command.substr(kPrefix.size())); + } + return std::nullopt; +} + std::pair AclFamily::MaybeParseAclCommand( std::string_view command) const { if (absl::StartsWith(command, "+")) { @@ -1019,6 +1027,12 @@ std::variant AclFamily::ParseAclSetUser( continue; } + auto ns = MaybeParseNamespace(command); + if (ns.has_value()) { + req.ns = *ns; + continue; + } + auto [cmd, sign] = MaybeParseAclCommand(command); if (!cmd) { return ErrorReply(absl::StrCat("Unrecognized parameter ", command)); diff --git a/src/server/acl/acl_family.h b/src/server/acl/acl_family.h index 143e29f02..0be818725 
100644 --- a/src/server/acl/acl_family.h +++ b/src/server/acl/acl_family.h @@ -80,6 +80,8 @@ class AclFamily final { using OptCommand = std::optional>; std::pair MaybeParseAclCommand(std::string_view command) const; + std::optional MaybeParseNamespace(std::string_view command) const; + std::variant ParseAclSetUser( const facade::ArgRange& args, bool hashed = false, bool has_all_keys = false) const; diff --git a/src/server/acl/user.cc b/src/server/acl/user.cc index fa59c7a25..341e6b546 100644 --- a/src/server/acl/user.cc +++ b/src/server/acl/user.cc @@ -79,6 +79,8 @@ void User::Update(UpdateRequest&& req, const CategoryToIdxStore& cat_to_id, if (req.is_active) { SetIsActive(*req.is_active); } + + SetNamespace(req.ns); } void User::SetPasswordHash(std::string_view password, bool is_hashed) { @@ -94,6 +96,14 @@ void User::UnsetPassword(std::string_view password) { password_hashes_.erase(StringSHA256(password)); } +void User::SetNamespace(const std::string& ns) { + namespace_ = ns; +} + +const std::string& User::Namespace() const { + return namespace_; +} + bool User::HasPassword(std::string_view password) const { if (nopass_) { return true; diff --git a/src/server/acl/user.h b/src/server/acl/user.h index d8e440293..187856728 100644 --- a/src/server/acl/user.h +++ b/src/server/acl/user.h @@ -58,8 +58,11 @@ class User final { std::vector keys; bool reset_all_keys{false}; bool allow_all_keys{false}; + // TODO allow reset all // bool reset_all{false}; + + std::string ns; }; using CategoryChange = uint32_t; @@ -104,6 +107,8 @@ class User final { const AclKeys& Keys() const; + const std::string& Namespace() const; + using CategoryChanges = absl::flat_hash_map; using CommandChanges = absl::flat_hash_map; @@ -135,6 +140,7 @@ class User final { // For ACL key globs void SetKeyGlobs(std::vector keys); + void SetNamespace(const std::string& ns); // Set NOPASS and remove all passwords void SetNopass(); @@ -166,6 +172,8 @@ class User final { // if the user is on/off bool 
is_active_{false}; + + std::string namespace_; }; } // namespace dfly::acl diff --git a/src/server/acl/user_registry.cc b/src/server/acl/user_registry.cc index ae252ea3a..6c5d946af 100644 --- a/src/server/acl/user_registry.cc +++ b/src/server/acl/user_registry.cc @@ -35,7 +35,8 @@ UserCredentials UserRegistry::GetCredentials(std::string_view username) const { if (it == registry_.end()) { return {}; } - return {it->second.AclCategory(), it->second.AclCommands(), it->second.Keys()}; + return {it->second.AclCategory(), it->second.AclCommands(), it->second.Keys(), + it->second.Namespace()}; } bool UserRegistry::IsUserActive(std::string_view username) const { @@ -73,10 +74,13 @@ UserRegistry::UserWithWriteLock::UserWithWriteLock(std::unique_lock acl{User::Sign::PLUS, acl::ALL}; - auto key = User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}; - auto pass = std::vector{{"", false, true}}; - return {std::move(pass), true, false, {std::move(acl)}, {std::move(key)}}; + // Assign field by field to suppress an annoying compiler warning + User::UpdateRequest req; + req.passwords = std::vector{{"", false, true}}; + req.is_active = true; + req.updates = {std::pair{User::Sign::PLUS, acl::ALL}}; + req.keys = {User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}}; + return req; } void UserRegistry::Init(const CategoryToIdxStore* cat_to_id_table, diff --git a/src/server/blocking_controller.cc b/src/server/blocking_controller.cc index 307e60b41..6f51e2bf7 100644 --- a/src/server/blocking_controller.cc +++ b/src/server/blocking_controller.cc @@ -10,6 +10,7 @@ #include "base/logging.h" #include "server/engine_shard_set.h" +#include "server/namespaces.h" #include "server/transaction.h" namespace dfly { @@ -102,7 +103,7 @@ bool BlockingController::DbWatchTable::UnwatchTx(string_view key, Transaction* t return res; } -BlockingController::BlockingController(EngineShard* owner) : owner_(owner) { +BlockingController::BlockingController(EngineShard* owner, Namespace* ns) : owner_(owner), 
ns_(ns) { } BlockingController::~BlockingController() { @@ -153,6 +154,7 @@ void BlockingController::NotifyPending() { CHECK(tx == nullptr) << tx->DebugId(); DbContext context; + context.ns = ns_; context.time_now_ms = GetCurrentTimeMs(); for (DbIndex index : awakened_indices_) { diff --git a/src/server/blocking_controller.h b/src/server/blocking_controller.h index 94375e7ab..c8e6e13f0 100644 --- a/src/server/blocking_controller.h +++ b/src/server/blocking_controller.h @@ -15,10 +15,11 @@ namespace dfly { class Transaction; +class Namespace; class BlockingController { public: - explicit BlockingController(EngineShard* owner); + explicit BlockingController(EngineShard* owner, Namespace* ns); ~BlockingController(); using Keys = std::variant; @@ -60,6 +61,7 @@ class BlockingController { // void NotifyConvergence(Transaction* tx); EngineShard* owner_; + Namespace* ns_; absl::flat_hash_map> watched_dbs_; diff --git a/src/server/blocking_controller_test.cc b/src/server/blocking_controller_test.cc index 2079f244a..f516c8e53 100644 --- a/src/server/blocking_controller_test.cc +++ b/src/server/blocking_controller_test.cc @@ -61,7 +61,7 @@ void BlockingControllerTest::SetUp() { arg_vec_.emplace_back(s); } - trans_->InitByArgs(0, {arg_vec_.data(), arg_vec_.size()}); + trans_->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {arg_vec_.data(), arg_vec_.size()}); CHECK_EQ(0u, Shard("x", shard_set->size())); CHECK_EQ(2u, Shard("z", shard_set->size())); @@ -70,6 +70,8 @@ void BlockingControllerTest::SetUp() { } void BlockingControllerTest::TearDown() { + namespaces.Clear(); + shard_set->Shutdown(); delete shard_set; @@ -79,7 +81,7 @@ void BlockingControllerTest::TearDown() { TEST_F(BlockingControllerTest, Basic) { trans_->ScheduleSingleHop([&](Transaction* t, EngineShard* shard) { - BlockingController bc(shard); + BlockingController bc(shard, &namespaces.GetDefaultNamespace()); auto keys = t->GetShardArgs(shard->shard_id()); bc.AddWatched( keys, [](auto...) 
{ return true; }, t); @@ -103,7 +105,12 @@ TEST_F(BlockingControllerTest, Timeout) { EXPECT_EQ(status, facade::OpStatus::TIMED_OUT); unsigned num_watched = shard_set->Await( - 0, [&] { return EngineShard::tlocal()->blocking_controller()->NumWatched(0); }); + + 0, [&] { + return namespaces.GetDefaultNamespace() + .GetBlockingController(EngineShard::tlocal()->shard_id()) + ->NumWatched(0); + }); EXPECT_EQ(0, num_watched); trans_.reset(); diff --git a/src/server/cluster/cluster_family.cc b/src/server/cluster/cluster_family.cc index a5312f51d..f441e5970 100644 --- a/src/server/cluster/cluster_family.cc +++ b/src/server/cluster/cluster_family.cc @@ -21,6 +21,7 @@ #include "server/error.h" #include "server/journal/journal.h" #include "server/main_service.h" +#include "server/namespaces.h" #include "server/server_family.h" #include "server/server_state.h" @@ -451,7 +452,7 @@ void DeleteSlots(const SlotRanges& slots_ranges) { if (shard == nullptr) return; - shard->db_slice().FlushSlots(slots_ranges); + namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).FlushSlots(slots_ranges); }; shard_set->pool()->AwaitFiberOnAll(std::move(cb)); } @@ -599,7 +600,7 @@ void ClusterFamily::DflyClusterGetSlotInfo(CmdArgList args, ConnectionContext* c lock_guard lk(mu); for (auto& [slot, data] : slots_stats) { - data += shard->db_slice().GetSlotStats(slot); + data += namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).GetSlotStats(slot); } }; diff --git a/src/server/cluster/cluster_utility.cc b/src/server/cluster/cluster_utility.cc index c624425b7..c56ad919f 100644 --- a/src/server/cluster/cluster_utility.cc +++ b/src/server/cluster/cluster_utility.cc @@ -2,6 +2,7 @@ #include "server/cluster/cluster_defs.h" #include "server/engine_shard_set.h" +#include "server/namespaces.h" using namespace std; @@ -49,7 +50,10 @@ uint64_t GetKeyCount(const SlotRanges& slots) { uint64_t shard_keys = 0; for (const SlotRange& range : slots) { for (SlotId slot = range.start; slot <= 
range.end; slot++) { - shard_keys += shard->db_slice().GetSlotStats(slot).key_count; + shard_keys += namespaces.GetDefaultNamespace() + .GetDbSlice(shard->shard_id()) + .GetSlotStats(slot) + .key_count; } } keys.fetch_add(shard_keys); diff --git a/src/server/cluster/outgoing_slot_migration.cc b/src/server/cluster/outgoing_slot_migration.cc index b2b7cc5a6..1a7235b00 100644 --- a/src/server/cluster/outgoing_slot_migration.cc +++ b/src/server/cluster/outgoing_slot_migration.cc @@ -96,7 +96,7 @@ OutgoingMigration::OutgoingMigration(MigrationInfo info, ClusterFamily* cf, Serv server_family_(sf), cf_(cf), tx_(new Transaction{sf->service().FindCmd("DFLYCLUSTER")}) { - tx_->InitByArgs(0, {}); + tx_->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {}); } OutgoingMigration::~OutgoingMigration() { @@ -212,10 +212,10 @@ void OutgoingMigration::SyncFb() { } OnAllShards([this](auto& migration) { - auto* shard = EngineShard::tlocal(); + DbSlice& db_slice = namespaces.GetDefaultNamespace().GetCurrentDbSlice(); server_family_->journal()->StartInThread(); migration = std::make_unique( - &shard->db_slice(), server(), migration_info_.slot_ranges, server_family_->journal()); + &db_slice, server(), migration_info_.slot_ranges, server_family_->journal()); }); if (!ChangeState(MigrationState::C_SYNC)) { @@ -284,8 +284,9 @@ bool OutgoingMigration::FinalizeMigration(long attempt) { // TODO implement blocking on migrated slots only bool is_block_active = true; auto is_pause_in_progress = [&is_block_active] { return is_block_active; }; - auto pause_fb_opt = Pause(server_family_->GetNonPriviligedListeners(), nullptr, - ClientPause::WRITE, is_pause_in_progress); + auto pause_fb_opt = + Pause(server_family_->GetNonPriviligedListeners(), &namespaces.GetDefaultNamespace(), nullptr, + ClientPause::WRITE, is_pause_in_progress); if (!pause_fb_opt) { LOG(WARNING) << "Cluster migration finalization time out"; diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 
13341c262..a61c513fd 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -282,6 +282,7 @@ class ConnectionContext : public facade::ConnectionContext { DebugInfo last_command_debug; // TODO: to introduce proper accessors. + Namespace* ns = nullptr; Transaction* transaction = nullptr; const CommandId* cid = nullptr; diff --git a/src/server/container_utils.cc b/src/server/container_utils.cc index 314cb0191..8848dae82 100644 --- a/src/server/container_utils.cc +++ b/src/server/container_utils.cc @@ -339,9 +339,10 @@ OpResult RunCbOnFirstNonEmptyBlocking(Transaction* trans, int req_obj_ty } auto wcb = [](Transaction* t, EngineShard* shard) { return t->GetShardArgs(shard->shard_id()); }; - const auto key_checker = [req_obj_type](EngineShard* owner, const DbContext& context, - Transaction*, std::string_view key) -> bool { - return context.GetDbSlice(owner->shard_id()).FindReadOnly(context, key, req_obj_type).ok(); + auto* ns = &trans->GetNamespace(); + const auto key_checker = [req_obj_type, ns](EngineShard* owner, const DbContext& context, + Transaction*, std::string_view key) -> bool { + return ns->GetDbSlice(owner->shard_id()).FindReadOnly(context, key, req_obj_type).ok(); }; auto status = trans->WaitOnWatch(limit_tp, std::move(wcb), key_checker, block_flag, pause_flag); diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 8858a7f2f..e860fa689 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1078,7 +1078,7 @@ void DbSlice::ExpireAllIfNeeded() { LOG(ERROR) << "Expire entry " << exp_it->first.ToString() << " not found in prime table"; return; } - ExpireIfNeeded(Context{db_index, GetCurrentTimeMs()}, prime_it); + ExpireIfNeeded(Context{nullptr, db_index, GetCurrentTimeMs()}, prime_it); }; ExpireTable::Cursor cursor; diff --git a/src/server/debugcmd.cc b/src/server/debugcmd.cc index 623a4ee8a..abacd231a 100644 --- a/src/server/debugcmd.cc +++ b/src/server/debugcmd.cc @@ -159,7 +159,7 @@ void 
DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool stub_tx->MultiSwitchCmd(cid); local_cntx.cid = cid; crb.SetReplyMode(ReplyMode::NONE); - stub_tx->InitByArgs(local_cntx.conn_state.db_index, args_span); + stub_tx->InitByArgs(cntx->ns, local_cntx.conn_state.db_index, args_span); sf->service().InvokeCmd(cid, args_span, &local_cntx); } @@ -261,8 +261,8 @@ void MergeObjHistMap(ObjHistMap&& src, ObjHistMap* dest) { } } -void DoBuildObjHist(EngineShard* shard, ObjHistMap* obj_hist_map) { - auto& db_slice = shard->db_slice(); +void DoBuildObjHist(EngineShard* shard, ConnectionContext* cntx, ObjHistMap* obj_hist_map) { + auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id()); unsigned steps = 0; for (unsigned i = 0; i < db_slice.db_array_size(); ++i) { @@ -288,8 +288,9 @@ void DoBuildObjHist(EngineShard* shard, ObjHistMap* obj_hist_map) { } } -ObjInfo InspectOp(string_view key, DbIndex db_index) { - auto& db_slice = EngineShard::tlocal()->db_slice(); +ObjInfo InspectOp(ConnectionContext* cntx, string_view key) { + auto& db_slice = cntx->ns->GetCurrentDbSlice(); + auto db_index = cntx->db_index(); auto [pt, exp_t] = db_slice.GetTables(db_index); PrimeIterator it = pt->Find(key); @@ -323,8 +324,9 @@ ObjInfo InspectOp(string_view key, DbIndex db_index) { return oinfo; } -OpResult EstimateCompression(string_view key, DbIndex db_index) { - auto& db_slice = EngineShard::tlocal()->db_slice(); +OpResult EstimateCompression(ConnectionContext* cntx, string_view key) { + auto& db_slice = cntx->ns->GetCurrentDbSlice(); + auto db_index = cntx->db_index(); auto [pt, exp_t] = db_slice.GetTables(db_index); PrimeIterator it = pt->Find(key); @@ -544,7 +546,7 @@ void DebugCmd::Load(string_view filename) { const CommandId* cid = sf_.service().FindCmd("FLUSHALL"); intrusive_ptr flush_trans(new Transaction{cid}); - flush_trans->InitByArgs(0, {}); + flush_trans->InitByArgs(cntx_->ns, 0, {}); VLOG(1) << "Performing flush"; error_code ec = 
sf_.Drakarys(flush_trans.get(), DbSlice::kDbAll); if (ec) { @@ -750,7 +752,7 @@ void DebugCmd::PopulateRangeFiber(uint64_t from, uint64_t num_of_keys, // after running the callback // Note that running debug populate while running flushall/db can cause dcheck fail because the // finish cb is executed just when we finish populating the database. - shard->db_slice().OnCbFinish(); + cntx_->ns->GetDbSlice(shard->shard_id()).OnCbFinish(); }); } @@ -801,7 +803,7 @@ void DebugCmd::Inspect(string_view key, CmdArgList args) { string resp; if (check_compression) { - auto cb = [&] { return EstimateCompression(key, cntx_->db_index()); }; + auto cb = [&] { return EstimateCompression(cntx_, key); }; auto res = ess.Await(sid, std::move(cb)); if (!res) { cntx_->SendError(res.status()); @@ -812,7 +814,7 @@ void DebugCmd::Inspect(string_view key, CmdArgList args) { StrAppend(&resp, " ratio: ", static_cast(res->compressed_size) / (res->raw_size)); } } else { - auto cb = [&] { return InspectOp(key, cntx_->db_index()); }; + auto cb = [&] { return InspectOp(cntx_, key); }; ObjInfo res = ess.Await(sid, std::move(cb)); @@ -846,7 +848,7 @@ void DebugCmd::Watched() { vector awaked_trans; auto cb = [&](EngineShard* shard) { - auto* bc = shard->blocking_controller(); + auto* bc = cntx_->ns->GetBlockingController(shard->shard_id()); if (bc) { auto keys = bc->GetWatchedKeys(cntx_->db_index()); @@ -894,8 +896,9 @@ void DebugCmd::TxAnalysis() { void DebugCmd::ObjHist() { vector obj_hist_map_arr(shard_set->size()); - shard_set->RunBlockingInParallel( - [&](EngineShard* shard) { DoBuildObjHist(shard, &obj_hist_map_arr[shard->shard_id()]); }); + shard_set->RunBlockingInParallel([&](EngineShard* shard) { + DoBuildObjHist(shard, cntx_, &obj_hist_map_arr[shard->shard_id()]); + }); for (size_t i = shard_set->size() - 1; i > 0; --i) { MergeObjHistMap(std::move(obj_hist_map_arr[i]), &obj_hist_map_arr[0]); @@ -937,8 +940,9 @@ void DebugCmd::Shards() { vector infos(shard_set->size()); 
shard_set->RunBriefInParallel([&](EngineShard* shard) { - auto slice_stats = shard->db_slice().GetStats(); - auto& stats = infos[shard->shard_id()]; + auto sid = shard->shard_id(); + auto slice_stats = cntx_->ns->GetDbSlice(sid).GetStats(); + auto& stats = infos[sid]; stats.used_memory = shard->UsedMemory(); for (const auto& db_stats : slice_stats.db_stats) { diff --git a/src/server/detail/save_stages_controller.cc b/src/server/detail/save_stages_controller.cc index fc823074b..0a56ca02e 100644 --- a/src/server/detail/save_stages_controller.cc +++ b/src/server/detail/save_stages_controller.cc @@ -13,6 +13,7 @@ #include "base/logging.h" #include "server/detail/snapshot_storage.h" #include "server/main_service.h" +#include "server/namespaces.h" #include "server/script_mgr.h" #include "server/transaction.h" #include "strings/human_readable.h" @@ -400,7 +401,7 @@ void SaveStagesController::CloseCb(unsigned index) { } if (auto* es = EngineShard::tlocal(); use_dfs_format_ && es) - es->db_slice().ResetUpdateEvents(); + namespaces.GetDefaultNamespace().GetDbSlice(es->shard_id()).ResetUpdateEvents(); } void SaveStagesController::RunStage(void (SaveStagesController::*cb)(unsigned)) { diff --git a/src/server/dflycmd.cc b/src/server/dflycmd.cc index d611d030d..cc6446718 100644 --- a/src/server/dflycmd.cc +++ b/src/server/dflycmd.cc @@ -75,7 +75,7 @@ OpStatus WaitReplicaFlowToCatchup(absl::Time end_time, shared_ptrdb_slice().SetExpireAllowed(false); + namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).SetExpireAllowed(false); shard->journal()->RecordEntry(0, journal::Op::PING, 0, 0, nullopt, {}, true); FlowInfo* flow = &replica->flows[shard->shard_id()]; @@ -396,8 +396,9 @@ void DflyCmd::TakeOver(CmdArgList args, ConnectionContext* cntx) { VLOG(1) << "AwaitCurrentDispatches done"; absl::Cleanup([] { - shard_set->RunBriefInParallel( - [](EngineShard* shard) { shard->db_slice().SetExpireAllowed(true); }); + shard_set->RunBriefInParallel([](EngineShard* shard) { + 
namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).SetExpireAllowed(true); + }); VLOG(2) << "Enable expiration"; }); diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc index f036568c8..550777f5e 100644 --- a/src/server/dragonfly_test.cc +++ b/src/server/dragonfly_test.cc @@ -366,7 +366,11 @@ TEST_F(DflyEngineTest, MemcacheFlags) { ASSERT_EQ(Run("resp", {"flushdb"}), "OK"); pp_->AwaitFiberOnAll([](auto*) { if (auto* shard = EngineShard::tlocal(); shard) { - EXPECT_EQ(shard->db_slice().GetDBTable(0)->mcflag.size(), 0u); + EXPECT_EQ(namespaces.GetDefaultNamespace() + .GetDbSlice(shard->shard_id()) + .GetDBTable(0) + ->mcflag.size(), + 0u); } }); } @@ -584,7 +588,7 @@ TEST_F(DflyEngineTest, Bug496) { if (shard == nullptr) return; - auto& db = shard->db_slice(); + auto& db = namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()); int cb_hits = 0; uint32_t cb_id = diff --git a/src/server/engine_shard_set.cc b/src/server/engine_shard_set.cc index 3618b4751..38c62086a 100644 --- a/src/server/engine_shard_set.cc +++ b/src/server/engine_shard_set.cc @@ -20,6 +20,7 @@ extern "C" { #include "io/proc_reader.h" #include "server/blocking_controller.h" #include "server/cluster/cluster_defs.h" +#include "server/namespaces.h" #include "server/search/doc_index.h" #include "server/server_state.h" #include "server/tiered_storage.h" @@ -294,7 +295,8 @@ bool EngineShard::DoDefrag() { constexpr size_t kMaxTraverses = 40; const float threshold = GetFlag(FLAGS_mem_defrag_page_utilization_threshold); - auto& slice = db_slice(); + // TODO: enable tiered storage on non-default db slice + DbSlice& slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_->shard_id()); // If we moved to an invalid db, skip as long as it's not the last one while (!slice.IsDbValid(defrag_state_.dbid) && defrag_state_.dbid + 1 < slice.db_array_size()) @@ -324,7 +326,7 @@ bool EngineShard::DoDefrag() { } }); traverses_count++; - } while (traverses_count < kMaxTraverses && 
cur); + } while (traverses_count < kMaxTraverses && cur && namespaces.IsInitialized()); defrag_state_.UpdateScanState(cur.value()); @@ -355,11 +357,14 @@ bool EngineShard::DoDefrag() { // priority. // otherwise lower the task priority so that it would not use the CPU when not required uint32_t EngineShard::DefragTask() { + if (!namespaces.IsInitialized()) { + return util::ProactorBase::kOnIdleMaxLevel; + } + constexpr uint32_t kRunAtLowPriority = 0u; - const auto shard_id = db_slice().shard_id(); if (defrag_state_.CheckRequired()) { - VLOG(2) << shard_id << ": need to run defrag memory cursor state: " << defrag_state_.cursor; + VLOG(2) << shard_id_ << ": need to run defrag memory cursor state: " << defrag_state_.cursor; if (DoDefrag()) { // we didn't finish the scan return util::ProactorBase::kOnIdleMaxLevel; @@ -372,13 +377,11 @@ EngineShard::EngineShard(util::ProactorBase* pb, mi_heap_t* heap) : queue_(1, kQueueLen), txq_([](const Transaction* t) { return t->txid(); }), mi_resource_(heap), - db_slice_(pb->GetPoolIndex(), GetFlag(FLAGS_cache_mode), this) { + shard_id_(pb->GetPoolIndex()) { tmp_str1 = sdsempty(); - db_slice_.UpdateExpireBase(absl::GetCurrentTimeNanos() / 1000000, 0); - // start the defragmented task here - defrag_task_ = pb->AddOnIdleTask([this]() { return this->DefragTask(); }); - queue_.Start(absl::StrCat("shard_queue_", db_slice_.shard_id())); + defrag_task_ = pb->AddOnIdleTask([this]() { return DefragTask(); }); + queue_.Start(absl::StrCat("shard_queue_", shard_id())); } EngineShard::~EngineShard() { @@ -437,8 +440,10 @@ void EngineShard::InitTieredStorage(ProactorBase* pb, size_t max_file_size) { LOG_IF(FATAL, pb->GetKind() != ProactorBase::IOURING) << "Only ioring based backing storage is supported. 
Exiting..."; + // TODO: enable tiered storage on non-default namespace + DbSlice& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_id()); auto* shard = EngineShard::tlocal(); - shard->tiered_storage_ = make_unique(&db_slice_, max_file_size); + shard->tiered_storage_ = make_unique(&db_slice, max_file_size); error_code ec = shard->tiered_storage_->Open(backing_prefix); CHECK(!ec) << ec.message(); } @@ -515,24 +520,31 @@ void EngineShard::PollExecution(const char* context, Transaction* trans) { trans = nullptr; if ((is_self && disarmed) || continuation_trans_->DisarmInShard(sid)) { + auto bc = continuation_trans_->GetNamespace().GetBlockingController(shard_id_); if (bool keep = run(continuation_trans_, false); !keep) { // if this holds, we can remove this check altogether. DCHECK(continuation_trans_ == nullptr); continuation_trans_ = nullptr; } + if (bc && bc->HasAwakedTransaction()) { + // Break if there are any awakened transactions, as we must give way to them + // before continuing to handle regular transactions from the queue. + return; + } } } // Progress on the transaction queue if no transaction is running currently. Transaction* head = nullptr; while (continuation_trans_ == nullptr && !txq_.Empty()) { + head = get(txq_.Front()); + // Break if there are any awakened transactions, as we must give way to them // before continuing to handle regular transactions from the queue. 
- if (blocking_controller_ && blocking_controller_->HasAwakedTransaction()) + if (head->GetNamespace().GetBlockingController(shard_id_) && + head->GetNamespace().GetBlockingController(shard_id_)->HasAwakedTransaction()) break; - head = get(txq_.Front()); - VLOG(2) << "Considering head " << head->DebugId() << " isarmed: " << head->DEBUG_IsArmedInShard(sid); @@ -610,22 +622,28 @@ void EngineShard::Heartbeat() { DbContext db_cntx; db_cntx.time_now_ms = GetCurrentTimeMs(); - for (unsigned i = 0; i < db_slice_.db_array_size(); ++i) { - if (!db_slice_.IsDbValid(i)) + // TODO: iterate over all namespaces + if (!namespaces.IsInitialized()) { + return; + } + + DbSlice& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_id()); + for (unsigned i = 0; i < db_slice.db_array_size(); ++i) { + if (!db_slice.IsDbValid(i)) continue; db_cntx.db_index = i; - auto [pt, expt] = db_slice_.GetTables(i); + auto [pt, expt] = db_slice.GetTables(i); if (expt->size() > pt->size() / 4) { - DbSlice::DeleteExpiredStats stats = db_slice_.DeleteExpiredStep(db_cntx, ttl_delete_target); + DbSlice::DeleteExpiredStats stats = db_slice.DeleteExpiredStep(db_cntx, ttl_delete_target); counter_[TTL_TRAVERSE].IncBy(stats.traversed); counter_[TTL_DELETE].IncBy(stats.deleted); } // if our budget is below the limit - if (db_slice_.memory_budget() < eviction_redline) { - db_slice_.FreeMemWithEvictionStep(i, eviction_redline - db_slice_.memory_budget()); + if (db_slice.memory_budget() < eviction_redline) { + db_slice.FreeMemWithEvictionStep(i, eviction_redline - db_slice.memory_budget()); } if (UsedMemory() > tiering_offload_threshold) { @@ -686,18 +704,23 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms) { } void EngineShard::CacheStats() { + if (!namespaces.IsInitialized()) { + return; + } + // mi_heap_visit_blocks(tlh, false /* visit all blocks*/, visit_cb, &sum); mi_stats_merge(); // Used memory for this shard. 
size_t used_mem = UsedMemory(); - cached_stats[db_slice_.shard_id()].used_memory.store(used_mem, memory_order_relaxed); + DbSlice& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_id()); + cached_stats[db_slice.shard_id()].used_memory.store(used_mem, memory_order_relaxed); ssize_t free_mem = max_memory_limit - used_mem_current.load(memory_order_relaxed); size_t entries = 0; size_t table_memory = 0; - for (size_t i = 0; i < db_slice_.db_array_size(); ++i) { - DbTable* table = db_slice_.GetDBTable(i); + for (size_t i = 0; i < db_slice.db_array_size(); ++i) { + DbTable* table = db_slice.GetDBTable(i); if (table) { entries += table->prime.size(); table_memory += (table->prime.mem_usage() + table->expire.mem_usage()); @@ -706,7 +729,7 @@ void EngineShard::CacheStats() { size_t obj_memory = table_memory <= used_mem ? used_mem - table_memory : 0; size_t bytes_per_obj = entries > 0 ? obj_memory / entries : 0; - db_slice_.SetCachedParams(free_mem / shard_set->size(), bytes_per_obj); + db_slice.SetCachedParams(free_mem / shard_set->size(), bytes_per_obj); } size_t EngineShard::UsedMemory() const { @@ -714,14 +737,6 @@ size_t EngineShard::UsedMemory() const { search_indices()->GetUsedMemory(); } -BlockingController* EngineShard::EnsureBlockingController() { - if (!blocking_controller_) { - blocking_controller_.reset(new BlockingController(this)); - } - - return blocking_controller_.get(); -} - void EngineShard::TEST_EnableHeartbeat() { fiber_periodic_ = fb2::Fiber("shard_periodic_TEST", [this, period_ms = 1] { RunPeriodic(std::chrono::milliseconds(period_ms)); @@ -750,6 +765,8 @@ auto EngineShard::AnalyzeTxQueue() const -> TxQueueInfo { info.tx_total = queue->size(); unsigned max_db_id = 0; + auto& db_slice = namespaces.GetDefaultNamespace().GetCurrentDbSlice(); + do { auto value = queue->At(cur); Transaction* trx = std::get(value); @@ -766,7 +783,7 @@ auto EngineShard::AnalyzeTxQueue() const -> TxQueueInfo { if (trx->IsGlobal() || (trx->IsMulti() && 
trx->GetMultiMode() == Transaction::GLOBAL)) { info.tx_global++; } else { - const DbTable* table = db_slice().GetDBTable(trx->GetDbIndex()); + const DbTable* table = db_slice.GetDBTable(trx->GetDbIndex()); bool can_run = !HasContendedLocks(sid, trx, table); if (can_run) { info.tx_runnable++; @@ -778,7 +795,7 @@ auto EngineShard::AnalyzeTxQueue() const -> TxQueueInfo { // Analyze locks for (unsigned i = 0; i <= max_db_id; ++i) { - const DbTable* table = db_slice().GetDBTable(i); + const DbTable* table = db_slice.GetDBTable(i); if (table == nullptr) continue; @@ -869,6 +886,8 @@ void EngineShardSet::Init(uint32_t sz, bool update_db_time) { } }); + namespaces.Init(); + pp_->AwaitFiberOnAll([&](uint32_t index, ProactorBase* pb) { if (index < shard_queue_.size()) { EngineShard::tlocal()->InitTieredStorage(pb, max_shard_file_size); @@ -895,7 +914,9 @@ void EngineShardSet::TEST_EnableHeartBeat() { } void EngineShardSet::TEST_EnableCacheMode() { - RunBriefInParallel([](EngineShard* shard) { shard->db_slice().TEST_EnableCacheMode(); }); + RunBlockingInParallel([](EngineShard* shard) { + namespaces.GetDefaultNamespace().GetCurrentDbSlice().TEST_EnableCacheMode(); + }); } ShardId Shard(string_view v, ShardId shard_num) { diff --git a/src/server/engine_shard_set.h b/src/server/engine_shard_set.h index 388c9d61d..b641a3af9 100644 --- a/src/server/engine_shard_set.h +++ b/src/server/engine_shard_set.h @@ -60,15 +60,7 @@ class EngineShard { } ShardId shard_id() const { - return db_slice_.shard_id(); - } - - DbSlice& db_slice() { - return db_slice_; - } - - const DbSlice& db_slice() const { - return db_slice_; + return shard_id_; } PMR_NS::memory_resource* memory_resource() { @@ -124,12 +116,6 @@ class EngineShard { return shard_search_indices_.get(); } - BlockingController* EnsureBlockingController(); - - BlockingController* blocking_controller() { - return blocking_controller_.get(); - } - // for everyone to use for string transformations during atomic cpu sequences. 
sds tmp_str1; @@ -242,7 +228,7 @@ class EngineShard { TxQueue txq_; MiMemoryResource mi_resource_; - DbSlice db_slice_; + ShardId shard_id_; Stats stats_; @@ -261,8 +247,8 @@ class EngineShard { DefragTaskState defrag_state_; std::unique_ptr tiered_storage_; + // TODO: Move indices to Namespace std::unique_ptr shard_search_indices_; - std::unique_ptr blocking_controller_; using Counter = util::SlidingCounter<7>; diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index 1ce3d8115..44956a57f 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -451,7 +451,7 @@ OpStatus Renamer::DeserializeDest(Transaction* t, EngineShard* shard) { auto& dest_it = restored_dest_it->it; dest_it->first.SetSticky(serialized_value_.sticky); - auto bc = shard->blocking_controller(); + auto bc = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id()); if (bc) { bc->AwakeWatched(t->GetDbIndex(), dest_key_); } @@ -607,7 +607,7 @@ uint64_t ScanGeneric(uint64_t cursor, const ScanOpts& scan_opts, StringVec* keys } cursor >>= 10; - DbContext db_cntx{cntx->conn_state.db_index, GetCurrentTimeMs()}; + DbContext db_cntx{cntx->ns, cntx->conn_state.db_index, GetCurrentTimeMs()}; do { auto cb = [&] { @@ -1355,8 +1355,10 @@ void GenericFamily::Select(CmdArgList args, ConnectionContext* cntx) { return cntx->SendError(kDbIndOutOfRangeErr); } cntx->conn_state.db_index = index; - auto cb = [index](EngineShard* shard) { - shard->db_slice().ActivateDb(index); + auto cb = [cntx, index](EngineShard* shard) { + CHECK(cntx->ns != nullptr); + auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id()); + db_slice.ActivateDb(index); return OpStatus::OK; }; shard_set->RunBriefInParallel(std::move(cb)); @@ -1385,7 +1387,8 @@ void GenericFamily::Type(CmdArgList args, ConnectionContext* cntx) { std::string_view key = ArgS(args, 0); auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult { - auto it = shard->db_slice().FindReadOnly(t->GetDbContext(), 
key).it; + auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id()); + auto it = db_slice.FindReadOnly(t->GetDbContext(), key).it; if (!it.is_done()) { return it->second.ObjType(); } else { @@ -1549,8 +1552,9 @@ OpResult GenericFamily::OpRen(const OpArgs& op_args, string_view from_key, to_res.it->first.SetSticky(sticky); } - if (!is_prior_list && to_res.it->second.ObjType() == OBJ_LIST && es->blocking_controller()) { - es->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, to_key); + auto bc = op_args.db_cntx.ns->GetBlockingController(es->shard_id()); + if (!is_prior_list && to_res.it->second.ObjType() == OBJ_LIST && bc) { + bc->AwakeWatched(op_args.db_cntx.db_index, to_key); } return OpStatus::OK; } @@ -1590,8 +1594,9 @@ OpStatus GenericFamily::OpMove(const OpArgs& op_args, string_view key, DbIndex t auto& add_res = *op_result; add_res.it->first.SetSticky(sticky); - if (add_res.it->second.ObjType() == OBJ_LIST && op_args.shard->blocking_controller()) { - op_args.shard->blocking_controller()->AwakeWatched(target_db, key); + auto bc = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id()); + if (add_res.it->second.ObjType() == OBJ_LIST && bc) { + bc->AwakeWatched(target_db, key); } return OpStatus::OK; @@ -1602,14 +1607,15 @@ void GenericFamily::RandomKey(CmdArgList args, ConnectionContext* cntx) { absl::BitGen bitgen; atomic_size_t candidates_counter{0}; - DbContext db_cntx{cntx->conn_state.db_index, GetCurrentTimeMs()}; + DbContext db_cntx{cntx->ns, cntx->conn_state.db_index, GetCurrentTimeMs()}; ScanOpts scan_opts; scan_opts.limit = 3; // number of entries per shard std::vector candidates_collection(shard_set->size()); shard_set->RunBriefInParallel( [&](EngineShard* shard) { - auto [prime_table, expire_table] = shard->db_slice().GetTables(db_cntx.db_index); + auto [prime_table, expire_table] = + cntx->ns->GetDbSlice(shard->shard_id()).GetTables(db_cntx.db_index); if (prime_table->size() == 0) { return; } diff --git 
a/src/server/hset_family.cc b/src/server/hset_family.cc index fb1ad8558..b54e70029 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -1055,7 +1055,7 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) { } auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult { - auto& db_slice = shard->db_slice(); + auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id()); DbContext db_context = t->GetDbContext(); auto it_res = db_slice.FindReadOnly(db_context, key, OBJ_HASH); diff --git a/src/server/journal/executor.cc b/src/server/journal/executor.cc index 5f5fff3dd..753968718 100644 --- a/src/server/journal/executor.cc +++ b/src/server/journal/executor.cc @@ -44,6 +44,7 @@ JournalExecutor::JournalExecutor(Service* service) conn_context_.is_replicating = true; conn_context_.journal_emulated = true; conn_context_.skip_acl_validation = true; + conn_context_.ns = &namespaces.GetDefaultNamespace(); } JournalExecutor::~JournalExecutor() { diff --git a/src/server/list_family.cc b/src/server/list_family.cc index ee518fc45..522721b53 100644 --- a/src/server/list_family.cc +++ b/src/server/list_family.cc @@ -341,11 +341,12 @@ OpResult OpPush(const OpArgs& op_args, std::string_view key, ListDir d } if (res.is_new) { - if (es->blocking_controller()) { + auto blocking_controller = op_args.db_cntx.ns->GetBlockingController(es->shard_id()); + if (blocking_controller) { string tmp; string_view key = res.it->first.GetSlice(&tmp); - es->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, key); + blocking_controller->AwakeWatched(op_args.db_cntx.db_index, key); absl::StrAppend(debugMessages.Next(), "OpPush AwakeWatched: ", key, " by ", op_args.tx->DebugId()); } @@ -444,11 +445,12 @@ OpResult MoveTwoShards(Transaction* trans, string_view src, string_view OpPush(op_args, key, dest_dir, false, ArgSlice{val}, true); // blocking_controller does not have to be set with non-blocking transactions. 
- if (shard->blocking_controller()) { + auto blocking_controller = t->GetNamespace().GetBlockingController(shard->shard_id()); + if (blocking_controller) { // hack, again. since we hacked which queue we are waiting on (see RunPair) // we must clean-up src key here manually. See RunPair why we do this. // in short- we suspended on "src" on both shards. - shard->blocking_controller()->FinalizeWatched(ArgSlice({src}), t); + blocking_controller->FinalizeWatched(ArgSlice({src}), t); } } else { DVLOG(1) << "Popping value from list: " << key; @@ -852,10 +854,11 @@ OpResult BPopPusher::RunSingle(ConnectionContext* cntx, time_point tp) { std::array arr = {pop_key_, push_key_, DirToSv(popdir_), DirToSv(pushdir_)}; RecordJournal(op_args, "LMOVE", arr, 1); } - if (shard->blocking_controller()) { + auto blocking_controller = cntx->ns->GetBlockingController(shard->shard_id()); + if (blocking_controller) { string tmp; - shard->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, push_key_); + blocking_controller->AwakeWatched(op_args.db_cntx.db_index, push_key_); absl::StrAppend(debugMessages.Next(), "OpPush AwakeWatched: ", push_key_, " by ", op_args.tx->DebugId()); } diff --git a/src/server/list_family_test.cc b/src/server/list_family_test.cc index 817b54256..3c42664c2 100644 --- a/src/server/list_family_test.cc +++ b/src/server/list_family_test.cc @@ -32,8 +32,10 @@ class ListFamilyTest : public BaseFamilyTest { static unsigned NumWatched() { atomic_uint32_t sum{0}; + + auto ns = &namespaces.GetDefaultNamespace(); shard_set->RunBriefInParallel([&](EngineShard* es) { - auto* bc = es->blocking_controller(); + auto* bc = ns->GetBlockingController(es->shard_id()); if (bc) sum.fetch_add(bc->NumWatched(0), memory_order_relaxed); }); @@ -43,8 +45,9 @@ class ListFamilyTest : public BaseFamilyTest { static bool HasAwakened() { atomic_uint32_t sum{0}; + auto ns = &namespaces.GetDefaultNamespace(); shard_set->RunBriefInParallel([&](EngineShard* es) { - auto* bc = 
es->blocking_controller(); + auto* bc = ns->GetBlockingController(es->shard_id()); if (bc) sum.fetch_add(bc->HasAwakedTransaction(), memory_order_relaxed); }); @@ -168,7 +171,7 @@ TEST_F(ListFamilyTest, BLPopTimeout) { RespExpr resp = Run({"blpop", kKey1, kKey2, kKey3, "0.01"}); EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY)); EXPECT_EQ(3, GetDebugInfo().shards_count); - ASSERT_FALSE(service_->IsLocked(0, kKey1)); + ASSERT_FALSE(IsLocked(0, kKey1)); // Under Multi resp = Run({"multi"}); @@ -178,7 +181,7 @@ TEST_F(ListFamilyTest, BLPopTimeout) { resp = Run({"exec"}); EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY)); - ASSERT_FALSE(service_->IsLocked(0, kKey1)); + ASSERT_FALSE(IsLocked(0, kKey1)); ASSERT_EQ(0, NumWatched()); } diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 702ae5de2..04796d93a 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -47,6 +47,7 @@ extern "C" { #include "server/json_family.h" #include "server/list_family.h" #include "server/multi_command_squasher.h" +#include "server/namespaces.h" #include "server/script_mgr.h" #include "server/search/search_family.h" #include "server/server_state.h" @@ -135,9 +136,11 @@ constexpr size_t kMaxThreadSize = 1024; // Unwatch all keys for a connection and unregister from DbSlices. // Used by UNWATCH, DICARD and EXEC. 
-void UnwatchAllKeys(ConnectionState::ExecInfo* exec_info) { +void UnwatchAllKeys(Namespace* ns, ConnectionState::ExecInfo* exec_info) { if (!exec_info->watched_keys.empty()) { - auto cb = [&](EngineShard* shard) { shard->db_slice().UnregisterConnectionWatches(exec_info); }; + auto cb = [&](EngineShard* shard) { + ns->GetDbSlice(shard->shard_id()).UnregisterConnectionWatches(exec_info); + }; shard_set->RunBriefInParallel(std::move(cb)); } exec_info->ClearWatched(); @@ -149,7 +152,7 @@ void MultiCleanup(ConnectionContext* cntx) { ServerState::tlocal()->ReturnInterpreter(borrowed); exec_info.preborrowed_interpreter = nullptr; } - UnwatchAllKeys(&exec_info); + UnwatchAllKeys(cntx->ns, &exec_info); exec_info.Clear(); } @@ -513,7 +516,8 @@ void Topkeys(const http::QueryArgs& args, HttpContext* send) { vector rows(shard_set->size()); shard_set->RunBriefInParallel([&](EngineShard* shard) { - for (const auto& db : shard->db_slice().databases()) { + for (const auto& db : + namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).databases()) { if (db->top_keys.IsEnabled()) { is_enabled = true; for (const auto& [key, count] : db->top_keys.GetTopKeys()) { @@ -829,6 +833,7 @@ Service::Service(ProactorPool* pp) Service::~Service() { delete shard_set; shard_set = nullptr; + namespaces.Clear(); } void Service::Init(util::AcceptServer* acceptor, std::vector listeners, @@ -908,7 +913,10 @@ void Service::Shutdown() { ChannelStore::Destroy(); + namespaces.Clear(); + shard_set->Shutdown(); + pp_.Await([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); }); // wait for all the pending callbacks to stop. 
@@ -1212,8 +1220,8 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) DCHECK(dfly_cntx->transaction); if (cid->IsTransactional()) { dfly_cntx->transaction->MultiSwitchCmd(cid); - OpStatus status = - dfly_cntx->transaction->InitByArgs(dfly_cntx->conn_state.db_index, args_no_cmd); + OpStatus status = dfly_cntx->transaction->InitByArgs( + dfly_cntx->ns, dfly_cntx->conn_state.db_index, args_no_cmd); if (status != OpStatus::OK) return cntx->SendError(status); @@ -1225,7 +1233,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) dist_trans.reset(new Transaction{cid}); if (!dist_trans->IsMulti()) { // Multi command initialize themself based on their mode. - if (auto st = dist_trans->InitByArgs(dfly_cntx->conn_state.db_index, args_no_cmd); + CHECK(dfly_cntx->ns != nullptr); + if (auto st = + dist_trans->InitByArgs(dfly_cntx->ns, dfly_cntx->conn_state.db_index, args_no_cmd); st != OpStatus::OK) return cntx->SendError(st); } @@ -1581,6 +1591,7 @@ facade::ConnectionContext* Service::CreateContext(util::FiberSocketBase* peer, facade::Connection* owner) { auto cred = user_registry_.GetCredentials("default"); ConnectionContext* res = new ConnectionContext{peer, owner, std::move(cred)}; + res->ns = &namespaces.GetOrInsert(""); if (peer->IsUDS()) { res->req_auth = false; @@ -1606,10 +1617,10 @@ const CommandId* Service::FindCmd(std::string_view cmd) const { return registry_.Find(registry_.RenamedOrOriginal(cmd)); } -bool Service::IsLocked(DbIndex db_index, std::string_view key) const { +bool Service::IsLocked(Namespace* ns, DbIndex db_index, std::string_view key) const { ShardId sid = Shard(key, shard_count()); - bool is_open = pp_.at(sid)->AwaitBrief([db_index, key] { - return EngineShard::tlocal()->db_slice().CheckLock(IntentLock::EXCLUSIVE, db_index, key); + bool is_open = pp_.at(sid)->AwaitBrief([db_index, key, ns, sid] { + return ns->GetDbSlice(sid).CheckLock(IntentLock::EXCLUSIVE, db_index, key); }); return 
!is_open; } @@ -1682,7 +1693,7 @@ void Service::Watch(CmdArgList args, ConnectionContext* cntx) { } void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) { - UnwatchAllKeys(&cntx->conn_state.exec_info); + UnwatchAllKeys(cntx->ns, &cntx->conn_state.exec_info); return cntx->SendOk(); } @@ -1841,6 +1852,7 @@ Transaction::MultiMode DetermineMultiMode(ScriptMgr::ScriptParams params) { optional StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptParams params, ConnectionContext* cntx) { Transaction* trans = cntx->transaction; + Namespace* ns = cntx->ns; Transaction::MultiMode script_mode = DetermineMultiMode(params); Transaction::MultiMode multi_mode = trans->GetMultiMode(); // Check if eval is already part of a running multi transaction @@ -1860,10 +1872,10 @@ optional StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptPa switch (script_mode) { case Transaction::GLOBAL: - trans->StartMultiGlobal(dbid); + trans->StartMultiGlobal(ns, dbid); return true; case Transaction::LOCK_AHEAD: - trans->StartMultiLockedAhead(dbid, keys); + trans->StartMultiLockedAhead(ns, dbid, keys); return true; case Transaction::NON_ATOMIC: trans->StartMultiNonAtomic(); @@ -1988,7 +2000,7 @@ void Service::EvalInternal(CmdArgList args, const EvalArgs& eval_args, Interpret }); ++ServerState::tlocal()->stats.eval_shardlocal_coordination_cnt; - tx->PrepareMultiForScheduleSingleHop(*sid, tx->GetDbIndex(), args); + tx->PrepareMultiForScheduleSingleHop(cntx->ns, *sid, tx->GetDbIndex(), args); tx->ScheduleSingleHop([&](Transaction*, EngineShard*) { boost::intrusive_ptr stub_tx = new Transaction{tx, *sid, slot_checker.GetUniqueSlotId()}; @@ -2077,7 +2089,7 @@ bool CheckWatchedKeyExpiry(ConnectionContext* cntx, const CommandRegistry& regis }; cntx->transaction->MultiSwitchCmd(registry.Find(EXISTS)); - cntx->transaction->InitByArgs(cntx->conn_state.db_index, CmdArgList{str_list}); + cntx->transaction->InitByArgs(cntx->ns, cntx->conn_state.db_index, CmdArgList{str_list}); 
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb)); CHECK_EQ(OpStatus::OK, status); @@ -2139,11 +2151,11 @@ void StartMultiExec(ConnectionContext* cntx, ConnectionState::ExecInfo* exec_inf auto dbid = cntx->db_index(); switch (multi_mode) { case Transaction::GLOBAL: - trans->StartMultiGlobal(dbid); + trans->StartMultiGlobal(cntx->ns, dbid); break; case Transaction::LOCK_AHEAD: { auto vec = CollectAllKeys(exec_info); - trans->StartMultiLockedAhead(dbid, absl::MakeSpan(vec)); + trans->StartMultiLockedAhead(cntx->ns, dbid, absl::MakeSpan(vec)); } break; case Transaction::NON_ATOMIC: trans->StartMultiNonAtomic(); @@ -2234,7 +2246,7 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) { CmdArgList args = absl::MakeSpan(arg_vec); if (scmd.Cid()->IsTransactional()) { - OpStatus st = cntx->transaction->InitByArgs(cntx->conn_state.db_index, args); + OpStatus st = cntx->transaction->InitByArgs(cntx->ns, cntx->conn_state.db_index, args); if (st != OpStatus::OK) { cntx->SendError(st); break; @@ -2444,7 +2456,7 @@ void Service::Command(CmdArgList args, ConnectionContext* cntx) { VarzValue::Map Service::GetVarzStats() { VarzValue::Map res; - Metrics m = server_family_.GetMetrics(); + Metrics m = server_family_.GetMetrics(&namespaces.GetDefaultNamespace()); DbStats db_stats; for (const auto& s : m.db_stats) { db_stats += s; @@ -2543,7 +2555,7 @@ void Service::OnClose(facade::ConnectionContext* cntx) { DCHECK(!conn_state.subscribe_info); } - UnwatchAllKeys(&conn_state.exec_info); + UnwatchAllKeys(server_cntx->ns, &conn_state.exec_info); DeactivateMonitoring(server_cntx); diff --git a/src/server/main_service.h b/src/server/main_service.h index f8cf381a6..21ebbbda3 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -102,7 +102,7 @@ class Service : public facade::ServiceInterface { } // Used by tests. 
- bool IsLocked(DbIndex db_index, std::string_view key) const; + bool IsLocked(Namespace* ns, DbIndex db_index, std::string_view key) const; bool IsShardSetLocked() const; util::ProactorPool& proactor_pool() { diff --git a/src/server/memory_cmd.cc b/src/server/memory_cmd.cc index 91c9b25f5..f066fd01e 100644 --- a/src/server/memory_cmd.cc +++ b/src/server/memory_cmd.cc @@ -245,7 +245,7 @@ void PushMemoryUsageStats(const base::IoBuf::MemoryUsage& mem, string_view prefi void MemoryCmd::Stats() { vector> stats; stats.reserve(25); - auto server_metrics = owner_->GetMetrics(); + auto server_metrics = owner_->GetMetrics(cntx_->ns); // RSS stats.push_back({"rss_bytes", rss_mem_current.load(memory_order_relaxed)}); @@ -369,8 +369,8 @@ void MemoryCmd::ArenaStats(CmdArgList args) { void MemoryCmd::Usage(std::string_view key) { ShardId sid = Shard(key, shard_set->size()); - ssize_t memory_usage = shard_set->pool()->at(sid)->AwaitBrief([key, this]() -> ssize_t { - auto& db_slice = EngineShard::tlocal()->db_slice(); + ssize_t memory_usage = shard_set->pool()->at(sid)->AwaitBrief([key, this, sid]() -> ssize_t { + auto& db_slice = cntx_->ns->GetDbSlice(sid); auto [pt, exp_t] = db_slice.GetTables(cntx_->db_index()); PrimeIterator it = pt->Find(key); if (IsValid(it)) { diff --git a/src/server/multi_command_squasher.cc b/src/server/multi_command_squasher.cc index b9eeb6b54..c25680f3b 100644 --- a/src/server/multi_command_squasher.cc +++ b/src/server/multi_command_squasher.cc @@ -145,7 +145,7 @@ bool MultiCommandSquasher::ExecuteStandalone(StoredCmd* cmd) { cntx_->cid = cmd->Cid(); if (cmd->Cid()->IsTransactional()) - tx->InitByArgs(cntx_->conn_state.db_index, args); + tx->InitByArgs(cntx_->ns, cntx_->conn_state.db_index, args); service_->InvokeCmd(cmd->Cid(), args, cntx_); return true; @@ -181,7 +181,7 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard local_cntx.cid = cmd->Cid(); crb.SetReplyMode(cmd->ReplyMode()); - 
local_tx->InitByArgs(local_cntx.conn_state.db_index, args); + local_tx->InitByArgs(cntx_->ns, local_cntx.conn_state.db_index, args); service_->InvokeCmd(cmd->Cid(), args, &local_cntx); sinfo.replies.emplace_back(crb.Take()); diff --git a/src/server/multi_test.cc b/src/server/multi_test.cc index 02a7c687b..ebf1b6f36 100644 --- a/src/server/multi_test.cc +++ b/src/server/multi_test.cc @@ -109,8 +109,8 @@ TEST_F(MultiTest, Multi) { resp = Run({"get", kKey4}); ASSERT_THAT(resp, ArgType(RespExpr::NIL)); - ASSERT_FALSE(service_->IsLocked(0, kKey1)); - ASSERT_FALSE(service_->IsLocked(0, kKey4)); + ASSERT_FALSE(IsLocked(0, kKey1)); + ASSERT_FALSE(IsLocked(0, kKey4)); ASSERT_FALSE(service_->IsShardSetLocked()); } @@ -129,8 +129,8 @@ TEST_F(MultiTest, MultiGlobalCommands) { ASSERT_THAT(Run({"select", "2"}), "OK"); ASSERT_THAT(Run({"get", "key"}), "val"); - ASSERT_FALSE(service_->IsLocked(0, "key")); - ASSERT_FALSE(service_->IsLocked(2, "key")); + ASSERT_FALSE(IsLocked(0, "key")); + ASSERT_FALSE(IsLocked(2, "key")); } TEST_F(MultiTest, HitMissStats) { @@ -181,8 +181,8 @@ TEST_F(MultiTest, MultiSeq) { ASSERT_EQ(resp, "QUEUED"); resp = Run({"exec"}); - ASSERT_FALSE(service_->IsLocked(0, kKey1)); - ASSERT_FALSE(service_->IsLocked(0, kKey4)); + ASSERT_FALSE(IsLocked(0, kKey1)); + ASSERT_FALSE(IsLocked(0, kKey4)); ASSERT_FALSE(service_->IsShardSetLocked()); ASSERT_THAT(resp, ArrLen(3)); @@ -237,8 +237,8 @@ TEST_F(MultiTest, MultiConsistent) { mset_fb.Join(); fb.Join(); - ASSERT_FALSE(service_->IsLocked(0, kKey1)); - ASSERT_FALSE(service_->IsLocked(0, kKey4)); + ASSERT_FALSE(IsLocked(0, kKey1)); + ASSERT_FALSE(IsLocked(0, kKey4)); ASSERT_FALSE(service_->IsShardSetLocked()); } @@ -312,9 +312,9 @@ TEST_F(MultiTest, MultiRename) { resp = Run({"exec"}); EXPECT_EQ(resp, "OK"); - EXPECT_FALSE(service_->IsLocked(0, kKey1)); - EXPECT_FALSE(service_->IsLocked(0, kKey2)); - EXPECT_FALSE(service_->IsLocked(0, kKey4)); + EXPECT_FALSE(IsLocked(0, kKey1)); + EXPECT_FALSE(IsLocked(0, kKey2)); + 
EXPECT_FALSE(IsLocked(0, kKey4)); EXPECT_FALSE(service_->IsShardSetLocked()); } @@ -366,8 +366,8 @@ TEST_F(MultiTest, FlushDb) { fb0.Join(); - ASSERT_FALSE(service_->IsLocked(0, kKey1)); - ASSERT_FALSE(service_->IsLocked(0, kKey4)); + ASSERT_FALSE(IsLocked(0, kKey1)); + ASSERT_FALSE(IsLocked(0, kKey4)); ASSERT_FALSE(service_->IsShardSetLocked()); } @@ -400,17 +400,17 @@ TEST_F(MultiTest, Eval) { resp = Run({"eval", "return redis.call('get', 'foo')", "1", "bar"}); EXPECT_THAT(resp, ErrArg("undeclared")); - ASSERT_FALSE(service_->IsLocked(0, "foo")); + ASSERT_FALSE(IsLocked(0, "foo")); Run({"script", "flush"}); // Reset global flag from autocorrect resp = Run({"eval", "return redis.call('get', 'foo')", "1", "foo"}); EXPECT_THAT(resp, "42"); - ASSERT_FALSE(service_->IsLocked(0, "foo")); + ASSERT_FALSE(IsLocked(0, "foo")); resp = Run({"eval", "return redis.call('get', KEYS[1])", "1", "foo"}); EXPECT_THAT(resp, "42"); - ASSERT_FALSE(service_->IsLocked(0, "foo")); + ASSERT_FALSE(IsLocked(0, "foo")); ASSERT_FALSE(service_->IsShardSetLocked()); resp = Run({"eval", "return 77", "2", "foo", "zoo"}); @@ -451,7 +451,7 @@ TEST_F(MultiTest, Eval) { "1", "foo"}), "42"); - auto condition = [&]() { return service_->IsLocked(0, "foo"); }; + auto condition = [&]() { return IsLocked(0, "foo"); }; auto fb = ExpectConditionWithSuspension(condition); EXPECT_EQ(Run({"eval", R"(redis.call('set', 'foo', '42') @@ -974,7 +974,7 @@ TEST_F(MultiTest, TestLockedKeys) { GTEST_SKIP() << "Skipped TestLockedKeys test because multi_exec_mode is not lock ahead"; return; } - auto condition = [&]() { return service_->IsLocked(0, "key1") && service_->IsLocked(0, "key2"); }; + auto condition = [&]() { return IsLocked(0, "key1") && IsLocked(0, "key2"); }; auto fb = ExpectConditionWithSuspension(condition); EXPECT_EQ(Run({"multi"}), "OK"); @@ -983,8 +983,8 @@ TEST_F(MultiTest, TestLockedKeys) { EXPECT_EQ(Run({"mset", "key1", "val3", "key1", "val4"}), "QUEUED"); EXPECT_THAT(Run({"exec"}), 
RespArray(ElementsAre("OK", "OK", "OK"))); fb.Join(); - EXPECT_FALSE(service_->IsLocked(0, "key1")); - EXPECT_FALSE(service_->IsLocked(0, "key2")); + EXPECT_FALSE(IsLocked(0, "key1")); + EXPECT_FALSE(IsLocked(0, "key2")); } TEST_F(MultiTest, EvalExpiration) { diff --git a/src/server/namespaces.cc b/src/server/namespaces.cc new file mode 100644 index 000000000..1c16f69d3 --- /dev/null +++ b/src/server/namespaces.cc @@ -0,0 +1,103 @@ +#include "server/namespaces.h" + +#include "base/flags.h" +#include "base/logging.h" +#include "server/engine_shard_set.h" + +ABSL_DECLARE_FLAG(bool, cache_mode); + +namespace dfly { + +using namespace std; + +Namespace::Namespace() { + shard_db_slices_.resize(shard_set->size()); + shard_blocking_controller_.resize(shard_set->size()); + shard_set->RunBriefInParallel([&](EngineShard* es) { + CHECK(es != nullptr); + ShardId sid = es->shard_id(); + shard_db_slices_[sid] = make_unique(sid, absl::GetFlag(FLAGS_cache_mode), es); + shard_db_slices_[sid]->UpdateExpireBase(absl::GetCurrentTimeNanos() / 1000000, 0); + }); +} + +DbSlice& Namespace::GetCurrentDbSlice() { + EngineShard* es = EngineShard::tlocal(); + CHECK(es != nullptr); + return GetDbSlice(es->shard_id()); +} + +DbSlice& Namespace::GetDbSlice(ShardId sid) { + CHECK_LT(sid, shard_db_slices_.size()); + return *shard_db_slices_[sid]; +} + +BlockingController* Namespace::GetOrAddBlockingController(EngineShard* shard) { + if (!shard_blocking_controller_[shard->shard_id()]) { + shard_blocking_controller_[shard->shard_id()] = make_unique(shard, this); + } + + return shard_blocking_controller_[shard->shard_id()].get(); +} + +BlockingController* Namespace::GetBlockingController(ShardId sid) { + return shard_blocking_controller_[sid].get(); +} + +Namespaces namespaces; + +Namespaces::~Namespaces() { + Clear(); +} + +void Namespaces::Init() { + DCHECK(default_namespace_ == nullptr); + default_namespace_ = &GetOrInsert(""); +} + +bool Namespaces::IsInitialized() const { + return 
default_namespace_ != nullptr; +} + +void Namespaces::Clear() { + std::unique_lock guard(mu_); + + namespaces.default_namespace_ = nullptr; + + if (namespaces_.empty()) { + return; + } + + shard_set->RunBriefInParallel([&](EngineShard* es) { + CHECK(es != nullptr); + for (auto& ns : namespaces_) { + ns.second.shard_db_slices_[es->shard_id()].reset(); + } + }); + + namespaces.namespaces_.clear(); +} + +Namespace& Namespaces::GetDefaultNamespace() const { + CHECK(default_namespace_ != nullptr); + return *default_namespace_; +} + +Namespace& Namespaces::GetOrInsert(std::string_view ns) { + { + // Try to look up under a shared lock + std::shared_lock guard(mu_); + auto it = namespaces_.find(ns); + if (it != namespaces_.end()) { + return it->second; + } + } + + { + // Key was not found, so we create it under a unique lock + std::unique_lock guard(mu_); + return namespaces_[ns]; + } +} + +} // namespace dfly diff --git a/src/server/namespaces.h b/src/server/namespaces.h new file mode 100644 index 000000000..d4d6390e3 --- /dev/null +++ b/src/server/namespaces.h @@ -0,0 +1,70 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include + +#include +#include +#include + +#include "server/blocking_controller.h" +#include "server/db_slice.h" +#include "server/tx_base.h" +#include "util/fibers/synchronization.h" + +namespace dfly { + +// A Namespace is a way to separate and isolate different databases in a single instance. +// It can be used to allow multiple tenants to use the same server without hacks of using a common +// prefix, or SELECT-ing a different database. +// Each Namespace contains per-shard DbSlice, as well as a BlockingController.
+class Namespace { + public: + Namespace(); + + DbSlice& GetCurrentDbSlice(); + + DbSlice& GetDbSlice(ShardId sid); + BlockingController* GetOrAddBlockingController(EngineShard* shard); + BlockingController* GetBlockingController(ShardId sid); + + private: + std::vector> shard_db_slices_; + std::vector> shard_blocking_controller_; + + friend class Namespaces; +}; + +// Namespaces is a registry and container for Namespace instances. +// Each Namespace has a unique string name, which identifies it in the store. +// Any attempt to access a non-existing Namespace will first create it, add it to the internal map +// and will then return it. +// It is currently impossible to remove a Namespace after it has been created. +// The default Namespace can be accessed via either GetDefaultNamespace() (which guarantees not to +// yield), or via the GetOrInsert() with an empty string. +// The initialization order of this class with the engine shards is slightly subtle, as they have +// mutual dependencies. +class Namespaces { + public: + Namespaces() = default; + ~Namespaces(); + + void Init(); + bool IsInitialized() const; + void Clear(); // Thread unsafe, use in tear-down or tests + + Namespace& GetDefaultNamespace() const; // No locks + Namespace& GetOrInsert(std::string_view ns); + + private: + util::fb2::SharedMutex mu_{}; + absl::node_hash_map namespaces_ ABSL_GUARDED_BY(mu_); + Namespace* default_namespace_ = nullptr; +}; + +extern Namespaces namespaces; + +} // namespace dfly diff --git a/src/server/rdb_load.cc b/src/server/rdb_load.cc index c03062f84..61c44fdb9 100644 --- a/src/server/rdb_load.cc +++ b/src/server/rdb_load.cc @@ -2066,7 +2066,8 @@ error_code RdbLoader::Load(io::Source* src) { FlushShardAsync(i); // Active database if not existed before. 
- shard_set->Add(i, [dbid] { EngineShard::tlocal()->db_slice().ActivateDb(dbid); }); + shard_set->Add( + i, [dbid] { namespaces.GetDefaultNamespace().GetCurrentDbSlice().ActivateDb(dbid); }); } cur_db_index_ = dbid; @@ -2451,8 +2452,8 @@ std::error_code RdbLoaderBase::FromOpaque(const OpaqueObj& opaque, CompactObj* p void RdbLoader::LoadItemsBuffer(DbIndex db_ind, const ItemsBuf& ib) { EngineShard* es = EngineShard::tlocal(); - DbSlice& db_slice = es->db_slice(); - DbContext db_cntx{db_ind, GetCurrentTimeMs()}; + DbContext db_cntx{&namespaces.GetDefaultNamespace(), db_ind, GetCurrentTimeMs()}; + DbSlice& db_slice = db_cntx.GetDbSlice(es->shard_id()); for (const auto* item : ib) { PrimeValue pv; @@ -2564,6 +2565,7 @@ void RdbLoader::LoadSearchIndexDefFromAux(string&& def) { cntx.is_replicating = true; cntx.journal_emulated = true; cntx.skip_acl_validation = true; + cntx.ns = &namespaces.GetDefaultNamespace(); // Avoid deleting local crb absl::Cleanup cntx_clean = [&cntx] { cntx.Inject(nullptr); }; @@ -2613,7 +2615,8 @@ void RdbLoader::PerformPostLoad(Service* service) { // Rebuild all search indices as only their definitions are extracted from the snapshot shard_set->AwaitRunningOnShardQueue([](EngineShard* es) { - es->search_indices()->RebuildAllIndices(OpArgs{es, nullptr, DbContext{0, GetCurrentTimeMs()}}); + es->search_indices()->RebuildAllIndices( + OpArgs{es, nullptr, DbContext{&namespaces.GetDefaultNamespace(), 0, GetCurrentTimeMs()}}); }); } diff --git a/src/server/rdb_save.cc b/src/server/rdb_save.cc index 9166bdc97..1c20ba2c2 100644 --- a/src/server/rdb_save.cc +++ b/src/server/rdb_save.cc @@ -35,6 +35,7 @@ extern "C" { #include "server/engine_shard_set.h" #include "server/error.h" #include "server/main_service.h" +#include "server/namespaces.h" #include "server/rdb_extensions.h" #include "server/search/doc_index.h" #include "server/serializer_commons.h" @@ -1251,15 +1252,17 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) { void 
RdbSaver::Impl::StartSnapshotting(bool stream_journal, const Cancellation* cll, EngineShard* shard) { auto& s = GetSnapshot(shard); - s = std::make_unique(&shard->db_slice(), &channel_, compression_mode_); + auto& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()); + s = std::make_unique(&db_slice, &channel_, compression_mode_); s->Start(stream_journal, cll); } void RdbSaver::Impl::StartIncrementalSnapshotting(Context* cntx, EngineShard* shard, LSN start_lsn) { + auto& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()); auto& s = GetSnapshot(shard); - s = std::make_unique(&shard->db_slice(), &channel_, compression_mode_); + s = std::make_unique(&db_slice, &channel_, compression_mode_); s->StartIncremental(cntx, start_lsn); } diff --git a/src/server/replica.cc b/src/server/replica.cc index d7e02fa1d..63c261c23 100644 --- a/src/server/replica.cc +++ b/src/server/replica.cc @@ -570,6 +570,7 @@ error_code Replica::ConsumeRedisStream() { conn_context.is_replicating = true; conn_context.journal_emulated = true; conn_context.skip_acl_validation = true; + conn_context.ns = &namespaces.GetDefaultNamespace(); ResetParser(true); // Master waits for this command in order to start sending replication stream. 
diff --git a/src/server/server_family.cc b/src/server/server_family.cc index a5f5e760b..7b5a40cb8 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -447,7 +447,7 @@ void ClientPauseCmd(CmdArgList args, vector listeners, Connec }; if (auto pause_fb_opt = - Pause(listeners, cntx->conn(), pause_state, std::move(is_pause_in_progress)); + Pause(listeners, cntx->ns, cntx->conn(), pause_state, std::move(is_pause_in_progress)); pause_fb_opt) { pause_fb_opt->Detach(); cntx->SendOk(); @@ -673,8 +673,8 @@ optional ReplicaOfArgs::FromCmdArgs(CmdArgList args, ConnectionCo } // namespace -std::optional Pause(std::vector listeners, facade::Connection* conn, - ClientPause pause_state, +std::optional Pause(std::vector listeners, Namespace* ns, + facade::Connection* conn, ClientPause pause_state, std::function is_pause_in_progress) { // Track connections and set pause state to be able to wait untill all running transactions read // the new pause state. Exlude already paused commands from the busy count. Exlude tracking @@ -683,7 +683,7 @@ std::optional Pause(std::vector listeners, facade // command that did not pause on the new state yet we will pause after waking up. DispatchTracker tracker{std::move(listeners), conn, true /* ignore paused commands */, true /*ignore blocking*/}; - shard_set->pool()->AwaitBrief([&tracker, pause_state](unsigned, util::ProactorBase*) { + shard_set->pool()->AwaitBrief([&tracker, pause_state, ns](unsigned, util::ProactorBase*) { // Commands don't suspend before checking the pause state, so // it's impossible to deadlock on waiting for a command that will be paused. tracker.TrackOnThread(); @@ -703,9 +703,9 @@ std::optional Pause(std::vector listeners, facade // We should not expire/evict keys while clients are puased. 
shard_set->RunBriefInParallel( - [](EngineShard* shard) { shard->db_slice().SetExpireAllowed(false); }); + [ns](EngineShard* shard) { ns->GetDbSlice(shard->shard_id()).SetExpireAllowed(false); }); - return fb2::Fiber("client_pause", [is_pause_in_progress, pause_state]() mutable { + return fb2::Fiber("client_pause", [is_pause_in_progress, pause_state, ns]() mutable { // On server shutdown we sleep 10ms to make sure all running task finish, therefore 10ms steps // ensure this fiber will not left hanging . constexpr auto step = 10ms; @@ -719,7 +719,7 @@ std::optional Pause(std::vector listeners, facade ServerState::tlocal()->SetPauseState(pause_state, false); }); shard_set->RunBriefInParallel( - [](EngineShard* shard) { shard->db_slice().SetExpireAllowed(true); }); + [ns](EngineShard* shard) { ns->GetDbSlice(shard->shard_id()).SetExpireAllowed(true); }); } }); } @@ -1345,7 +1345,8 @@ void ServerFamily::ConfigureMetrics(util::HttpListenerBase* http_base) { auto cb = [this](const util::http::QueryArgs& args, util::HttpContext* send) { StringResponse resp = util::http::MakeStringResponse(boost::beast::http::status::ok); - PrintPrometheusMetrics(this->GetMetrics(), this->dfly_cmd_.get(), &resp); + PrintPrometheusMetrics(this->GetMetrics(&namespaces.GetDefaultNamespace()), + this->dfly_cmd_.get(), &resp); return send->Invoke(std::move(resp)); }; @@ -1424,7 +1425,7 @@ void ServerFamily::StatsMC(std::string_view section, facade::ConnectionContext* double utime = dbl_time(ru.ru_utime); double systime = dbl_time(ru.ru_stime); - Metrics m = GetMetrics(); + Metrics m = GetMetrics(&namespaces.GetDefaultNamespace()); ADD_LINE(pid, getpid()); ADD_LINE(uptime, m.uptime); @@ -1454,7 +1455,7 @@ GenericError ServerFamily::DoSave(bool ignore_state) { const CommandId* cid = service().FindCmd("SAVE"); CHECK_NOTNULL(cid); boost::intrusive_ptr trans(new Transaction{cid}); - trans->InitByArgs(0, {}); + trans->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {}); return 
DoSave(absl::GetFlag(FLAGS_df_snapshot_format), {}, trans.get(), ignore_state); } @@ -1551,7 +1552,7 @@ void ServerFamily::DbSize(CmdArgList args, ConnectionContext* cntx) { shard_set->RunBriefInParallel( [&](EngineShard* shard) { - auto db_size = shard->db_slice().DbSize(cntx->conn_state.db_index); + auto db_size = cntx->ns->GetDbSlice(shard->shard_id()).DbSize(cntx->conn_state.db_index); num_keys.fetch_add(db_size, memory_order_relaxed); }, [](ShardId) { return true; }); @@ -1647,6 +1648,7 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) { auto cred = registry->GetCredentials(username); cntx->acl_commands = cred.acl_commands; cntx->keys = std::move(cred.keys); + cntx->ns = &namespaces.GetOrInsert(cred.ns); cntx->authenticated = true; return cntx->SendOk(); } @@ -1773,7 +1775,7 @@ void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) { } if (sub_cmd == "RESETSTAT") { - ResetStat(); + ResetStat(cntx->ns); return cntx->SendOk(); } else { return cntx->SendError(UnknownSubCmd(sub_cmd, "CONFIG"), kSyntaxErrType); @@ -1883,17 +1885,16 @@ static void MergeDbSliceStats(const DbSlice::Stats& src, Metrics* dest) { dest->small_string_bytes += src.small_string_bytes; } -void ServerFamily::ResetStat() { +void ServerFamily::ResetStat(Namespace* ns) { shard_set->pool()->AwaitBrief( - [registry = service_.mutable_registry(), this](unsigned index, auto*) { + [registry = service_.mutable_registry(), this, ns](unsigned index, auto*) { registry->ResetCallStats(index); SinkReplyBuilder::ResetThreadLocalStats(); auto& stats = tl_facade_stats->conn_stats; stats.command_cnt = 0; stats.pipelined_cmd_cnt = 0; - EngineShard* shard = EngineShard::tlocal(); - shard->db_slice().ResetEvents(); + ns->GetCurrentDbSlice().ResetEvents(); tl_facade_stats->conn_stats.conn_received_cnt = 0; tl_facade_stats->conn_stats.pipelined_cmd_cnt = 0; tl_facade_stats->conn_stats.command_cnt = 0; @@ -1910,7 +1911,7 @@ void ServerFamily::ResetStat() { }); } -Metrics 
ServerFamily::GetMetrics() const { +Metrics ServerFamily::GetMetrics(Namespace* ns) const { Metrics result; util::fb2::Mutex mu; @@ -1942,7 +1943,7 @@ Metrics ServerFamily::GetMetrics() const { if (shard) { result.heap_used_bytes += shard->UsedMemory(); - MergeDbSliceStats(shard->db_slice().GetStats(), &result); + MergeDbSliceStats(ns->GetDbSlice(shard->shard_id()).GetStats(), &result); result.shard_stats += shard->stats(); if (shard->tiered_storage()) { @@ -2017,7 +2018,7 @@ void ServerFamily::Info(CmdArgList args, ConnectionContext* cntx) { absl::StrAppend(&info, a1, ":", a2, "\r\n"); }; - Metrics m = GetMetrics(); + Metrics m = GetMetrics(cntx->ns); DbStats total; for (const auto& db_stats : m.db_stats) total += db_stats; @@ -2589,8 +2590,9 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) { void ServerFamily::Replicate(string_view host, string_view port) { io::NullSink sink; - ConnectionContext ctxt{&sink, nullptr, {}}; - ctxt.skip_acl_validation = true; + ConnectionContext cntx{&sink, nullptr, {}}; + cntx.ns = &namespaces.GetDefaultNamespace(); + cntx.skip_acl_validation = true; StringVec replicaof_params{string(host), string(port)}; @@ -2599,7 +2601,7 @@ void ServerFamily::Replicate(string_view host, string_view port) { args_vec.emplace_back(MutableSlice{s.data(), s.size()}); } CmdArgList args_list = absl::MakeSpan(args_vec); - ReplicaOfInternal(args_list, &ctxt, ActionOnConnectionFail::kContinueReplication); + ReplicaOfInternal(args_list, &cntx, ActionOnConnectionFail::kContinueReplication); } // REPLTAKEOVER [SAVE] diff --git a/src/server/server_family.h b/src/server/server_family.h index 021812d9c..fd7fb1420 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -15,6 +15,7 @@ #include "server/detail/save_stages_controller.h" #include "server/dflycmd.h" #include "server/engine_shard_set.h" +#include "server/namespaces.h" #include "server/replica.h" #include "server/server_state.h" #include 
"util/fibers/fiberqueue_threadpool.h" @@ -158,9 +159,9 @@ class ServerFamily { return service_; } - void ResetStat(); + void ResetStat(Namespace* ns); - Metrics GetMetrics() const; + Metrics GetMetrics(Namespace* ns) const; ScriptMgr* script_mgr() { return script_mgr_.get(); @@ -337,7 +338,7 @@ class ServerFamily { }; // Reusable CLIENT PAUSE implementation that blocks while polling is_pause_in_progress -std::optional Pause(std::vector listeners, +std::optional Pause(std::vector listeners, Namespace* ns, facade::Connection* conn, ClientPause pause_state, std::function is_pause_in_progress); diff --git a/src/server/stream_family.cc b/src/server/stream_family.cc index afcb05d89..3c94d9431 100644 --- a/src/server/stream_family.cc +++ b/src/server/stream_family.cc @@ -648,9 +648,9 @@ OpResult OpAdd(const OpArgs& op_args, const AddTrimOpts& opts, CmdArgL StreamTrim(opts, stream_inst); - EngineShard* es = op_args.shard; - if (es->blocking_controller()) { - es->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, opts.key); + auto blocking_controller = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id()); + if (blocking_controller) { + blocking_controller->AwakeWatched(op_args.db_cntx.db_index, opts.key); } return result_id; @@ -2364,7 +2364,7 @@ void StreamFamily::XInfo(CmdArgList args, ConnectionContext* cntx) { // We do not use transactional xemantics for xinfo since it's informational command. 
auto cb = [&]() { EngineShard* shard = EngineShard::tlocal(); - DbContext db_context{cntx->db_index(), GetCurrentTimeMs()}; + DbContext db_context{cntx->ns, cntx->db_index(), GetCurrentTimeMs()}; return OpListGroups(db_context, key, shard); }; @@ -2433,7 +2433,8 @@ void StreamFamily::XInfo(CmdArgList args, ConnectionContext* cntx) { auto cb = [&]() { EngineShard* shard = EngineShard::tlocal(); - return OpStreams(DbContext{cntx->db_index(), GetCurrentTimeMs()}, key, shard, full, count); + return OpStreams(DbContext{cntx->ns, cntx->db_index(), GetCurrentTimeMs()}, key, shard, + full, count); }; auto* rb = static_cast(cntx->reply_builder()); @@ -2561,8 +2562,8 @@ void StreamFamily::XInfo(CmdArgList args, ConnectionContext* cntx) { string_view stream_name = ArgS(args, 1); string_view group_name = ArgS(args, 2); auto cb = [&]() { - return OpConsumers(DbContext{cntx->db_index(), GetCurrentTimeMs()}, EngineShard::tlocal(), - stream_name, group_name); + return OpConsumers(DbContext{cntx->ns, cntx->db_index(), GetCurrentTimeMs()}, + EngineShard::tlocal(), stream_name, group_name); }; OpResult> result = shard_set->Await(sid, std::move(cb)); diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 497dbc6b7..6bac876d6 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -1226,9 +1226,10 @@ void StringFamily::MSetNx(CmdArgList args, ConnectionContext* cntx) { atomic_bool exists{false}; auto cb = [&](Transaction* t, EngineShard* es) { - auto args = t->GetShardArgs(es->shard_id()); + auto sid = es->shard_id(); + auto args = t->GetShardArgs(sid); for (auto arg_it = args.begin(); arg_it != args.end(); ++arg_it) { - auto it = es->db_slice().FindReadOnly(t->GetDbContext(), *arg_it).it; + auto it = cntx->ns->GetDbSlice(sid).FindReadOnly(t->GetDbContext(), *arg_it).it; ++arg_it; if (IsValid(it)) { exists.store(true, memory_order_relaxed); diff --git a/src/server/string_family_test.cc b/src/server/string_family_test.cc index 
3f84add0c..95532cc53 100644 --- a/src/server/string_family_test.cc +++ b/src/server/string_family_test.cc @@ -804,7 +804,7 @@ TEST_F(StringFamilyTest, SetWithHashtagsNoCluster) { auto fb = ExpectUsedKeys({"{key}1"}); EXPECT_EQ(Run({"set", "{key}1", "val1"}), "OK"); fb.Join(); - EXPECT_FALSE(service_->IsLocked(0, "{key}1")); + EXPECT_FALSE(IsLocked(0, "{key}1")); fb = ExpectUsedKeys({"{key}2"}); EXPECT_EQ(Run({"set", "{key}2", "val2"}), "OK"); diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index 3150284ba..73ea7b527 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -81,7 +81,7 @@ void TransactionSuspension::Start() { transaction_ = new dfly::Transaction{&cid}; - auto st = transaction_->InitByArgs(0, {}); + auto st = transaction_->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {}); CHECK_EQ(st, OpStatus::OK); transaction_->Execute([](Transaction* t, EngineShard* shard) { return OpStatus::OK; }, false); @@ -107,7 +107,9 @@ class BaseFamilyTest::TestConnWrapper { const facade::Connection::InvalidationMessage& GetInvalidationMessage(size_t index) const; ConnectionContext* cmd_cntx() { - return static_cast(dummy_conn_->cntx()); + auto cntx = static_cast(dummy_conn_->cntx()); + cntx->ns = &namespaces.GetDefaultNamespace(); + return cntx; } StringVec SplitLines() const { @@ -210,7 +212,10 @@ void BaseFamilyTest::ResetService() { used_mem_current = 0; TEST_current_time_ms = absl::GetCurrentTimeNanos() / 1000000; - auto cb = [&](EngineShard* s) { s->db_slice().UpdateExpireBase(TEST_current_time_ms - 1000, 0); }; + auto default_ns = &namespaces.GetDefaultNamespace(); + auto cb = [&](EngineShard* s) { + default_ns->GetDbSlice(s->shard_id()).UpdateExpireBase(TEST_current_time_ms - 1000, 0); + }; shard_set->RunBriefInParallel(cb); const TestInfo* const test_info = UnitTest::GetInstance()->current_test_info(); @@ -244,7 +249,10 @@ void BaseFamilyTest::ResetService() { } LOG(ERROR) << "TxLocks for shard " << es->shard_id(); - for (const 
auto& k_v : es->db_slice().GetDBTable(0)->trans_locks) { + for (const auto& k_v : namespaces.GetDefaultNamespace() + .GetDbSlice(es->shard_id()) + .GetDBTable(0) + ->trans_locks) { LOG(ERROR) << "Key " << k_v.first << " " << k_v.second; } } @@ -264,6 +272,7 @@ void BaseFamilyTest::ShutdownService() { service_->Shutdown(); service_.reset(); + delete shard_set; shard_set = nullptr; @@ -295,8 +304,9 @@ void BaseFamilyTest::CleanupSnapshots() { unsigned BaseFamilyTest::NumLocked() { atomic_uint count = 0; + auto default_ns = &namespaces.GetDefaultNamespace(); shard_set->RunBriefInParallel([&](EngineShard* shard) { - for (const auto& db : shard->db_slice().databases()) { + for (const auto& db : default_ns->GetDbSlice(shard->shard_id()).databases()) { if (db == nullptr) { continue; } @@ -375,6 +385,7 @@ RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) { CmdArgVec args = conn_wrapper->Args(slice); auto* context = conn_wrapper->cmd_cntx(); + context->ns = &namespaces.GetDefaultNamespace(); DCHECK(context->transaction == nullptr) << id; @@ -551,12 +562,7 @@ BaseFamilyTest::TestConnWrapper::GetInvalidationMessage(size_t index) const { } bool BaseFamilyTest::IsLocked(DbIndex db_index, std::string_view key) const { - ShardId sid = Shard(key, shard_set->size()); - - bool is_open = pp_->at(sid)->AwaitBrief([db_index, key] { - return EngineShard::tlocal()->db_slice().CheckLock(IntentLock::EXCLUSIVE, db_index, key); - }); - return !is_open; + return service_->IsLocked(&namespaces.GetDefaultNamespace(), db_index, key); } string BaseFamilyTest::GetId() const { @@ -643,7 +649,8 @@ vector BaseFamilyTest::GetLastFps() { } lock_guard lk(mu); - for (auto fp : shard->db_slice().TEST_GetLastLockedFps()) { + for (auto fp : + namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).TEST_GetLastLockedFps()) { result.push_back(fp); } }; diff --git a/src/server/test_utils.h b/src/server/test_utils.h index 1c7149269..3656def8a 100644 --- a/src/server/test_utils.h +++ 
b/src/server/test_utils.h @@ -117,7 +117,7 @@ class BaseFamilyTest : public ::testing::Test { static std::vector StrArray(const RespExpr& expr); Metrics GetMetrics() const { - return service_->server_family().GetMetrics(); + return service_->server_family().GetMetrics(&namespaces.GetDefaultNamespace()); } void ClearMetrics(); diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 27e222d75..2523ea425 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -173,8 +173,9 @@ Transaction::~Transaction() { << " destroyed"; } -void Transaction::InitBase(DbIndex dbid, CmdArgList args) { +void Transaction::InitBase(Namespace* ns, DbIndex dbid, CmdArgList args) { global_ = false; + namespace_ = ns; db_index_ = dbid; full_args_ = args; local_result_ = OpStatus::OK; @@ -359,8 +360,8 @@ void Transaction::InitByKeys(const KeyIndex& key_index) { } } -OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { - InitBase(index, args); +OpStatus Transaction::InitByArgs(Namespace* ns, DbIndex index, CmdArgList args) { + InitBase(ns, index, args); if ((cid_->opt_mask() & CO::GLOBAL_TRANS) > 0) { InitGlobal(); @@ -393,7 +394,7 @@ void Transaction::PrepareSquashedMultiHop(const CommandId* cid, MultiSwitchCmd(cid); - InitBase(db_index_, {}); + InitBase(namespace_, db_index_, {}); // Because squashing already determines active shards by partitioning commands, // we don't have to work with keys manually and can just mark active shards. @@ -414,19 +415,20 @@ void Transaction::PrepareSquashedMultiHop(const CommandId* cid, MultiBecomeSquasher(); } -void Transaction::StartMultiGlobal(DbIndex dbid) { +void Transaction::StartMultiGlobal(Namespace* ns, DbIndex dbid) { CHECK(multi_); CHECK(shard_data_.empty()); // Make sure default InitByArgs didn't run. 
multi_->mode = GLOBAL; - InitBase(dbid, {}); + InitBase(ns, dbid, {}); InitGlobal(); multi_->lock_mode = IntentLock::EXCLUSIVE; ScheduleInternal(); } -void Transaction::StartMultiLockedAhead(DbIndex dbid, CmdArgList keys, bool skip_scheduling) { +void Transaction::StartMultiLockedAhead(Namespace* ns, DbIndex dbid, CmdArgList keys, + bool skip_scheduling) { DVLOG(1) << "StartMultiLockedAhead on " << keys.size() << " keys"; DCHECK(multi_); @@ -437,7 +439,7 @@ void Transaction::StartMultiLockedAhead(DbIndex dbid, CmdArgList keys, bool skip PrepareMultiFps(keys); - InitBase(dbid, keys); + InitBase(ns, dbid, keys); InitByKeys(KeyIndex::Range(0, keys.size())); if (!skip_scheduling) @@ -504,6 +506,7 @@ void Transaction::MultiUpdateWithParent(const Transaction* parent) { txid_ = parent->txid_; time_now_ms_ = parent->time_now_ms_; unique_slot_checker_ = parent->unique_slot_checker_; + namespace_ = parent->namespace_; } void Transaction::MultiBecomeSquasher() { @@ -528,9 +531,10 @@ string Transaction::DebugId(std::optional sid) const { return res; } -void Transaction::PrepareMultiForScheduleSingleHop(ShardId sid, DbIndex db, CmdArgList args) { +void Transaction::PrepareMultiForScheduleSingleHop(Namespace* ns, ShardId sid, DbIndex db, + CmdArgList args) { multi_.reset(); - InitBase(db, args); + InitBase(ns, db, args); EnableShard(sid); OpResult key_index = DetermineKeys(cid_, args); CHECK(key_index); @@ -608,7 +612,8 @@ bool Transaction::RunInShard(EngineShard* shard, bool txq_ooo) { // 1: to go over potential wakened keys, verify them and activate watch queues. // 2: if this transaction was notified and finished running - to remove it from the head // of the queue and notify the next one. 
- if (auto* bcontroller = shard->blocking_controller(); bcontroller) { + + if (auto* bcontroller = namespace_->GetBlockingController(shard->shard_id()); bcontroller) { if (awaked_prerun || was_suspended) { bcontroller->FinalizeWatched(GetShardArgs(idx), this); } @@ -864,6 +869,7 @@ void Transaction::DispatchHop() { use_count_.fetch_add(run_cnt, memory_order_relaxed); // for each pointer from poll_cb auto poll_cb = [this] { + CHECK(namespace_ != nullptr); EngineShard::tlocal()->PollExecution("exec_cb", this); DVLOG(3) << "ptr_release " << DebugId(); intrusive_ptr_release(this); // against use_count_.fetch_add above. @@ -1149,7 +1155,7 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p // Register keys on active shards blocking controllers and mark shard state as suspended. auto cb = [&](Transaction* t, EngineShard* shard) { auto keys = wkeys_provider(t, shard); - return t->WatchInShard(keys, shard, krc); + return t->WatchInShard(&t->GetNamespace(), keys, shard, krc); }; Execute(std::move(cb), true); @@ -1187,7 +1193,7 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p return result; } -OpStatus Transaction::WatchInShard(BlockingController::Keys keys, EngineShard* shard, +OpStatus Transaction::WatchInShard(Namespace* ns, BlockingController::Keys keys, EngineShard* shard, KeyReadyChecker krc) { auto& sd = shard_data_[SidToId(shard->shard_id())]; @@ -1195,7 +1201,7 @@ OpStatus Transaction::WatchInShard(BlockingController::Keys keys, EngineShard* s sd.local_mask |= SUSPENDED_Q; sd.local_mask &= ~OUT_OF_ORDER; - shard->EnsureBlockingController()->AddWatched(keys, std::move(krc), this); + ns->GetOrAddBlockingController(shard)->AddWatched(keys, std::move(krc), this); DVLOG(2) << "WatchInShard " << DebugId(); return OpStatus::OK; @@ -1209,8 +1215,10 @@ void Transaction::ExpireShardCb(BlockingController::Keys keys, EngineShard* shar auto& sd = shard_data_[SidToId(shard->shard_id())]; sd.local_mask &= 
~KEYLOCK_ACQUIRED; - shard->blocking_controller()->FinalizeWatched(keys, this); - DCHECK(!shard->blocking_controller()->awakened_transactions().contains(this)); + namespace_->GetBlockingController(shard->shard_id())->FinalizeWatched(keys, this); + DCHECK(!namespace_->GetBlockingController(shard->shard_id()) + ->awakened_transactions() + .contains(this)); // Resume processing of transaction queue shard->PollExecution("unwatchcb", nullptr); @@ -1218,9 +1226,8 @@ void Transaction::ExpireShardCb(BlockingController::Keys keys, EngineShard* shar } DbSlice& Transaction::GetDbSlice(ShardId shard_id) const { - auto* shard = EngineShard::tlocal(); - DCHECK_EQ(shard->shard_id(), shard_id); - return shard->db_slice(); + CHECK(namespace_ != nullptr); + return namespace_->GetDbSlice(shard_id); } OpStatus Transaction::RunSquashedMultiCb(RunnableType cb) { @@ -1270,8 +1277,9 @@ void Transaction::UnlockMultiShardCb(absl::Span fps, EngineShard* shard->RemoveContTx(this); // Wake only if no tx queue head is currently running - if (shard->blocking_controller() && shard->GetContTx() == nullptr) - shard->blocking_controller()->NotifyPending(); + auto bc = namespace_->GetBlockingController(shard->shard_id()); + if (bc && shard->GetContTx() == nullptr) + bc->NotifyPending(); shard->PollExecution("unlockmulti", nullptr); } diff --git a/src/server/transaction.h b/src/server/transaction.h index 4c5b8e9d7..da0836dc5 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -21,6 +21,7 @@ #include "server/cluster/cluster_utility.h" #include "server/common.h" #include "server/journal/types.h" +#include "server/namespaces.h" #include "server/table.h" #include "server/tx_base.h" #include "util/fibers/synchronization.h" @@ -185,7 +186,7 @@ class Transaction { std::optional slot_id); // Initialize from command (args) on specific db. 
- OpStatus InitByArgs(DbIndex index, CmdArgList args); + OpStatus InitByArgs(Namespace* ns, DbIndex index, CmdArgList args); // Get command arguments for specific shard. Called from shard thread. ShardArgs GetShardArgs(ShardId sid) const; @@ -230,10 +231,11 @@ class Transaction { void PrepareSquashedMultiHop(const CommandId* cid, absl::FunctionRef enabled); // Start multi in GLOBAL mode. - void StartMultiGlobal(DbIndex dbid); + void StartMultiGlobal(Namespace* ns, DbIndex dbid); // Start multi in LOCK_AHEAD mode with given keys. - void StartMultiLockedAhead(DbIndex dbid, CmdArgList keys, bool skip_scheduling = false); + void StartMultiLockedAhead(Namespace* ns, DbIndex dbid, CmdArgList keys, + bool skip_scheduling = false); // Start multi in NON_ATOMIC mode. void StartMultiNonAtomic(); @@ -311,7 +313,11 @@ class Transaction { bool IsGlobal() const; DbContext GetDbContext() const { - return DbContext{db_index_, time_now_ms_}; + return DbContext{namespace_, db_index_, time_now_ms_}; + } + + Namespace& GetNamespace() const { + return *namespace_; } DbSlice& GetDbSlice(ShardId sid) const; @@ -330,7 +336,7 @@ class Transaction { // Prepares for running ScheduleSingleHop() for a single-shard multi tx. // It is safe to call ScheduleSingleHop() after calling this method, but the callback passed // to it must not block. - void PrepareMultiForScheduleSingleHop(ShardId sid, DbIndex db, CmdArgList args); + void PrepareMultiForScheduleSingleHop(Namespace* ns, ShardId sid, DbIndex db, CmdArgList args); // Write a journal entry to a shard journal with the given payload. void LogJournalOnShard(EngineShard* shard, journal::Entry::Payload&& payload, uint32_t shard_cnt, @@ -479,7 +485,7 @@ class Transaction { }; // Init basic fields and reset re-usable. - void InitBase(DbIndex dbid, CmdArgList args); + void InitBase(Namespace* ns, DbIndex dbid, CmdArgList args); // Init as a global transaction. 
void InitGlobal(); @@ -518,7 +524,7 @@ class Transaction { void RunCallback(EngineShard* shard); // Adds itself to watched queue in the shard. Must run in that shard thread. - OpStatus WatchInShard(std::variant keys, EngineShard* shard, + OpStatus WatchInShard(Namespace* ns, std::variant keys, EngineShard* shard, KeyReadyChecker krc); // Expire blocking transaction, unlock keys and unregister it from the blocking controller @@ -612,6 +618,7 @@ class Transaction { TxId txid_{0}; bool global_{false}; + Namespace* namespace_{nullptr}; DbIndex db_index_{0}; uint64_t time_now_ms_{0}; diff --git a/src/server/tx_base.cc b/src/server/tx_base.cc index ba075aa33..08081df7b 100644 --- a/src/server/tx_base.cc +++ b/src/server/tx_base.cc @@ -9,6 +9,7 @@ #include "server/cluster/cluster_defs.h" #include "server/engine_shard_set.h" #include "server/journal/journal.h" +#include "server/namespaces.h" #include "server/transaction.h" namespace dfly { @@ -17,14 +18,11 @@ using namespace std; using Payload = journal::Entry::Payload; DbSlice& DbContext::GetDbSlice(ShardId shard_id) const { - // TODO: Update this when adding namespaces - DCHECK_EQ(shard_id, EngineShard::tlocal()->shard_id()); - return EngineShard::tlocal()->db_slice(); + return ns->GetDbSlice(shard_id); } DbSlice& OpArgs::GetDbSlice() const { - // TODO: Update this when adding namespaces - return shard->db_slice(); + return db_cntx.GetDbSlice(shard->shard_id()); } size_t ShardArgs::Size() const { diff --git a/src/server/tx_base.h b/src/server/tx_base.h index 03aab1a5f..c83be839b 100644 --- a/src/server/tx_base.h +++ b/src/server/tx_base.h @@ -14,6 +14,7 @@ namespace dfly { class EngineShard; class Transaction; +class Namespace; class DbSlice; using DbIndex = uint16_t; @@ -58,6 +59,7 @@ struct KeyIndex { }; struct DbContext { + Namespace* ns = nullptr; DbIndex db_index = 0; uint64_t time_now_ms = 0; diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 0e504c8c1..e60f130c9 100644 --- 
a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -213,10 +213,11 @@ OpResult FindZEntry(const ZParams& zparams, const OpArgs& return OpStatus::WRONG_TYPE; } - if (add_res.is_new && op_args.shard->blocking_controller()) { + auto* blocking_controller = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id()); + if (add_res.is_new && blocking_controller) { string tmp; string_view key = it->first.GetSlice(&tmp); - op_args.shard->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, key); + blocking_controller->AwakeWatched(op_args.db_cntx.db_index, key); } return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)}; diff --git a/tests/dragonfly/acl_family_test.py b/tests/dragonfly/acl_family_test.py index 9a2e76694..1138cba28 100644 --- a/tests/dragonfly/acl_family_test.py +++ b/tests/dragonfly/acl_family_test.py @@ -2,7 +2,7 @@ import pytest import redis from redis import asyncio as aioredis from .instance import DflyInstanceFactory -from .utility import disconnect_clients +from .utility import * import tempfile import asyncio import os @@ -567,6 +567,43 @@ async def test_acl_keys(async_client): await async_client.execute_command("ZUNIONSTORE destkey 2 barz1 barz2") +@pytest.mark.asyncio +async def test_namespaces(df_factory): + df = df_factory.create() + df.start() + + admin = aioredis.Redis(port=df.port) + assert await admin.execute_command("SET foo admin") == b"OK" + assert await admin.execute_command("GET foo") == b"admin" + + # Associate user adi with namespace 'ns1'; admin stays in the default namespace + await admin.execute_command("ACL SETUSER adi NAMESPACE:ns1 ON >adi_pass +@all ~*") + + adi = aioredis.Redis(port=df.port) + assert await adi.execute_command("AUTH adi adi_pass") == b"OK" + assert await adi.execute_command("SET foo bar") == b"OK" + assert await adi.execute_command("GET foo") == b"bar" + assert await admin.execute_command("GET foo") == b"admin" + + # Adi and Shahar are on the same team + await admin.execute_command("ACL SETUSER
shahar NAMESPACE:ns1 ON >shahar_pass +@all ~*") + + shahar = aioredis.Redis(port=df.port) + assert await shahar.execute_command("AUTH shahar shahar_pass") == b"OK" + assert await shahar.execute_command("GET foo") == b"bar" + assert await shahar.execute_command("SET foo bar2") == b"OK" + assert await adi.execute_command("GET foo") == b"bar2" + + # Roman is a CTO, he has his own private namespace (ns2) where foo was never set + await admin.execute_command("ACL SETUSER roman NAMESPACE:ns2 ON >roman_pass +@all ~*") + + roman = aioredis.Redis(port=df.port) + assert await roman.execute_command("AUTH roman roman_pass") == b"OK" + assert await roman.execute_command("GET foo") is None + + await close_clients(admin, adi, shahar, roman) + + @pytest.mark.asyncio async def default_user_bug(df_factory): df.start()