1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

feat(namespaces): Initial support for multi-tenant (#3260)

* feat(namespaces): Initial support for multi-tenant #3050

This PR introduces a way to create multiple, separate and isolated
namespaces in Dragonfly. Each user can be associated with a single
namespace, and will not be able to interact with other namespaces.

This is still experimental, and lacks some important features, such as:
* Replication and RDB saving completely ignore non-default namespaces
* Defrag and statistics either use the default namespace or all
  namespaces without separation

To associate a user with a namespace, use the `ACL SETUSER` command with the
`NAMESPACE:<namespace>` argument:

```
ACL SETUSER user NAMESPACE:namespace1 ON >user_pass +@all ~*
```

For more examples and up-to-date information, see
`tests/dragonfly/acl_family_test.py` - specifically the
`test_namespaces` function.
This commit is contained in:
Shahar Mike 2024-07-16 19:34:49 +03:00 committed by GitHub
parent 3891efac2c
commit 18ca61d29b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
51 changed files with 600 additions and 255 deletions

View file

@ -1,6 +1,6 @@
# Namespaces in Dragonfly # Namespaces in Dragonfly
Dragonfly will soon add an _experimental_ feature, allowing complete separation of data by different users. Dragonfly added an _experimental_ feature, allowing complete separation of data by different users.
We call this feature _namespaces_, and it allows using a single Dragonfly server with multiple We call this feature _namespaces_, and it allows using a single Dragonfly server with multiple
tenants, each using their own data, without being able to mix them together. tenants, each using their own data, without being able to mix them together.

View file

@ -28,6 +28,7 @@ struct UserCredentials {
uint32_t acl_categories{0}; uint32_t acl_categories{0};
std::vector<uint64_t> acl_commands; std::vector<uint64_t> acl_commands;
AclKeys keys; AclKeys keys;
std::string ns;
}; };
} // namespace dfly::acl } // namespace dfly::acl

View file

@ -28,7 +28,7 @@ endif()
add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc
command_registry.cc cluster/cluster_utility.cc command_registry.cc cluster/cluster_utility.cc
journal/tx_executor.cc journal/tx_executor.cc namespaces.cc
common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc
server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc
serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc

View file

@ -942,6 +942,14 @@ std::pair<OptCat, bool> AclFamily::MaybeParseAclCategory(std::string_view comman
return {}; return {};
} }
// Tries to interpret `command` as a namespace argument of the form
// "NAMESPACE:<name>".
//
// Returns the text following the prefix (which may be empty) when the prefix
// matches, and std::nullopt when `command` does not start with "NAMESPACE:".
// The match is done as-is on the bytes of `command`.
std::optional<std::string> AclFamily::MaybeParseNamespace(std::string_view command) const {
  constexpr std::string_view kPrefix = "NAMESPACE:";

  // Not a namespace argument: let the caller try the remaining parsers.
  if (!absl::StartsWith(command, kPrefix)) {
    return std::nullopt;
  }

  // `command` is a by-value view, so trimming it here does not affect the caller.
  command.remove_prefix(kPrefix.size());
  return std::string(command);
}
std::pair<AclFamily::OptCommand, bool> AclFamily::MaybeParseAclCommand( std::pair<AclFamily::OptCommand, bool> AclFamily::MaybeParseAclCommand(
std::string_view command) const { std::string_view command) const {
if (absl::StartsWith(command, "+")) { if (absl::StartsWith(command, "+")) {
@ -1019,6 +1027,12 @@ std::variant<User::UpdateRequest, ErrorReply> AclFamily::ParseAclSetUser(
continue; continue;
} }
auto ns = MaybeParseNamespace(command);
if (ns.has_value()) {
req.ns = *ns;
continue;
}
auto [cmd, sign] = MaybeParseAclCommand(command); auto [cmd, sign] = MaybeParseAclCommand(command);
if (!cmd) { if (!cmd) {
return ErrorReply(absl::StrCat("Unrecognized parameter ", command)); return ErrorReply(absl::StrCat("Unrecognized parameter ", command));

View file

@ -80,6 +80,8 @@ class AclFamily final {
using OptCommand = std::optional<std::pair<size_t, uint64_t>>; using OptCommand = std::optional<std::pair<size_t, uint64_t>>;
std::pair<OptCommand, bool> MaybeParseAclCommand(std::string_view command) const; std::pair<OptCommand, bool> MaybeParseAclCommand(std::string_view command) const;
std::optional<std::string> MaybeParseNamespace(std::string_view command) const;
std::variant<User::UpdateRequest, facade::ErrorReply> ParseAclSetUser( std::variant<User::UpdateRequest, facade::ErrorReply> ParseAclSetUser(
const facade::ArgRange& args, bool hashed = false, bool has_all_keys = false) const; const facade::ArgRange& args, bool hashed = false, bool has_all_keys = false) const;

View file

@ -79,6 +79,8 @@ void User::Update(UpdateRequest&& req, const CategoryToIdxStore& cat_to_id,
if (req.is_active) { if (req.is_active) {
SetIsActive(*req.is_active); SetIsActive(*req.is_active);
} }
SetNamespace(req.ns);
} }
void User::SetPasswordHash(std::string_view password, bool is_hashed) { void User::SetPasswordHash(std::string_view password, bool is_hashed) {
@ -94,6 +96,14 @@ void User::UnsetPassword(std::string_view password) {
password_hashes_.erase(StringSHA256(password)); password_hashes_.erase(StringSHA256(password));
} }
void User::SetNamespace(const std::string& ns) {
namespace_ = ns;
}
const std::string& User::Namespace() const {
return namespace_;
}
bool User::HasPassword(std::string_view password) const { bool User::HasPassword(std::string_view password) const {
if (nopass_) { if (nopass_) {
return true; return true;

View file

@ -58,8 +58,11 @@ class User final {
std::vector<UpdateKey> keys; std::vector<UpdateKey> keys;
bool reset_all_keys{false}; bool reset_all_keys{false};
bool allow_all_keys{false}; bool allow_all_keys{false};
// TODO allow reset all // TODO allow reset all
// bool reset_all{false}; // bool reset_all{false};
std::string ns;
}; };
using CategoryChange = uint32_t; using CategoryChange = uint32_t;
@ -104,6 +107,8 @@ class User final {
const AclKeys& Keys() const; const AclKeys& Keys() const;
const std::string& Namespace() const;
using CategoryChanges = absl::flat_hash_map<CategoryChange, ChangeMetadata>; using CategoryChanges = absl::flat_hash_map<CategoryChange, ChangeMetadata>;
using CommandChanges = absl::flat_hash_map<CommandChange, ChangeMetadata>; using CommandChanges = absl::flat_hash_map<CommandChange, ChangeMetadata>;
@ -135,6 +140,7 @@ class User final {
// For ACL key globs // For ACL key globs
void SetKeyGlobs(std::vector<UpdateKey> keys); void SetKeyGlobs(std::vector<UpdateKey> keys);
void SetNamespace(const std::string& ns);
// Set NOPASS and remove all passwords // Set NOPASS and remove all passwords
void SetNopass(); void SetNopass();
@ -166,6 +172,8 @@ class User final {
// if the user is on/off // if the user is on/off
bool is_active_{false}; bool is_active_{false};
std::string namespace_;
}; };
} // namespace dfly::acl } // namespace dfly::acl

View file

@ -35,7 +35,8 @@ UserCredentials UserRegistry::GetCredentials(std::string_view username) const {
if (it == registry_.end()) { if (it == registry_.end()) {
return {}; return {};
} }
return {it->second.AclCategory(), it->second.AclCommands(), it->second.Keys()}; return {it->second.AclCategory(), it->second.AclCommands(), it->second.Keys(),
it->second.Namespace()};
} }
bool UserRegistry::IsUserActive(std::string_view username) const { bool UserRegistry::IsUserActive(std::string_view username) const {
@ -73,10 +74,13 @@ UserRegistry::UserWithWriteLock::UserWithWriteLock(std::unique_lock<fb2::SharedM
} }
User::UpdateRequest UserRegistry::DefaultUserUpdateRequest() const { User::UpdateRequest UserRegistry::DefaultUserUpdateRequest() const {
std::pair<User::Sign, uint32_t> acl{User::Sign::PLUS, acl::ALL}; // Assign field by field to supress an annoying compiler warning
auto key = User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}; User::UpdateRequest req;
auto pass = std::vector<User::UpdatePass>{{"", false, true}}; req.passwords = std::vector<User::UpdatePass>{{"", false, true}};
return {std::move(pass), true, false, {std::move(acl)}, {std::move(key)}}; req.is_active = true;
req.updates = {std::pair<User::Sign, uint32_t>{User::Sign::PLUS, acl::ALL}};
req.keys = {User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}};
return req;
} }
void UserRegistry::Init(const CategoryToIdxStore* cat_to_id_table, void UserRegistry::Init(const CategoryToIdxStore* cat_to_id_table,

View file

@ -10,6 +10,7 @@
#include "base/logging.h" #include "base/logging.h"
#include "server/engine_shard_set.h" #include "server/engine_shard_set.h"
#include "server/namespaces.h"
#include "server/transaction.h" #include "server/transaction.h"
namespace dfly { namespace dfly {
@ -102,7 +103,7 @@ bool BlockingController::DbWatchTable::UnwatchTx(string_view key, Transaction* t
return res; return res;
} }
BlockingController::BlockingController(EngineShard* owner) : owner_(owner) { BlockingController::BlockingController(EngineShard* owner, Namespace* ns) : owner_(owner), ns_(ns) {
} }
BlockingController::~BlockingController() { BlockingController::~BlockingController() {
@ -153,6 +154,7 @@ void BlockingController::NotifyPending() {
CHECK(tx == nullptr) << tx->DebugId(); CHECK(tx == nullptr) << tx->DebugId();
DbContext context; DbContext context;
context.ns = ns_;
context.time_now_ms = GetCurrentTimeMs(); context.time_now_ms = GetCurrentTimeMs();
for (DbIndex index : awakened_indices_) { for (DbIndex index : awakened_indices_) {

View file

@ -15,10 +15,11 @@
namespace dfly { namespace dfly {
class Transaction; class Transaction;
class Namespace;
class BlockingController { class BlockingController {
public: public:
explicit BlockingController(EngineShard* owner); explicit BlockingController(EngineShard* owner, Namespace* ns);
~BlockingController(); ~BlockingController();
using Keys = std::variant<ShardArgs, ArgSlice>; using Keys = std::variant<ShardArgs, ArgSlice>;
@ -60,6 +61,7 @@ class BlockingController {
// void NotifyConvergence(Transaction* tx); // void NotifyConvergence(Transaction* tx);
EngineShard* owner_; EngineShard* owner_;
Namespace* ns_;
absl::flat_hash_map<DbIndex, std::unique_ptr<DbWatchTable>> watched_dbs_; absl::flat_hash_map<DbIndex, std::unique_ptr<DbWatchTable>> watched_dbs_;

View file

@ -61,7 +61,7 @@ void BlockingControllerTest::SetUp() {
arg_vec_.emplace_back(s); arg_vec_.emplace_back(s);
} }
trans_->InitByArgs(0, {arg_vec_.data(), arg_vec_.size()}); trans_->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {arg_vec_.data(), arg_vec_.size()});
CHECK_EQ(0u, Shard("x", shard_set->size())); CHECK_EQ(0u, Shard("x", shard_set->size()));
CHECK_EQ(2u, Shard("z", shard_set->size())); CHECK_EQ(2u, Shard("z", shard_set->size()));
@ -70,6 +70,8 @@ void BlockingControllerTest::SetUp() {
} }
void BlockingControllerTest::TearDown() { void BlockingControllerTest::TearDown() {
namespaces.Clear();
shard_set->Shutdown(); shard_set->Shutdown();
delete shard_set; delete shard_set;
@ -79,7 +81,7 @@ void BlockingControllerTest::TearDown() {
TEST_F(BlockingControllerTest, Basic) { TEST_F(BlockingControllerTest, Basic) {
trans_->ScheduleSingleHop([&](Transaction* t, EngineShard* shard) { trans_->ScheduleSingleHop([&](Transaction* t, EngineShard* shard) {
BlockingController bc(shard); BlockingController bc(shard, &namespaces.GetDefaultNamespace());
auto keys = t->GetShardArgs(shard->shard_id()); auto keys = t->GetShardArgs(shard->shard_id());
bc.AddWatched( bc.AddWatched(
keys, [](auto...) { return true; }, t); keys, [](auto...) { return true; }, t);
@ -103,7 +105,12 @@ TEST_F(BlockingControllerTest, Timeout) {
EXPECT_EQ(status, facade::OpStatus::TIMED_OUT); EXPECT_EQ(status, facade::OpStatus::TIMED_OUT);
unsigned num_watched = shard_set->Await( unsigned num_watched = shard_set->Await(
0, [&] { return EngineShard::tlocal()->blocking_controller()->NumWatched(0); });
0, [&] {
return namespaces.GetDefaultNamespace()
.GetBlockingController(EngineShard::tlocal()->shard_id())
->NumWatched(0);
});
EXPECT_EQ(0, num_watched); EXPECT_EQ(0, num_watched);
trans_.reset(); trans_.reset();

View file

@ -21,6 +21,7 @@
#include "server/error.h" #include "server/error.h"
#include "server/journal/journal.h" #include "server/journal/journal.h"
#include "server/main_service.h" #include "server/main_service.h"
#include "server/namespaces.h"
#include "server/server_family.h" #include "server/server_family.h"
#include "server/server_state.h" #include "server/server_state.h"
@ -451,7 +452,7 @@ void DeleteSlots(const SlotRanges& slots_ranges) {
if (shard == nullptr) if (shard == nullptr)
return; return;
shard->db_slice().FlushSlots(slots_ranges); namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).FlushSlots(slots_ranges);
}; };
shard_set->pool()->AwaitFiberOnAll(std::move(cb)); shard_set->pool()->AwaitFiberOnAll(std::move(cb));
} }
@ -599,7 +600,7 @@ void ClusterFamily::DflyClusterGetSlotInfo(CmdArgList args, ConnectionContext* c
lock_guard lk(mu); lock_guard lk(mu);
for (auto& [slot, data] : slots_stats) { for (auto& [slot, data] : slots_stats) {
data += shard->db_slice().GetSlotStats(slot); data += namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).GetSlotStats(slot);
} }
}; };

View file

@ -2,6 +2,7 @@
#include "server/cluster/cluster_defs.h" #include "server/cluster/cluster_defs.h"
#include "server/engine_shard_set.h" #include "server/engine_shard_set.h"
#include "server/namespaces.h"
using namespace std; using namespace std;
@ -49,7 +50,10 @@ uint64_t GetKeyCount(const SlotRanges& slots) {
uint64_t shard_keys = 0; uint64_t shard_keys = 0;
for (const SlotRange& range : slots) { for (const SlotRange& range : slots) {
for (SlotId slot = range.start; slot <= range.end; slot++) { for (SlotId slot = range.start; slot <= range.end; slot++) {
shard_keys += shard->db_slice().GetSlotStats(slot).key_count; shard_keys += namespaces.GetDefaultNamespace()
.GetDbSlice(shard->shard_id())
.GetSlotStats(slot)
.key_count;
} }
} }
keys.fetch_add(shard_keys); keys.fetch_add(shard_keys);

View file

@ -96,7 +96,7 @@ OutgoingMigration::OutgoingMigration(MigrationInfo info, ClusterFamily* cf, Serv
server_family_(sf), server_family_(sf),
cf_(cf), cf_(cf),
tx_(new Transaction{sf->service().FindCmd("DFLYCLUSTER")}) { tx_(new Transaction{sf->service().FindCmd("DFLYCLUSTER")}) {
tx_->InitByArgs(0, {}); tx_->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {});
} }
OutgoingMigration::~OutgoingMigration() { OutgoingMigration::~OutgoingMigration() {
@ -212,10 +212,10 @@ void OutgoingMigration::SyncFb() {
} }
OnAllShards([this](auto& migration) { OnAllShards([this](auto& migration) {
auto* shard = EngineShard::tlocal(); DbSlice& db_slice = namespaces.GetDefaultNamespace().GetCurrentDbSlice();
server_family_->journal()->StartInThread(); server_family_->journal()->StartInThread();
migration = std::make_unique<SliceSlotMigration>( migration = std::make_unique<SliceSlotMigration>(
&shard->db_slice(), server(), migration_info_.slot_ranges, server_family_->journal()); &db_slice, server(), migration_info_.slot_ranges, server_family_->journal());
}); });
if (!ChangeState(MigrationState::C_SYNC)) { if (!ChangeState(MigrationState::C_SYNC)) {
@ -284,8 +284,9 @@ bool OutgoingMigration::FinalizeMigration(long attempt) {
// TODO implement blocking on migrated slots only // TODO implement blocking on migrated slots only
bool is_block_active = true; bool is_block_active = true;
auto is_pause_in_progress = [&is_block_active] { return is_block_active; }; auto is_pause_in_progress = [&is_block_active] { return is_block_active; };
auto pause_fb_opt = Pause(server_family_->GetNonPriviligedListeners(), nullptr, auto pause_fb_opt =
ClientPause::WRITE, is_pause_in_progress); Pause(server_family_->GetNonPriviligedListeners(), &namespaces.GetDefaultNamespace(), nullptr,
ClientPause::WRITE, is_pause_in_progress);
if (!pause_fb_opt) { if (!pause_fb_opt) {
LOG(WARNING) << "Cluster migration finalization time out"; LOG(WARNING) << "Cluster migration finalization time out";

View file

@ -282,6 +282,7 @@ class ConnectionContext : public facade::ConnectionContext {
DebugInfo last_command_debug; DebugInfo last_command_debug;
// TODO: to introduce proper accessors. // TODO: to introduce proper accessors.
Namespace* ns = nullptr;
Transaction* transaction = nullptr; Transaction* transaction = nullptr;
const CommandId* cid = nullptr; const CommandId* cid = nullptr;

View file

@ -339,9 +339,10 @@ OpResult<string> RunCbOnFirstNonEmptyBlocking(Transaction* trans, int req_obj_ty
} }
auto wcb = [](Transaction* t, EngineShard* shard) { return t->GetShardArgs(shard->shard_id()); }; auto wcb = [](Transaction* t, EngineShard* shard) { return t->GetShardArgs(shard->shard_id()); };
const auto key_checker = [req_obj_type](EngineShard* owner, const DbContext& context, auto* ns = &trans->GetNamespace();
Transaction*, std::string_view key) -> bool { const auto key_checker = [req_obj_type, ns](EngineShard* owner, const DbContext& context,
return context.GetDbSlice(owner->shard_id()).FindReadOnly(context, key, req_obj_type).ok(); Transaction*, std::string_view key) -> bool {
return ns->GetDbSlice(owner->shard_id()).FindReadOnly(context, key, req_obj_type).ok();
}; };
auto status = trans->WaitOnWatch(limit_tp, std::move(wcb), key_checker, block_flag, pause_flag); auto status = trans->WaitOnWatch(limit_tp, std::move(wcb), key_checker, block_flag, pause_flag);

View file

@ -1078,7 +1078,7 @@ void DbSlice::ExpireAllIfNeeded() {
LOG(ERROR) << "Expire entry " << exp_it->first.ToString() << " not found in prime table"; LOG(ERROR) << "Expire entry " << exp_it->first.ToString() << " not found in prime table";
return; return;
} }
ExpireIfNeeded(Context{db_index, GetCurrentTimeMs()}, prime_it); ExpireIfNeeded(Context{nullptr, db_index, GetCurrentTimeMs()}, prime_it);
}; };
ExpireTable::Cursor cursor; ExpireTable::Cursor cursor;

View file

@ -159,7 +159,7 @@ void DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool
stub_tx->MultiSwitchCmd(cid); stub_tx->MultiSwitchCmd(cid);
local_cntx.cid = cid; local_cntx.cid = cid;
crb.SetReplyMode(ReplyMode::NONE); crb.SetReplyMode(ReplyMode::NONE);
stub_tx->InitByArgs(local_cntx.conn_state.db_index, args_span); stub_tx->InitByArgs(cntx->ns, local_cntx.conn_state.db_index, args_span);
sf->service().InvokeCmd(cid, args_span, &local_cntx); sf->service().InvokeCmd(cid, args_span, &local_cntx);
} }
@ -261,8 +261,8 @@ void MergeObjHistMap(ObjHistMap&& src, ObjHistMap* dest) {
} }
} }
void DoBuildObjHist(EngineShard* shard, ObjHistMap* obj_hist_map) { void DoBuildObjHist(EngineShard* shard, ConnectionContext* cntx, ObjHistMap* obj_hist_map) {
auto& db_slice = shard->db_slice(); auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id());
unsigned steps = 0; unsigned steps = 0;
for (unsigned i = 0; i < db_slice.db_array_size(); ++i) { for (unsigned i = 0; i < db_slice.db_array_size(); ++i) {
@ -288,8 +288,9 @@ void DoBuildObjHist(EngineShard* shard, ObjHistMap* obj_hist_map) {
} }
} }
ObjInfo InspectOp(string_view key, DbIndex db_index) { ObjInfo InspectOp(ConnectionContext* cntx, string_view key) {
auto& db_slice = EngineShard::tlocal()->db_slice(); auto& db_slice = cntx->ns->GetCurrentDbSlice();
auto db_index = cntx->db_index();
auto [pt, exp_t] = db_slice.GetTables(db_index); auto [pt, exp_t] = db_slice.GetTables(db_index);
PrimeIterator it = pt->Find(key); PrimeIterator it = pt->Find(key);
@ -323,8 +324,9 @@ ObjInfo InspectOp(string_view key, DbIndex db_index) {
return oinfo; return oinfo;
} }
OpResult<ValueCompressInfo> EstimateCompression(string_view key, DbIndex db_index) { OpResult<ValueCompressInfo> EstimateCompression(ConnectionContext* cntx, string_view key) {
auto& db_slice = EngineShard::tlocal()->db_slice(); auto& db_slice = cntx->ns->GetCurrentDbSlice();
auto db_index = cntx->db_index();
auto [pt, exp_t] = db_slice.GetTables(db_index); auto [pt, exp_t] = db_slice.GetTables(db_index);
PrimeIterator it = pt->Find(key); PrimeIterator it = pt->Find(key);
@ -544,7 +546,7 @@ void DebugCmd::Load(string_view filename) {
const CommandId* cid = sf_.service().FindCmd("FLUSHALL"); const CommandId* cid = sf_.service().FindCmd("FLUSHALL");
intrusive_ptr<Transaction> flush_trans(new Transaction{cid}); intrusive_ptr<Transaction> flush_trans(new Transaction{cid});
flush_trans->InitByArgs(0, {}); flush_trans->InitByArgs(cntx_->ns, 0, {});
VLOG(1) << "Performing flush"; VLOG(1) << "Performing flush";
error_code ec = sf_.Drakarys(flush_trans.get(), DbSlice::kDbAll); error_code ec = sf_.Drakarys(flush_trans.get(), DbSlice::kDbAll);
if (ec) { if (ec) {
@ -750,7 +752,7 @@ void DebugCmd::PopulateRangeFiber(uint64_t from, uint64_t num_of_keys,
// after running the callback // after running the callback
// Note that running debug populate while running flushall/db can cause dcheck fail because the // Note that running debug populate while running flushall/db can cause dcheck fail because the
// finish cb is executed just when we finish populating the database. // finish cb is executed just when we finish populating the database.
shard->db_slice().OnCbFinish(); cntx_->ns->GetDbSlice(shard->shard_id()).OnCbFinish();
}); });
} }
@ -801,7 +803,7 @@ void DebugCmd::Inspect(string_view key, CmdArgList args) {
string resp; string resp;
if (check_compression) { if (check_compression) {
auto cb = [&] { return EstimateCompression(key, cntx_->db_index()); }; auto cb = [&] { return EstimateCompression(cntx_, key); };
auto res = ess.Await(sid, std::move(cb)); auto res = ess.Await(sid, std::move(cb));
if (!res) { if (!res) {
cntx_->SendError(res.status()); cntx_->SendError(res.status());
@ -812,7 +814,7 @@ void DebugCmd::Inspect(string_view key, CmdArgList args) {
StrAppend(&resp, " ratio: ", static_cast<double>(res->compressed_size) / (res->raw_size)); StrAppend(&resp, " ratio: ", static_cast<double>(res->compressed_size) / (res->raw_size));
} }
} else { } else {
auto cb = [&] { return InspectOp(key, cntx_->db_index()); }; auto cb = [&] { return InspectOp(cntx_, key); };
ObjInfo res = ess.Await(sid, std::move(cb)); ObjInfo res = ess.Await(sid, std::move(cb));
@ -846,7 +848,7 @@ void DebugCmd::Watched() {
vector<string> awaked_trans; vector<string> awaked_trans;
auto cb = [&](EngineShard* shard) { auto cb = [&](EngineShard* shard) {
auto* bc = shard->blocking_controller(); auto* bc = cntx_->ns->GetBlockingController(shard->shard_id());
if (bc) { if (bc) {
auto keys = bc->GetWatchedKeys(cntx_->db_index()); auto keys = bc->GetWatchedKeys(cntx_->db_index());
@ -894,8 +896,9 @@ void DebugCmd::TxAnalysis() {
void DebugCmd::ObjHist() { void DebugCmd::ObjHist() {
vector<ObjHistMap> obj_hist_map_arr(shard_set->size()); vector<ObjHistMap> obj_hist_map_arr(shard_set->size());
shard_set->RunBlockingInParallel( shard_set->RunBlockingInParallel([&](EngineShard* shard) {
[&](EngineShard* shard) { DoBuildObjHist(shard, &obj_hist_map_arr[shard->shard_id()]); }); DoBuildObjHist(shard, cntx_, &obj_hist_map_arr[shard->shard_id()]);
});
for (size_t i = shard_set->size() - 1; i > 0; --i) { for (size_t i = shard_set->size() - 1; i > 0; --i) {
MergeObjHistMap(std::move(obj_hist_map_arr[i]), &obj_hist_map_arr[0]); MergeObjHistMap(std::move(obj_hist_map_arr[i]), &obj_hist_map_arr[0]);
@ -937,8 +940,9 @@ void DebugCmd::Shards() {
vector<ShardInfo> infos(shard_set->size()); vector<ShardInfo> infos(shard_set->size());
shard_set->RunBriefInParallel([&](EngineShard* shard) { shard_set->RunBriefInParallel([&](EngineShard* shard) {
auto slice_stats = shard->db_slice().GetStats(); auto sid = shard->shard_id();
auto& stats = infos[shard->shard_id()]; auto slice_stats = cntx_->ns->GetDbSlice(sid).GetStats();
auto& stats = infos[sid];
stats.used_memory = shard->UsedMemory(); stats.used_memory = shard->UsedMemory();
for (const auto& db_stats : slice_stats.db_stats) { for (const auto& db_stats : slice_stats.db_stats) {

View file

@ -13,6 +13,7 @@
#include "base/logging.h" #include "base/logging.h"
#include "server/detail/snapshot_storage.h" #include "server/detail/snapshot_storage.h"
#include "server/main_service.h" #include "server/main_service.h"
#include "server/namespaces.h"
#include "server/script_mgr.h" #include "server/script_mgr.h"
#include "server/transaction.h" #include "server/transaction.h"
#include "strings/human_readable.h" #include "strings/human_readable.h"
@ -400,7 +401,7 @@ void SaveStagesController::CloseCb(unsigned index) {
} }
if (auto* es = EngineShard::tlocal(); use_dfs_format_ && es) if (auto* es = EngineShard::tlocal(); use_dfs_format_ && es)
es->db_slice().ResetUpdateEvents(); namespaces.GetDefaultNamespace().GetDbSlice(es->shard_id()).ResetUpdateEvents();
} }
void SaveStagesController::RunStage(void (SaveStagesController::*cb)(unsigned)) { void SaveStagesController::RunStage(void (SaveStagesController::*cb)(unsigned)) {

View file

@ -75,7 +75,7 @@ OpStatus WaitReplicaFlowToCatchup(absl::Time end_time, shared_ptr<DflyCmd::Repli
EngineShard* shard) { EngineShard* shard) {
// We don't want any writes to the journal after we send the `PING`, // We don't want any writes to the journal after we send the `PING`,
// and expirations could ruin that. // and expirations could ruin that.
shard->db_slice().SetExpireAllowed(false); namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).SetExpireAllowed(false);
shard->journal()->RecordEntry(0, journal::Op::PING, 0, 0, nullopt, {}, true); shard->journal()->RecordEntry(0, journal::Op::PING, 0, 0, nullopt, {}, true);
FlowInfo* flow = &replica->flows[shard->shard_id()]; FlowInfo* flow = &replica->flows[shard->shard_id()];
@ -396,8 +396,9 @@ void DflyCmd::TakeOver(CmdArgList args, ConnectionContext* cntx) {
VLOG(1) << "AwaitCurrentDispatches done"; VLOG(1) << "AwaitCurrentDispatches done";
absl::Cleanup([] { absl::Cleanup([] {
shard_set->RunBriefInParallel( shard_set->RunBriefInParallel([](EngineShard* shard) {
[](EngineShard* shard) { shard->db_slice().SetExpireAllowed(true); }); namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).SetExpireAllowed(true);
});
VLOG(2) << "Enable expiration"; VLOG(2) << "Enable expiration";
}); });

View file

@ -366,7 +366,11 @@ TEST_F(DflyEngineTest, MemcacheFlags) {
ASSERT_EQ(Run("resp", {"flushdb"}), "OK"); ASSERT_EQ(Run("resp", {"flushdb"}), "OK");
pp_->AwaitFiberOnAll([](auto*) { pp_->AwaitFiberOnAll([](auto*) {
if (auto* shard = EngineShard::tlocal(); shard) { if (auto* shard = EngineShard::tlocal(); shard) {
EXPECT_EQ(shard->db_slice().GetDBTable(0)->mcflag.size(), 0u); EXPECT_EQ(namespaces.GetDefaultNamespace()
.GetDbSlice(shard->shard_id())
.GetDBTable(0)
->mcflag.size(),
0u);
} }
}); });
} }
@ -584,7 +588,7 @@ TEST_F(DflyEngineTest, Bug496) {
if (shard == nullptr) if (shard == nullptr)
return; return;
auto& db = shard->db_slice(); auto& db = namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id());
int cb_hits = 0; int cb_hits = 0;
uint32_t cb_id = uint32_t cb_id =

View file

@ -20,6 +20,7 @@ extern "C" {
#include "io/proc_reader.h" #include "io/proc_reader.h"
#include "server/blocking_controller.h" #include "server/blocking_controller.h"
#include "server/cluster/cluster_defs.h" #include "server/cluster/cluster_defs.h"
#include "server/namespaces.h"
#include "server/search/doc_index.h" #include "server/search/doc_index.h"
#include "server/server_state.h" #include "server/server_state.h"
#include "server/tiered_storage.h" #include "server/tiered_storage.h"
@ -294,7 +295,8 @@ bool EngineShard::DoDefrag() {
constexpr size_t kMaxTraverses = 40; constexpr size_t kMaxTraverses = 40;
const float threshold = GetFlag(FLAGS_mem_defrag_page_utilization_threshold); const float threshold = GetFlag(FLAGS_mem_defrag_page_utilization_threshold);
auto& slice = db_slice(); // TODO: enable tiered storage on non-default db slice
DbSlice& slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_->shard_id());
// If we moved to an invalid db, skip as long as it's not the last one // If we moved to an invalid db, skip as long as it's not the last one
while (!slice.IsDbValid(defrag_state_.dbid) && defrag_state_.dbid + 1 < slice.db_array_size()) while (!slice.IsDbValid(defrag_state_.dbid) && defrag_state_.dbid + 1 < slice.db_array_size())
@ -324,7 +326,7 @@ bool EngineShard::DoDefrag() {
} }
}); });
traverses_count++; traverses_count++;
} while (traverses_count < kMaxTraverses && cur); } while (traverses_count < kMaxTraverses && cur && namespaces.IsInitialized());
defrag_state_.UpdateScanState(cur.value()); defrag_state_.UpdateScanState(cur.value());
@ -355,11 +357,14 @@ bool EngineShard::DoDefrag() {
// priority. // priority.
// otherwise lower the task priority so that it would not use the CPU when not required // otherwise lower the task priority so that it would not use the CPU when not required
uint32_t EngineShard::DefragTask() { uint32_t EngineShard::DefragTask() {
if (!namespaces.IsInitialized()) {
return util::ProactorBase::kOnIdleMaxLevel;
}
constexpr uint32_t kRunAtLowPriority = 0u; constexpr uint32_t kRunAtLowPriority = 0u;
const auto shard_id = db_slice().shard_id();
if (defrag_state_.CheckRequired()) { if (defrag_state_.CheckRequired()) {
VLOG(2) << shard_id << ": need to run defrag memory cursor state: " << defrag_state_.cursor; VLOG(2) << shard_id_ << ": need to run defrag memory cursor state: " << defrag_state_.cursor;
if (DoDefrag()) { if (DoDefrag()) {
// we didn't finish the scan // we didn't finish the scan
return util::ProactorBase::kOnIdleMaxLevel; return util::ProactorBase::kOnIdleMaxLevel;
@ -372,13 +377,11 @@ EngineShard::EngineShard(util::ProactorBase* pb, mi_heap_t* heap)
: queue_(1, kQueueLen), : queue_(1, kQueueLen),
txq_([](const Transaction* t) { return t->txid(); }), txq_([](const Transaction* t) { return t->txid(); }),
mi_resource_(heap), mi_resource_(heap),
db_slice_(pb->GetPoolIndex(), GetFlag(FLAGS_cache_mode), this) { shard_id_(pb->GetPoolIndex()) {
tmp_str1 = sdsempty(); tmp_str1 = sdsempty();
db_slice_.UpdateExpireBase(absl::GetCurrentTimeNanos() / 1000000, 0); defrag_task_ = pb->AddOnIdleTask([this]() { return DefragTask(); });
// start the defragmented task here queue_.Start(absl::StrCat("shard_queue_", shard_id()));
defrag_task_ = pb->AddOnIdleTask([this]() { return this->DefragTask(); });
queue_.Start(absl::StrCat("shard_queue_", db_slice_.shard_id()));
} }
EngineShard::~EngineShard() { EngineShard::~EngineShard() {
@ -437,8 +440,10 @@ void EngineShard::InitTieredStorage(ProactorBase* pb, size_t max_file_size) {
LOG_IF(FATAL, pb->GetKind() != ProactorBase::IOURING) LOG_IF(FATAL, pb->GetKind() != ProactorBase::IOURING)
<< "Only ioring based backing storage is supported. Exiting..."; << "Only ioring based backing storage is supported. Exiting...";
// TODO: enable tiered storage on non-default namespace
DbSlice& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_id());
auto* shard = EngineShard::tlocal(); auto* shard = EngineShard::tlocal();
shard->tiered_storage_ = make_unique<TieredStorage>(&db_slice_, max_file_size); shard->tiered_storage_ = make_unique<TieredStorage>(&db_slice, max_file_size);
error_code ec = shard->tiered_storage_->Open(backing_prefix); error_code ec = shard->tiered_storage_->Open(backing_prefix);
CHECK(!ec) << ec.message(); CHECK(!ec) << ec.message();
} }
@ -515,24 +520,31 @@ void EngineShard::PollExecution(const char* context, Transaction* trans) {
trans = nullptr; trans = nullptr;
if ((is_self && disarmed) || continuation_trans_->DisarmInShard(sid)) { if ((is_self && disarmed) || continuation_trans_->DisarmInShard(sid)) {
auto bc = continuation_trans_->GetNamespace().GetBlockingController(shard_id_);
if (bool keep = run(continuation_trans_, false); !keep) { if (bool keep = run(continuation_trans_, false); !keep) {
// if this holds, we can remove this check altogether. // if this holds, we can remove this check altogether.
DCHECK(continuation_trans_ == nullptr); DCHECK(continuation_trans_ == nullptr);
continuation_trans_ = nullptr; continuation_trans_ = nullptr;
} }
if (bc && bc->HasAwakedTransaction()) {
// Break if there are any awakened transactions, as we must give way to them
// before continuing to handle regular transactions from the queue.
return;
}
} }
} }
// Progress on the transaction queue if no transaction is running currently. // Progress on the transaction queue if no transaction is running currently.
Transaction* head = nullptr; Transaction* head = nullptr;
while (continuation_trans_ == nullptr && !txq_.Empty()) { while (continuation_trans_ == nullptr && !txq_.Empty()) {
head = get<Transaction*>(txq_.Front());
// Break if there are any awakened transactions, as we must give way to them // Break if there are any awakened transactions, as we must give way to them
// before continuing to handle regular transactions from the queue. // before continuing to handle regular transactions from the queue.
if (blocking_controller_ && blocking_controller_->HasAwakedTransaction()) if (head->GetNamespace().GetBlockingController(shard_id_) &&
head->GetNamespace().GetBlockingController(shard_id_)->HasAwakedTransaction())
break; break;
head = get<Transaction*>(txq_.Front());
VLOG(2) << "Considering head " << head->DebugId() VLOG(2) << "Considering head " << head->DebugId()
<< " isarmed: " << head->DEBUG_IsArmedInShard(sid); << " isarmed: " << head->DEBUG_IsArmedInShard(sid);
@ -610,22 +622,28 @@ void EngineShard::Heartbeat() {
DbContext db_cntx; DbContext db_cntx;
db_cntx.time_now_ms = GetCurrentTimeMs(); db_cntx.time_now_ms = GetCurrentTimeMs();
for (unsigned i = 0; i < db_slice_.db_array_size(); ++i) { // TODO: iterate over all namespaces
if (!db_slice_.IsDbValid(i)) if (!namespaces.IsInitialized()) {
return;
}
DbSlice& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_id());
for (unsigned i = 0; i < db_slice.db_array_size(); ++i) {
if (!db_slice.IsDbValid(i))
continue; continue;
db_cntx.db_index = i; db_cntx.db_index = i;
auto [pt, expt] = db_slice_.GetTables(i); auto [pt, expt] = db_slice.GetTables(i);
if (expt->size() > pt->size() / 4) { if (expt->size() > pt->size() / 4) {
DbSlice::DeleteExpiredStats stats = db_slice_.DeleteExpiredStep(db_cntx, ttl_delete_target); DbSlice::DeleteExpiredStats stats = db_slice.DeleteExpiredStep(db_cntx, ttl_delete_target);
counter_[TTL_TRAVERSE].IncBy(stats.traversed); counter_[TTL_TRAVERSE].IncBy(stats.traversed);
counter_[TTL_DELETE].IncBy(stats.deleted); counter_[TTL_DELETE].IncBy(stats.deleted);
} }
// if our budget is below the limit // if our budget is below the limit
if (db_slice_.memory_budget() < eviction_redline) { if (db_slice.memory_budget() < eviction_redline) {
db_slice_.FreeMemWithEvictionStep(i, eviction_redline - db_slice_.memory_budget()); db_slice.FreeMemWithEvictionStep(i, eviction_redline - db_slice.memory_budget());
} }
if (UsedMemory() > tiering_offload_threshold) { if (UsedMemory() > tiering_offload_threshold) {
@ -686,18 +704,23 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms) {
} }
void EngineShard::CacheStats() { void EngineShard::CacheStats() {
if (!namespaces.IsInitialized()) {
return;
}
// mi_heap_visit_blocks(tlh, false /* visit all blocks*/, visit_cb, &sum); // mi_heap_visit_blocks(tlh, false /* visit all blocks*/, visit_cb, &sum);
mi_stats_merge(); mi_stats_merge();
// Used memory for this shard. // Used memory for this shard.
size_t used_mem = UsedMemory(); size_t used_mem = UsedMemory();
cached_stats[db_slice_.shard_id()].used_memory.store(used_mem, memory_order_relaxed); DbSlice& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_id());
cached_stats[db_slice.shard_id()].used_memory.store(used_mem, memory_order_relaxed);
ssize_t free_mem = max_memory_limit - used_mem_current.load(memory_order_relaxed); ssize_t free_mem = max_memory_limit - used_mem_current.load(memory_order_relaxed);
size_t entries = 0; size_t entries = 0;
size_t table_memory = 0; size_t table_memory = 0;
for (size_t i = 0; i < db_slice_.db_array_size(); ++i) { for (size_t i = 0; i < db_slice.db_array_size(); ++i) {
DbTable* table = db_slice_.GetDBTable(i); DbTable* table = db_slice.GetDBTable(i);
if (table) { if (table) {
entries += table->prime.size(); entries += table->prime.size();
table_memory += (table->prime.mem_usage() + table->expire.mem_usage()); table_memory += (table->prime.mem_usage() + table->expire.mem_usage());
@ -706,7 +729,7 @@ void EngineShard::CacheStats() {
size_t obj_memory = table_memory <= used_mem ? used_mem - table_memory : 0; size_t obj_memory = table_memory <= used_mem ? used_mem - table_memory : 0;
size_t bytes_per_obj = entries > 0 ? obj_memory / entries : 0; size_t bytes_per_obj = entries > 0 ? obj_memory / entries : 0;
db_slice_.SetCachedParams(free_mem / shard_set->size(), bytes_per_obj); db_slice.SetCachedParams(free_mem / shard_set->size(), bytes_per_obj);
} }
size_t EngineShard::UsedMemory() const { size_t EngineShard::UsedMemory() const {
@ -714,14 +737,6 @@ size_t EngineShard::UsedMemory() const {
search_indices()->GetUsedMemory(); search_indices()->GetUsedMemory();
} }
BlockingController* EngineShard::EnsureBlockingController() {
if (!blocking_controller_) {
blocking_controller_.reset(new BlockingController(this));
}
return blocking_controller_.get();
}
void EngineShard::TEST_EnableHeartbeat() { void EngineShard::TEST_EnableHeartbeat() {
fiber_periodic_ = fb2::Fiber("shard_periodic_TEST", [this, period_ms = 1] { fiber_periodic_ = fb2::Fiber("shard_periodic_TEST", [this, period_ms = 1] {
RunPeriodic(std::chrono::milliseconds(period_ms)); RunPeriodic(std::chrono::milliseconds(period_ms));
@ -750,6 +765,8 @@ auto EngineShard::AnalyzeTxQueue() const -> TxQueueInfo {
info.tx_total = queue->size(); info.tx_total = queue->size();
unsigned max_db_id = 0; unsigned max_db_id = 0;
auto& db_slice = namespaces.GetDefaultNamespace().GetCurrentDbSlice();
do { do {
auto value = queue->At(cur); auto value = queue->At(cur);
Transaction* trx = std::get<Transaction*>(value); Transaction* trx = std::get<Transaction*>(value);
@ -766,7 +783,7 @@ auto EngineShard::AnalyzeTxQueue() const -> TxQueueInfo {
if (trx->IsGlobal() || (trx->IsMulti() && trx->GetMultiMode() == Transaction::GLOBAL)) { if (trx->IsGlobal() || (trx->IsMulti() && trx->GetMultiMode() == Transaction::GLOBAL)) {
info.tx_global++; info.tx_global++;
} else { } else {
const DbTable* table = db_slice().GetDBTable(trx->GetDbIndex()); const DbTable* table = db_slice.GetDBTable(trx->GetDbIndex());
bool can_run = !HasContendedLocks(sid, trx, table); bool can_run = !HasContendedLocks(sid, trx, table);
if (can_run) { if (can_run) {
info.tx_runnable++; info.tx_runnable++;
@ -778,7 +795,7 @@ auto EngineShard::AnalyzeTxQueue() const -> TxQueueInfo {
// Analyze locks // Analyze locks
for (unsigned i = 0; i <= max_db_id; ++i) { for (unsigned i = 0; i <= max_db_id; ++i) {
const DbTable* table = db_slice().GetDBTable(i); const DbTable* table = db_slice.GetDBTable(i);
if (table == nullptr) if (table == nullptr)
continue; continue;
@ -869,6 +886,8 @@ void EngineShardSet::Init(uint32_t sz, bool update_db_time) {
} }
}); });
namespaces.Init();
pp_->AwaitFiberOnAll([&](uint32_t index, ProactorBase* pb) { pp_->AwaitFiberOnAll([&](uint32_t index, ProactorBase* pb) {
if (index < shard_queue_.size()) { if (index < shard_queue_.size()) {
EngineShard::tlocal()->InitTieredStorage(pb, max_shard_file_size); EngineShard::tlocal()->InitTieredStorage(pb, max_shard_file_size);
@ -895,7 +914,9 @@ void EngineShardSet::TEST_EnableHeartBeat() {
} }
void EngineShardSet::TEST_EnableCacheMode() { void EngineShardSet::TEST_EnableCacheMode() {
RunBriefInParallel([](EngineShard* shard) { shard->db_slice().TEST_EnableCacheMode(); }); RunBlockingInParallel([](EngineShard* shard) {
namespaces.GetDefaultNamespace().GetCurrentDbSlice().TEST_EnableCacheMode();
});
} }
ShardId Shard(string_view v, ShardId shard_num) { ShardId Shard(string_view v, ShardId shard_num) {

View file

@ -60,15 +60,7 @@ class EngineShard {
} }
ShardId shard_id() const { ShardId shard_id() const {
return db_slice_.shard_id(); return shard_id_;
}
DbSlice& db_slice() {
return db_slice_;
}
const DbSlice& db_slice() const {
return db_slice_;
} }
PMR_NS::memory_resource* memory_resource() { PMR_NS::memory_resource* memory_resource() {
@ -124,12 +116,6 @@ class EngineShard {
return shard_search_indices_.get(); return shard_search_indices_.get();
} }
BlockingController* EnsureBlockingController();
BlockingController* blocking_controller() {
return blocking_controller_.get();
}
// for everyone to use for string transformations during atomic cpu sequences. // for everyone to use for string transformations during atomic cpu sequences.
sds tmp_str1; sds tmp_str1;
@ -242,7 +228,7 @@ class EngineShard {
TxQueue txq_; TxQueue txq_;
MiMemoryResource mi_resource_; MiMemoryResource mi_resource_;
DbSlice db_slice_; ShardId shard_id_;
Stats stats_; Stats stats_;
@ -261,8 +247,8 @@ class EngineShard {
DefragTaskState defrag_state_; DefragTaskState defrag_state_;
std::unique_ptr<TieredStorage> tiered_storage_; std::unique_ptr<TieredStorage> tiered_storage_;
// TODO: Move indices to Namespace
std::unique_ptr<ShardDocIndices> shard_search_indices_; std::unique_ptr<ShardDocIndices> shard_search_indices_;
std::unique_ptr<BlockingController> blocking_controller_;
using Counter = util::SlidingCounter<7>; using Counter = util::SlidingCounter<7>;

View file

@ -451,7 +451,7 @@ OpStatus Renamer::DeserializeDest(Transaction* t, EngineShard* shard) {
auto& dest_it = restored_dest_it->it; auto& dest_it = restored_dest_it->it;
dest_it->first.SetSticky(serialized_value_.sticky); dest_it->first.SetSticky(serialized_value_.sticky);
auto bc = shard->blocking_controller(); auto bc = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id());
if (bc) { if (bc) {
bc->AwakeWatched(t->GetDbIndex(), dest_key_); bc->AwakeWatched(t->GetDbIndex(), dest_key_);
} }
@ -607,7 +607,7 @@ uint64_t ScanGeneric(uint64_t cursor, const ScanOpts& scan_opts, StringVec* keys
} }
cursor >>= 10; cursor >>= 10;
DbContext db_cntx{cntx->conn_state.db_index, GetCurrentTimeMs()}; DbContext db_cntx{cntx->ns, cntx->conn_state.db_index, GetCurrentTimeMs()};
do { do {
auto cb = [&] { auto cb = [&] {
@ -1355,8 +1355,10 @@ void GenericFamily::Select(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendError(kDbIndOutOfRangeErr); return cntx->SendError(kDbIndOutOfRangeErr);
} }
cntx->conn_state.db_index = index; cntx->conn_state.db_index = index;
auto cb = [index](EngineShard* shard) { auto cb = [cntx, index](EngineShard* shard) {
shard->db_slice().ActivateDb(index); CHECK(cntx->ns != nullptr);
auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id());
db_slice.ActivateDb(index);
return OpStatus::OK; return OpStatus::OK;
}; };
shard_set->RunBriefInParallel(std::move(cb)); shard_set->RunBriefInParallel(std::move(cb));
@ -1385,7 +1387,8 @@ void GenericFamily::Type(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 0); std::string_view key = ArgS(args, 0);
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<int> { auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<int> {
auto it = shard->db_slice().FindReadOnly(t->GetDbContext(), key).it; auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id());
auto it = db_slice.FindReadOnly(t->GetDbContext(), key).it;
if (!it.is_done()) { if (!it.is_done()) {
return it->second.ObjType(); return it->second.ObjType();
} else { } else {
@ -1549,8 +1552,9 @@ OpResult<void> GenericFamily::OpRen(const OpArgs& op_args, string_view from_key,
to_res.it->first.SetSticky(sticky); to_res.it->first.SetSticky(sticky);
} }
if (!is_prior_list && to_res.it->second.ObjType() == OBJ_LIST && es->blocking_controller()) { auto bc = op_args.db_cntx.ns->GetBlockingController(es->shard_id());
es->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, to_key); if (!is_prior_list && to_res.it->second.ObjType() == OBJ_LIST && bc) {
bc->AwakeWatched(op_args.db_cntx.db_index, to_key);
} }
return OpStatus::OK; return OpStatus::OK;
} }
@ -1590,8 +1594,9 @@ OpStatus GenericFamily::OpMove(const OpArgs& op_args, string_view key, DbIndex t
auto& add_res = *op_result; auto& add_res = *op_result;
add_res.it->first.SetSticky(sticky); add_res.it->first.SetSticky(sticky);
if (add_res.it->second.ObjType() == OBJ_LIST && op_args.shard->blocking_controller()) { auto bc = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id());
op_args.shard->blocking_controller()->AwakeWatched(target_db, key); if (add_res.it->second.ObjType() == OBJ_LIST && bc) {
bc->AwakeWatched(target_db, key);
} }
return OpStatus::OK; return OpStatus::OK;
@ -1602,14 +1607,15 @@ void GenericFamily::RandomKey(CmdArgList args, ConnectionContext* cntx) {
absl::BitGen bitgen; absl::BitGen bitgen;
atomic_size_t candidates_counter{0}; atomic_size_t candidates_counter{0};
DbContext db_cntx{cntx->conn_state.db_index, GetCurrentTimeMs()}; DbContext db_cntx{cntx->ns, cntx->conn_state.db_index, GetCurrentTimeMs()};
ScanOpts scan_opts; ScanOpts scan_opts;
scan_opts.limit = 3; // number of entries per shard scan_opts.limit = 3; // number of entries per shard
std::vector<StringVec> candidates_collection(shard_set->size()); std::vector<StringVec> candidates_collection(shard_set->size());
shard_set->RunBriefInParallel( shard_set->RunBriefInParallel(
[&](EngineShard* shard) { [&](EngineShard* shard) {
auto [prime_table, expire_table] = shard->db_slice().GetTables(db_cntx.db_index); auto [prime_table, expire_table] =
cntx->ns->GetDbSlice(shard->shard_id()).GetTables(db_cntx.db_index);
if (prime_table->size() == 0) { if (prime_table->size() == 0) {
return; return;
} }

View file

@ -1055,7 +1055,7 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) {
} }
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<StringVec> { auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<StringVec> {
auto& db_slice = shard->db_slice(); auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id());
DbContext db_context = t->GetDbContext(); DbContext db_context = t->GetDbContext();
auto it_res = db_slice.FindReadOnly(db_context, key, OBJ_HASH); auto it_res = db_slice.FindReadOnly(db_context, key, OBJ_HASH);

View file

@ -44,6 +44,7 @@ JournalExecutor::JournalExecutor(Service* service)
conn_context_.is_replicating = true; conn_context_.is_replicating = true;
conn_context_.journal_emulated = true; conn_context_.journal_emulated = true;
conn_context_.skip_acl_validation = true; conn_context_.skip_acl_validation = true;
conn_context_.ns = &namespaces.GetDefaultNamespace();
} }
JournalExecutor::~JournalExecutor() { JournalExecutor::~JournalExecutor() {

View file

@ -341,11 +341,12 @@ OpResult<uint32_t> OpPush(const OpArgs& op_args, std::string_view key, ListDir d
} }
if (res.is_new) { if (res.is_new) {
if (es->blocking_controller()) { auto blocking_controller = op_args.db_cntx.ns->GetBlockingController(es->shard_id());
if (blocking_controller) {
string tmp; string tmp;
string_view key = res.it->first.GetSlice(&tmp); string_view key = res.it->first.GetSlice(&tmp);
es->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, key); blocking_controller->AwakeWatched(op_args.db_cntx.db_index, key);
absl::StrAppend(debugMessages.Next(), "OpPush AwakeWatched: ", key, " by ", absl::StrAppend(debugMessages.Next(), "OpPush AwakeWatched: ", key, " by ",
op_args.tx->DebugId()); op_args.tx->DebugId());
} }
@ -444,11 +445,12 @@ OpResult<string> MoveTwoShards(Transaction* trans, string_view src, string_view
OpPush(op_args, key, dest_dir, false, ArgSlice{val}, true); OpPush(op_args, key, dest_dir, false, ArgSlice{val}, true);
// blocking_controller does not have to be set with non-blocking transactions. // blocking_controller does not have to be set with non-blocking transactions.
if (shard->blocking_controller()) { auto blocking_controller = t->GetNamespace().GetBlockingController(shard->shard_id());
if (blocking_controller) {
// hack, again. since we hacked which queue we are waiting on (see RunPair) // hack, again. since we hacked which queue we are waiting on (see RunPair)
// we must clean-up src key here manually. See RunPair why we do this. // we must clean-up src key here manually. See RunPair why we do this.
// in short- we suspended on "src" on both shards. // in short- we suspended on "src" on both shards.
shard->blocking_controller()->FinalizeWatched(ArgSlice({src}), t); blocking_controller->FinalizeWatched(ArgSlice({src}), t);
} }
} else { } else {
DVLOG(1) << "Popping value from list: " << key; DVLOG(1) << "Popping value from list: " << key;
@ -852,10 +854,11 @@ OpResult<string> BPopPusher::RunSingle(ConnectionContext* cntx, time_point tp) {
std::array<string_view, 4> arr = {pop_key_, push_key_, DirToSv(popdir_), DirToSv(pushdir_)}; std::array<string_view, 4> arr = {pop_key_, push_key_, DirToSv(popdir_), DirToSv(pushdir_)};
RecordJournal(op_args, "LMOVE", arr, 1); RecordJournal(op_args, "LMOVE", arr, 1);
} }
if (shard->blocking_controller()) { auto blocking_controller = cntx->ns->GetBlockingController(shard->shard_id());
if (blocking_controller) {
string tmp; string tmp;
shard->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, push_key_); blocking_controller->AwakeWatched(op_args.db_cntx.db_index, push_key_);
absl::StrAppend(debugMessages.Next(), "OpPush AwakeWatched: ", push_key_, " by ", absl::StrAppend(debugMessages.Next(), "OpPush AwakeWatched: ", push_key_, " by ",
op_args.tx->DebugId()); op_args.tx->DebugId());
} }

View file

@ -32,8 +32,10 @@ class ListFamilyTest : public BaseFamilyTest {
static unsigned NumWatched() { static unsigned NumWatched() {
atomic_uint32_t sum{0}; atomic_uint32_t sum{0};
auto ns = &namespaces.GetDefaultNamespace();
shard_set->RunBriefInParallel([&](EngineShard* es) { shard_set->RunBriefInParallel([&](EngineShard* es) {
auto* bc = es->blocking_controller(); auto* bc = ns->GetBlockingController(es->shard_id());
if (bc) if (bc)
sum.fetch_add(bc->NumWatched(0), memory_order_relaxed); sum.fetch_add(bc->NumWatched(0), memory_order_relaxed);
}); });
@ -43,8 +45,9 @@ class ListFamilyTest : public BaseFamilyTest {
static bool HasAwakened() { static bool HasAwakened() {
atomic_uint32_t sum{0}; atomic_uint32_t sum{0};
auto ns = &namespaces.GetDefaultNamespace();
shard_set->RunBriefInParallel([&](EngineShard* es) { shard_set->RunBriefInParallel([&](EngineShard* es) {
auto* bc = es->blocking_controller(); auto* bc = ns->GetBlockingController(es->shard_id());
if (bc) if (bc)
sum.fetch_add(bc->HasAwakedTransaction(), memory_order_relaxed); sum.fetch_add(bc->HasAwakedTransaction(), memory_order_relaxed);
}); });
@ -168,7 +171,7 @@ TEST_F(ListFamilyTest, BLPopTimeout) {
RespExpr resp = Run({"blpop", kKey1, kKey2, kKey3, "0.01"}); RespExpr resp = Run({"blpop", kKey1, kKey2, kKey3, "0.01"});
EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY)); EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY));
EXPECT_EQ(3, GetDebugInfo().shards_count); EXPECT_EQ(3, GetDebugInfo().shards_count);
ASSERT_FALSE(service_->IsLocked(0, kKey1)); ASSERT_FALSE(IsLocked(0, kKey1));
// Under Multi // Under Multi
resp = Run({"multi"}); resp = Run({"multi"});
@ -178,7 +181,7 @@ TEST_F(ListFamilyTest, BLPopTimeout) {
resp = Run({"exec"}); resp = Run({"exec"});
EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY)); EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY));
ASSERT_FALSE(service_->IsLocked(0, kKey1)); ASSERT_FALSE(IsLocked(0, kKey1));
ASSERT_EQ(0, NumWatched()); ASSERT_EQ(0, NumWatched());
} }

View file

@ -47,6 +47,7 @@ extern "C" {
#include "server/json_family.h" #include "server/json_family.h"
#include "server/list_family.h" #include "server/list_family.h"
#include "server/multi_command_squasher.h" #include "server/multi_command_squasher.h"
#include "server/namespaces.h"
#include "server/script_mgr.h" #include "server/script_mgr.h"
#include "server/search/search_family.h" #include "server/search/search_family.h"
#include "server/server_state.h" #include "server/server_state.h"
@ -135,9 +136,11 @@ constexpr size_t kMaxThreadSize = 1024;
// Unwatch all keys for a connection and unregister from DbSlices. // Unwatch all keys for a connection and unregister from DbSlices.
// Used by UNWATCH, DISCARD and EXEC. // Used by UNWATCH, DISCARD and EXEC.
void UnwatchAllKeys(ConnectionState::ExecInfo* exec_info) { void UnwatchAllKeys(Namespace* ns, ConnectionState::ExecInfo* exec_info) {
if (!exec_info->watched_keys.empty()) { if (!exec_info->watched_keys.empty()) {
auto cb = [&](EngineShard* shard) { shard->db_slice().UnregisterConnectionWatches(exec_info); }; auto cb = [&](EngineShard* shard) {
ns->GetDbSlice(shard->shard_id()).UnregisterConnectionWatches(exec_info);
};
shard_set->RunBriefInParallel(std::move(cb)); shard_set->RunBriefInParallel(std::move(cb));
} }
exec_info->ClearWatched(); exec_info->ClearWatched();
@ -149,7 +152,7 @@ void MultiCleanup(ConnectionContext* cntx) {
ServerState::tlocal()->ReturnInterpreter(borrowed); ServerState::tlocal()->ReturnInterpreter(borrowed);
exec_info.preborrowed_interpreter = nullptr; exec_info.preborrowed_interpreter = nullptr;
} }
UnwatchAllKeys(&exec_info); UnwatchAllKeys(cntx->ns, &exec_info);
exec_info.Clear(); exec_info.Clear();
} }
@ -513,7 +516,8 @@ void Topkeys(const http::QueryArgs& args, HttpContext* send) {
vector<string> rows(shard_set->size()); vector<string> rows(shard_set->size());
shard_set->RunBriefInParallel([&](EngineShard* shard) { shard_set->RunBriefInParallel([&](EngineShard* shard) {
for (const auto& db : shard->db_slice().databases()) { for (const auto& db :
namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).databases()) {
if (db->top_keys.IsEnabled()) { if (db->top_keys.IsEnabled()) {
is_enabled = true; is_enabled = true;
for (const auto& [key, count] : db->top_keys.GetTopKeys()) { for (const auto& [key, count] : db->top_keys.GetTopKeys()) {
@ -829,6 +833,7 @@ Service::Service(ProactorPool* pp)
Service::~Service() { Service::~Service() {
delete shard_set; delete shard_set;
shard_set = nullptr; shard_set = nullptr;
namespaces.Clear();
} }
void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*> listeners, void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*> listeners,
@ -908,7 +913,10 @@ void Service::Shutdown() {
ChannelStore::Destroy(); ChannelStore::Destroy();
namespaces.Clear();
shard_set->Shutdown(); shard_set->Shutdown();
pp_.Await([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); }); pp_.Await([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); });
// wait for all the pending callbacks to stop. // wait for all the pending callbacks to stop.
@ -1212,8 +1220,8 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
DCHECK(dfly_cntx->transaction); DCHECK(dfly_cntx->transaction);
if (cid->IsTransactional()) { if (cid->IsTransactional()) {
dfly_cntx->transaction->MultiSwitchCmd(cid); dfly_cntx->transaction->MultiSwitchCmd(cid);
OpStatus status = OpStatus status = dfly_cntx->transaction->InitByArgs(
dfly_cntx->transaction->InitByArgs(dfly_cntx->conn_state.db_index, args_no_cmd); dfly_cntx->ns, dfly_cntx->conn_state.db_index, args_no_cmd);
if (status != OpStatus::OK) if (status != OpStatus::OK)
return cntx->SendError(status); return cntx->SendError(status);
@ -1225,7 +1233,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
dist_trans.reset(new Transaction{cid}); dist_trans.reset(new Transaction{cid});
if (!dist_trans->IsMulti()) { // Multi command initialize themself based on their mode. if (!dist_trans->IsMulti()) { // Multi command initialize themself based on their mode.
if (auto st = dist_trans->InitByArgs(dfly_cntx->conn_state.db_index, args_no_cmd); CHECK(dfly_cntx->ns != nullptr);
if (auto st =
dist_trans->InitByArgs(dfly_cntx->ns, dfly_cntx->conn_state.db_index, args_no_cmd);
st != OpStatus::OK) st != OpStatus::OK)
return cntx->SendError(st); return cntx->SendError(st);
} }
@ -1581,6 +1591,7 @@ facade::ConnectionContext* Service::CreateContext(util::FiberSocketBase* peer,
facade::Connection* owner) { facade::Connection* owner) {
auto cred = user_registry_.GetCredentials("default"); auto cred = user_registry_.GetCredentials("default");
ConnectionContext* res = new ConnectionContext{peer, owner, std::move(cred)}; ConnectionContext* res = new ConnectionContext{peer, owner, std::move(cred)};
res->ns = &namespaces.GetOrInsert("");
if (peer->IsUDS()) { if (peer->IsUDS()) {
res->req_auth = false; res->req_auth = false;
@ -1606,10 +1617,10 @@ const CommandId* Service::FindCmd(std::string_view cmd) const {
return registry_.Find(registry_.RenamedOrOriginal(cmd)); return registry_.Find(registry_.RenamedOrOriginal(cmd));
} }
bool Service::IsLocked(DbIndex db_index, std::string_view key) const { bool Service::IsLocked(Namespace* ns, DbIndex db_index, std::string_view key) const {
ShardId sid = Shard(key, shard_count()); ShardId sid = Shard(key, shard_count());
bool is_open = pp_.at(sid)->AwaitBrief([db_index, key] { bool is_open = pp_.at(sid)->AwaitBrief([db_index, key, ns, sid] {
return EngineShard::tlocal()->db_slice().CheckLock(IntentLock::EXCLUSIVE, db_index, key); return ns->GetDbSlice(sid).CheckLock(IntentLock::EXCLUSIVE, db_index, key);
}); });
return !is_open; return !is_open;
} }
@ -1682,7 +1693,7 @@ void Service::Watch(CmdArgList args, ConnectionContext* cntx) {
} }
void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) { void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) {
UnwatchAllKeys(&cntx->conn_state.exec_info); UnwatchAllKeys(cntx->ns, &cntx->conn_state.exec_info);
return cntx->SendOk(); return cntx->SendOk();
} }
@ -1841,6 +1852,7 @@ Transaction::MultiMode DetermineMultiMode(ScriptMgr::ScriptParams params) {
optional<bool> StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptParams params, optional<bool> StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptParams params,
ConnectionContext* cntx) { ConnectionContext* cntx) {
Transaction* trans = cntx->transaction; Transaction* trans = cntx->transaction;
Namespace* ns = cntx->ns;
Transaction::MultiMode script_mode = DetermineMultiMode(params); Transaction::MultiMode script_mode = DetermineMultiMode(params);
Transaction::MultiMode multi_mode = trans->GetMultiMode(); Transaction::MultiMode multi_mode = trans->GetMultiMode();
// Check if eval is already part of a running multi transaction // Check if eval is already part of a running multi transaction
@ -1860,10 +1872,10 @@ optional<bool> StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptPa
switch (script_mode) { switch (script_mode) {
case Transaction::GLOBAL: case Transaction::GLOBAL:
trans->StartMultiGlobal(dbid); trans->StartMultiGlobal(ns, dbid);
return true; return true;
case Transaction::LOCK_AHEAD: case Transaction::LOCK_AHEAD:
trans->StartMultiLockedAhead(dbid, keys); trans->StartMultiLockedAhead(ns, dbid, keys);
return true; return true;
case Transaction::NON_ATOMIC: case Transaction::NON_ATOMIC:
trans->StartMultiNonAtomic(); trans->StartMultiNonAtomic();
@ -1988,7 +2000,7 @@ void Service::EvalInternal(CmdArgList args, const EvalArgs& eval_args, Interpret
}); });
++ServerState::tlocal()->stats.eval_shardlocal_coordination_cnt; ++ServerState::tlocal()->stats.eval_shardlocal_coordination_cnt;
tx->PrepareMultiForScheduleSingleHop(*sid, tx->GetDbIndex(), args); tx->PrepareMultiForScheduleSingleHop(cntx->ns, *sid, tx->GetDbIndex(), args);
tx->ScheduleSingleHop([&](Transaction*, EngineShard*) { tx->ScheduleSingleHop([&](Transaction*, EngineShard*) {
boost::intrusive_ptr<Transaction> stub_tx = boost::intrusive_ptr<Transaction> stub_tx =
new Transaction{tx, *sid, slot_checker.GetUniqueSlotId()}; new Transaction{tx, *sid, slot_checker.GetUniqueSlotId()};
@ -2077,7 +2089,7 @@ bool CheckWatchedKeyExpiry(ConnectionContext* cntx, const CommandRegistry& regis
}; };
cntx->transaction->MultiSwitchCmd(registry.Find(EXISTS)); cntx->transaction->MultiSwitchCmd(registry.Find(EXISTS));
cntx->transaction->InitByArgs(cntx->conn_state.db_index, CmdArgList{str_list}); cntx->transaction->InitByArgs(cntx->ns, cntx->conn_state.db_index, CmdArgList{str_list});
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb)); OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
CHECK_EQ(OpStatus::OK, status); CHECK_EQ(OpStatus::OK, status);
@ -2139,11 +2151,11 @@ void StartMultiExec(ConnectionContext* cntx, ConnectionState::ExecInfo* exec_inf
auto dbid = cntx->db_index(); auto dbid = cntx->db_index();
switch (multi_mode) { switch (multi_mode) {
case Transaction::GLOBAL: case Transaction::GLOBAL:
trans->StartMultiGlobal(dbid); trans->StartMultiGlobal(cntx->ns, dbid);
break; break;
case Transaction::LOCK_AHEAD: { case Transaction::LOCK_AHEAD: {
auto vec = CollectAllKeys(exec_info); auto vec = CollectAllKeys(exec_info);
trans->StartMultiLockedAhead(dbid, absl::MakeSpan(vec)); trans->StartMultiLockedAhead(cntx->ns, dbid, absl::MakeSpan(vec));
} break; } break;
case Transaction::NON_ATOMIC: case Transaction::NON_ATOMIC:
trans->StartMultiNonAtomic(); trans->StartMultiNonAtomic();
@ -2234,7 +2246,7 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
CmdArgList args = absl::MakeSpan(arg_vec); CmdArgList args = absl::MakeSpan(arg_vec);
if (scmd.Cid()->IsTransactional()) { if (scmd.Cid()->IsTransactional()) {
OpStatus st = cntx->transaction->InitByArgs(cntx->conn_state.db_index, args); OpStatus st = cntx->transaction->InitByArgs(cntx->ns, cntx->conn_state.db_index, args);
if (st != OpStatus::OK) { if (st != OpStatus::OK) {
cntx->SendError(st); cntx->SendError(st);
break; break;
@ -2444,7 +2456,7 @@ void Service::Command(CmdArgList args, ConnectionContext* cntx) {
VarzValue::Map Service::GetVarzStats() { VarzValue::Map Service::GetVarzStats() {
VarzValue::Map res; VarzValue::Map res;
Metrics m = server_family_.GetMetrics(); Metrics m = server_family_.GetMetrics(&namespaces.GetDefaultNamespace());
DbStats db_stats; DbStats db_stats;
for (const auto& s : m.db_stats) { for (const auto& s : m.db_stats) {
db_stats += s; db_stats += s;
@ -2543,7 +2555,7 @@ void Service::OnClose(facade::ConnectionContext* cntx) {
DCHECK(!conn_state.subscribe_info); DCHECK(!conn_state.subscribe_info);
} }
UnwatchAllKeys(&conn_state.exec_info); UnwatchAllKeys(server_cntx->ns, &conn_state.exec_info);
DeactivateMonitoring(server_cntx); DeactivateMonitoring(server_cntx);

View file

@ -102,7 +102,7 @@ class Service : public facade::ServiceInterface {
} }
// Used by tests. // Used by tests.
bool IsLocked(DbIndex db_index, std::string_view key) const; bool IsLocked(Namespace* ns, DbIndex db_index, std::string_view key) const;
bool IsShardSetLocked() const; bool IsShardSetLocked() const;
util::ProactorPool& proactor_pool() { util::ProactorPool& proactor_pool() {

View file

@ -245,7 +245,7 @@ void PushMemoryUsageStats(const base::IoBuf::MemoryUsage& mem, string_view prefi
void MemoryCmd::Stats() { void MemoryCmd::Stats() {
vector<pair<string, size_t>> stats; vector<pair<string, size_t>> stats;
stats.reserve(25); stats.reserve(25);
auto server_metrics = owner_->GetMetrics(); auto server_metrics = owner_->GetMetrics(cntx_->ns);
// RSS // RSS
stats.push_back({"rss_bytes", rss_mem_current.load(memory_order_relaxed)}); stats.push_back({"rss_bytes", rss_mem_current.load(memory_order_relaxed)});
@ -369,8 +369,8 @@ void MemoryCmd::ArenaStats(CmdArgList args) {
void MemoryCmd::Usage(std::string_view key) { void MemoryCmd::Usage(std::string_view key) {
ShardId sid = Shard(key, shard_set->size()); ShardId sid = Shard(key, shard_set->size());
ssize_t memory_usage = shard_set->pool()->at(sid)->AwaitBrief([key, this]() -> ssize_t { ssize_t memory_usage = shard_set->pool()->at(sid)->AwaitBrief([key, this, sid]() -> ssize_t {
auto& db_slice = EngineShard::tlocal()->db_slice(); auto& db_slice = cntx_->ns->GetDbSlice(sid);
auto [pt, exp_t] = db_slice.GetTables(cntx_->db_index()); auto [pt, exp_t] = db_slice.GetTables(cntx_->db_index());
PrimeIterator it = pt->Find(key); PrimeIterator it = pt->Find(key);
if (IsValid(it)) { if (IsValid(it)) {

View file

@ -145,7 +145,7 @@ bool MultiCommandSquasher::ExecuteStandalone(StoredCmd* cmd) {
cntx_->cid = cmd->Cid(); cntx_->cid = cmd->Cid();
if (cmd->Cid()->IsTransactional()) if (cmd->Cid()->IsTransactional())
tx->InitByArgs(cntx_->conn_state.db_index, args); tx->InitByArgs(cntx_->ns, cntx_->conn_state.db_index, args);
service_->InvokeCmd(cmd->Cid(), args, cntx_); service_->InvokeCmd(cmd->Cid(), args, cntx_);
return true; return true;
@ -181,7 +181,7 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard
local_cntx.cid = cmd->Cid(); local_cntx.cid = cmd->Cid();
crb.SetReplyMode(cmd->ReplyMode()); crb.SetReplyMode(cmd->ReplyMode());
local_tx->InitByArgs(local_cntx.conn_state.db_index, args); local_tx->InitByArgs(cntx_->ns, local_cntx.conn_state.db_index, args);
service_->InvokeCmd(cmd->Cid(), args, &local_cntx); service_->InvokeCmd(cmd->Cid(), args, &local_cntx);
sinfo.replies.emplace_back(crb.Take()); sinfo.replies.emplace_back(crb.Take());

View file

@ -109,8 +109,8 @@ TEST_F(MultiTest, Multi) {
resp = Run({"get", kKey4}); resp = Run({"get", kKey4});
ASSERT_THAT(resp, ArgType(RespExpr::NIL)); ASSERT_THAT(resp, ArgType(RespExpr::NIL));
ASSERT_FALSE(service_->IsLocked(0, kKey1)); ASSERT_FALSE(IsLocked(0, kKey1));
ASSERT_FALSE(service_->IsLocked(0, kKey4)); ASSERT_FALSE(IsLocked(0, kKey4));
ASSERT_FALSE(service_->IsShardSetLocked()); ASSERT_FALSE(service_->IsShardSetLocked());
} }
@ -129,8 +129,8 @@ TEST_F(MultiTest, MultiGlobalCommands) {
ASSERT_THAT(Run({"select", "2"}), "OK"); ASSERT_THAT(Run({"select", "2"}), "OK");
ASSERT_THAT(Run({"get", "key"}), "val"); ASSERT_THAT(Run({"get", "key"}), "val");
ASSERT_FALSE(service_->IsLocked(0, "key")); ASSERT_FALSE(IsLocked(0, "key"));
ASSERT_FALSE(service_->IsLocked(2, "key")); ASSERT_FALSE(IsLocked(2, "key"));
} }
TEST_F(MultiTest, HitMissStats) { TEST_F(MultiTest, HitMissStats) {
@ -181,8 +181,8 @@ TEST_F(MultiTest, MultiSeq) {
ASSERT_EQ(resp, "QUEUED"); ASSERT_EQ(resp, "QUEUED");
resp = Run({"exec"}); resp = Run({"exec"});
ASSERT_FALSE(service_->IsLocked(0, kKey1)); ASSERT_FALSE(IsLocked(0, kKey1));
ASSERT_FALSE(service_->IsLocked(0, kKey4)); ASSERT_FALSE(IsLocked(0, kKey4));
ASSERT_FALSE(service_->IsShardSetLocked()); ASSERT_FALSE(service_->IsShardSetLocked());
ASSERT_THAT(resp, ArrLen(3)); ASSERT_THAT(resp, ArrLen(3));
@ -237,8 +237,8 @@ TEST_F(MultiTest, MultiConsistent) {
mset_fb.Join(); mset_fb.Join();
fb.Join(); fb.Join();
ASSERT_FALSE(service_->IsLocked(0, kKey1)); ASSERT_FALSE(IsLocked(0, kKey1));
ASSERT_FALSE(service_->IsLocked(0, kKey4)); ASSERT_FALSE(IsLocked(0, kKey4));
ASSERT_FALSE(service_->IsShardSetLocked()); ASSERT_FALSE(service_->IsShardSetLocked());
} }
@ -312,9 +312,9 @@ TEST_F(MultiTest, MultiRename) {
resp = Run({"exec"}); resp = Run({"exec"});
EXPECT_EQ(resp, "OK"); EXPECT_EQ(resp, "OK");
EXPECT_FALSE(service_->IsLocked(0, kKey1)); EXPECT_FALSE(IsLocked(0, kKey1));
EXPECT_FALSE(service_->IsLocked(0, kKey2)); EXPECT_FALSE(IsLocked(0, kKey2));
EXPECT_FALSE(service_->IsLocked(0, kKey4)); EXPECT_FALSE(IsLocked(0, kKey4));
EXPECT_FALSE(service_->IsShardSetLocked()); EXPECT_FALSE(service_->IsShardSetLocked());
} }
@ -366,8 +366,8 @@ TEST_F(MultiTest, FlushDb) {
fb0.Join(); fb0.Join();
ASSERT_FALSE(service_->IsLocked(0, kKey1)); ASSERT_FALSE(IsLocked(0, kKey1));
ASSERT_FALSE(service_->IsLocked(0, kKey4)); ASSERT_FALSE(IsLocked(0, kKey4));
ASSERT_FALSE(service_->IsShardSetLocked()); ASSERT_FALSE(service_->IsShardSetLocked());
} }
@ -400,17 +400,17 @@ TEST_F(MultiTest, Eval) {
resp = Run({"eval", "return redis.call('get', 'foo')", "1", "bar"}); resp = Run({"eval", "return redis.call('get', 'foo')", "1", "bar"});
EXPECT_THAT(resp, ErrArg("undeclared")); EXPECT_THAT(resp, ErrArg("undeclared"));
ASSERT_FALSE(service_->IsLocked(0, "foo")); ASSERT_FALSE(IsLocked(0, "foo"));
Run({"script", "flush"}); // Reset global flag from autocorrect Run({"script", "flush"}); // Reset global flag from autocorrect
resp = Run({"eval", "return redis.call('get', 'foo')", "1", "foo"}); resp = Run({"eval", "return redis.call('get', 'foo')", "1", "foo"});
EXPECT_THAT(resp, "42"); EXPECT_THAT(resp, "42");
ASSERT_FALSE(service_->IsLocked(0, "foo")); ASSERT_FALSE(IsLocked(0, "foo"));
resp = Run({"eval", "return redis.call('get', KEYS[1])", "1", "foo"}); resp = Run({"eval", "return redis.call('get', KEYS[1])", "1", "foo"});
EXPECT_THAT(resp, "42"); EXPECT_THAT(resp, "42");
ASSERT_FALSE(service_->IsLocked(0, "foo")); ASSERT_FALSE(IsLocked(0, "foo"));
ASSERT_FALSE(service_->IsShardSetLocked()); ASSERT_FALSE(service_->IsShardSetLocked());
resp = Run({"eval", "return 77", "2", "foo", "zoo"}); resp = Run({"eval", "return 77", "2", "foo", "zoo"});
@ -451,7 +451,7 @@ TEST_F(MultiTest, Eval) {
"1", "foo"}), "1", "foo"}),
"42"); "42");
auto condition = [&]() { return service_->IsLocked(0, "foo"); }; auto condition = [&]() { return IsLocked(0, "foo"); };
auto fb = ExpectConditionWithSuspension(condition); auto fb = ExpectConditionWithSuspension(condition);
EXPECT_EQ(Run({"eval", EXPECT_EQ(Run({"eval",
R"(redis.call('set', 'foo', '42') R"(redis.call('set', 'foo', '42')
@ -974,7 +974,7 @@ TEST_F(MultiTest, TestLockedKeys) {
GTEST_SKIP() << "Skipped TestLockedKeys test because multi_exec_mode is not lock ahead"; GTEST_SKIP() << "Skipped TestLockedKeys test because multi_exec_mode is not lock ahead";
return; return;
} }
auto condition = [&]() { return service_->IsLocked(0, "key1") && service_->IsLocked(0, "key2"); }; auto condition = [&]() { return IsLocked(0, "key1") && IsLocked(0, "key2"); };
auto fb = ExpectConditionWithSuspension(condition); auto fb = ExpectConditionWithSuspension(condition);
EXPECT_EQ(Run({"multi"}), "OK"); EXPECT_EQ(Run({"multi"}), "OK");
@ -983,8 +983,8 @@ TEST_F(MultiTest, TestLockedKeys) {
EXPECT_EQ(Run({"mset", "key1", "val3", "key1", "val4"}), "QUEUED"); EXPECT_EQ(Run({"mset", "key1", "val3", "key1", "val4"}), "QUEUED");
EXPECT_THAT(Run({"exec"}), RespArray(ElementsAre("OK", "OK", "OK"))); EXPECT_THAT(Run({"exec"}), RespArray(ElementsAre("OK", "OK", "OK")));
fb.Join(); fb.Join();
EXPECT_FALSE(service_->IsLocked(0, "key1")); EXPECT_FALSE(IsLocked(0, "key1"));
EXPECT_FALSE(service_->IsLocked(0, "key2")); EXPECT_FALSE(IsLocked(0, "key2"));
} }
TEST_F(MultiTest, EvalExpiration) { TEST_F(MultiTest, EvalExpiration) {

103
src/server/namespaces.cc Normal file
View file

@ -0,0 +1,103 @@
#include "server/namespaces.h"
#include "base/flags.h"
#include "base/logging.h"
#include "server/engine_shard_set.h"
ABSL_DECLARE_FLAG(bool, cache_mode);
namespace dfly {
using namespace std;
// Sizes both per-shard vectors to the number of shards and creates one DbSlice per shard.
// Blocking-controller slots are left empty here; GetOrAddBlockingController() fills them lazily.
Namespace::Namespace() {
  const auto shard_count = shard_set->size();
  shard_db_slices_.resize(shard_count);
  shard_blocking_controller_.resize(shard_count);

  // The callback is invoked once per engine shard and only touches that shard's own slot.
  auto init_slice = [this](EngineShard* shard) {
    CHECK(shard != nullptr);

    const ShardId sid = shard->shard_id();
    auto& slice = shard_db_slices_[sid];
    slice = make_unique<DbSlice>(sid, absl::GetFlag(FLAGS_cache_mode), shard);

    // Current time converted from nanoseconds to milliseconds.
    slice->UpdateExpireBase(absl::GetCurrentTimeNanos() / 1000000, 0);
  };
  shard_set->RunBriefInParallel(init_slice);
}
// Returns this namespace's DbSlice for the engine shard bound to the calling thread.
// Aborts (CHECK) when the calling thread has no thread-local EngineShard.
DbSlice& Namespace::GetCurrentDbSlice() {
  EngineShard* const current = EngineShard::tlocal();
  CHECK(current != nullptr);

  const ShardId sid = current->shard_id();
  return GetDbSlice(sid);
}
// Returns this namespace's DbSlice for shard `sid`.
// Aborts (CHECK) if `sid` is not a valid index into the per-shard slice vector.
DbSlice& Namespace::GetDbSlice(ShardId sid) {
  CHECK_LT(sid, shard_db_slices_.size());

  const auto& slot = shard_db_slices_[sid];
  return *slot;
}
// Returns the BlockingController of this namespace for `shard`, creating it on first use.
// The returned pointer is owned by this Namespace.
BlockingController* Namespace::GetOrAddBlockingController(EngineShard* shard) {
  auto& slot = shard_blocking_controller_[shard->shard_id()];
  if (slot == nullptr) {
    // First request for this shard: build the controller bound to this namespace.
    slot = make_unique<BlockingController>(shard, this);
  }
  return slot.get();
}
// Returns the BlockingController of this namespace for shard `sid`, or nullptr if none has been
// created for that shard yet (see GetOrAddBlockingController()).
// Aborts (CHECK) on an out-of-range `sid`, matching the guard in GetDbSlice(), instead of
// indexing past the end of the vector.
BlockingController* Namespace::GetBlockingController(ShardId sid) {
  CHECK_LT(sid, shard_blocking_controller_.size());
  return shard_blocking_controller_[sid].get();
}
// Definition of the global namespace registry object.
Namespaces namespaces;
// Releases every registered namespace by delegating to Clear().
Namespaces::~Namespaces() {
  Clear();
}
void Namespaces::Init() {
DCHECK(default_namespace_ == nullptr);
default_namespace_ = &GetOrInsert("");
}
// Reports whether a default namespace is currently set, i.e. Init() ran and Clear() has not
// reset it since.
bool Namespaces::IsInitialized() const {
  const bool has_default = (default_namespace_ != nullptr);
  return has_default;
}
void Namespaces::Clear() {
std::unique_lock guard(mu_);
namespaces.default_namespace_ = nullptr;
if (namespaces_.empty()) {
return;
}
shard_set->RunBriefInParallel([&](EngineShard* es) {
CHECK(es != nullptr);
for (auto& ns : namespaces_) {
ns.second.shard_db_slices_[es->shard_id()].reset();
}
});
namespaces.namespaces_.clear();
}
// Returns the default namespace without taking any lock.
// Aborts (CHECK) if no default namespace is set.
Namespace& Namespaces::GetDefaultNamespace() const {
  CHECK(default_namespace_ != nullptr);

  Namespace* const default_ns = default_namespace_;
  return *default_ns;
}
// Returns the namespace registered under `ns`, creating it if it does not exist yet.
Namespace& Namespaces::GetOrInsert(std::string_view ns) {
  // Fast path: existing namespaces are looked up under a shared lock.
  {
    std::shared_lock read_guard(mu_);
    if (auto found = namespaces_.find(ns); found != namespaces_.end()) {
      return found->second;
    }
  }

  // Slow path: the key was not found, so take the unique lock. operator[] returns the existing
  // entry if one was inserted between the two locks, and otherwise creates it.
  std::unique_lock write_guard(mu_);
  return namespaces_[ns];
}
} // namespace dfly

70
src/server/namespaces.h Normal file
View file

@ -0,0 +1,70 @@
// Copyright 2024, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <absl/container/node_hash_map.h>
#include <memory>
#include <string>
#include <vector>
#include "server/blocking_controller.h"
#include "server/db_slice.h"
#include "server/tx_base.h"
#include "util/fibers/synchronization.h"
namespace dfly {
// A Namespace is a way to separate and isolate different databases in a single instance.
// It can be used to allow multiple tenants to use the same server without hacks of using a common
// prefix, or SELECT-ing a different database.
// Each Namespace contains per-shard DbSlice, as well as a BlockingController.
class Namespace {
 public:
  // Sizes the per-shard containers to the shard count and creates a DbSlice for every shard.
  Namespace();

  // Returns the DbSlice of the engine shard bound to the calling thread.
  // Aborts if the calling thread has no engine shard.
  DbSlice& GetCurrentDbSlice();

  // Returns the DbSlice of shard `sid`. Aborts if `sid` is out of range.
  DbSlice& GetDbSlice(ShardId sid);

  // Returns the BlockingController for `shard`, creating it on first use.
  BlockingController* GetOrAddBlockingController(EngineShard* shard);

  // Returns the BlockingController for shard `sid`, or nullptr if none was created yet.
  BlockingController* GetBlockingController(ShardId sid);

 private:
  // Both vectors are indexed by shard id.
  std::vector<std::unique_ptr<DbSlice>> shard_db_slices_;
  std::vector<std::unique_ptr<BlockingController>> shard_blocking_controller_;

  // Namespaces::Clear() resets shard_db_slices_ entries directly.
  friend class Namespaces;
};
// Namespaces is a registry and container for Namespace instances.
// Each Namespace has a unique string name, which identifies it in the store.
// Any attempt to access a non-existing Namespace will first create it, add it to the internal map
// and will then return it.
// It is currently impossible to remove a Namespace after it has been created.
// The default Namespace can be accessed via either GetDefaultNamespace() (which guarantees not to
// yield), or via the GetOrInsert() with an empty string.
// The initialization order of this class with the engine shards is slightly subtle, as they have
// mutual dependencies.
class Namespaces {
 public:
  Namespaces() = default;

  // Calls Clear().
  ~Namespaces();

  // Creates the default namespace (registered under the empty name). Must be called while no
  // default namespace is set.
  void Init();

  // Returns true while a default namespace is set.
  bool IsInitialized() const;

  void Clear();  // Thread unsafe, use in tear-down or tests

  // Aborts if no default namespace is set.
  Namespace& GetDefaultNamespace() const;  // No locks

  // Returns the namespace named `ns`, creating it if it does not exist yet.
  Namespace& GetOrInsert(std::string_view ns);

 private:
  // Protects namespaces_: shared for lookups, unique for insertion and Clear().
  util::fb2::SharedMutex mu_{};
  absl::node_hash_map<std::string, Namespace> namespaces_ ABSL_GUARDED_BY(mu_);

  // Points at the entry of namespaces_ registered under the empty name; nullptr when unset.
  Namespace* default_namespace_ = nullptr;
};
extern Namespaces namespaces;
} // namespace dfly

View file

@ -2066,7 +2066,8 @@ error_code RdbLoader::Load(io::Source* src) {
FlushShardAsync(i); FlushShardAsync(i);
// Active database if not existed before. // Active database if not existed before.
shard_set->Add(i, [dbid] { EngineShard::tlocal()->db_slice().ActivateDb(dbid); }); shard_set->Add(
i, [dbid] { namespaces.GetDefaultNamespace().GetCurrentDbSlice().ActivateDb(dbid); });
} }
cur_db_index_ = dbid; cur_db_index_ = dbid;
@ -2451,8 +2452,8 @@ std::error_code RdbLoaderBase::FromOpaque(const OpaqueObj& opaque, CompactObj* p
void RdbLoader::LoadItemsBuffer(DbIndex db_ind, const ItemsBuf& ib) { void RdbLoader::LoadItemsBuffer(DbIndex db_ind, const ItemsBuf& ib) {
EngineShard* es = EngineShard::tlocal(); EngineShard* es = EngineShard::tlocal();
DbSlice& db_slice = es->db_slice(); DbContext db_cntx{&namespaces.GetDefaultNamespace(), db_ind, GetCurrentTimeMs()};
DbContext db_cntx{db_ind, GetCurrentTimeMs()}; DbSlice& db_slice = db_cntx.GetDbSlice(es->shard_id());
for (const auto* item : ib) { for (const auto* item : ib) {
PrimeValue pv; PrimeValue pv;
@ -2564,6 +2565,7 @@ void RdbLoader::LoadSearchIndexDefFromAux(string&& def) {
cntx.is_replicating = true; cntx.is_replicating = true;
cntx.journal_emulated = true; cntx.journal_emulated = true;
cntx.skip_acl_validation = true; cntx.skip_acl_validation = true;
cntx.ns = &namespaces.GetDefaultNamespace();
// Avoid deleting local crb // Avoid deleting local crb
absl::Cleanup cntx_clean = [&cntx] { cntx.Inject(nullptr); }; absl::Cleanup cntx_clean = [&cntx] { cntx.Inject(nullptr); };
@ -2613,7 +2615,8 @@ void RdbLoader::PerformPostLoad(Service* service) {
// Rebuild all search indices as only their definitions are extracted from the snapshot // Rebuild all search indices as only their definitions are extracted from the snapshot
shard_set->AwaitRunningOnShardQueue([](EngineShard* es) { shard_set->AwaitRunningOnShardQueue([](EngineShard* es) {
es->search_indices()->RebuildAllIndices(OpArgs{es, nullptr, DbContext{0, GetCurrentTimeMs()}}); es->search_indices()->RebuildAllIndices(
OpArgs{es, nullptr, DbContext{&namespaces.GetDefaultNamespace(), 0, GetCurrentTimeMs()}});
}); });
} }

View file

@ -35,6 +35,7 @@ extern "C" {
#include "server/engine_shard_set.h" #include "server/engine_shard_set.h"
#include "server/error.h" #include "server/error.h"
#include "server/main_service.h" #include "server/main_service.h"
#include "server/namespaces.h"
#include "server/rdb_extensions.h" #include "server/rdb_extensions.h"
#include "server/search/doc_index.h" #include "server/search/doc_index.h"
#include "server/serializer_commons.h" #include "server/serializer_commons.h"
@ -1251,15 +1252,17 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
void RdbSaver::Impl::StartSnapshotting(bool stream_journal, const Cancellation* cll, void RdbSaver::Impl::StartSnapshotting(bool stream_journal, const Cancellation* cll,
EngineShard* shard) { EngineShard* shard) {
auto& s = GetSnapshot(shard); auto& s = GetSnapshot(shard);
s = std::make_unique<SliceSnapshot>(&shard->db_slice(), &channel_, compression_mode_); auto& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id());
s = std::make_unique<SliceSnapshot>(&db_slice, &channel_, compression_mode_);
s->Start(stream_journal, cll); s->Start(stream_journal, cll);
} }
void RdbSaver::Impl::StartIncrementalSnapshotting(Context* cntx, EngineShard* shard, void RdbSaver::Impl::StartIncrementalSnapshotting(Context* cntx, EngineShard* shard,
LSN start_lsn) { LSN start_lsn) {
auto& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id());
auto& s = GetSnapshot(shard); auto& s = GetSnapshot(shard);
s = std::make_unique<SliceSnapshot>(&shard->db_slice(), &channel_, compression_mode_); s = std::make_unique<SliceSnapshot>(&db_slice, &channel_, compression_mode_);
s->StartIncremental(cntx, start_lsn); s->StartIncremental(cntx, start_lsn);
} }

View file

@ -570,6 +570,7 @@ error_code Replica::ConsumeRedisStream() {
conn_context.is_replicating = true; conn_context.is_replicating = true;
conn_context.journal_emulated = true; conn_context.journal_emulated = true;
conn_context.skip_acl_validation = true; conn_context.skip_acl_validation = true;
conn_context.ns = &namespaces.GetDefaultNamespace();
ResetParser(true); ResetParser(true);
// Master waits for this command in order to start sending replication stream. // Master waits for this command in order to start sending replication stream.

View file

@ -447,7 +447,7 @@ void ClientPauseCmd(CmdArgList args, vector<facade::Listener*> listeners, Connec
}; };
if (auto pause_fb_opt = if (auto pause_fb_opt =
Pause(listeners, cntx->conn(), pause_state, std::move(is_pause_in_progress)); Pause(listeners, cntx->ns, cntx->conn(), pause_state, std::move(is_pause_in_progress));
pause_fb_opt) { pause_fb_opt) {
pause_fb_opt->Detach(); pause_fb_opt->Detach();
cntx->SendOk(); cntx->SendOk();
@ -673,8 +673,8 @@ optional<ReplicaOfArgs> ReplicaOfArgs::FromCmdArgs(CmdArgList args, ConnectionCo
} // namespace } // namespace
std::optional<fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, facade::Connection* conn, std::optional<fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, Namespace* ns,
ClientPause pause_state, facade::Connection* conn, ClientPause pause_state,
std::function<bool()> is_pause_in_progress) { std::function<bool()> is_pause_in_progress) {
// Track connections and set pause state to be able to wait until all running transactions read // Track connections and set pause state to be able to wait until all running transactions read
// the new pause state. Exclude already paused commands from the busy count. Exclude tracking // the new pause state. Exclude already paused commands from the busy count. Exclude tracking
@ -683,7 +683,7 @@ std::optional<fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, facade
// command that did not pause on the new state yet we will pause after waking up. // command that did not pause on the new state yet we will pause after waking up.
DispatchTracker tracker{std::move(listeners), conn, true /* ignore paused commands */, DispatchTracker tracker{std::move(listeners), conn, true /* ignore paused commands */,
true /*ignore blocking*/}; true /*ignore blocking*/};
shard_set->pool()->AwaitBrief([&tracker, pause_state](unsigned, util::ProactorBase*) { shard_set->pool()->AwaitBrief([&tracker, pause_state, ns](unsigned, util::ProactorBase*) {
// Commands don't suspend before checking the pause state, so // Commands don't suspend before checking the pause state, so
// it's impossible to deadlock on waiting for a command that will be paused. // it's impossible to deadlock on waiting for a command that will be paused.
tracker.TrackOnThread(); tracker.TrackOnThread();
@ -703,9 +703,9 @@ std::optional<fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, facade
// We should not expire/evict keys while clients are paused. // We should not expire/evict keys while clients are paused.
shard_set->RunBriefInParallel( shard_set->RunBriefInParallel(
[](EngineShard* shard) { shard->db_slice().SetExpireAllowed(false); }); [ns](EngineShard* shard) { ns->GetDbSlice(shard->shard_id()).SetExpireAllowed(false); });
return fb2::Fiber("client_pause", [is_pause_in_progress, pause_state]() mutable { return fb2::Fiber("client_pause", [is_pause_in_progress, pause_state, ns]() mutable {
// On server shutdown we sleep 10ms to make sure all running task finish, therefore 10ms steps // On server shutdown we sleep 10ms to make sure all running task finish, therefore 10ms steps
// ensure this fiber will not left hanging . // ensure this fiber will not left hanging .
constexpr auto step = 10ms; constexpr auto step = 10ms;
@ -719,7 +719,7 @@ std::optional<fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, facade
ServerState::tlocal()->SetPauseState(pause_state, false); ServerState::tlocal()->SetPauseState(pause_state, false);
}); });
shard_set->RunBriefInParallel( shard_set->RunBriefInParallel(
[](EngineShard* shard) { shard->db_slice().SetExpireAllowed(true); }); [ns](EngineShard* shard) { ns->GetDbSlice(shard->shard_id()).SetExpireAllowed(true); });
} }
}); });
} }
@ -1345,7 +1345,8 @@ void ServerFamily::ConfigureMetrics(util::HttpListenerBase* http_base) {
auto cb = [this](const util::http::QueryArgs& args, util::HttpContext* send) { auto cb = [this](const util::http::QueryArgs& args, util::HttpContext* send) {
StringResponse resp = util::http::MakeStringResponse(boost::beast::http::status::ok); StringResponse resp = util::http::MakeStringResponse(boost::beast::http::status::ok);
PrintPrometheusMetrics(this->GetMetrics(), this->dfly_cmd_.get(), &resp); PrintPrometheusMetrics(this->GetMetrics(&namespaces.GetDefaultNamespace()),
this->dfly_cmd_.get(), &resp);
return send->Invoke(std::move(resp)); return send->Invoke(std::move(resp));
}; };
@ -1424,7 +1425,7 @@ void ServerFamily::StatsMC(std::string_view section, facade::ConnectionContext*
double utime = dbl_time(ru.ru_utime); double utime = dbl_time(ru.ru_utime);
double systime = dbl_time(ru.ru_stime); double systime = dbl_time(ru.ru_stime);
Metrics m = GetMetrics(); Metrics m = GetMetrics(&namespaces.GetDefaultNamespace());
ADD_LINE(pid, getpid()); ADD_LINE(pid, getpid());
ADD_LINE(uptime, m.uptime); ADD_LINE(uptime, m.uptime);
@ -1454,7 +1455,7 @@ GenericError ServerFamily::DoSave(bool ignore_state) {
const CommandId* cid = service().FindCmd("SAVE"); const CommandId* cid = service().FindCmd("SAVE");
CHECK_NOTNULL(cid); CHECK_NOTNULL(cid);
boost::intrusive_ptr<Transaction> trans(new Transaction{cid}); boost::intrusive_ptr<Transaction> trans(new Transaction{cid});
trans->InitByArgs(0, {}); trans->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {});
return DoSave(absl::GetFlag(FLAGS_df_snapshot_format), {}, trans.get(), ignore_state); return DoSave(absl::GetFlag(FLAGS_df_snapshot_format), {}, trans.get(), ignore_state);
} }
@ -1551,7 +1552,7 @@ void ServerFamily::DbSize(CmdArgList args, ConnectionContext* cntx) {
shard_set->RunBriefInParallel( shard_set->RunBriefInParallel(
[&](EngineShard* shard) { [&](EngineShard* shard) {
auto db_size = shard->db_slice().DbSize(cntx->conn_state.db_index); auto db_size = cntx->ns->GetDbSlice(shard->shard_id()).DbSize(cntx->conn_state.db_index);
num_keys.fetch_add(db_size, memory_order_relaxed); num_keys.fetch_add(db_size, memory_order_relaxed);
}, },
[](ShardId) { return true; }); [](ShardId) { return true; });
@ -1647,6 +1648,7 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
auto cred = registry->GetCredentials(username); auto cred = registry->GetCredentials(username);
cntx->acl_commands = cred.acl_commands; cntx->acl_commands = cred.acl_commands;
cntx->keys = std::move(cred.keys); cntx->keys = std::move(cred.keys);
cntx->ns = &namespaces.GetOrInsert(cred.ns);
cntx->authenticated = true; cntx->authenticated = true;
return cntx->SendOk(); return cntx->SendOk();
} }
@ -1773,7 +1775,7 @@ void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) {
} }
if (sub_cmd == "RESETSTAT") { if (sub_cmd == "RESETSTAT") {
ResetStat(); ResetStat(cntx->ns);
return cntx->SendOk(); return cntx->SendOk();
} else { } else {
return cntx->SendError(UnknownSubCmd(sub_cmd, "CONFIG"), kSyntaxErrType); return cntx->SendError(UnknownSubCmd(sub_cmd, "CONFIG"), kSyntaxErrType);
@ -1883,17 +1885,16 @@ static void MergeDbSliceStats(const DbSlice::Stats& src, Metrics* dest) {
dest->small_string_bytes += src.small_string_bytes; dest->small_string_bytes += src.small_string_bytes;
} }
void ServerFamily::ResetStat() { void ServerFamily::ResetStat(Namespace* ns) {
shard_set->pool()->AwaitBrief( shard_set->pool()->AwaitBrief(
[registry = service_.mutable_registry(), this](unsigned index, auto*) { [registry = service_.mutable_registry(), this, ns](unsigned index, auto*) {
registry->ResetCallStats(index); registry->ResetCallStats(index);
SinkReplyBuilder::ResetThreadLocalStats(); SinkReplyBuilder::ResetThreadLocalStats();
auto& stats = tl_facade_stats->conn_stats; auto& stats = tl_facade_stats->conn_stats;
stats.command_cnt = 0; stats.command_cnt = 0;
stats.pipelined_cmd_cnt = 0; stats.pipelined_cmd_cnt = 0;
EngineShard* shard = EngineShard::tlocal(); ns->GetCurrentDbSlice().ResetEvents();
shard->db_slice().ResetEvents();
tl_facade_stats->conn_stats.conn_received_cnt = 0; tl_facade_stats->conn_stats.conn_received_cnt = 0;
tl_facade_stats->conn_stats.pipelined_cmd_cnt = 0; tl_facade_stats->conn_stats.pipelined_cmd_cnt = 0;
tl_facade_stats->conn_stats.command_cnt = 0; tl_facade_stats->conn_stats.command_cnt = 0;
@ -1910,7 +1911,7 @@ void ServerFamily::ResetStat() {
}); });
} }
Metrics ServerFamily::GetMetrics() const { Metrics ServerFamily::GetMetrics(Namespace* ns) const {
Metrics result; Metrics result;
util::fb2::Mutex mu; util::fb2::Mutex mu;
@ -1942,7 +1943,7 @@ Metrics ServerFamily::GetMetrics() const {
if (shard) { if (shard) {
result.heap_used_bytes += shard->UsedMemory(); result.heap_used_bytes += shard->UsedMemory();
MergeDbSliceStats(shard->db_slice().GetStats(), &result); MergeDbSliceStats(ns->GetDbSlice(shard->shard_id()).GetStats(), &result);
result.shard_stats += shard->stats(); result.shard_stats += shard->stats();
if (shard->tiered_storage()) { if (shard->tiered_storage()) {
@ -2017,7 +2018,7 @@ void ServerFamily::Info(CmdArgList args, ConnectionContext* cntx) {
absl::StrAppend(&info, a1, ":", a2, "\r\n"); absl::StrAppend(&info, a1, ":", a2, "\r\n");
}; };
Metrics m = GetMetrics(); Metrics m = GetMetrics(cntx->ns);
DbStats total; DbStats total;
for (const auto& db_stats : m.db_stats) for (const auto& db_stats : m.db_stats)
total += db_stats; total += db_stats;
@ -2589,8 +2590,9 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) {
void ServerFamily::Replicate(string_view host, string_view port) { void ServerFamily::Replicate(string_view host, string_view port) {
io::NullSink sink; io::NullSink sink;
ConnectionContext ctxt{&sink, nullptr, {}}; ConnectionContext cntx{&sink, nullptr, {}};
ctxt.skip_acl_validation = true; cntx.ns = &namespaces.GetDefaultNamespace();
cntx.skip_acl_validation = true;
StringVec replicaof_params{string(host), string(port)}; StringVec replicaof_params{string(host), string(port)};
@ -2599,7 +2601,7 @@ void ServerFamily::Replicate(string_view host, string_view port) {
args_vec.emplace_back(MutableSlice{s.data(), s.size()}); args_vec.emplace_back(MutableSlice{s.data(), s.size()});
} }
CmdArgList args_list = absl::MakeSpan(args_vec); CmdArgList args_list = absl::MakeSpan(args_vec);
ReplicaOfInternal(args_list, &ctxt, ActionOnConnectionFail::kContinueReplication); ReplicaOfInternal(args_list, &cntx, ActionOnConnectionFail::kContinueReplication);
} }
// REPLTAKEOVER <seconds> [SAVE] // REPLTAKEOVER <seconds> [SAVE]

View file

@ -15,6 +15,7 @@
#include "server/detail/save_stages_controller.h" #include "server/detail/save_stages_controller.h"
#include "server/dflycmd.h" #include "server/dflycmd.h"
#include "server/engine_shard_set.h" #include "server/engine_shard_set.h"
#include "server/namespaces.h"
#include "server/replica.h" #include "server/replica.h"
#include "server/server_state.h" #include "server/server_state.h"
#include "util/fibers/fiberqueue_threadpool.h" #include "util/fibers/fiberqueue_threadpool.h"
@ -158,9 +159,9 @@ class ServerFamily {
return service_; return service_;
} }
void ResetStat(); void ResetStat(Namespace* ns);
Metrics GetMetrics() const; Metrics GetMetrics(Namespace* ns) const;
ScriptMgr* script_mgr() { ScriptMgr* script_mgr() {
return script_mgr_.get(); return script_mgr_.get();
@ -337,7 +338,7 @@ class ServerFamily {
}; };
// Reusable CLIENT PAUSE implementation that blocks while polling is_pause_in_progress // Reusable CLIENT PAUSE implementation that blocks while polling is_pause_in_progress
std::optional<util::fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, std::optional<util::fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, Namespace* ns,
facade::Connection* conn, ClientPause pause_state, facade::Connection* conn, ClientPause pause_state,
std::function<bool()> is_pause_in_progress); std::function<bool()> is_pause_in_progress);

View file

@ -648,9 +648,9 @@ OpResult<streamID> OpAdd(const OpArgs& op_args, const AddTrimOpts& opts, CmdArgL
StreamTrim(opts, stream_inst); StreamTrim(opts, stream_inst);
EngineShard* es = op_args.shard; auto blocking_controller = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id());
if (es->blocking_controller()) { if (blocking_controller) {
es->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, opts.key); blocking_controller->AwakeWatched(op_args.db_cntx.db_index, opts.key);
} }
return result_id; return result_id;
@ -2364,7 +2364,7 @@ void StreamFamily::XInfo(CmdArgList args, ConnectionContext* cntx) {
// We do not use transactional semantics for xinfo since it's informational command. // We do not use transactional semantics for xinfo since it's informational command.
auto cb = [&]() { auto cb = [&]() {
EngineShard* shard = EngineShard::tlocal(); EngineShard* shard = EngineShard::tlocal();
DbContext db_context{cntx->db_index(), GetCurrentTimeMs()}; DbContext db_context{cntx->ns, cntx->db_index(), GetCurrentTimeMs()};
return OpListGroups(db_context, key, shard); return OpListGroups(db_context, key, shard);
}; };
@ -2433,7 +2433,8 @@ void StreamFamily::XInfo(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&]() { auto cb = [&]() {
EngineShard* shard = EngineShard::tlocal(); EngineShard* shard = EngineShard::tlocal();
return OpStreams(DbContext{cntx->db_index(), GetCurrentTimeMs()}, key, shard, full, count); return OpStreams(DbContext{cntx->ns, cntx->db_index(), GetCurrentTimeMs()}, key, shard,
full, count);
}; };
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder()); auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
@ -2561,8 +2562,8 @@ void StreamFamily::XInfo(CmdArgList args, ConnectionContext* cntx) {
string_view stream_name = ArgS(args, 1); string_view stream_name = ArgS(args, 1);
string_view group_name = ArgS(args, 2); string_view group_name = ArgS(args, 2);
auto cb = [&]() { auto cb = [&]() {
return OpConsumers(DbContext{cntx->db_index(), GetCurrentTimeMs()}, EngineShard::tlocal(), return OpConsumers(DbContext{cntx->ns, cntx->db_index(), GetCurrentTimeMs()},
stream_name, group_name); EngineShard::tlocal(), stream_name, group_name);
}; };
OpResult<vector<ConsumerInfo>> result = shard_set->Await(sid, std::move(cb)); OpResult<vector<ConsumerInfo>> result = shard_set->Await(sid, std::move(cb));

View file

@ -1226,9 +1226,10 @@ void StringFamily::MSetNx(CmdArgList args, ConnectionContext* cntx) {
atomic_bool exists{false}; atomic_bool exists{false};
auto cb = [&](Transaction* t, EngineShard* es) { auto cb = [&](Transaction* t, EngineShard* es) {
auto args = t->GetShardArgs(es->shard_id()); auto sid = es->shard_id();
auto args = t->GetShardArgs(sid);
for (auto arg_it = args.begin(); arg_it != args.end(); ++arg_it) { for (auto arg_it = args.begin(); arg_it != args.end(); ++arg_it) {
auto it = es->db_slice().FindReadOnly(t->GetDbContext(), *arg_it).it; auto it = cntx->ns->GetDbSlice(sid).FindReadOnly(t->GetDbContext(), *arg_it).it;
++arg_it; ++arg_it;
if (IsValid(it)) { if (IsValid(it)) {
exists.store(true, memory_order_relaxed); exists.store(true, memory_order_relaxed);

View file

@ -804,7 +804,7 @@ TEST_F(StringFamilyTest, SetWithHashtagsNoCluster) {
auto fb = ExpectUsedKeys({"{key}1"}); auto fb = ExpectUsedKeys({"{key}1"});
EXPECT_EQ(Run({"set", "{key}1", "val1"}), "OK"); EXPECT_EQ(Run({"set", "{key}1", "val1"}), "OK");
fb.Join(); fb.Join();
EXPECT_FALSE(service_->IsLocked(0, "{key}1")); EXPECT_FALSE(IsLocked(0, "{key}1"));
fb = ExpectUsedKeys({"{key}2"}); fb = ExpectUsedKeys({"{key}2"});
EXPECT_EQ(Run({"set", "{key}2", "val2"}), "OK"); EXPECT_EQ(Run({"set", "{key}2", "val2"}), "OK");

View file

@ -81,7 +81,7 @@ void TransactionSuspension::Start() {
transaction_ = new dfly::Transaction{&cid}; transaction_ = new dfly::Transaction{&cid};
auto st = transaction_->InitByArgs(0, {}); auto st = transaction_->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {});
CHECK_EQ(st, OpStatus::OK); CHECK_EQ(st, OpStatus::OK);
transaction_->Execute([](Transaction* t, EngineShard* shard) { return OpStatus::OK; }, false); transaction_->Execute([](Transaction* t, EngineShard* shard) { return OpStatus::OK; }, false);
@ -107,7 +107,9 @@ class BaseFamilyTest::TestConnWrapper {
const facade::Connection::InvalidationMessage& GetInvalidationMessage(size_t index) const; const facade::Connection::InvalidationMessage& GetInvalidationMessage(size_t index) const;
ConnectionContext* cmd_cntx() { ConnectionContext* cmd_cntx() {
return static_cast<ConnectionContext*>(dummy_conn_->cntx()); auto cntx = static_cast<ConnectionContext*>(dummy_conn_->cntx());
cntx->ns = &namespaces.GetDefaultNamespace();
return cntx;
} }
StringVec SplitLines() const { StringVec SplitLines() const {
@ -210,7 +212,10 @@ void BaseFamilyTest::ResetService() {
used_mem_current = 0; used_mem_current = 0;
TEST_current_time_ms = absl::GetCurrentTimeNanos() / 1000000; TEST_current_time_ms = absl::GetCurrentTimeNanos() / 1000000;
auto cb = [&](EngineShard* s) { s->db_slice().UpdateExpireBase(TEST_current_time_ms - 1000, 0); }; auto default_ns = &namespaces.GetDefaultNamespace();
auto cb = [&](EngineShard* s) {
default_ns->GetDbSlice(s->shard_id()).UpdateExpireBase(TEST_current_time_ms - 1000, 0);
};
shard_set->RunBriefInParallel(cb); shard_set->RunBriefInParallel(cb);
const TestInfo* const test_info = UnitTest::GetInstance()->current_test_info(); const TestInfo* const test_info = UnitTest::GetInstance()->current_test_info();
@ -244,7 +249,10 @@ void BaseFamilyTest::ResetService() {
} }
LOG(ERROR) << "TxLocks for shard " << es->shard_id(); LOG(ERROR) << "TxLocks for shard " << es->shard_id();
for (const auto& k_v : es->db_slice().GetDBTable(0)->trans_locks) { for (const auto& k_v : namespaces.GetDefaultNamespace()
.GetDbSlice(es->shard_id())
.GetDBTable(0)
->trans_locks) {
LOG(ERROR) << "Key " << k_v.first << " " << k_v.second; LOG(ERROR) << "Key " << k_v.first << " " << k_v.second;
} }
} }
@ -264,6 +272,7 @@ void BaseFamilyTest::ShutdownService() {
service_->Shutdown(); service_->Shutdown();
service_.reset(); service_.reset();
delete shard_set; delete shard_set;
shard_set = nullptr; shard_set = nullptr;
@ -295,8 +304,9 @@ void BaseFamilyTest::CleanupSnapshots() {
unsigned BaseFamilyTest::NumLocked() { unsigned BaseFamilyTest::NumLocked() {
atomic_uint count = 0; atomic_uint count = 0;
auto default_ns = &namespaces.GetDefaultNamespace();
shard_set->RunBriefInParallel([&](EngineShard* shard) { shard_set->RunBriefInParallel([&](EngineShard* shard) {
for (const auto& db : shard->db_slice().databases()) { for (const auto& db : default_ns->GetDbSlice(shard->shard_id()).databases()) {
if (db == nullptr) { if (db == nullptr) {
continue; continue;
} }
@ -375,6 +385,7 @@ RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) {
CmdArgVec args = conn_wrapper->Args(slice); CmdArgVec args = conn_wrapper->Args(slice);
auto* context = conn_wrapper->cmd_cntx(); auto* context = conn_wrapper->cmd_cntx();
context->ns = &namespaces.GetDefaultNamespace();
DCHECK(context->transaction == nullptr) << id; DCHECK(context->transaction == nullptr) << id;
@ -551,12 +562,7 @@ BaseFamilyTest::TestConnWrapper::GetInvalidationMessage(size_t index) const {
} }
bool BaseFamilyTest::IsLocked(DbIndex db_index, std::string_view key) const { bool BaseFamilyTest::IsLocked(DbIndex db_index, std::string_view key) const {
ShardId sid = Shard(key, shard_set->size()); return service_->IsLocked(&namespaces.GetDefaultNamespace(), db_index, key);
bool is_open = pp_->at(sid)->AwaitBrief([db_index, key] {
return EngineShard::tlocal()->db_slice().CheckLock(IntentLock::EXCLUSIVE, db_index, key);
});
return !is_open;
} }
string BaseFamilyTest::GetId() const { string BaseFamilyTest::GetId() const {
@ -643,7 +649,8 @@ vector<LockFp> BaseFamilyTest::GetLastFps() {
} }
lock_guard lk(mu); lock_guard lk(mu);
for (auto fp : shard->db_slice().TEST_GetLastLockedFps()) { for (auto fp :
namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).TEST_GetLastLockedFps()) {
result.push_back(fp); result.push_back(fp);
} }
}; };

View file

@ -117,7 +117,7 @@ class BaseFamilyTest : public ::testing::Test {
static std::vector<std::string> StrArray(const RespExpr& expr); static std::vector<std::string> StrArray(const RespExpr& expr);
Metrics GetMetrics() const { Metrics GetMetrics() const {
return service_->server_family().GetMetrics(); return service_->server_family().GetMetrics(&namespaces.GetDefaultNamespace());
} }
void ClearMetrics(); void ClearMetrics();

View file

@ -173,8 +173,9 @@ Transaction::~Transaction() {
<< " destroyed"; << " destroyed";
} }
void Transaction::InitBase(DbIndex dbid, CmdArgList args) { void Transaction::InitBase(Namespace* ns, DbIndex dbid, CmdArgList args) {
global_ = false; global_ = false;
namespace_ = ns;
db_index_ = dbid; db_index_ = dbid;
full_args_ = args; full_args_ = args;
local_result_ = OpStatus::OK; local_result_ = OpStatus::OK;
@ -359,8 +360,8 @@ void Transaction::InitByKeys(const KeyIndex& key_index) {
} }
} }
OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { OpStatus Transaction::InitByArgs(Namespace* ns, DbIndex index, CmdArgList args) {
InitBase(index, args); InitBase(ns, index, args);
if ((cid_->opt_mask() & CO::GLOBAL_TRANS) > 0) { if ((cid_->opt_mask() & CO::GLOBAL_TRANS) > 0) {
InitGlobal(); InitGlobal();
@ -393,7 +394,7 @@ void Transaction::PrepareSquashedMultiHop(const CommandId* cid,
MultiSwitchCmd(cid); MultiSwitchCmd(cid);
InitBase(db_index_, {}); InitBase(namespace_, db_index_, {});
// Because squashing already determines active shards by partitioning commands, // Because squashing already determines active shards by partitioning commands,
// we don't have to work with keys manually and can just mark active shards. // we don't have to work with keys manually and can just mark active shards.
@ -414,19 +415,20 @@ void Transaction::PrepareSquashedMultiHop(const CommandId* cid,
MultiBecomeSquasher(); MultiBecomeSquasher();
} }
void Transaction::StartMultiGlobal(DbIndex dbid) { void Transaction::StartMultiGlobal(Namespace* ns, DbIndex dbid) {
CHECK(multi_); CHECK(multi_);
CHECK(shard_data_.empty()); // Make sure default InitByArgs didn't run. CHECK(shard_data_.empty()); // Make sure default InitByArgs didn't run.
multi_->mode = GLOBAL; multi_->mode = GLOBAL;
InitBase(dbid, {}); InitBase(ns, dbid, {});
InitGlobal(); InitGlobal();
multi_->lock_mode = IntentLock::EXCLUSIVE; multi_->lock_mode = IntentLock::EXCLUSIVE;
ScheduleInternal(); ScheduleInternal();
} }
void Transaction::StartMultiLockedAhead(DbIndex dbid, CmdArgList keys, bool skip_scheduling) { void Transaction::StartMultiLockedAhead(Namespace* ns, DbIndex dbid, CmdArgList keys,
bool skip_scheduling) {
DVLOG(1) << "StartMultiLockedAhead on " << keys.size() << " keys"; DVLOG(1) << "StartMultiLockedAhead on " << keys.size() << " keys";
DCHECK(multi_); DCHECK(multi_);
@ -437,7 +439,7 @@ void Transaction::StartMultiLockedAhead(DbIndex dbid, CmdArgList keys, bool skip
PrepareMultiFps(keys); PrepareMultiFps(keys);
InitBase(dbid, keys); InitBase(ns, dbid, keys);
InitByKeys(KeyIndex::Range(0, keys.size())); InitByKeys(KeyIndex::Range(0, keys.size()));
if (!skip_scheduling) if (!skip_scheduling)
@ -504,6 +506,7 @@ void Transaction::MultiUpdateWithParent(const Transaction* parent) {
txid_ = parent->txid_; txid_ = parent->txid_;
time_now_ms_ = parent->time_now_ms_; time_now_ms_ = parent->time_now_ms_;
unique_slot_checker_ = parent->unique_slot_checker_; unique_slot_checker_ = parent->unique_slot_checker_;
namespace_ = parent->namespace_;
} }
void Transaction::MultiBecomeSquasher() { void Transaction::MultiBecomeSquasher() {
@ -528,9 +531,10 @@ string Transaction::DebugId(std::optional<ShardId> sid) const {
return res; return res;
} }
void Transaction::PrepareMultiForScheduleSingleHop(ShardId sid, DbIndex db, CmdArgList args) { void Transaction::PrepareMultiForScheduleSingleHop(Namespace* ns, ShardId sid, DbIndex db,
CmdArgList args) {
multi_.reset(); multi_.reset();
InitBase(db, args); InitBase(ns, db, args);
EnableShard(sid); EnableShard(sid);
OpResult<KeyIndex> key_index = DetermineKeys(cid_, args); OpResult<KeyIndex> key_index = DetermineKeys(cid_, args);
CHECK(key_index); CHECK(key_index);
@ -608,7 +612,8 @@ bool Transaction::RunInShard(EngineShard* shard, bool txq_ooo) {
// 1: to go over potential wakened keys, verify them and activate watch queues. // 1: to go over potential wakened keys, verify them and activate watch queues.
// 2: if this transaction was notified and finished running - to remove it from the head // 2: if this transaction was notified and finished running - to remove it from the head
// of the queue and notify the next one. // of the queue and notify the next one.
if (auto* bcontroller = shard->blocking_controller(); bcontroller) {
if (auto* bcontroller = namespace_->GetBlockingController(shard->shard_id()); bcontroller) {
if (awaked_prerun || was_suspended) { if (awaked_prerun || was_suspended) {
bcontroller->FinalizeWatched(GetShardArgs(idx), this); bcontroller->FinalizeWatched(GetShardArgs(idx), this);
} }
@ -864,6 +869,7 @@ void Transaction::DispatchHop() {
use_count_.fetch_add(run_cnt, memory_order_relaxed); // for each pointer from poll_cb use_count_.fetch_add(run_cnt, memory_order_relaxed); // for each pointer from poll_cb
auto poll_cb = [this] { auto poll_cb = [this] {
CHECK(namespace_ != nullptr);
EngineShard::tlocal()->PollExecution("exec_cb", this); EngineShard::tlocal()->PollExecution("exec_cb", this);
DVLOG(3) << "ptr_release " << DebugId(); DVLOG(3) << "ptr_release " << DebugId();
intrusive_ptr_release(this); // against use_count_.fetch_add above. intrusive_ptr_release(this); // against use_count_.fetch_add above.
@ -1149,7 +1155,7 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p
// Register keys on active shards blocking controllers and mark shard state as suspended. // Register keys on active shards blocking controllers and mark shard state as suspended.
auto cb = [&](Transaction* t, EngineShard* shard) { auto cb = [&](Transaction* t, EngineShard* shard) {
auto keys = wkeys_provider(t, shard); auto keys = wkeys_provider(t, shard);
return t->WatchInShard(keys, shard, krc); return t->WatchInShard(&t->GetNamespace(), keys, shard, krc);
}; };
Execute(std::move(cb), true); Execute(std::move(cb), true);
@ -1187,7 +1193,7 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p
return result; return result;
} }
OpStatus Transaction::WatchInShard(BlockingController::Keys keys, EngineShard* shard, OpStatus Transaction::WatchInShard(Namespace* ns, BlockingController::Keys keys, EngineShard* shard,
KeyReadyChecker krc) { KeyReadyChecker krc) {
auto& sd = shard_data_[SidToId(shard->shard_id())]; auto& sd = shard_data_[SidToId(shard->shard_id())];
@ -1195,7 +1201,7 @@ OpStatus Transaction::WatchInShard(BlockingController::Keys keys, EngineShard* s
sd.local_mask |= SUSPENDED_Q; sd.local_mask |= SUSPENDED_Q;
sd.local_mask &= ~OUT_OF_ORDER; sd.local_mask &= ~OUT_OF_ORDER;
shard->EnsureBlockingController()->AddWatched(keys, std::move(krc), this); ns->GetOrAddBlockingController(shard)->AddWatched(keys, std::move(krc), this);
DVLOG(2) << "WatchInShard " << DebugId(); DVLOG(2) << "WatchInShard " << DebugId();
return OpStatus::OK; return OpStatus::OK;
@ -1209,8 +1215,10 @@ void Transaction::ExpireShardCb(BlockingController::Keys keys, EngineShard* shar
auto& sd = shard_data_[SidToId(shard->shard_id())]; auto& sd = shard_data_[SidToId(shard->shard_id())];
sd.local_mask &= ~KEYLOCK_ACQUIRED; sd.local_mask &= ~KEYLOCK_ACQUIRED;
shard->blocking_controller()->FinalizeWatched(keys, this); namespace_->GetBlockingController(shard->shard_id())->FinalizeWatched(keys, this);
DCHECK(!shard->blocking_controller()->awakened_transactions().contains(this)); DCHECK(!namespace_->GetBlockingController(shard->shard_id())
->awakened_transactions()
.contains(this));
// Resume processing of transaction queue // Resume processing of transaction queue
shard->PollExecution("unwatchcb", nullptr); shard->PollExecution("unwatchcb", nullptr);
@ -1218,9 +1226,8 @@ void Transaction::ExpireShardCb(BlockingController::Keys keys, EngineShard* shar
} }
DbSlice& Transaction::GetDbSlice(ShardId shard_id) const { DbSlice& Transaction::GetDbSlice(ShardId shard_id) const {
auto* shard = EngineShard::tlocal(); CHECK(namespace_ != nullptr);
DCHECK_EQ(shard->shard_id(), shard_id); return namespace_->GetDbSlice(shard_id);
return shard->db_slice();
} }
OpStatus Transaction::RunSquashedMultiCb(RunnableType cb) { OpStatus Transaction::RunSquashedMultiCb(RunnableType cb) {
@ -1270,8 +1277,9 @@ void Transaction::UnlockMultiShardCb(absl::Span<const LockFp> fps, EngineShard*
shard->RemoveContTx(this); shard->RemoveContTx(this);
// Wake only if no tx queue head is currently running // Wake only if no tx queue head is currently running
if (shard->blocking_controller() && shard->GetContTx() == nullptr) auto bc = namespace_->GetBlockingController(shard->shard_id());
shard->blocking_controller()->NotifyPending(); if (bc && shard->GetContTx() == nullptr)
bc->NotifyPending();
shard->PollExecution("unlockmulti", nullptr); shard->PollExecution("unlockmulti", nullptr);
} }

View file

@ -21,6 +21,7 @@
#include "server/cluster/cluster_utility.h" #include "server/cluster/cluster_utility.h"
#include "server/common.h" #include "server/common.h"
#include "server/journal/types.h" #include "server/journal/types.h"
#include "server/namespaces.h"
#include "server/table.h" #include "server/table.h"
#include "server/tx_base.h" #include "server/tx_base.h"
#include "util/fibers/synchronization.h" #include "util/fibers/synchronization.h"
@ -185,7 +186,7 @@ class Transaction {
std::optional<cluster::SlotId> slot_id); std::optional<cluster::SlotId> slot_id);
// Initialize from command (args) on specific db. // Initialize from command (args) on specific db.
OpStatus InitByArgs(DbIndex index, CmdArgList args); OpStatus InitByArgs(Namespace* ns, DbIndex index, CmdArgList args);
// Get command arguments for specific shard. Called from shard thread. // Get command arguments for specific shard. Called from shard thread.
ShardArgs GetShardArgs(ShardId sid) const; ShardArgs GetShardArgs(ShardId sid) const;
@ -230,10 +231,11 @@ class Transaction {
void PrepareSquashedMultiHop(const CommandId* cid, absl::FunctionRef<bool(ShardId)> enabled); void PrepareSquashedMultiHop(const CommandId* cid, absl::FunctionRef<bool(ShardId)> enabled);
// Start multi in GLOBAL mode. // Start multi in GLOBAL mode.
void StartMultiGlobal(DbIndex dbid); void StartMultiGlobal(Namespace* ns, DbIndex dbid);
// Start multi in LOCK_AHEAD mode with given keys. // Start multi in LOCK_AHEAD mode with given keys.
void StartMultiLockedAhead(DbIndex dbid, CmdArgList keys, bool skip_scheduling = false); void StartMultiLockedAhead(Namespace* ns, DbIndex dbid, CmdArgList keys,
bool skip_scheduling = false);
// Start multi in NON_ATOMIC mode. // Start multi in NON_ATOMIC mode.
void StartMultiNonAtomic(); void StartMultiNonAtomic();
@ -311,7 +313,11 @@ class Transaction {
bool IsGlobal() const; bool IsGlobal() const;
DbContext GetDbContext() const { DbContext GetDbContext() const {
return DbContext{db_index_, time_now_ms_}; return DbContext{namespace_, db_index_, time_now_ms_};
}
Namespace& GetNamespace() const {
return *namespace_;
} }
DbSlice& GetDbSlice(ShardId sid) const; DbSlice& GetDbSlice(ShardId sid) const;
@ -330,7 +336,7 @@ class Transaction {
// Prepares for running ScheduleSingleHop() for a single-shard multi tx. // Prepares for running ScheduleSingleHop() for a single-shard multi tx.
// It is safe to call ScheduleSingleHop() after calling this method, but the callback passed // It is safe to call ScheduleSingleHop() after calling this method, but the callback passed
// to it must not block. // to it must not block.
void PrepareMultiForScheduleSingleHop(ShardId sid, DbIndex db, CmdArgList args); void PrepareMultiForScheduleSingleHop(Namespace* ns, ShardId sid, DbIndex db, CmdArgList args);
// Write a journal entry to a shard journal with the given payload. // Write a journal entry to a shard journal with the given payload.
void LogJournalOnShard(EngineShard* shard, journal::Entry::Payload&& payload, uint32_t shard_cnt, void LogJournalOnShard(EngineShard* shard, journal::Entry::Payload&& payload, uint32_t shard_cnt,
@ -479,7 +485,7 @@ class Transaction {
}; };
// Init basic fields and reset re-usable. // Init basic fields and reset re-usable.
void InitBase(DbIndex dbid, CmdArgList args); void InitBase(Namespace* ns, DbIndex dbid, CmdArgList args);
// Init as a global transaction. // Init as a global transaction.
void InitGlobal(); void InitGlobal();
@ -518,7 +524,7 @@ class Transaction {
void RunCallback(EngineShard* shard); void RunCallback(EngineShard* shard);
// Adds itself to watched queue in the shard. Must run in that shard thread. // Adds itself to watched queue in the shard. Must run in that shard thread.
OpStatus WatchInShard(std::variant<ShardArgs, ArgSlice> keys, EngineShard* shard, OpStatus WatchInShard(Namespace* ns, std::variant<ShardArgs, ArgSlice> keys, EngineShard* shard,
KeyReadyChecker krc); KeyReadyChecker krc);
// Expire blocking transaction, unlock keys and unregister it from the blocking controller // Expire blocking transaction, unlock keys and unregister it from the blocking controller
@ -612,6 +618,7 @@ class Transaction {
TxId txid_{0}; TxId txid_{0};
bool global_{false}; bool global_{false};
Namespace* namespace_{nullptr};
DbIndex db_index_{0}; DbIndex db_index_{0};
uint64_t time_now_ms_{0}; uint64_t time_now_ms_{0};

View file

@ -9,6 +9,7 @@
#include "server/cluster/cluster_defs.h" #include "server/cluster/cluster_defs.h"
#include "server/engine_shard_set.h" #include "server/engine_shard_set.h"
#include "server/journal/journal.h" #include "server/journal/journal.h"
#include "server/namespaces.h"
#include "server/transaction.h" #include "server/transaction.h"
namespace dfly { namespace dfly {
@ -17,14 +18,11 @@ using namespace std;
using Payload = journal::Entry::Payload; using Payload = journal::Entry::Payload;
DbSlice& DbContext::GetDbSlice(ShardId shard_id) const { DbSlice& DbContext::GetDbSlice(ShardId shard_id) const {
// TODO: Update this when adding namespaces return ns->GetDbSlice(shard_id);
DCHECK_EQ(shard_id, EngineShard::tlocal()->shard_id());
return EngineShard::tlocal()->db_slice();
} }
DbSlice& OpArgs::GetDbSlice() const { DbSlice& OpArgs::GetDbSlice() const {
// TODO: Update this when adding namespaces return db_cntx.GetDbSlice(shard->shard_id());
return shard->db_slice();
} }
size_t ShardArgs::Size() const { size_t ShardArgs::Size() const {

View file

@ -14,6 +14,7 @@ namespace dfly {
class EngineShard; class EngineShard;
class Transaction; class Transaction;
class Namespace;
class DbSlice; class DbSlice;
using DbIndex = uint16_t; using DbIndex = uint16_t;
@ -58,6 +59,7 @@ struct KeyIndex {
}; };
struct DbContext { struct DbContext {
Namespace* ns = nullptr;
DbIndex db_index = 0; DbIndex db_index = 0;
uint64_t time_now_ms = 0; uint64_t time_now_ms = 0;

View file

@ -213,10 +213,11 @@ OpResult<DbSlice::ItAndUpdater> FindZEntry(const ZParams& zparams, const OpArgs&
return OpStatus::WRONG_TYPE; return OpStatus::WRONG_TYPE;
} }
if (add_res.is_new && op_args.shard->blocking_controller()) { auto* blocking_controller = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id());
if (add_res.is_new && blocking_controller) {
string tmp; string tmp;
string_view key = it->first.GetSlice(&tmp); string_view key = it->first.GetSlice(&tmp);
op_args.shard->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, key); blocking_controller->AwakeWatched(op_args.db_cntx.db_index, key);
} }
return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)}; return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)};

View file

@ -2,7 +2,7 @@ import pytest
import redis import redis
from redis import asyncio as aioredis from redis import asyncio as aioredis
from .instance import DflyInstanceFactory from .instance import DflyInstanceFactory
from .utility import disconnect_clients from .utility import *
import tempfile import tempfile
import asyncio import asyncio
import os import os
@ -567,6 +567,43 @@ async def test_acl_keys(async_client):
await async_client.execute_command("ZUNIONSTORE destkey 2 barz1 barz2") await async_client.execute_command("ZUNIONSTORE destkey 2 barz1 barz2")
@pytest.mark.asyncio
async def test_namespaces(df_factory):
    """Check that ACL users bound to a namespace only see that namespace's keys.

    Scenario exercised against a freshly started instance:
      * the unauthenticated (admin) connection writes ``foo`` and keeps seeing
        its own value throughout;
      * ``adi`` and ``shahar`` are both assigned to ``ns1`` and therefore read
        and overwrite the same ``foo``;
      * ``roman`` is assigned to ``ns2`` and sees no ``foo`` at all.
    """
    df = df_factory.create()
    df.start()

    admin = aioredis.Redis(port=df.port)
    assert await admin.execute_command("SET foo admin") == b"OK"
    assert await admin.execute_command("GET foo") == b"admin"

    # Create a namespace named 'ns1' by assigning a user to it
    await admin.execute_command("ACL SETUSER adi NAMESPACE:ns1 ON >adi_pass +@all ~*")

    adi = aioredis.Redis(port=df.port)
    assert await adi.execute_command("AUTH adi adi_pass") == b"OK"
    assert await adi.execute_command("SET foo bar") == b"OK"
    assert await adi.execute_command("GET foo") == b"bar"
    # Writing inside 'ns1' must not affect the admin connection's key
    assert await admin.execute_command("GET foo") == b"admin"

    # Adi and Shahar are on the same team, so they share 'ns1'
    await admin.execute_command("ACL SETUSER shahar NAMESPACE:ns1 ON >shahar_pass +@all ~*")

    shahar = aioredis.Redis(port=df.port)
    assert await shahar.execute_command("AUTH shahar shahar_pass") == b"OK"
    assert await shahar.execute_command("GET foo") == b"bar"
    assert await shahar.execute_command("SET foo bar2") == b"OK"
    assert await adi.execute_command("GET foo") == b"bar2"

    # Roman is a CTO, he has his own private space
    await admin.execute_command("ACL SETUSER roman NAMESPACE:ns2 ON >roman_pass +@all ~*")

    roman = aioredis.Redis(port=df.port)
    assert await roman.execute_command("AUTH roman roman_pass") == b"OK"
    # 'foo' was never set in 'ns2'
    assert await roman.execute_command("GET foo") is None

    await close_clients(admin, adi, shahar, roman)
@pytest.mark.asyncio @pytest.mark.asyncio
async def default_user_bug(df_factory): async def default_user_bug(df_factory):
df.start() df.start()