
feat(namespaces): Initial support for multi-tenant (#3260)

* feat(namespaces): Initial support for multi-tenant #3050

This PR introduces a way to create multiple, separate and isolated
namespaces in Dragonfly. Each user can be associated with a single
namespace, and will not be able to interact with other namespaces.

This is still experimental, and lacks some important features, such as:
* Replication and RDB saving completely ignore non-default namespaces
* Defrag and statistics either use the default namespace or all
  namespaces without separation

To associate a user with a namespace, use the `ACL SETUSER` command with the
`NAMESPACE:<namespace>` flag:

```
ACL SETUSER user NAMESPACE:namespace1 ON >user_pass +@all ~*
```

For more examples and up-to-date info, check
`tests/dragonfly/acl_family_test.py` - specifically the
`test_namespaces` function.
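
To illustrate the intended isolation, here is a hypothetical session with two
tenants (user names, passwords and keys below are made up for illustration):

```
ACL SETUSER alice NAMESPACE:tenant1 ON >alice_pass +@all ~*
ACL SETUSER bob NAMESPACE:tenant2 ON >bob_pass +@all ~*

AUTH alice alice_pass
SET greeting hello        # lands in tenant1's keyspace

AUTH bob bob_pass
GET greeting              # (nil) - tenant2 has its own, separate keyspace
```
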
Shahar Mike 2024-07-16 19:34:49 +03:00 committed by GitHub
parent 3891efac2c
commit 18ca61d29b
51 changed files with 600 additions and 255 deletions

View file

@ -1,6 +1,6 @@
# Namespaces in Dragonfly
Dragonfly will soon add an _experimental_ feature, allowing complete separation of data by different users.
Dragonfly added an _experimental_ feature, allowing complete separation of data by different users.
We call this feature _namespaces_, and it allows using a single Dragonfly server with multiple
tenants, each using their own data, without being able to mix them together.

View file

@ -28,6 +28,7 @@ struct UserCredentials {
uint32_t acl_categories{0};
std::vector<uint64_t> acl_commands;
AclKeys keys;
std::string ns;
};
} // namespace dfly::acl

View file

@ -28,7 +28,7 @@ endif()
add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc
command_registry.cc cluster/cluster_utility.cc
journal/tx_executor.cc
journal/tx_executor.cc namespaces.cc
common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc
server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc
serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc

View file

@ -942,6 +942,14 @@ std::pair<OptCat, bool> AclFamily::MaybeParseAclCategory(std::string_view comman
return {};
}
std::optional<std::string> AclFamily::MaybeParseNamespace(std::string_view command) const {
constexpr std::string_view kPrefix = "NAMESPACE:";
if (absl::StartsWith(command, kPrefix)) {
return std::string(command.substr(kPrefix.size()));
}
return std::nullopt;
}
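Below is a minimal, self-contained sketch of this parsing behavior
(reimplemented for illustration with plain `std::string_view` operations
instead of `absl::StartsWith`; not part of the diff):

```
#include <iostream>
#include <optional>
#include <string>
#include <string_view>

// Mirrors AclFamily::MaybeParseNamespace: strip the "NAMESPACE:" prefix,
// or return nullopt so the caller can try the next parser.
std::optional<std::string> MaybeParseNamespace(std::string_view command) {
  constexpr std::string_view kPrefix = "NAMESPACE:";
  if (command.substr(0, kPrefix.size()) == kPrefix) {
    return std::string(command.substr(kPrefix.size()));
  }
  return std::nullopt;
}

int main() {
  std::cout << MaybeParseNamespace("NAMESPACE:tenant1").value_or("<none>") << '\n';  // tenant1
  std::cout << MaybeParseNamespace("NAMESPACE:").value_or("<none>") << '\n';  // empty -> default namespace
  std::cout << MaybeParseNamespace("+@all").value_or("<none>") << '\n';       // <none>
}
```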
std::pair<AclFamily::OptCommand, bool> AclFamily::MaybeParseAclCommand(
std::string_view command) const {
if (absl::StartsWith(command, "+")) {
@ -1019,6 +1027,12 @@ std::variant<User::UpdateRequest, ErrorReply> AclFamily::ParseAclSetUser(
continue;
}
auto ns = MaybeParseNamespace(command);
if (ns.has_value()) {
req.ns = *ns;
continue;
}
auto [cmd, sign] = MaybeParseAclCommand(command);
if (!cmd) {
return ErrorReply(absl::StrCat("Unrecognized parameter ", command));

View file

@ -80,6 +80,8 @@ class AclFamily final {
using OptCommand = std::optional<std::pair<size_t, uint64_t>>;
std::pair<OptCommand, bool> MaybeParseAclCommand(std::string_view command) const;
std::optional<std::string> MaybeParseNamespace(std::string_view command) const;
std::variant<User::UpdateRequest, facade::ErrorReply> ParseAclSetUser(
const facade::ArgRange& args, bool hashed = false, bool has_all_keys = false) const;

View file

@ -79,6 +79,8 @@ void User::Update(UpdateRequest&& req, const CategoryToIdxStore& cat_to_id,
if (req.is_active) {
SetIsActive(*req.is_active);
}
SetNamespace(req.ns);
}
void User::SetPasswordHash(std::string_view password, bool is_hashed) {
@ -94,6 +96,14 @@ void User::UnsetPassword(std::string_view password) {
password_hashes_.erase(StringSHA256(password));
}
void User::SetNamespace(const std::string& ns) {
namespace_ = ns;
}
const std::string& User::Namespace() const {
return namespace_;
}
bool User::HasPassword(std::string_view password) const {
if (nopass_) {
return true;

View file

@ -58,8 +58,11 @@ class User final {
std::vector<UpdateKey> keys;
bool reset_all_keys{false};
bool allow_all_keys{false};
// TODO allow reset all
// bool reset_all{false};
std::string ns;
};
using CategoryChange = uint32_t;
@ -104,6 +107,8 @@ class User final {
const AclKeys& Keys() const;
const std::string& Namespace() const;
using CategoryChanges = absl::flat_hash_map<CategoryChange, ChangeMetadata>;
using CommandChanges = absl::flat_hash_map<CommandChange, ChangeMetadata>;
@ -135,6 +140,7 @@ class User final {
// For ACL key globs
void SetKeyGlobs(std::vector<UpdateKey> keys);
void SetNamespace(const std::string& ns);
// Set NOPASS and remove all passwords
void SetNopass();
@ -166,6 +172,8 @@ class User final {
// if the user is on/off
bool is_active_{false};
std::string namespace_;
};
} // namespace dfly::acl

View file

@ -35,7 +35,8 @@ UserCredentials UserRegistry::GetCredentials(std::string_view username) const {
if (it == registry_.end()) {
return {};
}
return {it->second.AclCategory(), it->second.AclCommands(), it->second.Keys()};
return {it->second.AclCategory(), it->second.AclCommands(), it->second.Keys(),
it->second.Namespace()};
}
bool UserRegistry::IsUserActive(std::string_view username) const {
@ -73,10 +74,13 @@ UserRegistry::UserWithWriteLock::UserWithWriteLock(std::unique_lock<fb2::SharedM
}
User::UpdateRequest UserRegistry::DefaultUserUpdateRequest() const {
std::pair<User::Sign, uint32_t> acl{User::Sign::PLUS, acl::ALL};
auto key = User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false};
auto pass = std::vector<User::UpdatePass>{{"", false, true}};
return {std::move(pass), true, false, {std::move(acl)}, {std::move(key)}};
// Assign field by field to suppress an annoying compiler warning
User::UpdateRequest req;
req.passwords = std::vector<User::UpdatePass>{{"", false, true}};
req.is_active = true;
req.updates = {std::pair<User::Sign, uint32_t>{User::Sign::PLUS, acl::ALL}};
req.keys = {User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}};
return req;
}
void UserRegistry::Init(const CategoryToIdxStore* cat_to_id_table,

View file

@ -10,6 +10,7 @@
#include "base/logging.h"
#include "server/engine_shard_set.h"
#include "server/namespaces.h"
#include "server/transaction.h"
namespace dfly {
@ -102,7 +103,7 @@ bool BlockingController::DbWatchTable::UnwatchTx(string_view key, Transaction* t
return res;
}
BlockingController::BlockingController(EngineShard* owner) : owner_(owner) {
BlockingController::BlockingController(EngineShard* owner, Namespace* ns) : owner_(owner), ns_(ns) {
}
BlockingController::~BlockingController() {
@ -153,6 +154,7 @@ void BlockingController::NotifyPending() {
CHECK(tx == nullptr) << tx->DebugId();
DbContext context;
context.ns = ns_;
context.time_now_ms = GetCurrentTimeMs();
for (DbIndex index : awakened_indices_) {

View file

@ -15,10 +15,11 @@
namespace dfly {
class Transaction;
class Namespace;
class BlockingController {
public:
explicit BlockingController(EngineShard* owner);
explicit BlockingController(EngineShard* owner, Namespace* ns);
~BlockingController();
using Keys = std::variant<ShardArgs, ArgSlice>;
@ -60,6 +61,7 @@ class BlockingController {
// void NotifyConvergence(Transaction* tx);
EngineShard* owner_;
Namespace* ns_;
absl::flat_hash_map<DbIndex, std::unique_ptr<DbWatchTable>> watched_dbs_;

View file

@ -61,7 +61,7 @@ void BlockingControllerTest::SetUp() {
arg_vec_.emplace_back(s);
}
trans_->InitByArgs(0, {arg_vec_.data(), arg_vec_.size()});
trans_->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {arg_vec_.data(), arg_vec_.size()});
CHECK_EQ(0u, Shard("x", shard_set->size()));
CHECK_EQ(2u, Shard("z", shard_set->size()));
@ -70,6 +70,8 @@ void BlockingControllerTest::SetUp() {
}
void BlockingControllerTest::TearDown() {
namespaces.Clear();
shard_set->Shutdown();
delete shard_set;
@ -79,7 +81,7 @@ void BlockingControllerTest::TearDown() {
TEST_F(BlockingControllerTest, Basic) {
trans_->ScheduleSingleHop([&](Transaction* t, EngineShard* shard) {
BlockingController bc(shard);
BlockingController bc(shard, &namespaces.GetDefaultNamespace());
auto keys = t->GetShardArgs(shard->shard_id());
bc.AddWatched(
keys, [](auto...) { return true; }, t);
@ -103,7 +105,12 @@ TEST_F(BlockingControllerTest, Timeout) {
EXPECT_EQ(status, facade::OpStatus::TIMED_OUT);
unsigned num_watched = shard_set->Await(
0, [&] { return EngineShard::tlocal()->blocking_controller()->NumWatched(0); });
0, [&] {
return namespaces.GetDefaultNamespace()
.GetBlockingController(EngineShard::tlocal()->shard_id())
->NumWatched(0);
});
EXPECT_EQ(0, num_watched);
trans_.reset();

View file

@ -21,6 +21,7 @@
#include "server/error.h"
#include "server/journal/journal.h"
#include "server/main_service.h"
#include "server/namespaces.h"
#include "server/server_family.h"
#include "server/server_state.h"
@ -451,7 +452,7 @@ void DeleteSlots(const SlotRanges& slots_ranges) {
if (shard == nullptr)
return;
shard->db_slice().FlushSlots(slots_ranges);
namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).FlushSlots(slots_ranges);
};
shard_set->pool()->AwaitFiberOnAll(std::move(cb));
}
@ -599,7 +600,7 @@ void ClusterFamily::DflyClusterGetSlotInfo(CmdArgList args, ConnectionContext* c
lock_guard lk(mu);
for (auto& [slot, data] : slots_stats) {
data += shard->db_slice().GetSlotStats(slot);
data += namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).GetSlotStats(slot);
}
};

View file

@ -2,6 +2,7 @@
#include "server/cluster/cluster_defs.h"
#include "server/engine_shard_set.h"
#include "server/namespaces.h"
using namespace std;
@ -49,7 +50,10 @@ uint64_t GetKeyCount(const SlotRanges& slots) {
uint64_t shard_keys = 0;
for (const SlotRange& range : slots) {
for (SlotId slot = range.start; slot <= range.end; slot++) {
shard_keys += shard->db_slice().GetSlotStats(slot).key_count;
shard_keys += namespaces.GetDefaultNamespace()
.GetDbSlice(shard->shard_id())
.GetSlotStats(slot)
.key_count;
}
}
keys.fetch_add(shard_keys);

View file

@ -96,7 +96,7 @@ OutgoingMigration::OutgoingMigration(MigrationInfo info, ClusterFamily* cf, Serv
server_family_(sf),
cf_(cf),
tx_(new Transaction{sf->service().FindCmd("DFLYCLUSTER")}) {
tx_->InitByArgs(0, {});
tx_->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {});
}
OutgoingMigration::~OutgoingMigration() {
@ -212,10 +212,10 @@ void OutgoingMigration::SyncFb() {
}
OnAllShards([this](auto& migration) {
auto* shard = EngineShard::tlocal();
DbSlice& db_slice = namespaces.GetDefaultNamespace().GetCurrentDbSlice();
server_family_->journal()->StartInThread();
migration = std::make_unique<SliceSlotMigration>(
&shard->db_slice(), server(), migration_info_.slot_ranges, server_family_->journal());
&db_slice, server(), migration_info_.slot_ranges, server_family_->journal());
});
if (!ChangeState(MigrationState::C_SYNC)) {
@ -284,8 +284,9 @@ bool OutgoingMigration::FinalizeMigration(long attempt) {
// TODO implement blocking on migrated slots only
bool is_block_active = true;
auto is_pause_in_progress = [&is_block_active] { return is_block_active; };
auto pause_fb_opt = Pause(server_family_->GetNonPriviligedListeners(), nullptr,
ClientPause::WRITE, is_pause_in_progress);
auto pause_fb_opt =
Pause(server_family_->GetNonPriviligedListeners(), &namespaces.GetDefaultNamespace(), nullptr,
ClientPause::WRITE, is_pause_in_progress);
if (!pause_fb_opt) {
LOG(WARNING) << "Cluster migration finalization time out";

View file

@ -282,6 +282,7 @@ class ConnectionContext : public facade::ConnectionContext {
DebugInfo last_command_debug;
// TODO: to introduce proper accessors.
Namespace* ns = nullptr;
Transaction* transaction = nullptr;
const CommandId* cid = nullptr;

View file

@ -339,9 +339,10 @@ OpResult<string> RunCbOnFirstNonEmptyBlocking(Transaction* trans, int req_obj_ty
}
auto wcb = [](Transaction* t, EngineShard* shard) { return t->GetShardArgs(shard->shard_id()); };
const auto key_checker = [req_obj_type](EngineShard* owner, const DbContext& context,
Transaction*, std::string_view key) -> bool {
return context.GetDbSlice(owner->shard_id()).FindReadOnly(context, key, req_obj_type).ok();
auto* ns = &trans->GetNamespace();
const auto key_checker = [req_obj_type, ns](EngineShard* owner, const DbContext& context,
Transaction*, std::string_view key) -> bool {
return ns->GetDbSlice(owner->shard_id()).FindReadOnly(context, key, req_obj_type).ok();
};
auto status = trans->WaitOnWatch(limit_tp, std::move(wcb), key_checker, block_flag, pause_flag);

View file

@ -1078,7 +1078,7 @@ void DbSlice::ExpireAllIfNeeded() {
LOG(ERROR) << "Expire entry " << exp_it->first.ToString() << " not found in prime table";
return;
}
ExpireIfNeeded(Context{db_index, GetCurrentTimeMs()}, prime_it);
ExpireIfNeeded(Context{nullptr, db_index, GetCurrentTimeMs()}, prime_it);
};
ExpireTable::Cursor cursor;

View file

@ -159,7 +159,7 @@ void DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool
stub_tx->MultiSwitchCmd(cid);
local_cntx.cid = cid;
crb.SetReplyMode(ReplyMode::NONE);
stub_tx->InitByArgs(local_cntx.conn_state.db_index, args_span);
stub_tx->InitByArgs(cntx->ns, local_cntx.conn_state.db_index, args_span);
sf->service().InvokeCmd(cid, args_span, &local_cntx);
}
@ -261,8 +261,8 @@ void MergeObjHistMap(ObjHistMap&& src, ObjHistMap* dest) {
}
}
void DoBuildObjHist(EngineShard* shard, ObjHistMap* obj_hist_map) {
auto& db_slice = shard->db_slice();
void DoBuildObjHist(EngineShard* shard, ConnectionContext* cntx, ObjHistMap* obj_hist_map) {
auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id());
unsigned steps = 0;
for (unsigned i = 0; i < db_slice.db_array_size(); ++i) {
@ -288,8 +288,9 @@ void DoBuildObjHist(EngineShard* shard, ObjHistMap* obj_hist_map) {
}
}
ObjInfo InspectOp(string_view key, DbIndex db_index) {
auto& db_slice = EngineShard::tlocal()->db_slice();
ObjInfo InspectOp(ConnectionContext* cntx, string_view key) {
auto& db_slice = cntx->ns->GetCurrentDbSlice();
auto db_index = cntx->db_index();
auto [pt, exp_t] = db_slice.GetTables(db_index);
PrimeIterator it = pt->Find(key);
@ -323,8 +324,9 @@ ObjInfo InspectOp(string_view key, DbIndex db_index) {
return oinfo;
}
OpResult<ValueCompressInfo> EstimateCompression(string_view key, DbIndex db_index) {
auto& db_slice = EngineShard::tlocal()->db_slice();
OpResult<ValueCompressInfo> EstimateCompression(ConnectionContext* cntx, string_view key) {
auto& db_slice = cntx->ns->GetCurrentDbSlice();
auto db_index = cntx->db_index();
auto [pt, exp_t] = db_slice.GetTables(db_index);
PrimeIterator it = pt->Find(key);
@ -544,7 +546,7 @@ void DebugCmd::Load(string_view filename) {
const CommandId* cid = sf_.service().FindCmd("FLUSHALL");
intrusive_ptr<Transaction> flush_trans(new Transaction{cid});
flush_trans->InitByArgs(0, {});
flush_trans->InitByArgs(cntx_->ns, 0, {});
VLOG(1) << "Performing flush";
error_code ec = sf_.Drakarys(flush_trans.get(), DbSlice::kDbAll);
if (ec) {
@ -750,7 +752,7 @@ void DebugCmd::PopulateRangeFiber(uint64_t from, uint64_t num_of_keys,
// after running the callback
// Note that running debug populate while running flushall/db can cause a DCHECK failure because the
// finish cb is executed just when we finish populating the database.
shard->db_slice().OnCbFinish();
cntx_->ns->GetDbSlice(shard->shard_id()).OnCbFinish();
});
}
@ -801,7 +803,7 @@ void DebugCmd::Inspect(string_view key, CmdArgList args) {
string resp;
if (check_compression) {
auto cb = [&] { return EstimateCompression(key, cntx_->db_index()); };
auto cb = [&] { return EstimateCompression(cntx_, key); };
auto res = ess.Await(sid, std::move(cb));
if (!res) {
cntx_->SendError(res.status());
@ -812,7 +814,7 @@ void DebugCmd::Inspect(string_view key, CmdArgList args) {
StrAppend(&resp, " ratio: ", static_cast<double>(res->compressed_size) / (res->raw_size));
}
} else {
auto cb = [&] { return InspectOp(key, cntx_->db_index()); };
auto cb = [&] { return InspectOp(cntx_, key); };
ObjInfo res = ess.Await(sid, std::move(cb));
@ -846,7 +848,7 @@ void DebugCmd::Watched() {
vector<string> awaked_trans;
auto cb = [&](EngineShard* shard) {
auto* bc = shard->blocking_controller();
auto* bc = cntx_->ns->GetBlockingController(shard->shard_id());
if (bc) {
auto keys = bc->GetWatchedKeys(cntx_->db_index());
@ -894,8 +896,9 @@ void DebugCmd::TxAnalysis() {
void DebugCmd::ObjHist() {
vector<ObjHistMap> obj_hist_map_arr(shard_set->size());
shard_set->RunBlockingInParallel(
[&](EngineShard* shard) { DoBuildObjHist(shard, &obj_hist_map_arr[shard->shard_id()]); });
shard_set->RunBlockingInParallel([&](EngineShard* shard) {
DoBuildObjHist(shard, cntx_, &obj_hist_map_arr[shard->shard_id()]);
});
for (size_t i = shard_set->size() - 1; i > 0; --i) {
MergeObjHistMap(std::move(obj_hist_map_arr[i]), &obj_hist_map_arr[0]);
@ -937,8 +940,9 @@ void DebugCmd::Shards() {
vector<ShardInfo> infos(shard_set->size());
shard_set->RunBriefInParallel([&](EngineShard* shard) {
auto slice_stats = shard->db_slice().GetStats();
auto& stats = infos[shard->shard_id()];
auto sid = shard->shard_id();
auto slice_stats = cntx_->ns->GetDbSlice(sid).GetStats();
auto& stats = infos[sid];
stats.used_memory = shard->UsedMemory();
for (const auto& db_stats : slice_stats.db_stats) {

View file

@ -13,6 +13,7 @@
#include "base/logging.h"
#include "server/detail/snapshot_storage.h"
#include "server/main_service.h"
#include "server/namespaces.h"
#include "server/script_mgr.h"
#include "server/transaction.h"
#include "strings/human_readable.h"
@ -400,7 +401,7 @@ void SaveStagesController::CloseCb(unsigned index) {
}
if (auto* es = EngineShard::tlocal(); use_dfs_format_ && es)
es->db_slice().ResetUpdateEvents();
namespaces.GetDefaultNamespace().GetDbSlice(es->shard_id()).ResetUpdateEvents();
}
void SaveStagesController::RunStage(void (SaveStagesController::*cb)(unsigned)) {

View file

@ -75,7 +75,7 @@ OpStatus WaitReplicaFlowToCatchup(absl::Time end_time, shared_ptr<DflyCmd::Repli
EngineShard* shard) {
// We don't want any writes to the journal after we send the `PING`,
// and expirations could ruin that.
shard->db_slice().SetExpireAllowed(false);
namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).SetExpireAllowed(false);
shard->journal()->RecordEntry(0, journal::Op::PING, 0, 0, nullopt, {}, true);
FlowInfo* flow = &replica->flows[shard->shard_id()];
@ -396,8 +396,9 @@ void DflyCmd::TakeOver(CmdArgList args, ConnectionContext* cntx) {
VLOG(1) << "AwaitCurrentDispatches done";
absl::Cleanup([] {
shard_set->RunBriefInParallel(
[](EngineShard* shard) { shard->db_slice().SetExpireAllowed(true); });
shard_set->RunBriefInParallel([](EngineShard* shard) {
namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).SetExpireAllowed(true);
});
VLOG(2) << "Enable expiration";
});

View file

@ -366,7 +366,11 @@ TEST_F(DflyEngineTest, MemcacheFlags) {
ASSERT_EQ(Run("resp", {"flushdb"}), "OK");
pp_->AwaitFiberOnAll([](auto*) {
if (auto* shard = EngineShard::tlocal(); shard) {
EXPECT_EQ(shard->db_slice().GetDBTable(0)->mcflag.size(), 0u);
EXPECT_EQ(namespaces.GetDefaultNamespace()
.GetDbSlice(shard->shard_id())
.GetDBTable(0)
->mcflag.size(),
0u);
}
});
}
@ -584,7 +588,7 @@ TEST_F(DflyEngineTest, Bug496) {
if (shard == nullptr)
return;
auto& db = shard->db_slice();
auto& db = namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id());
int cb_hits = 0;
uint32_t cb_id =

View file

@ -20,6 +20,7 @@ extern "C" {
#include "io/proc_reader.h"
#include "server/blocking_controller.h"
#include "server/cluster/cluster_defs.h"
#include "server/namespaces.h"
#include "server/search/doc_index.h"
#include "server/server_state.h"
#include "server/tiered_storage.h"
@ -294,7 +295,8 @@ bool EngineShard::DoDefrag() {
constexpr size_t kMaxTraverses = 40;
const float threshold = GetFlag(FLAGS_mem_defrag_page_utilization_threshold);
auto& slice = db_slice();
// TODO: enable tiered storage on non-default db slice
DbSlice& slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_->shard_id());
// If we moved to an invalid db, skip as long as it's not the last one
while (!slice.IsDbValid(defrag_state_.dbid) && defrag_state_.dbid + 1 < slice.db_array_size())
@ -324,7 +326,7 @@ bool EngineShard::DoDefrag() {
}
});
traverses_count++;
} while (traverses_count < kMaxTraverses && cur);
} while (traverses_count < kMaxTraverses && cur && namespaces.IsInitialized());
defrag_state_.UpdateScanState(cur.value());
@ -355,11 +357,14 @@ bool EngineShard::DoDefrag() {
// priority.
// otherwise lower the task priority so that it would not use the CPU when not required
uint32_t EngineShard::DefragTask() {
if (!namespaces.IsInitialized()) {
return util::ProactorBase::kOnIdleMaxLevel;
}
constexpr uint32_t kRunAtLowPriority = 0u;
const auto shard_id = db_slice().shard_id();
if (defrag_state_.CheckRequired()) {
VLOG(2) << shard_id << ": need to run defrag memory cursor state: " << defrag_state_.cursor;
VLOG(2) << shard_id_ << ": need to run defrag memory cursor state: " << defrag_state_.cursor;
if (DoDefrag()) {
// we didn't finish the scan
return util::ProactorBase::kOnIdleMaxLevel;
@ -372,13 +377,11 @@ EngineShard::EngineShard(util::ProactorBase* pb, mi_heap_t* heap)
: queue_(1, kQueueLen),
txq_([](const Transaction* t) { return t->txid(); }),
mi_resource_(heap),
db_slice_(pb->GetPoolIndex(), GetFlag(FLAGS_cache_mode), this) {
shard_id_(pb->GetPoolIndex()) {
tmp_str1 = sdsempty();
db_slice_.UpdateExpireBase(absl::GetCurrentTimeNanos() / 1000000, 0);
// start the defragmented task here
defrag_task_ = pb->AddOnIdleTask([this]() { return this->DefragTask(); });
queue_.Start(absl::StrCat("shard_queue_", db_slice_.shard_id()));
defrag_task_ = pb->AddOnIdleTask([this]() { return DefragTask(); });
queue_.Start(absl::StrCat("shard_queue_", shard_id()));
}
EngineShard::~EngineShard() {
@ -437,8 +440,10 @@ void EngineShard::InitTieredStorage(ProactorBase* pb, size_t max_file_size) {
LOG_IF(FATAL, pb->GetKind() != ProactorBase::IOURING)
<< "Only ioring based backing storage is supported. Exiting...";
// TODO: enable tiered storage on non-default namespace
DbSlice& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_id());
auto* shard = EngineShard::tlocal();
shard->tiered_storage_ = make_unique<TieredStorage>(&db_slice_, max_file_size);
shard->tiered_storage_ = make_unique<TieredStorage>(&db_slice, max_file_size);
error_code ec = shard->tiered_storage_->Open(backing_prefix);
CHECK(!ec) << ec.message();
}
@ -515,24 +520,31 @@ void EngineShard::PollExecution(const char* context, Transaction* trans) {
trans = nullptr;
if ((is_self && disarmed) || continuation_trans_->DisarmInShard(sid)) {
auto bc = continuation_trans_->GetNamespace().GetBlockingController(shard_id_);
if (bool keep = run(continuation_trans_, false); !keep) {
// if this holds, we can remove this check altogether.
DCHECK(continuation_trans_ == nullptr);
continuation_trans_ = nullptr;
}
if (bc && bc->HasAwakedTransaction()) {
// Break if there are any awakened transactions, as we must give way to them
// before continuing to handle regular transactions from the queue.
return;
}
}
}
// Progress on the transaction queue if no transaction is running currently.
Transaction* head = nullptr;
while (continuation_trans_ == nullptr && !txq_.Empty()) {
head = get<Transaction*>(txq_.Front());
// Break if there are any awakened transactions, as we must give way to them
// before continuing to handle regular transactions from the queue.
if (blocking_controller_ && blocking_controller_->HasAwakedTransaction())
if (head->GetNamespace().GetBlockingController(shard_id_) &&
head->GetNamespace().GetBlockingController(shard_id_)->HasAwakedTransaction())
break;
head = get<Transaction*>(txq_.Front());
VLOG(2) << "Considering head " << head->DebugId()
<< " isarmed: " << head->DEBUG_IsArmedInShard(sid);
@ -610,22 +622,28 @@ void EngineShard::Heartbeat() {
DbContext db_cntx;
db_cntx.time_now_ms = GetCurrentTimeMs();
for (unsigned i = 0; i < db_slice_.db_array_size(); ++i) {
if (!db_slice_.IsDbValid(i))
// TODO: iterate over all namespaces
if (!namespaces.IsInitialized()) {
return;
}
DbSlice& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_id());
for (unsigned i = 0; i < db_slice.db_array_size(); ++i) {
if (!db_slice.IsDbValid(i))
continue;
db_cntx.db_index = i;
auto [pt, expt] = db_slice_.GetTables(i);
auto [pt, expt] = db_slice.GetTables(i);
if (expt->size() > pt->size() / 4) {
DbSlice::DeleteExpiredStats stats = db_slice_.DeleteExpiredStep(db_cntx, ttl_delete_target);
DbSlice::DeleteExpiredStats stats = db_slice.DeleteExpiredStep(db_cntx, ttl_delete_target);
counter_[TTL_TRAVERSE].IncBy(stats.traversed);
counter_[TTL_DELETE].IncBy(stats.deleted);
}
// if our budget is below the limit
if (db_slice_.memory_budget() < eviction_redline) {
db_slice_.FreeMemWithEvictionStep(i, eviction_redline - db_slice_.memory_budget());
if (db_slice.memory_budget() < eviction_redline) {
db_slice.FreeMemWithEvictionStep(i, eviction_redline - db_slice.memory_budget());
}
if (UsedMemory() > tiering_offload_threshold) {
@ -686,18 +704,23 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms) {
}
void EngineShard::CacheStats() {
if (!namespaces.IsInitialized()) {
return;
}
// mi_heap_visit_blocks(tlh, false /* visit all blocks*/, visit_cb, &sum);
mi_stats_merge();
// Used memory for this shard.
size_t used_mem = UsedMemory();
cached_stats[db_slice_.shard_id()].used_memory.store(used_mem, memory_order_relaxed);
DbSlice& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard_id());
cached_stats[db_slice.shard_id()].used_memory.store(used_mem, memory_order_relaxed);
ssize_t free_mem = max_memory_limit - used_mem_current.load(memory_order_relaxed);
size_t entries = 0;
size_t table_memory = 0;
for (size_t i = 0; i < db_slice_.db_array_size(); ++i) {
DbTable* table = db_slice_.GetDBTable(i);
for (size_t i = 0; i < db_slice.db_array_size(); ++i) {
DbTable* table = db_slice.GetDBTable(i);
if (table) {
entries += table->prime.size();
table_memory += (table->prime.mem_usage() + table->expire.mem_usage());
@ -706,7 +729,7 @@ void EngineShard::CacheStats() {
size_t obj_memory = table_memory <= used_mem ? used_mem - table_memory : 0;
size_t bytes_per_obj = entries > 0 ? obj_memory / entries : 0;
db_slice_.SetCachedParams(free_mem / shard_set->size(), bytes_per_obj);
db_slice.SetCachedParams(free_mem / shard_set->size(), bytes_per_obj);
}
size_t EngineShard::UsedMemory() const {
@ -714,14 +737,6 @@ size_t EngineShard::UsedMemory() const {
search_indices()->GetUsedMemory();
}
BlockingController* EngineShard::EnsureBlockingController() {
if (!blocking_controller_) {
blocking_controller_.reset(new BlockingController(this));
}
return blocking_controller_.get();
}
void EngineShard::TEST_EnableHeartbeat() {
fiber_periodic_ = fb2::Fiber("shard_periodic_TEST", [this, period_ms = 1] {
RunPeriodic(std::chrono::milliseconds(period_ms));
@ -750,6 +765,8 @@ auto EngineShard::AnalyzeTxQueue() const -> TxQueueInfo {
info.tx_total = queue->size();
unsigned max_db_id = 0;
auto& db_slice = namespaces.GetDefaultNamespace().GetCurrentDbSlice();
do {
auto value = queue->At(cur);
Transaction* trx = std::get<Transaction*>(value);
@ -766,7 +783,7 @@ auto EngineShard::AnalyzeTxQueue() const -> TxQueueInfo {
if (trx->IsGlobal() || (trx->IsMulti() && trx->GetMultiMode() == Transaction::GLOBAL)) {
info.tx_global++;
} else {
const DbTable* table = db_slice().GetDBTable(trx->GetDbIndex());
const DbTable* table = db_slice.GetDBTable(trx->GetDbIndex());
bool can_run = !HasContendedLocks(sid, trx, table);
if (can_run) {
info.tx_runnable++;
@ -778,7 +795,7 @@ auto EngineShard::AnalyzeTxQueue() const -> TxQueueInfo {
// Analyze locks
for (unsigned i = 0; i <= max_db_id; ++i) {
const DbTable* table = db_slice().GetDBTable(i);
const DbTable* table = db_slice.GetDBTable(i);
if (table == nullptr)
continue;
@ -869,6 +886,8 @@ void EngineShardSet::Init(uint32_t sz, bool update_db_time) {
}
});
namespaces.Init();
pp_->AwaitFiberOnAll([&](uint32_t index, ProactorBase* pb) {
if (index < shard_queue_.size()) {
EngineShard::tlocal()->InitTieredStorage(pb, max_shard_file_size);
@ -895,7 +914,9 @@ void EngineShardSet::TEST_EnableHeartBeat() {
}
void EngineShardSet::TEST_EnableCacheMode() {
RunBriefInParallel([](EngineShard* shard) { shard->db_slice().TEST_EnableCacheMode(); });
RunBlockingInParallel([](EngineShard* shard) {
namespaces.GetDefaultNamespace().GetCurrentDbSlice().TEST_EnableCacheMode();
});
}
ShardId Shard(string_view v, ShardId shard_num) {

View file

@ -60,15 +60,7 @@ class EngineShard {
}
ShardId shard_id() const {
return db_slice_.shard_id();
}
DbSlice& db_slice() {
return db_slice_;
}
const DbSlice& db_slice() const {
return db_slice_;
return shard_id_;
}
PMR_NS::memory_resource* memory_resource() {
@ -124,12 +116,6 @@ class EngineShard {
return shard_search_indices_.get();
}
BlockingController* EnsureBlockingController();
BlockingController* blocking_controller() {
return blocking_controller_.get();
}
// for everyone to use for string transformations during atomic cpu sequences.
sds tmp_str1;
@ -242,7 +228,7 @@ class EngineShard {
TxQueue txq_;
MiMemoryResource mi_resource_;
DbSlice db_slice_;
ShardId shard_id_;
Stats stats_;
@ -261,8 +247,8 @@ class EngineShard {
DefragTaskState defrag_state_;
std::unique_ptr<TieredStorage> tiered_storage_;
// TODO: Move indices to Namespace
std::unique_ptr<ShardDocIndices> shard_search_indices_;
std::unique_ptr<BlockingController> blocking_controller_;
using Counter = util::SlidingCounter<7>;

View file

@ -451,7 +451,7 @@ OpStatus Renamer::DeserializeDest(Transaction* t, EngineShard* shard) {
auto& dest_it = restored_dest_it->it;
dest_it->first.SetSticky(serialized_value_.sticky);
auto bc = shard->blocking_controller();
auto bc = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id());
if (bc) {
bc->AwakeWatched(t->GetDbIndex(), dest_key_);
}
@ -607,7 +607,7 @@ uint64_t ScanGeneric(uint64_t cursor, const ScanOpts& scan_opts, StringVec* keys
}
cursor >>= 10;
DbContext db_cntx{cntx->conn_state.db_index, GetCurrentTimeMs()};
DbContext db_cntx{cntx->ns, cntx->conn_state.db_index, GetCurrentTimeMs()};
do {
auto cb = [&] {
@ -1355,8 +1355,10 @@ void GenericFamily::Select(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendError(kDbIndOutOfRangeErr);
}
cntx->conn_state.db_index = index;
auto cb = [index](EngineShard* shard) {
shard->db_slice().ActivateDb(index);
auto cb = [cntx, index](EngineShard* shard) {
CHECK(cntx->ns != nullptr);
auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id());
db_slice.ActivateDb(index);
return OpStatus::OK;
};
shard_set->RunBriefInParallel(std::move(cb));
@ -1385,7 +1387,8 @@ void GenericFamily::Type(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 0);
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<int> {
auto it = shard->db_slice().FindReadOnly(t->GetDbContext(), key).it;
auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id());
auto it = db_slice.FindReadOnly(t->GetDbContext(), key).it;
if (!it.is_done()) {
return it->second.ObjType();
} else {
@ -1549,8 +1552,9 @@ OpResult<void> GenericFamily::OpRen(const OpArgs& op_args, string_view from_key,
to_res.it->first.SetSticky(sticky);
}
if (!is_prior_list && to_res.it->second.ObjType() == OBJ_LIST && es->blocking_controller()) {
es->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, to_key);
auto bc = op_args.db_cntx.ns->GetBlockingController(es->shard_id());
if (!is_prior_list && to_res.it->second.ObjType() == OBJ_LIST && bc) {
bc->AwakeWatched(op_args.db_cntx.db_index, to_key);
}
return OpStatus::OK;
}
@ -1590,8 +1594,9 @@ OpStatus GenericFamily::OpMove(const OpArgs& op_args, string_view key, DbIndex t
auto& add_res = *op_result;
add_res.it->first.SetSticky(sticky);
if (add_res.it->second.ObjType() == OBJ_LIST && op_args.shard->blocking_controller()) {
op_args.shard->blocking_controller()->AwakeWatched(target_db, key);
auto bc = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id());
if (add_res.it->second.ObjType() == OBJ_LIST && bc) {
bc->AwakeWatched(target_db, key);
}
return OpStatus::OK;
@ -1602,14 +1607,15 @@ void GenericFamily::RandomKey(CmdArgList args, ConnectionContext* cntx) {
absl::BitGen bitgen;
atomic_size_t candidates_counter{0};
DbContext db_cntx{cntx->conn_state.db_index, GetCurrentTimeMs()};
DbContext db_cntx{cntx->ns, cntx->conn_state.db_index, GetCurrentTimeMs()};
ScanOpts scan_opts;
scan_opts.limit = 3; // number of entries per shard
std::vector<StringVec> candidates_collection(shard_set->size());
shard_set->RunBriefInParallel(
[&](EngineShard* shard) {
auto [prime_table, expire_table] = shard->db_slice().GetTables(db_cntx.db_index);
auto [prime_table, expire_table] =
cntx->ns->GetDbSlice(shard->shard_id()).GetTables(db_cntx.db_index);
if (prime_table->size() == 0) {
return;
}

View file

@ -1055,7 +1055,7 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) {
}
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<StringVec> {
auto& db_slice = shard->db_slice();
auto& db_slice = cntx->ns->GetDbSlice(shard->shard_id());
DbContext db_context = t->GetDbContext();
auto it_res = db_slice.FindReadOnly(db_context, key, OBJ_HASH);

View file

@ -44,6 +44,7 @@ JournalExecutor::JournalExecutor(Service* service)
conn_context_.is_replicating = true;
conn_context_.journal_emulated = true;
conn_context_.skip_acl_validation = true;
conn_context_.ns = &namespaces.GetDefaultNamespace();
}
JournalExecutor::~JournalExecutor() {

View file

@ -341,11 +341,12 @@ OpResult<uint32_t> OpPush(const OpArgs& op_args, std::string_view key, ListDir d
}
if (res.is_new) {
if (es->blocking_controller()) {
auto blocking_controller = op_args.db_cntx.ns->GetBlockingController(es->shard_id());
if (blocking_controller) {
string tmp;
string_view key = res.it->first.GetSlice(&tmp);
es->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, key);
blocking_controller->AwakeWatched(op_args.db_cntx.db_index, key);
absl::StrAppend(debugMessages.Next(), "OpPush AwakeWatched: ", key, " by ",
op_args.tx->DebugId());
}
@ -444,11 +445,12 @@ OpResult<string> MoveTwoShards(Transaction* trans, string_view src, string_view
OpPush(op_args, key, dest_dir, false, ArgSlice{val}, true);
// blocking_controller does not have to be set with non-blocking transactions.
if (shard->blocking_controller()) {
auto blocking_controller = t->GetNamespace().GetBlockingController(shard->shard_id());
if (blocking_controller) {
// hack, again. since we hacked which queue we are waiting on (see RunPair)
// we must clean-up src key here manually. See RunPair why we do this.
// in short- we suspended on "src" on both shards.
shard->blocking_controller()->FinalizeWatched(ArgSlice({src}), t);
blocking_controller->FinalizeWatched(ArgSlice({src}), t);
}
} else {
DVLOG(1) << "Popping value from list: " << key;
@ -852,10 +854,11 @@ OpResult<string> BPopPusher::RunSingle(ConnectionContext* cntx, time_point tp) {
std::array<string_view, 4> arr = {pop_key_, push_key_, DirToSv(popdir_), DirToSv(pushdir_)};
RecordJournal(op_args, "LMOVE", arr, 1);
}
if (shard->blocking_controller()) {
auto blocking_controller = cntx->ns->GetBlockingController(shard->shard_id());
if (blocking_controller) {
string tmp;
shard->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, push_key_);
blocking_controller->AwakeWatched(op_args.db_cntx.db_index, push_key_);
absl::StrAppend(debugMessages.Next(), "OpPush AwakeWatched: ", push_key_, " by ",
op_args.tx->DebugId());
}

View file

@ -32,8 +32,10 @@ class ListFamilyTest : public BaseFamilyTest {
static unsigned NumWatched() {
atomic_uint32_t sum{0};
auto ns = &namespaces.GetDefaultNamespace();
shard_set->RunBriefInParallel([&](EngineShard* es) {
auto* bc = es->blocking_controller();
auto* bc = ns->GetBlockingController(es->shard_id());
if (bc)
sum.fetch_add(bc->NumWatched(0), memory_order_relaxed);
});
@ -43,8 +45,9 @@ class ListFamilyTest : public BaseFamilyTest {
static bool HasAwakened() {
atomic_uint32_t sum{0};
auto ns = &namespaces.GetDefaultNamespace();
shard_set->RunBriefInParallel([&](EngineShard* es) {
auto* bc = es->blocking_controller();
auto* bc = ns->GetBlockingController(es->shard_id());
if (bc)
sum.fetch_add(bc->HasAwakedTransaction(), memory_order_relaxed);
});
@ -168,7 +171,7 @@ TEST_F(ListFamilyTest, BLPopTimeout) {
RespExpr resp = Run({"blpop", kKey1, kKey2, kKey3, "0.01"});
EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY));
EXPECT_EQ(3, GetDebugInfo().shards_count);
ASSERT_FALSE(service_->IsLocked(0, kKey1));
ASSERT_FALSE(IsLocked(0, kKey1));
// Under Multi
resp = Run({"multi"});
@ -178,7 +181,7 @@ TEST_F(ListFamilyTest, BLPopTimeout) {
resp = Run({"exec"});
EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY));
ASSERT_FALSE(service_->IsLocked(0, kKey1));
ASSERT_FALSE(IsLocked(0, kKey1));
ASSERT_EQ(0, NumWatched());
}

View file

@ -47,6 +47,7 @@ extern "C" {
#include "server/json_family.h"
#include "server/list_family.h"
#include "server/multi_command_squasher.h"
#include "server/namespaces.h"
#include "server/script_mgr.h"
#include "server/search/search_family.h"
#include "server/server_state.h"
@ -135,9 +136,11 @@ constexpr size_t kMaxThreadSize = 1024;
// Unwatch all keys for a connection and unregister from DbSlices.
// Used by UNWATCH, DISCARD and EXEC.
void UnwatchAllKeys(ConnectionState::ExecInfo* exec_info) {
void UnwatchAllKeys(Namespace* ns, ConnectionState::ExecInfo* exec_info) {
if (!exec_info->watched_keys.empty()) {
auto cb = [&](EngineShard* shard) { shard->db_slice().UnregisterConnectionWatches(exec_info); };
auto cb = [&](EngineShard* shard) {
ns->GetDbSlice(shard->shard_id()).UnregisterConnectionWatches(exec_info);
};
shard_set->RunBriefInParallel(std::move(cb));
}
exec_info->ClearWatched();
@ -149,7 +152,7 @@ void MultiCleanup(ConnectionContext* cntx) {
ServerState::tlocal()->ReturnInterpreter(borrowed);
exec_info.preborrowed_interpreter = nullptr;
}
UnwatchAllKeys(&exec_info);
UnwatchAllKeys(cntx->ns, &exec_info);
exec_info.Clear();
}
@ -513,7 +516,8 @@ void Topkeys(const http::QueryArgs& args, HttpContext* send) {
vector<string> rows(shard_set->size());
shard_set->RunBriefInParallel([&](EngineShard* shard) {
for (const auto& db : shard->db_slice().databases()) {
for (const auto& db :
namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).databases()) {
if (db->top_keys.IsEnabled()) {
is_enabled = true;
for (const auto& [key, count] : db->top_keys.GetTopKeys()) {
@ -829,6 +833,7 @@ Service::Service(ProactorPool* pp)
Service::~Service() {
delete shard_set;
shard_set = nullptr;
namespaces.Clear();
}
void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*> listeners,
@ -908,7 +913,10 @@ void Service::Shutdown() {
ChannelStore::Destroy();
namespaces.Clear();
shard_set->Shutdown();
pp_.Await([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); });
// wait for all the pending callbacks to stop.
@ -1212,8 +1220,8 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
DCHECK(dfly_cntx->transaction);
if (cid->IsTransactional()) {
dfly_cntx->transaction->MultiSwitchCmd(cid);
OpStatus status =
dfly_cntx->transaction->InitByArgs(dfly_cntx->conn_state.db_index, args_no_cmd);
OpStatus status = dfly_cntx->transaction->InitByArgs(
dfly_cntx->ns, dfly_cntx->conn_state.db_index, args_no_cmd);
if (status != OpStatus::OK)
return cntx->SendError(status);
@ -1225,7 +1233,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
dist_trans.reset(new Transaction{cid});
if (!dist_trans->IsMulti()) { // Multi command initialize themself based on their mode.
if (auto st = dist_trans->InitByArgs(dfly_cntx->conn_state.db_index, args_no_cmd);
CHECK(dfly_cntx->ns != nullptr);
if (auto st =
dist_trans->InitByArgs(dfly_cntx->ns, dfly_cntx->conn_state.db_index, args_no_cmd);
st != OpStatus::OK)
return cntx->SendError(st);
}
@ -1581,6 +1591,7 @@ facade::ConnectionContext* Service::CreateContext(util::FiberSocketBase* peer,
facade::Connection* owner) {
auto cred = user_registry_.GetCredentials("default");
ConnectionContext* res = new ConnectionContext{peer, owner, std::move(cred)};
res->ns = &namespaces.GetOrInsert("");
if (peer->IsUDS()) {
res->req_auth = false;
@ -1606,10 +1617,10 @@ const CommandId* Service::FindCmd(std::string_view cmd) const {
return registry_.Find(registry_.RenamedOrOriginal(cmd));
}
bool Service::IsLocked(DbIndex db_index, std::string_view key) const {
bool Service::IsLocked(Namespace* ns, DbIndex db_index, std::string_view key) const {
ShardId sid = Shard(key, shard_count());
bool is_open = pp_.at(sid)->AwaitBrief([db_index, key] {
return EngineShard::tlocal()->db_slice().CheckLock(IntentLock::EXCLUSIVE, db_index, key);
bool is_open = pp_.at(sid)->AwaitBrief([db_index, key, ns, sid] {
return ns->GetDbSlice(sid).CheckLock(IntentLock::EXCLUSIVE, db_index, key);
});
return !is_open;
}
@ -1682,7 +1693,7 @@ void Service::Watch(CmdArgList args, ConnectionContext* cntx) {
}
void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) {
UnwatchAllKeys(&cntx->conn_state.exec_info);
UnwatchAllKeys(cntx->ns, &cntx->conn_state.exec_info);
return cntx->SendOk();
}
@ -1841,6 +1852,7 @@ Transaction::MultiMode DetermineMultiMode(ScriptMgr::ScriptParams params) {
optional<bool> StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptParams params,
ConnectionContext* cntx) {
Transaction* trans = cntx->transaction;
Namespace* ns = cntx->ns;
Transaction::MultiMode script_mode = DetermineMultiMode(params);
Transaction::MultiMode multi_mode = trans->GetMultiMode();
// Check if eval is already part of a running multi transaction
@ -1860,10 +1872,10 @@ optional<bool> StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptPa
switch (script_mode) {
case Transaction::GLOBAL:
trans->StartMultiGlobal(dbid);
trans->StartMultiGlobal(ns, dbid);
return true;
case Transaction::LOCK_AHEAD:
trans->StartMultiLockedAhead(dbid, keys);
trans->StartMultiLockedAhead(ns, dbid, keys);
return true;
case Transaction::NON_ATOMIC:
trans->StartMultiNonAtomic();
@ -1988,7 +2000,7 @@ void Service::EvalInternal(CmdArgList args, const EvalArgs& eval_args, Interpret
});
++ServerState::tlocal()->stats.eval_shardlocal_coordination_cnt;
tx->PrepareMultiForScheduleSingleHop(*sid, tx->GetDbIndex(), args);
tx->PrepareMultiForScheduleSingleHop(cntx->ns, *sid, tx->GetDbIndex(), args);
tx->ScheduleSingleHop([&](Transaction*, EngineShard*) {
boost::intrusive_ptr<Transaction> stub_tx =
new Transaction{tx, *sid, slot_checker.GetUniqueSlotId()};
@ -2077,7 +2089,7 @@ bool CheckWatchedKeyExpiry(ConnectionContext* cntx, const CommandRegistry& regis
};
cntx->transaction->MultiSwitchCmd(registry.Find(EXISTS));
cntx->transaction->InitByArgs(cntx->conn_state.db_index, CmdArgList{str_list});
cntx->transaction->InitByArgs(cntx->ns, cntx->conn_state.db_index, CmdArgList{str_list});
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
CHECK_EQ(OpStatus::OK, status);
@ -2139,11 +2151,11 @@ void StartMultiExec(ConnectionContext* cntx, ConnectionState::ExecInfo* exec_inf
auto dbid = cntx->db_index();
switch (multi_mode) {
case Transaction::GLOBAL:
trans->StartMultiGlobal(dbid);
trans->StartMultiGlobal(cntx->ns, dbid);
break;
case Transaction::LOCK_AHEAD: {
auto vec = CollectAllKeys(exec_info);
trans->StartMultiLockedAhead(dbid, absl::MakeSpan(vec));
trans->StartMultiLockedAhead(cntx->ns, dbid, absl::MakeSpan(vec));
} break;
case Transaction::NON_ATOMIC:
trans->StartMultiNonAtomic();
@ -2234,7 +2246,7 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
CmdArgList args = absl::MakeSpan(arg_vec);
if (scmd.Cid()->IsTransactional()) {
OpStatus st = cntx->transaction->InitByArgs(cntx->conn_state.db_index, args);
OpStatus st = cntx->transaction->InitByArgs(cntx->ns, cntx->conn_state.db_index, args);
if (st != OpStatus::OK) {
cntx->SendError(st);
break;
@ -2444,7 +2456,7 @@ void Service::Command(CmdArgList args, ConnectionContext* cntx) {
VarzValue::Map Service::GetVarzStats() {
VarzValue::Map res;
Metrics m = server_family_.GetMetrics();
Metrics m = server_family_.GetMetrics(&namespaces.GetDefaultNamespace());
DbStats db_stats;
for (const auto& s : m.db_stats) {
db_stats += s;
@ -2543,7 +2555,7 @@ void Service::OnClose(facade::ConnectionContext* cntx) {
DCHECK(!conn_state.subscribe_info);
}
UnwatchAllKeys(&conn_state.exec_info);
UnwatchAllKeys(server_cntx->ns, &conn_state.exec_info);
DeactivateMonitoring(server_cntx);

View file

@ -102,7 +102,7 @@ class Service : public facade::ServiceInterface {
}
// Used by tests.
bool IsLocked(DbIndex db_index, std::string_view key) const;
bool IsLocked(Namespace* ns, DbIndex db_index, std::string_view key) const;
bool IsShardSetLocked() const;
util::ProactorPool& proactor_pool() {

View file

@ -245,7 +245,7 @@ void PushMemoryUsageStats(const base::IoBuf::MemoryUsage& mem, string_view prefi
void MemoryCmd::Stats() {
vector<pair<string, size_t>> stats;
stats.reserve(25);
auto server_metrics = owner_->GetMetrics();
auto server_metrics = owner_->GetMetrics(cntx_->ns);
// RSS
stats.push_back({"rss_bytes", rss_mem_current.load(memory_order_relaxed)});
@ -369,8 +369,8 @@ void MemoryCmd::ArenaStats(CmdArgList args) {
void MemoryCmd::Usage(std::string_view key) {
ShardId sid = Shard(key, shard_set->size());
ssize_t memory_usage = shard_set->pool()->at(sid)->AwaitBrief([key, this]() -> ssize_t {
auto& db_slice = EngineShard::tlocal()->db_slice();
ssize_t memory_usage = shard_set->pool()->at(sid)->AwaitBrief([key, this, sid]() -> ssize_t {
auto& db_slice = cntx_->ns->GetDbSlice(sid);
auto [pt, exp_t] = db_slice.GetTables(cntx_->db_index());
PrimeIterator it = pt->Find(key);
if (IsValid(it)) {

View file

@ -145,7 +145,7 @@ bool MultiCommandSquasher::ExecuteStandalone(StoredCmd* cmd) {
cntx_->cid = cmd->Cid();
if (cmd->Cid()->IsTransactional())
tx->InitByArgs(cntx_->conn_state.db_index, args);
tx->InitByArgs(cntx_->ns, cntx_->conn_state.db_index, args);
service_->InvokeCmd(cmd->Cid(), args, cntx_);
return true;
@ -181,7 +181,7 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard
local_cntx.cid = cmd->Cid();
crb.SetReplyMode(cmd->ReplyMode());
local_tx->InitByArgs(local_cntx.conn_state.db_index, args);
local_tx->InitByArgs(cntx_->ns, local_cntx.conn_state.db_index, args);
service_->InvokeCmd(cmd->Cid(), args, &local_cntx);
sinfo.replies.emplace_back(crb.Take());

View file

@ -109,8 +109,8 @@ TEST_F(MultiTest, Multi) {
resp = Run({"get", kKey4});
ASSERT_THAT(resp, ArgType(RespExpr::NIL));
ASSERT_FALSE(service_->IsLocked(0, kKey1));
ASSERT_FALSE(service_->IsLocked(0, kKey4));
ASSERT_FALSE(IsLocked(0, kKey1));
ASSERT_FALSE(IsLocked(0, kKey4));
ASSERT_FALSE(service_->IsShardSetLocked());
}
@ -129,8 +129,8 @@ TEST_F(MultiTest, MultiGlobalCommands) {
ASSERT_THAT(Run({"select", "2"}), "OK");
ASSERT_THAT(Run({"get", "key"}), "val");
ASSERT_FALSE(service_->IsLocked(0, "key"));
ASSERT_FALSE(service_->IsLocked(2, "key"));
ASSERT_FALSE(IsLocked(0, "key"));
ASSERT_FALSE(IsLocked(2, "key"));
}
TEST_F(MultiTest, HitMissStats) {
@ -181,8 +181,8 @@ TEST_F(MultiTest, MultiSeq) {
ASSERT_EQ(resp, "QUEUED");
resp = Run({"exec"});
ASSERT_FALSE(service_->IsLocked(0, kKey1));
ASSERT_FALSE(service_->IsLocked(0, kKey4));
ASSERT_FALSE(IsLocked(0, kKey1));
ASSERT_FALSE(IsLocked(0, kKey4));
ASSERT_FALSE(service_->IsShardSetLocked());
ASSERT_THAT(resp, ArrLen(3));
@ -237,8 +237,8 @@ TEST_F(MultiTest, MultiConsistent) {
mset_fb.Join();
fb.Join();
ASSERT_FALSE(service_->IsLocked(0, kKey1));
ASSERT_FALSE(service_->IsLocked(0, kKey4));
ASSERT_FALSE(IsLocked(0, kKey1));
ASSERT_FALSE(IsLocked(0, kKey4));
ASSERT_FALSE(service_->IsShardSetLocked());
}
@ -312,9 +312,9 @@ TEST_F(MultiTest, MultiRename) {
resp = Run({"exec"});
EXPECT_EQ(resp, "OK");
EXPECT_FALSE(service_->IsLocked(0, kKey1));
EXPECT_FALSE(service_->IsLocked(0, kKey2));
EXPECT_FALSE(service_->IsLocked(0, kKey4));
EXPECT_FALSE(IsLocked(0, kKey1));
EXPECT_FALSE(IsLocked(0, kKey2));
EXPECT_FALSE(IsLocked(0, kKey4));
EXPECT_FALSE(service_->IsShardSetLocked());
}
@ -366,8 +366,8 @@ TEST_F(MultiTest, FlushDb) {
fb0.Join();
ASSERT_FALSE(service_->IsLocked(0, kKey1));
ASSERT_FALSE(service_->IsLocked(0, kKey4));
ASSERT_FALSE(IsLocked(0, kKey1));
ASSERT_FALSE(IsLocked(0, kKey4));
ASSERT_FALSE(service_->IsShardSetLocked());
}
@ -400,17 +400,17 @@ TEST_F(MultiTest, Eval) {
resp = Run({"eval", "return redis.call('get', 'foo')", "1", "bar"});
EXPECT_THAT(resp, ErrArg("undeclared"));
ASSERT_FALSE(service_->IsLocked(0, "foo"));
ASSERT_FALSE(IsLocked(0, "foo"));
Run({"script", "flush"}); // Reset global flag from autocorrect
resp = Run({"eval", "return redis.call('get', 'foo')", "1", "foo"});
EXPECT_THAT(resp, "42");
ASSERT_FALSE(service_->IsLocked(0, "foo"));
ASSERT_FALSE(IsLocked(0, "foo"));
resp = Run({"eval", "return redis.call('get', KEYS[1])", "1", "foo"});
EXPECT_THAT(resp, "42");
ASSERT_FALSE(service_->IsLocked(0, "foo"));
ASSERT_FALSE(IsLocked(0, "foo"));
ASSERT_FALSE(service_->IsShardSetLocked());
resp = Run({"eval", "return 77", "2", "foo", "zoo"});
@ -451,7 +451,7 @@ TEST_F(MultiTest, Eval) {
"1", "foo"}),
"42");
auto condition = [&]() { return service_->IsLocked(0, "foo"); };
auto condition = [&]() { return IsLocked(0, "foo"); };
auto fb = ExpectConditionWithSuspension(condition);
EXPECT_EQ(Run({"eval",
R"(redis.call('set', 'foo', '42')
@ -974,7 +974,7 @@ TEST_F(MultiTest, TestLockedKeys) {
GTEST_SKIP() << "Skipped TestLockedKeys test because multi_exec_mode is not lock ahead";
return;
}
auto condition = [&]() { return service_->IsLocked(0, "key1") && service_->IsLocked(0, "key2"); };
auto condition = [&]() { return IsLocked(0, "key1") && IsLocked(0, "key2"); };
auto fb = ExpectConditionWithSuspension(condition);
EXPECT_EQ(Run({"multi"}), "OK");
@ -983,8 +983,8 @@ TEST_F(MultiTest, TestLockedKeys) {
EXPECT_EQ(Run({"mset", "key1", "val3", "key1", "val4"}), "QUEUED");
EXPECT_THAT(Run({"exec"}), RespArray(ElementsAre("OK", "OK", "OK")));
fb.Join();
EXPECT_FALSE(service_->IsLocked(0, "key1"));
EXPECT_FALSE(service_->IsLocked(0, "key2"));
EXPECT_FALSE(IsLocked(0, "key1"));
EXPECT_FALSE(IsLocked(0, "key2"));
}
TEST_F(MultiTest, EvalExpiration) {

src/server/namespaces.cc (new file, 103 lines)
View file

@ -0,0 +1,103 @@
#include "server/namespaces.h"
#include "base/flags.h"
#include "base/logging.h"
#include "server/engine_shard_set.h"
ABSL_DECLARE_FLAG(bool, cache_mode);
namespace dfly {
using namespace std;
Namespace::Namespace() {
shard_db_slices_.resize(shard_set->size());
shard_blocking_controller_.resize(shard_set->size());
shard_set->RunBriefInParallel([&](EngineShard* es) {
CHECK(es != nullptr);
ShardId sid = es->shard_id();
shard_db_slices_[sid] = make_unique<DbSlice>(sid, absl::GetFlag(FLAGS_cache_mode), es);
shard_db_slices_[sid]->UpdateExpireBase(absl::GetCurrentTimeNanos() / 1000000, 0);
});
}
DbSlice& Namespace::GetCurrentDbSlice() {
EngineShard* es = EngineShard::tlocal();
CHECK(es != nullptr);
return GetDbSlice(es->shard_id());
}
DbSlice& Namespace::GetDbSlice(ShardId sid) {
CHECK_LT(sid, shard_db_slices_.size());
return *shard_db_slices_[sid];
}
BlockingController* Namespace::GetOrAddBlockingController(EngineShard* shard) {
if (!shard_blocking_controller_[shard->shard_id()]) {
shard_blocking_controller_[shard->shard_id()] = make_unique<BlockingController>(shard, this);
}
return shard_blocking_controller_[shard->shard_id()].get();
}
BlockingController* Namespace::GetBlockingController(ShardId sid) {
return shard_blocking_controller_[sid].get();
}
Namespaces namespaces;
Namespaces::~Namespaces() {
Clear();
}
void Namespaces::Init() {
DCHECK(default_namespace_ == nullptr);
default_namespace_ = &GetOrInsert("");
}
bool Namespaces::IsInitialized() const {
return default_namespace_ != nullptr;
}
void Namespaces::Clear() {
std::unique_lock guard(mu_);
namespaces.default_namespace_ = nullptr;
if (namespaces_.empty()) {
return;
}
shard_set->RunBriefInParallel([&](EngineShard* es) {
CHECK(es != nullptr);
for (auto& ns : namespaces_) {
ns.second.shard_db_slices_[es->shard_id()].reset();
}
});
namespaces.namespaces_.clear();
}
Namespace& Namespaces::GetDefaultNamespace() const {
CHECK(default_namespace_ != nullptr);
return *default_namespace_;
}
Namespace& Namespaces::GetOrInsert(std::string_view ns) {
{
// Try to look up under a shared lock
std::shared_lock guard(mu_);
auto it = namespaces_.find(ns);
if (it != namespaces_.end()) {
return it->second;
}
}
{
// Key was not found, so we create it under a unique lock. If two fibers race
// here, map operator[] still inserts the entry only once.
std::unique_lock guard(mu_);
return namespaces_[ns];
}
}
} // namespace dfly
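
The shared-then-unique locking in `GetOrInsert()` keeps the common read path
concurrent. A minimal standalone model of the same pattern (illustrative only;
plain `std::shared_mutex` stands in for `util::fb2::SharedMutex`):

```
#include <map>
#include <shared_mutex>
#include <string>
#include <string_view>

std::shared_mutex mu;
std::map<std::string, int, std::less<>> table;  // std::less<> enables string_view lookups

int& GetOrInsert(std::string_view key) {
  {
    std::shared_lock read_lock(mu);  // readers proceed concurrently
    if (auto it = table.find(key); it != table.end())
      return it->second;
  }
  std::unique_lock write_lock(mu);  // writers serialize here
  // Two threads may race past the read-only check above; operator[] still
  // inserts the entry only once, the second caller simply finds it.
  return table[std::string(key)];
}
```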

src/server/namespaces.h (new file, 70 lines)
View file

@ -0,0 +1,70 @@
// Copyright 2024, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <absl/container/node_hash_map.h>
#include <memory>
#include <string>
#include <vector>
#include "server/blocking_controller.h"
#include "server/db_slice.h"
#include "server/tx_base.h"
#include "util/fibers/synchronization.h"
namespace dfly {
// A Namespace is a way to separate and isolate different databases in a single instance.
// It can be used to allow multiple tenants to use the same server without hacks of using a common
// prefix, or SELECT-ing a different database.
// Each Namespace contains per-shard DbSlice, as well as a BlockingController.
class Namespace {
public:
Namespace();
DbSlice& GetCurrentDbSlice();
DbSlice& GetDbSlice(ShardId sid);
BlockingController* GetOrAddBlockingController(EngineShard* shard);
BlockingController* GetBlockingController(ShardId sid);
private:
std::vector<std::unique_ptr<DbSlice>> shard_db_slices_;
std::vector<std::unique_ptr<BlockingController>> shard_blocking_controller_;
friend class Namespaces;
};
// Namespaces is a registry and container for Namespace instances.
// Each Namespace has a unique string name, which identifies it in the store.
// Any attempt to access a non-existing Namespace will first create it, add it to the internal map
// and will then return it.
// It is currently impossible to remove a Namespace after it has been created.
// The default Namespace can be accessed via either GetDefaultNamespace() (which guarantees not to
// yield), or via the GetOrInsert() with an empty string.
// The initialization order of this class with the engine shards is slightly subtle, as they have
// mutual dependencies: the Namespace constructor iterates shard_set, and shard code reads back
// from namespaces, so namespaces.Init() runs only after the shards themselves exist.
class Namespaces {
public:
Namespaces() = default;
~Namespaces();
void Init();
bool IsInitialized() const;
void Clear(); // Thread unsafe, use in tear-down or tests
Namespace& GetDefaultNamespace() const; // No locks
Namespace& GetOrInsert(std::string_view ns);
private:
util::fb2::SharedMutex mu_{};
absl::node_hash_map<std::string, Namespace> namespaces_ ABSL_GUARDED_BY(mu_);
Namespace* default_namespace_ = nullptr;
};
extern Namespaces namespaces;
} // namespace dfly
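
A rough usage sketch of the registry, assuming `shard_set` is initialized and
`namespaces.Init()` has already run (the helper `ResolveDbSlice` is a made-up
name for illustration, not part of the diff):

```
#include <string_view>

#include "server/namespaces.h"

namespace dfly {

// Pick the DbSlice a command should operate on, based on the namespace
// string stored in the user's ACL credentials ("" selects the default).
DbSlice& ResolveDbSlice(std::string_view user_ns, ShardId sid) {
  Namespace& ns = namespaces.GetOrInsert(user_ns);  // created on first access
  return ns.GetDbSlice(sid);
}

void UsageExample(ShardId sid) {
  // Internal code paths use the default namespace directly; this never yields.
  DbSlice& def = namespaces.GetDefaultNamespace().GetDbSlice(sid);
  (void)def;

  // Repeated lookups return the same instance, and namespaces are never
  // removed, so a held Namespace& stays valid.
  Namespace& a = namespaces.GetOrInsert("tenant1");
  Namespace& b = namespaces.GetOrInsert("tenant1");
  (void)a;
  (void)b;  // &a == &b holds here
}

}  // namespace dfly
```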

View file

@ -2066,7 +2066,8 @@ error_code RdbLoader::Load(io::Source* src) {
FlushShardAsync(i);
// Activate the database if it did not exist before.
shard_set->Add(i, [dbid] { EngineShard::tlocal()->db_slice().ActivateDb(dbid); });
shard_set->Add(
i, [dbid] { namespaces.GetDefaultNamespace().GetCurrentDbSlice().ActivateDb(dbid); });
}
cur_db_index_ = dbid;
@ -2451,8 +2452,8 @@ std::error_code RdbLoaderBase::FromOpaque(const OpaqueObj& opaque, CompactObj* p
void RdbLoader::LoadItemsBuffer(DbIndex db_ind, const ItemsBuf& ib) {
EngineShard* es = EngineShard::tlocal();
DbSlice& db_slice = es->db_slice();
DbContext db_cntx{db_ind, GetCurrentTimeMs()};
DbContext db_cntx{&namespaces.GetDefaultNamespace(), db_ind, GetCurrentTimeMs()};
DbSlice& db_slice = db_cntx.GetDbSlice(es->shard_id());
for (const auto* item : ib) {
PrimeValue pv;
@ -2564,6 +2565,7 @@ void RdbLoader::LoadSearchIndexDefFromAux(string&& def) {
cntx.is_replicating = true;
cntx.journal_emulated = true;
cntx.skip_acl_validation = true;
cntx.ns = &namespaces.GetDefaultNamespace();
// Avoid deleting local crb
absl::Cleanup cntx_clean = [&cntx] { cntx.Inject(nullptr); };
@ -2613,7 +2615,8 @@ void RdbLoader::PerformPostLoad(Service* service) {
// Rebuild all search indices as only their definitions are extracted from the snapshot
shard_set->AwaitRunningOnShardQueue([](EngineShard* es) {
es->search_indices()->RebuildAllIndices(OpArgs{es, nullptr, DbContext{0, GetCurrentTimeMs()}});
es->search_indices()->RebuildAllIndices(
OpArgs{es, nullptr, DbContext{&namespaces.GetDefaultNamespace(), 0, GetCurrentTimeMs()}});
});
}

View file

@ -35,6 +35,7 @@ extern "C" {
#include "server/engine_shard_set.h"
#include "server/error.h"
#include "server/main_service.h"
#include "server/namespaces.h"
#include "server/rdb_extensions.h"
#include "server/search/doc_index.h"
#include "server/serializer_commons.h"
@ -1251,15 +1252,17 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
void RdbSaver::Impl::StartSnapshotting(bool stream_journal, const Cancellation* cll,
EngineShard* shard) {
auto& s = GetSnapshot(shard);
s = std::make_unique<SliceSnapshot>(&shard->db_slice(), &channel_, compression_mode_);
auto& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id());
s = std::make_unique<SliceSnapshot>(&db_slice, &channel_, compression_mode_);
s->Start(stream_journal, cll);
}
void RdbSaver::Impl::StartIncrementalSnapshotting(Context* cntx, EngineShard* shard,
LSN start_lsn) {
auto& db_slice = namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id());
auto& s = GetSnapshot(shard);
s = std::make_unique<SliceSnapshot>(&shard->db_slice(), &channel_, compression_mode_);
s = std::make_unique<SliceSnapshot>(&db_slice, &channel_, compression_mode_);
s->StartIncremental(cntx, start_lsn);
}


@ -570,6 +570,7 @@ error_code Replica::ConsumeRedisStream() {
conn_context.is_replicating = true;
conn_context.journal_emulated = true;
conn_context.skip_acl_validation = true;
conn_context.ns = &namespaces.GetDefaultNamespace();
ResetParser(true);
// Master waits for this command in order to start sending replication stream.


@ -447,7 +447,7 @@ void ClientPauseCmd(CmdArgList args, vector<facade::Listener*> listeners, Connec
};
if (auto pause_fb_opt =
Pause(listeners, cntx->conn(), pause_state, std::move(is_pause_in_progress));
Pause(listeners, cntx->ns, cntx->conn(), pause_state, std::move(is_pause_in_progress));
pause_fb_opt) {
pause_fb_opt->Detach();
cntx->SendOk();
@ -673,8 +673,8 @@ optional<ReplicaOfArgs> ReplicaOfArgs::FromCmdArgs(CmdArgList args, ConnectionCo
} // namespace
std::optional<fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, facade::Connection* conn,
ClientPause pause_state,
std::optional<fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, Namespace* ns,
facade::Connection* conn, ClientPause pause_state,
std::function<bool()> is_pause_in_progress) {
// Track connections and set the pause state to be able to wait until all running transactions
// read the new pause state. Exclude already-paused commands from the busy count. Exclude tracking
@ -683,7 +683,7 @@ std::optional<fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, facade
// command that did not pause on the new state yet we will pause after waking up.
DispatchTracker tracker{std::move(listeners), conn, true /* ignore paused commands */,
true /*ignore blocking*/};
shard_set->pool()->AwaitBrief([&tracker, pause_state](unsigned, util::ProactorBase*) {
shard_set->pool()->AwaitBrief([&tracker, pause_state, ns](unsigned, util::ProactorBase*) {
// Commands don't suspend before checking the pause state, so
// it's impossible to deadlock on waiting for a command that will be paused.
tracker.TrackOnThread();
@ -703,9 +703,9 @@ std::optional<fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, facade
// We should not expire/evict keys while clients are paused.
shard_set->RunBriefInParallel(
[](EngineShard* shard) { shard->db_slice().SetExpireAllowed(false); });
[ns](EngineShard* shard) { ns->GetDbSlice(shard->shard_id()).SetExpireAllowed(false); });
return fb2::Fiber("client_pause", [is_pause_in_progress, pause_state]() mutable {
return fb2::Fiber("client_pause", [is_pause_in_progress, pause_state, ns]() mutable {
// On server shutdown we sleep 10ms to make sure all running tasks finish; therefore 10ms steps
// ensure this fiber will not be left hanging.
constexpr auto step = 10ms;
@ -719,7 +719,7 @@ std::optional<fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, facade
ServerState::tlocal()->SetPauseState(pause_state, false);
});
shard_set->RunBriefInParallel(
[](EngineShard* shard) { shard->db_slice().SetExpireAllowed(true); });
[ns](EngineShard* shard) { ns->GetDbSlice(shard->shard_id()).SetExpireAllowed(true); });
}
});
}
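
A subtlety in the `Pause()` hunks above: `ns` is a raw pointer captured into `AwaitBrief` callbacks and into a fiber that may outlive the pausing command. That appears safe precisely because namespaces are never removed and the registry's node-based map keeps them address-stable. The bracket itself reduces to the following shape, sketched with toy types rather than Dragonfly's API:

```cpp
#include <atomic>
#include <cstddef>
#include <functional>
#include <vector>

// Toy per-shard slice exposing the expiry switch toggled in the hunk above.
struct ToyDbSlice {
  std::atomic<bool> expire_allowed{true};
};

struct ToyNamespace {
  explicit ToyNamespace(std::size_t num_shards) : slices(num_shards) {}
  std::vector<ToyDbSlice> slices;
};

// Mirrors the bracket in Pause(): suspend expiry on every shard of *one*
// namespace, run the paused section, then restore it. Other namespaces on
// the same server keep expiring and evicting normally.
void RunWhilePaused(ToyNamespace* ns, const std::function<void()>& paused_section) {
  for (auto& slice : ns->slices) slice.expire_allowed.store(false);
  paused_section();
  for (auto& slice : ns->slices) slice.expire_allowed.store(true);
}
```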
@ -1345,7 +1345,8 @@ void ServerFamily::ConfigureMetrics(util::HttpListenerBase* http_base) {
auto cb = [this](const util::http::QueryArgs& args, util::HttpContext* send) {
StringResponse resp = util::http::MakeStringResponse(boost::beast::http::status::ok);
PrintPrometheusMetrics(this->GetMetrics(), this->dfly_cmd_.get(), &resp);
PrintPrometheusMetrics(this->GetMetrics(&namespaces.GetDefaultNamespace()),
this->dfly_cmd_.get(), &resp);
return send->Invoke(std::move(resp));
};
@ -1424,7 +1425,7 @@ void ServerFamily::StatsMC(std::string_view section, facade::ConnectionContext*
double utime = dbl_time(ru.ru_utime);
double systime = dbl_time(ru.ru_stime);
Metrics m = GetMetrics();
Metrics m = GetMetrics(&namespaces.GetDefaultNamespace());
ADD_LINE(pid, getpid());
ADD_LINE(uptime, m.uptime);
@ -1454,7 +1455,7 @@ GenericError ServerFamily::DoSave(bool ignore_state) {
const CommandId* cid = service().FindCmd("SAVE");
CHECK_NOTNULL(cid);
boost::intrusive_ptr<Transaction> trans(new Transaction{cid});
trans->InitByArgs(0, {});
trans->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {});
return DoSave(absl::GetFlag(FLAGS_df_snapshot_format), {}, trans.get(), ignore_state);
}
@ -1551,7 +1552,7 @@ void ServerFamily::DbSize(CmdArgList args, ConnectionContext* cntx) {
shard_set->RunBriefInParallel(
[&](EngineShard* shard) {
auto db_size = shard->db_slice().DbSize(cntx->conn_state.db_index);
auto db_size = cntx->ns->GetDbSlice(shard->shard_id()).DbSize(cntx->conn_state.db_index);
num_keys.fetch_add(db_size, memory_order_relaxed);
},
[](ShardId) { return true; });
@ -1647,6 +1648,7 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
auto cred = registry->GetCredentials(username);
cntx->acl_commands = cred.acl_commands;
cntx->keys = std::move(cred.keys);
cntx->ns = &namespaces.GetOrInsert(cred.ns);
cntx->authenticated = true;
return cntx->SendOk();
}
@ -1773,7 +1775,7 @@ void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) {
}
if (sub_cmd == "RESETSTAT") {
ResetStat();
ResetStat(cntx->ns);
return cntx->SendOk();
} else {
return cntx->SendError(UnknownSubCmd(sub_cmd, "CONFIG"), kSyntaxErrType);
@ -1883,17 +1885,16 @@ static void MergeDbSliceStats(const DbSlice::Stats& src, Metrics* dest) {
dest->small_string_bytes += src.small_string_bytes;
}
void ServerFamily::ResetStat() {
void ServerFamily::ResetStat(Namespace* ns) {
shard_set->pool()->AwaitBrief(
[registry = service_.mutable_registry(), this](unsigned index, auto*) {
[registry = service_.mutable_registry(), this, ns](unsigned index, auto*) {
registry->ResetCallStats(index);
SinkReplyBuilder::ResetThreadLocalStats();
auto& stats = tl_facade_stats->conn_stats;
stats.command_cnt = 0;
stats.pipelined_cmd_cnt = 0;
EngineShard* shard = EngineShard::tlocal();
shard->db_slice().ResetEvents();
ns->GetCurrentDbSlice().ResetEvents();
tl_facade_stats->conn_stats.conn_received_cnt = 0;
tl_facade_stats->conn_stats.pipelined_cmd_cnt = 0;
tl_facade_stats->conn_stats.command_cnt = 0;
@ -1910,7 +1911,7 @@ void ServerFamily::ResetStat() {
});
}
Metrics ServerFamily::GetMetrics() const {
Metrics ServerFamily::GetMetrics(Namespace* ns) const {
Metrics result;
util::fb2::Mutex mu;
@ -1942,7 +1943,7 @@ Metrics ServerFamily::GetMetrics() const {
if (shard) {
result.heap_used_bytes += shard->UsedMemory();
MergeDbSliceStats(shard->db_slice().GetStats(), &result);
MergeDbSliceStats(ns->GetDbSlice(shard->shard_id()).GetStats(), &result);
result.shard_stats += shard->stats();
if (shard->tiered_storage()) {
@ -2017,7 +2018,7 @@ void ServerFamily::Info(CmdArgList args, ConnectionContext* cntx) {
absl::StrAppend(&info, a1, ":", a2, "\r\n");
};
Metrics m = GetMetrics();
Metrics m = GetMetrics(cntx->ns);
DbStats total;
for (const auto& db_stats : m.db_stats)
total += db_stats;
@ -2589,8 +2590,9 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) {
void ServerFamily::Replicate(string_view host, string_view port) {
io::NullSink sink;
ConnectionContext ctxt{&sink, nullptr, {}};
ctxt.skip_acl_validation = true;
ConnectionContext cntx{&sink, nullptr, {}};
cntx.ns = &namespaces.GetDefaultNamespace();
cntx.skip_acl_validation = true;
StringVec replicaof_params{string(host), string(port)};
@ -2599,7 +2601,7 @@ void ServerFamily::Replicate(string_view host, string_view port) {
args_vec.emplace_back(MutableSlice{s.data(), s.size()});
}
CmdArgList args_list = absl::MakeSpan(args_vec);
ReplicaOfInternal(args_list, &ctxt, ActionOnConnectionFail::kContinueReplication);
ReplicaOfInternal(args_list, &cntx, ActionOnConnectionFail::kContinueReplication);
}
// REPLTAKEOVER <seconds> [SAVE]


@ -15,6 +15,7 @@
#include "server/detail/save_stages_controller.h"
#include "server/dflycmd.h"
#include "server/engine_shard_set.h"
#include "server/namespaces.h"
#include "server/replica.h"
#include "server/server_state.h"
#include "util/fibers/fiberqueue_threadpool.h"
@ -158,9 +159,9 @@ class ServerFamily {
return service_;
}
void ResetStat();
void ResetStat(Namespace* ns);
Metrics GetMetrics() const;
Metrics GetMetrics(Namespace* ns) const;
ScriptMgr* script_mgr() {
return script_mgr_.get();
@ -337,7 +338,7 @@ class ServerFamily {
};
// Reusable CLIENT PAUSE implementation that blocks while polling is_pause_in_progress
std::optional<util::fb2::Fiber> Pause(std::vector<facade::Listener*> listeners,
std::optional<util::fb2::Fiber> Pause(std::vector<facade::Listener*> listeners, Namespace* ns,
facade::Connection* conn, ClientPause pause_state,
std::function<bool()> is_pause_in_progress);


@ -648,9 +648,9 @@ OpResult<streamID> OpAdd(const OpArgs& op_args, const AddTrimOpts& opts, CmdArgL
StreamTrim(opts, stream_inst);
EngineShard* es = op_args.shard;
if (es->blocking_controller()) {
es->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, opts.key);
auto blocking_controller = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id());
if (blocking_controller) {
blocking_controller->AwakeWatched(op_args.db_cntx.db_index, opts.key);
}
return result_id;
@ -2364,7 +2364,7 @@ void StreamFamily::XInfo(CmdArgList args, ConnectionContext* cntx) {
// We do not use transactional semantics for XINFO since it's an informational command.
auto cb = [&]() {
EngineShard* shard = EngineShard::tlocal();
DbContext db_context{cntx->db_index(), GetCurrentTimeMs()};
DbContext db_context{cntx->ns, cntx->db_index(), GetCurrentTimeMs()};
return OpListGroups(db_context, key, shard);
};
@ -2433,7 +2433,8 @@ void StreamFamily::XInfo(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&]() {
EngineShard* shard = EngineShard::tlocal();
return OpStreams(DbContext{cntx->db_index(), GetCurrentTimeMs()}, key, shard, full, count);
return OpStreams(DbContext{cntx->ns, cntx->db_index(), GetCurrentTimeMs()}, key, shard,
full, count);
};
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
@ -2561,8 +2562,8 @@ void StreamFamily::XInfo(CmdArgList args, ConnectionContext* cntx) {
string_view stream_name = ArgS(args, 1);
string_view group_name = ArgS(args, 2);
auto cb = [&]() {
return OpConsumers(DbContext{cntx->db_index(), GetCurrentTimeMs()}, EngineShard::tlocal(),
stream_name, group_name);
return OpConsumers(DbContext{cntx->ns, cntx->db_index(), GetCurrentTimeMs()},
EngineShard::tlocal(), stream_name, group_name);
};
OpResult<vector<ConsumerInfo>> result = shard_set->Await(sid, std::move(cb));


@ -1226,9 +1226,10 @@ void StringFamily::MSetNx(CmdArgList args, ConnectionContext* cntx) {
atomic_bool exists{false};
auto cb = [&](Transaction* t, EngineShard* es) {
auto args = t->GetShardArgs(es->shard_id());
auto sid = es->shard_id();
auto args = t->GetShardArgs(sid);
for (auto arg_it = args.begin(); arg_it != args.end(); ++arg_it) {
auto it = es->db_slice().FindReadOnly(t->GetDbContext(), *arg_it).it;
auto it = cntx->ns->GetDbSlice(sid).FindReadOnly(t->GetDbContext(), *arg_it).it;
++arg_it;
if (IsValid(it)) {
exists.store(true, memory_order_relaxed);


@ -804,7 +804,7 @@ TEST_F(StringFamilyTest, SetWithHashtagsNoCluster) {
auto fb = ExpectUsedKeys({"{key}1"});
EXPECT_EQ(Run({"set", "{key}1", "val1"}), "OK");
fb.Join();
EXPECT_FALSE(service_->IsLocked(0, "{key}1"));
EXPECT_FALSE(IsLocked(0, "{key}1"));
fb = ExpectUsedKeys({"{key}2"});
EXPECT_EQ(Run({"set", "{key}2", "val2"}), "OK");


@ -81,7 +81,7 @@ void TransactionSuspension::Start() {
transaction_ = new dfly::Transaction{&cid};
auto st = transaction_->InitByArgs(0, {});
auto st = transaction_->InitByArgs(&namespaces.GetDefaultNamespace(), 0, {});
CHECK_EQ(st, OpStatus::OK);
transaction_->Execute([](Transaction* t, EngineShard* shard) { return OpStatus::OK; }, false);
@ -107,7 +107,9 @@ class BaseFamilyTest::TestConnWrapper {
const facade::Connection::InvalidationMessage& GetInvalidationMessage(size_t index) const;
ConnectionContext* cmd_cntx() {
return static_cast<ConnectionContext*>(dummy_conn_->cntx());
auto cntx = static_cast<ConnectionContext*>(dummy_conn_->cntx());
cntx->ns = &namespaces.GetDefaultNamespace();
return cntx;
}
StringVec SplitLines() const {
@ -210,7 +212,10 @@ void BaseFamilyTest::ResetService() {
used_mem_current = 0;
TEST_current_time_ms = absl::GetCurrentTimeNanos() / 1000000;
auto cb = [&](EngineShard* s) { s->db_slice().UpdateExpireBase(TEST_current_time_ms - 1000, 0); };
auto default_ns = &namespaces.GetDefaultNamespace();
auto cb = [&](EngineShard* s) {
default_ns->GetDbSlice(s->shard_id()).UpdateExpireBase(TEST_current_time_ms - 1000, 0);
};
shard_set->RunBriefInParallel(cb);
const TestInfo* const test_info = UnitTest::GetInstance()->current_test_info();
@ -244,7 +249,10 @@ void BaseFamilyTest::ResetService() {
}
LOG(ERROR) << "TxLocks for shard " << es->shard_id();
for (const auto& k_v : es->db_slice().GetDBTable(0)->trans_locks) {
for (const auto& k_v : namespaces.GetDefaultNamespace()
.GetDbSlice(es->shard_id())
.GetDBTable(0)
->trans_locks) {
LOG(ERROR) << "Key " << k_v.first << " " << k_v.second;
}
}
@ -264,6 +272,7 @@ void BaseFamilyTest::ShutdownService() {
service_->Shutdown();
service_.reset();
delete shard_set;
shard_set = nullptr;
@ -295,8 +304,9 @@ void BaseFamilyTest::CleanupSnapshots() {
unsigned BaseFamilyTest::NumLocked() {
atomic_uint count = 0;
auto default_ns = &namespaces.GetDefaultNamespace();
shard_set->RunBriefInParallel([&](EngineShard* shard) {
for (const auto& db : shard->db_slice().databases()) {
for (const auto& db : default_ns->GetDbSlice(shard->shard_id()).databases()) {
if (db == nullptr) {
continue;
}
@ -375,6 +385,7 @@ RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) {
CmdArgVec args = conn_wrapper->Args(slice);
auto* context = conn_wrapper->cmd_cntx();
context->ns = &namespaces.GetDefaultNamespace();
DCHECK(context->transaction == nullptr) << id;
@ -551,12 +562,7 @@ BaseFamilyTest::TestConnWrapper::GetInvalidationMessage(size_t index) const {
}
bool BaseFamilyTest::IsLocked(DbIndex db_index, std::string_view key) const {
ShardId sid = Shard(key, shard_set->size());
bool is_open = pp_->at(sid)->AwaitBrief([db_index, key] {
return EngineShard::tlocal()->db_slice().CheckLock(IntentLock::EXCLUSIVE, db_index, key);
});
return !is_open;
return service_->IsLocked(&namespaces.GetDefaultNamespace(), db_index, key);
}
string BaseFamilyTest::GetId() const {
@ -643,7 +649,8 @@ vector<LockFp> BaseFamilyTest::GetLastFps() {
}
lock_guard lk(mu);
for (auto fp : shard->db_slice().TEST_GetLastLockedFps()) {
for (auto fp :
namespaces.GetDefaultNamespace().GetDbSlice(shard->shard_id()).TEST_GetLastLockedFps()) {
result.push_back(fp);
}
};


@ -117,7 +117,7 @@ class BaseFamilyTest : public ::testing::Test {
static std::vector<std::string> StrArray(const RespExpr& expr);
Metrics GetMetrics() const {
return service_->server_family().GetMetrics();
return service_->server_family().GetMetrics(&namespaces.GetDefaultNamespace());
}
void ClearMetrics();


@ -173,8 +173,9 @@ Transaction::~Transaction() {
<< " destroyed";
}
void Transaction::InitBase(DbIndex dbid, CmdArgList args) {
void Transaction::InitBase(Namespace* ns, DbIndex dbid, CmdArgList args) {
global_ = false;
namespace_ = ns;
db_index_ = dbid;
full_args_ = args;
local_result_ = OpStatus::OK;
@ -359,8 +360,8 @@ void Transaction::InitByKeys(const KeyIndex& key_index) {
}
}
OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
InitBase(index, args);
OpStatus Transaction::InitByArgs(Namespace* ns, DbIndex index, CmdArgList args) {
InitBase(ns, index, args);
if ((cid_->opt_mask() & CO::GLOBAL_TRANS) > 0) {
InitGlobal();
@ -393,7 +394,7 @@ void Transaction::PrepareSquashedMultiHop(const CommandId* cid,
MultiSwitchCmd(cid);
InitBase(db_index_, {});
InitBase(namespace_, db_index_, {});
// Because squashing already determines active shards by partitioning commands,
// we don't have to work with keys manually and can just mark active shards.
@ -414,19 +415,20 @@ void Transaction::PrepareSquashedMultiHop(const CommandId* cid,
MultiBecomeSquasher();
}
void Transaction::StartMultiGlobal(DbIndex dbid) {
void Transaction::StartMultiGlobal(Namespace* ns, DbIndex dbid) {
CHECK(multi_);
CHECK(shard_data_.empty()); // Make sure default InitByArgs didn't run.
multi_->mode = GLOBAL;
InitBase(dbid, {});
InitBase(ns, dbid, {});
InitGlobal();
multi_->lock_mode = IntentLock::EXCLUSIVE;
ScheduleInternal();
}
void Transaction::StartMultiLockedAhead(DbIndex dbid, CmdArgList keys, bool skip_scheduling) {
void Transaction::StartMultiLockedAhead(Namespace* ns, DbIndex dbid, CmdArgList keys,
bool skip_scheduling) {
DVLOG(1) << "StartMultiLockedAhead on " << keys.size() << " keys";
DCHECK(multi_);
@ -437,7 +439,7 @@ void Transaction::StartMultiLockedAhead(DbIndex dbid, CmdArgList keys, bool skip
PrepareMultiFps(keys);
InitBase(dbid, keys);
InitBase(ns, dbid, keys);
InitByKeys(KeyIndex::Range(0, keys.size()));
if (!skip_scheduling)
@ -504,6 +506,7 @@ void Transaction::MultiUpdateWithParent(const Transaction* parent) {
txid_ = parent->txid_;
time_now_ms_ = parent->time_now_ms_;
unique_slot_checker_ = parent->unique_slot_checker_;
namespace_ = parent->namespace_;
}
void Transaction::MultiBecomeSquasher() {
@ -528,9 +531,10 @@ string Transaction::DebugId(std::optional<ShardId> sid) const {
return res;
}
void Transaction::PrepareMultiForScheduleSingleHop(ShardId sid, DbIndex db, CmdArgList args) {
void Transaction::PrepareMultiForScheduleSingleHop(Namespace* ns, ShardId sid, DbIndex db,
CmdArgList args) {
multi_.reset();
InitBase(db, args);
InitBase(ns, db, args);
EnableShard(sid);
OpResult<KeyIndex> key_index = DetermineKeys(cid_, args);
CHECK(key_index);
@ -608,7 +612,8 @@ bool Transaction::RunInShard(EngineShard* shard, bool txq_ooo) {
// 1: to go over potentially awakened keys, verify them and activate watch queues.
// 2: if this transaction was notified and finished running - to remove it from the head
// of the queue and notify the next one.
if (auto* bcontroller = shard->blocking_controller(); bcontroller) {
if (auto* bcontroller = namespace_->GetBlockingController(shard->shard_id()); bcontroller) {
if (awaked_prerun || was_suspended) {
bcontroller->FinalizeWatched(GetShardArgs(idx), this);
}
@ -864,6 +869,7 @@ void Transaction::DispatchHop() {
use_count_.fetch_add(run_cnt, memory_order_relaxed); // for each pointer from poll_cb
auto poll_cb = [this] {
CHECK(namespace_ != nullptr);
EngineShard::tlocal()->PollExecution("exec_cb", this);
DVLOG(3) << "ptr_release " << DebugId();
intrusive_ptr_release(this); // against use_count_.fetch_add above.
@ -1149,7 +1155,7 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p
// Register keys on active shards' blocking controllers and mark the shard state as suspended.
auto cb = [&](Transaction* t, EngineShard* shard) {
auto keys = wkeys_provider(t, shard);
return t->WatchInShard(keys, shard, krc);
return t->WatchInShard(&t->GetNamespace(), keys, shard, krc);
};
Execute(std::move(cb), true);
@ -1187,7 +1193,7 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p
return result;
}
OpStatus Transaction::WatchInShard(BlockingController::Keys keys, EngineShard* shard,
OpStatus Transaction::WatchInShard(Namespace* ns, BlockingController::Keys keys, EngineShard* shard,
KeyReadyChecker krc) {
auto& sd = shard_data_[SidToId(shard->shard_id())];
@ -1195,7 +1201,7 @@ OpStatus Transaction::WatchInShard(BlockingController::Keys keys, EngineShard* s
sd.local_mask |= SUSPENDED_Q;
sd.local_mask &= ~OUT_OF_ORDER;
shard->EnsureBlockingController()->AddWatched(keys, std::move(krc), this);
ns->GetOrAddBlockingController(shard)->AddWatched(keys, std::move(krc), this);
DVLOG(2) << "WatchInShard " << DebugId();
return OpStatus::OK;
@ -1209,8 +1215,10 @@ void Transaction::ExpireShardCb(BlockingController::Keys keys, EngineShard* shar
auto& sd = shard_data_[SidToId(shard->shard_id())];
sd.local_mask &= ~KEYLOCK_ACQUIRED;
shard->blocking_controller()->FinalizeWatched(keys, this);
DCHECK(!shard->blocking_controller()->awakened_transactions().contains(this));
namespace_->GetBlockingController(shard->shard_id())->FinalizeWatched(keys, this);
DCHECK(!namespace_->GetBlockingController(shard->shard_id())
->awakened_transactions()
.contains(this));
// Resume processing of transaction queue
shard->PollExecution("unwatchcb", nullptr);
@ -1218,9 +1226,8 @@ void Transaction::ExpireShardCb(BlockingController::Keys keys, EngineShard* shar
}
DbSlice& Transaction::GetDbSlice(ShardId shard_id) const {
auto* shard = EngineShard::tlocal();
DCHECK_EQ(shard->shard_id(), shard_id);
return shard->db_slice();
CHECK(namespace_ != nullptr);
return namespace_->GetDbSlice(shard_id);
}
OpStatus Transaction::RunSquashedMultiCb(RunnableType cb) {
@ -1270,8 +1277,9 @@ void Transaction::UnlockMultiShardCb(absl::Span<const LockFp> fps, EngineShard*
shard->RemoveContTx(this);
// Wake only if no tx queue head is currently running
if (shard->blocking_controller() && shard->GetContTx() == nullptr)
shard->blocking_controller()->NotifyPending();
auto bc = namespace_->GetBlockingController(shard->shard_id());
if (bc && shard->GetContTx() == nullptr)
bc->NotifyPending();
shard->PollExecution("unlockmulti", nullptr);
}
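
Across these transaction hunks there is a deliberate asymmetry: the registration path (`WatchInShard`) materializes a shard's controller on first use via `GetOrAddBlockingController()`, while wake-up and cleanup paths (`ExpireShardCb`, `UnlockMultiShardCb`) call `GetBlockingController()` and must tolerate `nullptr`, since most shards in most namespaces never host a blocked transaction. A standalone sketch of that shape, with hypothetical names:

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Toy stand-in for the per-shard blocking state kept inside a Namespace.
struct ToyBlockingController {
  // Watch queues for blocked transactions would live here.
};

class ToyNamespace {
 public:
  explicit ToyNamespace(std::size_t num_shards) : controllers_(num_shards) {}

  // Registration path: a transaction about to block creates the controller
  // for its shard lazily (mirrors GetOrAddBlockingController()).
  ToyBlockingController* GetOrAdd(std::size_t sid) {
    auto& slot = controllers_[sid];
    if (!slot)
      slot = std::make_unique<ToyBlockingController>();
    return slot.get();
  }

  // Wake-up path: callers only consult an existing controller and must
  // handle nullptr (mirrors GetBlockingController()).
  ToyBlockingController* Get(std::size_t sid) const {
    return controllers_[sid].get();
  }

 private:
  std::vector<std::unique_ptr<ToyBlockingController>> controllers_;
};
```

The same null-check discipline shows up in the stream `OpAdd` and sorted-set `FindZEntry` hunks, which wake watchers only when the controller already exists.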


@ -21,6 +21,7 @@
#include "server/cluster/cluster_utility.h"
#include "server/common.h"
#include "server/journal/types.h"
#include "server/namespaces.h"
#include "server/table.h"
#include "server/tx_base.h"
#include "util/fibers/synchronization.h"
@ -185,7 +186,7 @@ class Transaction {
std::optional<cluster::SlotId> slot_id);
// Initialize from command (args) on specific db.
OpStatus InitByArgs(DbIndex index, CmdArgList args);
OpStatus InitByArgs(Namespace* ns, DbIndex index, CmdArgList args);
// Get command arguments for specific shard. Called from shard thread.
ShardArgs GetShardArgs(ShardId sid) const;
@ -230,10 +231,11 @@ class Transaction {
void PrepareSquashedMultiHop(const CommandId* cid, absl::FunctionRef<bool(ShardId)> enabled);
// Start multi in GLOBAL mode.
void StartMultiGlobal(DbIndex dbid);
void StartMultiGlobal(Namespace* ns, DbIndex dbid);
// Start multi in LOCK_AHEAD mode with given keys.
void StartMultiLockedAhead(DbIndex dbid, CmdArgList keys, bool skip_scheduling = false);
void StartMultiLockedAhead(Namespace* ns, DbIndex dbid, CmdArgList keys,
bool skip_scheduling = false);
// Start multi in NON_ATOMIC mode.
void StartMultiNonAtomic();
@ -311,7 +313,11 @@ class Transaction {
bool IsGlobal() const;
DbContext GetDbContext() const {
return DbContext{db_index_, time_now_ms_};
return DbContext{namespace_, db_index_, time_now_ms_};
}
Namespace& GetNamespace() const {
return *namespace_;
}
DbSlice& GetDbSlice(ShardId sid) const;
@ -330,7 +336,7 @@ class Transaction {
// Prepares for running ScheduleSingleHop() for a single-shard multi tx.
// It is safe to call ScheduleSingleHop() after calling this method, but the callback passed
// to it must not block.
void PrepareMultiForScheduleSingleHop(ShardId sid, DbIndex db, CmdArgList args);
void PrepareMultiForScheduleSingleHop(Namespace* ns, ShardId sid, DbIndex db, CmdArgList args);
// Write a journal entry to a shard journal with the given payload.
void LogJournalOnShard(EngineShard* shard, journal::Entry::Payload&& payload, uint32_t shard_cnt,
@ -479,7 +485,7 @@ class Transaction {
};
// Init basic fields and reset re-usable.
void InitBase(DbIndex dbid, CmdArgList args);
void InitBase(Namespace* ns, DbIndex dbid, CmdArgList args);
// Init as a global transaction.
void InitGlobal();
@ -518,7 +524,7 @@ class Transaction {
void RunCallback(EngineShard* shard);
// Adds itself to watched queue in the shard. Must run in that shard thread.
OpStatus WatchInShard(std::variant<ShardArgs, ArgSlice> keys, EngineShard* shard,
OpStatus WatchInShard(Namespace* ns, std::variant<ShardArgs, ArgSlice> keys, EngineShard* shard,
KeyReadyChecker krc);
// Expire blocking transaction, unlock keys and unregister it from the blocking controller
@ -612,6 +618,7 @@ class Transaction {
TxId txid_{0};
bool global_{false};
Namespace* namespace_{nullptr};
DbIndex db_index_{0};
uint64_t time_now_ms_{0};


@ -9,6 +9,7 @@
#include "server/cluster/cluster_defs.h"
#include "server/engine_shard_set.h"
#include "server/journal/journal.h"
#include "server/namespaces.h"
#include "server/transaction.h"
namespace dfly {
@ -17,14 +18,11 @@ using namespace std;
using Payload = journal::Entry::Payload;
DbSlice& DbContext::GetDbSlice(ShardId shard_id) const {
// TODO: Update this when adding namespaces
DCHECK_EQ(shard_id, EngineShard::tlocal()->shard_id());
return EngineShard::tlocal()->db_slice();
return ns->GetDbSlice(shard_id);
}
DbSlice& OpArgs::GetDbSlice() const {
// TODO: Update this when adding namespaces
return shard->db_slice();
return db_cntx.GetDbSlice(shard->shard_id());
}
size_t ShardArgs::Size() const {
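
These two rewritten accessors are the crux of the change: `OpArgs` no longer reaches into the thread-local `EngineShard` for its slice but delegates to `DbContext`, which resolves through the `Namespace`. Condensed into a self-contained model (illustrative types, not the real ones):

```cpp
#include <cstdint>
#include <memory>
#include <vector>

struct ToyDbSlice { /* per-shard keyspace */ };

struct ToyNamespace {
  std::vector<std::unique_ptr<ToyDbSlice>> slices;  // one per shard
  ToyDbSlice& GetDbSlice(uint16_t sid) { return *slices[sid]; }
};

struct ToyDbContext {
  ToyNamespace* ns = nullptr;
  uint16_t db_index = 0;
  uint64_t time_now_ms = 0;
  ToyDbSlice& GetDbSlice(uint16_t sid) const { return ns->GetDbSlice(sid); }
};

struct ToyOpArgs {
  uint16_t shard_id;
  ToyDbContext db_cntx;
  // All key access resolves through the context's namespace, so the same
  // command code transparently operates on whichever tenant issued it.
  ToyDbSlice& GetDbSlice() const { return db_cntx.GetDbSlice(shard_id); }
};
```

Because command implementations already funnel data access through `OpArgs::GetDbSlice()`, rerouting the lookup here is what lets most of the remaining diff stay mechanical.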


@ -14,6 +14,7 @@ namespace dfly {
class EngineShard;
class Transaction;
class Namespace;
class DbSlice;
using DbIndex = uint16_t;
@ -58,6 +59,7 @@ struct KeyIndex {
};
struct DbContext {
Namespace* ns = nullptr;
DbIndex db_index = 0;
uint64_t time_now_ms = 0;


@ -213,10 +213,11 @@ OpResult<DbSlice::ItAndUpdater> FindZEntry(const ZParams& zparams, const OpArgs&
return OpStatus::WRONG_TYPE;
}
if (add_res.is_new && op_args.shard->blocking_controller()) {
auto* blocking_controller = op_args.db_cntx.ns->GetBlockingController(op_args.shard->shard_id());
if (add_res.is_new && blocking_controller) {
string tmp;
string_view key = it->first.GetSlice(&tmp);
op_args.shard->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, key);
blocking_controller->AwakeWatched(op_args.db_cntx.db_index, key);
}
return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)};


@ -2,7 +2,7 @@ import pytest
import redis
from redis import asyncio as aioredis
from .instance import DflyInstanceFactory
from .utility import disconnect_clients
from .utility import *
import tempfile
import asyncio
import os
@ -567,6 +567,43 @@ async def test_acl_keys(async_client):
    await async_client.execute_command("ZUNIONSTORE destkey 2 barz1 barz2")


@pytest.mark.asyncio
async def test_namespaces(df_factory):
    df = df_factory.create()
    df.start()

    admin = aioredis.Redis(port=df.port)
    assert await admin.execute_command("SET foo admin") == b"OK"
    assert await admin.execute_command("GET foo") == b"admin"

    # Create a namespace named 'ns1'
    await admin.execute_command("ACL SETUSER adi NAMESPACE:ns1 ON >adi_pass +@all ~*")
    adi = aioredis.Redis(port=df.port)
    assert await adi.execute_command("AUTH adi adi_pass") == b"OK"
    assert await adi.execute_command("SET foo bar") == b"OK"
    assert await adi.execute_command("GET foo") == b"bar"
    assert await admin.execute_command("GET foo") == b"admin"

    # Adi and Shahar are on the same team
    await admin.execute_command("ACL SETUSER shahar NAMESPACE:ns1 ON >shahar_pass +@all ~*")
    shahar = aioredis.Redis(port=df.port)
    assert await shahar.execute_command("AUTH shahar shahar_pass") == b"OK"
    assert await shahar.execute_command("GET foo") == b"bar"
    assert await shahar.execute_command("SET foo bar2") == b"OK"
    assert await adi.execute_command("GET foo") == b"bar2"

    # Roman is the CTO; he has his own private space
    await admin.execute_command("ACL SETUSER roman NAMESPACE:ns2 ON >roman_pass +@all ~*")
    roman = aioredis.Redis(port=df.port)
    assert await roman.execute_command("AUTH roman roman_pass") == b"OK"
    assert await roman.execute_command("GET foo") is None

    await close_clients(admin, adi, shahar, roman)
@pytest.mark.asyncio
async def default_user_bug(df_factory):
    df.start()