
chore: Introduce ShardArgs as a distinct type (#2952)

Done in preparation for making ShardArgs a smart iterable type,
but currently it's just a wrapper around ArgSlice.
Also refactored common.{h,cc} into tx_base.{h,cc}.

In addition, fixed a bug in key tracking where we wrongly created a weak_ref
in a shard thread instead of in the coordinator thread.
Finally, identified another bug (not fixed yet) where we track all the arguments
instead of tracking keys only.

Besides this, there are no functional changes around the moved code.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
Roman Gershman 2024-04-24 13:36:34 +03:00 committed by GitHub
parent 2230397a12
commit 89b1d7d52a
25 changed files with 568 additions and 476 deletions
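
To make the converted call sites below easier to follow, here is a minimal, hypothetical sketch of a ShardArgs wrapper over ArgSlice. It is not the committed implementation (which is not shown in full in this diff); it only mirrors the Front/Empty/Size and iterator interface that the updated call sites use.

// Hypothetical sketch only: a thin wrapper around ArgSlice providing the
// Front()/Empty()/Size() accessors and the iterator API used by the call
// sites in the diffs below. Names follow the usage in this commit, not a
// verified header.
#include <absl/types/span.h>

#include <cstddef>
#include <string_view>

namespace dfly {

using ArgSlice = absl::Span<const std::string_view>;

class ShardArgs {
 public:
  using Iterator = ArgSlice::iterator;

  ShardArgs() = default;
  ShardArgs(const std::string_view* start, size_t count) : slice_(start, count) {
  }

  Iterator begin() const { return slice_.begin(); }
  Iterator end() const { return slice_.end(); }
  Iterator cbegin() const { return slice_.begin(); }
  Iterator cend() const { return slice_.end(); }

  bool Empty() const { return slice_.empty(); }
  size_t Size() const { return slice_.size(); }
  std::string_view Front() const { return slice_.front(); }

  // Escape hatch for code that still needs a raw span (e.g. journaling).
  ArgSlice AsSlice() const { return slice_; }

 private:
  ArgSlice slice_;
};

}  // namespace dfly

With such a wrapper, callers iterate with range-for or explicit Iterator pairs instead of indexing and remove_prefix(), which is what the converted call sites (BitOp, SDiffStore, OpMSet, etc.) do below.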

View file

@ -31,7 +31,7 @@ add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc
command_registry.cc cluster/unique_slot_checker.cc
journal/tx_executor.cc
common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc
server_state.cc table.cc top_keys.cc transaction.cc
server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc
serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc
${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc
)

View file

@ -24,6 +24,7 @@
namespace dfly {
using namespace facade;
using namespace std;
namespace {
@ -57,7 +58,6 @@ bool SetBitValue(uint32_t offset, bool bit_value, std::string* entry);
std::size_t CountBitSetByByteIndices(std::string_view at, std::size_t start, std::size_t end);
std::size_t CountBitSet(std::string_view str, int64_t start, int64_t end, bool bits);
std::size_t CountBitSetByBitIndices(std::string_view at, std::size_t start, std::size_t end);
OpResult<std::string> RunBitOpOnShard(std::string_view op, const OpArgs& op_args, ArgSlice keys);
std::string RunBitOperationOnValues(std::string_view op, const BitsStrVec& values);
// ------------------------------------------------------------------------- //
@ -444,12 +444,9 @@ OpResult<std::string> CombineResultOp(ShardStringResults result, std::string_vie
}
// For bitop not - we cannot accumulate
OpResult<std::string> RunBitOpNot(const OpArgs& op_args, ArgSlice keys) {
DCHECK(keys.size() == 1);
OpResult<std::string> RunBitOpNot(const OpArgs& op_args, string_view key) {
EngineShard* es = op_args.shard;
// if we found the value, just return, if not found then skip, otherwise report an error
auto key = keys.front();
auto find_res = es->db_slice().FindAndFetchReadOnly(op_args.db_cntx, key, OBJ_STRING);
if (find_res) {
return GetString(find_res.value()->second);
@ -460,18 +457,18 @@ OpResult<std::string> RunBitOpNot(const OpArgs& op_args, ArgSlice keys) {
// Read only operation where we are running the bit operation on all the
// values that belong to same shard.
OpResult<std::string> RunBitOpOnShard(std::string_view op, const OpArgs& op_args, ArgSlice keys) {
DCHECK(!keys.empty());
OpResult<std::string> RunBitOpOnShard(std::string_view op, const OpArgs& op_args,
ShardArgs::Iterator start, ShardArgs::Iterator end) {
DCHECK(start != end);
if (op == NOT_OP_NAME) {
return RunBitOpNot(op_args, keys);
return RunBitOpNot(op_args, *start);
}
EngineShard* es = op_args.shard;
BitsStrVec values;
values.reserve(keys.size());
// collect all the value for this shard
for (auto& key : keys) {
auto find_res = es->db_slice().FindAndFetchReadOnly(op_args.db_cntx, key, OBJ_STRING);
for (; start != end; ++start) {
auto find_res = es->db_slice().FindAndFetchReadOnly(op_args.db_cntx, *start, OBJ_STRING);
if (find_res) {
values.emplace_back(GetString(find_res.value()->second));
} else {
@ -1143,18 +1140,18 @@ void BitOp(CmdArgList args, ConnectionContext* cntx) {
ShardId dest_shard = Shard(dest_key, result_set.size());
auto shard_bitop = [&](Transaction* t, EngineShard* shard) {
ArgSlice largs = t->GetShardArgs(shard->shard_id());
DCHECK(!largs.empty());
ShardArgs largs = t->GetShardArgs(shard->shard_id());
DCHECK(!largs.Empty());
ShardArgs::Iterator start = largs.begin(), end = largs.end();
if (shard->shard_id() == dest_shard) {
CHECK_EQ(largs.front(), dest_key);
largs.remove_prefix(1);
if (largs.empty()) { // no more keys to check
CHECK_EQ(*start, dest_key);
++start;
if (start == end) { // no more keys to check
return OpStatus::OK;
}
}
OpArgs op_args = t->GetOpArgs(shard);
result_set[shard->shard_id()] = RunBitOpOnShard(op, op_args, largs);
result_set[shard->shard_id()] = RunBitOpOnShard(op, op_args, start, end);
return OpStatus::OK;
};

View file

@ -118,7 +118,7 @@ bool BlockingController::DbWatchTable::AddAwakeEvent(string_view key) {
}
// Removes tx from its watch queues if tx appears there.
void BlockingController::FinalizeWatched(ArgSlice args, Transaction* tx) {
void BlockingController::FinalizeWatched(const ShardArgs& args, Transaction* tx) {
DCHECK(tx);
VLOG(1) << "FinalizeBlocking [" << owner_->shard_id() << "]" << tx->DebugId();
@ -197,7 +197,8 @@ void BlockingController::NotifyPending() {
awakened_indices_.clear();
}
void BlockingController::AddWatched(ArgSlice keys, KeyReadyChecker krc, Transaction* trans) {
void BlockingController::AddWatched(const ShardArgs& watch_keys, KeyReadyChecker krc,
Transaction* trans) {
auto [dbit, added] = watched_dbs_.emplace(trans->GetDbIndex(), nullptr);
if (added) {
dbit->second.reset(new DbWatchTable);
@ -205,7 +206,7 @@ void BlockingController::AddWatched(ArgSlice keys, KeyReadyChecker krc, Transact
DbWatchTable& wt = *dbit->second;
for (auto key : keys) {
for (auto key : watch_keys) {
auto [res, inserted] = wt.queue_map.emplace(key, nullptr);
if (inserted) {
res->second.reset(new WatchQueue);

View file

@ -10,6 +10,7 @@
#include "base/string_view_sso.h"
#include "server/common.h"
#include "server/tx_base.h"
namespace dfly {
@ -28,7 +29,7 @@ class BlockingController {
return awakened_transactions_;
}
void FinalizeWatched(ArgSlice args, Transaction* tx);
void FinalizeWatched(const ShardArgs& args, Transaction* tx);
// go over potential wakened keys, verify them and activate watch queues.
void NotifyPending();
@ -37,7 +38,7 @@ class BlockingController {
// TODO: consider moving all watched functions to
// EngineShard with separate per db map.
//! AddWatched adds a transaction to the blocking queue.
void AddWatched(ArgSlice watch_keys, KeyReadyChecker krc, Transaction* me);
void AddWatched(const ShardArgs& watch_keys, KeyReadyChecker krc, Transaction* me);
// Called from operations that create keys like lpush, rename etc.
void AwakeWatched(DbIndex db_index, std::string_view db_key);

View file

@ -255,30 +255,6 @@ bool ParseDouble(string_view src, double* value) {
return true;
}
void RecordJournal(const OpArgs& op_args, string_view cmd, ArgSlice args, uint32_t shard_cnt,
bool multi_commands) {
VLOG(2) << "Logging command " << cmd << " from txn " << op_args.tx->txid();
op_args.tx->LogJournalOnShard(op_args.shard, make_pair(cmd, args), shard_cnt, multi_commands,
false);
}
void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt) {
op_args.tx->FinishLogJournalOnShard(op_args.shard, shard_cnt);
}
void RecordExpiry(DbIndex dbid, string_view key) {
auto journal = EngineShard::tlocal()->journal();
CHECK(journal);
journal->RecordEntry(0, journal::Op::EXPIRED, dbid, 1, cluster::KeySlot(key),
make_pair("DEL", ArgSlice{key}), false);
}
void TriggerJournalWriteToSink() {
auto journal = EngineShard::tlocal()->journal();
CHECK(journal);
journal->RecordEntry(0, journal::Op::NOOP, 0, 0, nullopt, {}, true);
}
#define ADD(x) (x) += o.x
IoMgrStats& IoMgrStats::operator+=(const IoMgrStats& rhs) {
@ -462,24 +438,4 @@ std::ostream& operator<<(std::ostream& os, const GlobalState& state) {
return os << GlobalStateName(state);
}
std::ostream& operator<<(std::ostream& os, ArgSlice list) {
os << "[";
if (!list.empty()) {
std::for_each(list.begin(), list.end() - 1, [&os](const auto& val) { os << val << ", "; });
os << (*(list.end() - 1));
}
return os << "]";
}
LockTag::LockTag(std::string_view key) {
if (LockTagOptions::instance().enabled)
str_ = LockTagOptions::instance().Tag(key);
else
str_ = key;
}
LockFp LockTag::Fingerprint() const {
return XXH64(str_.data(), str_.size(), 0x1C69B3F74AC4AE35UL);
}
} // namespace dfly

View file

@ -1,4 +1,4 @@
// Copyright 2022, DragonflyDB authors. All rights reserved.
// Copyright 2024, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
@ -27,8 +27,6 @@ enum class ListDir : uint8_t { LEFT, RIGHT };
constexpr int64_t kMaxExpireDeadlineSec = (1u << 28) - 1; // 8.5 years
constexpr int64_t kMaxExpireDeadlineMs = kMaxExpireDeadlineSec * 1000;
using DbIndex = uint16_t;
using ShardId = uint16_t;
using LSN = uint64_t;
using TxId = uint64_t;
using TxClock = uint64_t;
@ -39,17 +37,11 @@ using facade::CmdArgVec;
using facade::MutableSlice;
using facade::OpResult;
using ArgSlice = absl::Span<const std::string_view>;
using StringVec = std::vector<std::string>;
// keys are RDB_TYPE_xxx constants.
using RdbTypeFreqMap = absl::flat_hash_map<unsigned, size_t>;
constexpr DbIndex kInvalidDbId = DbIndex(-1);
constexpr ShardId kInvalidSid = ShardId(-1);
constexpr DbIndex kMaxDbId = 1024; // Reasonable starting point.
using LockFp = uint64_t; // a key fingerprint used by the LockTable.
class CommandId;
class Transaction;
class EngineShard;
@ -67,98 +59,6 @@ struct LockTagOptions {
static const LockTagOptions& instance();
};
struct KeyLockArgs {
DbIndex db_index = 0;
absl::Span<const LockFp> fps;
};
// Describes key indices.
struct KeyIndex {
unsigned start;
unsigned end; // does not include this index (open limit).
unsigned step; // 1 for commands like mget. 2 for commands like mset.
// if index is non-zero then adds another key index (usually 0).
// relevant for commands like ZUNIONSTORE/ZINTERSTORE for destination key.
std::optional<uint16_t> bonus{};
bool has_reverse_mapping = false;
KeyIndex(unsigned s = 0, unsigned e = 0, unsigned step = 0) : start(s), end(e), step(step) {
}
static KeyIndex Range(unsigned start, unsigned end, unsigned step = 1) {
return KeyIndex{start, end, step};
}
bool HasSingleKey() const {
return !bonus && (start + step >= end);
}
unsigned num_args() const {
return end - start + bool(bonus);
}
};
struct DbContext {
DbIndex db_index = 0;
uint64_t time_now_ms = 0;
};
struct OpArgs {
EngineShard* shard;
const Transaction* tx;
DbContext db_cntx;
OpArgs() : shard(nullptr), tx(nullptr) {
}
OpArgs(EngineShard* s, const Transaction* tx, const DbContext& cntx)
: shard(s), tx(tx), db_cntx(cntx) {
}
};
// A strong type for a lock tag. Helps to disambiguate between keys and the parts of the
// keys that are used for locking.
class LockTag {
std::string_view str_;
public:
using is_stackonly = void; // marks that this object does not use heap.
LockTag() = default;
explicit LockTag(std::string_view key);
explicit operator std::string_view() const {
return str_;
}
LockFp Fingerprint() const;
// To make it hashable.
template <typename H> friend H AbslHashValue(H h, const LockTag& tag) {
return H::combine(std::move(h), tag.str_);
}
bool operator==(const LockTag& o) const {
return str_ == o.str_;
}
};
// Record non auto journal command with own txid and dbid.
void RecordJournal(const OpArgs& op_args, std::string_view cmd, ArgSlice args,
uint32_t shard_cnt = 1, bool multi_commands = false);
// Record non auto journal command finish. Call only when command translates to multi commands.
void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt);
// Record expiry in journal with independent transaction. Must be called from shard thread holding
// key.
void RecordExpiry(DbIndex dbid, std::string_view key);
// Trigger journal write to sink, no journal record will be added to journal.
// Must be called from shard thread of journal to sink.
void TriggerJournalWriteToSink();
struct IoMgrStats {
uint64_t read_total = 0;
uint64_t read_delay_usec = 0;
@ -205,8 +105,6 @@ enum class GlobalState : uint8_t {
std::ostream& operator<<(std::ostream& os, const GlobalState& state);
std::ostream& operator<<(std::ostream& os, ArgSlice list);
enum class TimeUnit : uint8_t { SEC, MSEC };
inline void ToUpper(const MutableSlice* val) {
@ -414,10 +312,6 @@ inline uint32_t MemberTimeSeconds(uint64_t now_ms) {
return (now_ms / 1000) - kMemberExpiryBase;
}
// Checks whether the touched key is valid for a blocking transaction watching it
using KeyReadyChecker =
std::function<bool(EngineShard*, const DbContext& context, Transaction* tx, std::string_view)>;
struct MemoryBytesFlag {
uint64_t value = 0;
};

View file

@ -11,6 +11,7 @@
#include "facade/conn_context.h"
#include "facade/reply_capture.h"
#include "server/common.h"
#include "server/tx_base.h"
#include "server/version.h"
namespace dfly {

View file

@ -24,7 +24,7 @@ extern "C" {
ABSL_FLAG(bool, singlehop_blocking, true, "Use single hop optimization for blocking commands");
namespace dfly::container_utils {
using namespace std;
namespace {
struct ShardFFResult {
@ -32,16 +32,38 @@ struct ShardFFResult {
ShardId sid = kInvalidSid;
};
// Returns (iterator, args-index) if found, KEY_NOTFOUND otherwise.
// If multiple keys are found, returns the first index in the ArgSlice.
OpResult<std::pair<DbSlice::ConstIterator, unsigned>> FindFirstReadOnly(const DbSlice& db_slice,
const DbContext& cntx,
const ShardArgs& args,
int req_obj_type) {
DCHECK(!args.Empty());
unsigned i = 0;
for (string_view key : args) {
OpResult<DbSlice::ConstIterator> res = db_slice.FindReadOnly(cntx, key, req_obj_type);
if (res)
return make_pair(res.value(), i);
if (res.status() != OpStatus::KEY_NOTFOUND)
return res.status();
++i;
}
VLOG(2) << "FindFirst not found";
return OpStatus::KEY_NOTFOUND;
}
// Find first non-empty key of a single shard transaction, pass it to `func` and return the key.
// If no such key exists or a wrong type is found, the appropriate status is returned.
// Optimized version of `FindFirstNonEmpty` below.
OpResult<std::string> FindFirstNonEmptySingleShard(Transaction* trans, int req_obj_type,
BlockingResultCb func) {
OpResult<string> FindFirstNonEmptySingleShard(Transaction* trans, int req_obj_type,
BlockingResultCb func) {
DCHECK_EQ(trans->GetUniqueShardCnt(), 1u);
std::string key;
string key;
auto cb = [&](Transaction* t, EngineShard* shard) -> Transaction::RunnableResult {
auto args = t->GetShardArgs(shard->shard_id());
auto ff_res = shard->db_slice().FindFirstReadOnly(t->GetDbContext(), args, req_obj_type);
auto ff_res = FindFirstReadOnly(shard->db_slice(), t->GetDbContext(), args, req_obj_type);
if (ff_res == OpStatus::WRONG_TYPE)
return OpStatus::WRONG_TYPE;
@ -77,7 +99,7 @@ OpResult<ShardFFResult> FindFirstNonEmpty(Transaction* trans, int req_obj_type)
auto cb = [&](Transaction* t, EngineShard* shard) {
auto args = t->GetShardArgs(shard->shard_id());
auto ff_res = shard->db_slice().FindFirstReadOnly(t->GetDbContext(), args, req_obj_type);
auto ff_res = FindFirstReadOnly(shard->db_slice(), t->GetDbContext(), args, req_obj_type);
if (ff_res) {
find_res[shard->shard_id()] =
FFResult{ff_res->first->first.AsRef(), ff_res->second, shard->shard_id()};

View file

@ -411,7 +411,7 @@ OpResult<DbSlice::ItAndUpdater> DbSlice::FindMutableInternal(const Context& cntx
}
}
DbSlice::ItAndExpConst DbSlice::FindReadOnly(const Context& cntx, std::string_view key) {
DbSlice::ItAndExpConst DbSlice::FindReadOnly(const Context& cntx, std::string_view key) const {
auto res = FindInternal(cntx, key, std::nullopt, UpdateStatsMode::kReadStats,
LoadExternalMode::kDontLoad);
return {ConstIterator(res->it, StringOrView::FromView(key)),
@ -419,7 +419,7 @@ DbSlice::ItAndExpConst DbSlice::FindReadOnly(const Context& cntx, std::string_vi
}
OpResult<DbSlice::ConstIterator> DbSlice::FindReadOnly(const Context& cntx, string_view key,
unsigned req_obj_type) {
unsigned req_obj_type) const {
auto res = FindInternal(cntx, key, req_obj_type, UpdateStatsMode::kReadStats,
LoadExternalMode::kDontLoad);
if (res.ok()) {
@ -442,7 +442,7 @@ OpResult<DbSlice::ConstIterator> DbSlice::FindAndFetchReadOnly(const Context& cn
OpResult<DbSlice::PrimeItAndExp> DbSlice::FindInternal(const Context& cntx, std::string_view key,
std::optional<unsigned> req_obj_type,
UpdateStatsMode stats_mode,
LoadExternalMode load_mode) {
LoadExternalMode load_mode) const {
if (!IsDbValid(cntx.db_index)) {
return OpStatus::KEY_NOTFOUND;
}
@ -536,24 +536,6 @@ OpResult<DbSlice::PrimeItAndExp> DbSlice::FindInternal(const Context& cntx, std:
return res;
}
OpResult<pair<DbSlice::ConstIterator, unsigned>> DbSlice::FindFirstReadOnly(const Context& cntx,
ArgSlice args,
int req_obj_type) {
DCHECK(!args.empty());
for (unsigned i = 0; i < args.size(); ++i) {
string_view s = args[i];
OpResult<ConstIterator> res = FindReadOnly(cntx, s, req_obj_type);
if (res)
return make_pair(res.value(), i);
if (res.status() != OpStatus::KEY_NOTFOUND)
return res.status();
}
VLOG(2) << "FindFirst " << args.front() << " not found";
return OpStatus::KEY_NOTFOUND;
}
OpResult<DbSlice::AddOrFindResult> DbSlice::AddOrFind(const Context& cntx, string_view key) {
return AddOrFindInternal(cntx, key, LoadExternalMode::kDontLoad);
}
@ -1082,12 +1064,12 @@ void DbSlice::PostUpdate(DbIndex db_ind, Iterator it, std::string_view key, size
SendInvalidationTrackingMessage(key);
}
DbSlice::ItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, Iterator it) {
DbSlice::ItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, Iterator it) const {
auto res = ExpireIfNeeded(cntx, it.GetInnerIt());
return {.it = Iterator::FromPrime(res.it), .exp_it = ExpIterator::FromPrime(res.exp_it)};
}
DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterator it) {
DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterator it) const {
if (!it->second.HasExpire()) {
LOG(ERROR) << "Invalid call to ExpireIfNeeded";
return {it, ExpireIterator{}};
@ -1124,8 +1106,9 @@ DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterato
doc_del_cb_(key, cntx, it->second);
}
PerformDeletion(Iterator(it, StringOrView::FromView(key)),
ExpIterator(expire_it, StringOrView::FromView(key)), db.get());
const_cast<DbSlice*>(this)->PerformDeletion(Iterator(it, StringOrView::FromView(key)),
ExpIterator(expire_it, StringOrView::FromView(key)),
db.get());
++events_.expired_keys;
return {PrimeIterator{}, ExpireIterator{}};
@ -1490,21 +1473,6 @@ void DbSlice::ResetEvents() {
events_ = {};
}
void DbSlice::TrackKeys(const facade::Connection::WeakRef& conn, const ArgSlice& keys) {
if (conn.IsExpired()) {
DVLOG(2) << "Connection expired, exiting TrackKey function.";
return;
}
DVLOG(2) << "Start tracking keys for client ID: " << conn.GetClientId()
<< " with thread ID: " << conn.Thread();
for (auto key : keys) {
DVLOG(2) << "Inserting client ID " << conn.GetClientId()
<< " into the tracking client set of key " << key;
client_tracking_map_[key].insert(conn);
}
}
void DbSlice::SendInvalidationTrackingMessage(std::string_view key) {
if (client_tracking_map_.empty())
return;

View file

@ -281,17 +281,12 @@ class DbSlice {
ConstIterator it;
ExpConstIterator exp_it;
};
ItAndExpConst FindReadOnly(const Context& cntx, std::string_view key);
ItAndExpConst FindReadOnly(const Context& cntx, std::string_view key) const;
OpResult<ConstIterator> FindReadOnly(const Context& cntx, std::string_view key,
unsigned req_obj_type);
unsigned req_obj_type) const;
OpResult<ConstIterator> FindAndFetchReadOnly(const Context& cntx, std::string_view key,
unsigned req_obj_type);
// Returns (iterator, args-index) if found, KEY_NOTFOUND otherwise.
// If multiple keys are found, returns the first index in the ArgSlice.
OpResult<std::pair<ConstIterator, unsigned>> FindFirstReadOnly(const Context& cntx, ArgSlice args,
int req_obj_type);
struct AddOrFindResult {
Iterator it;
ExpIterator exp_it;
@ -404,7 +399,7 @@ class DbSlice {
Iterator it;
ExpIterator exp_it;
};
ItAndExp ExpireIfNeeded(const Context& cntx, Iterator it);
ItAndExp ExpireIfNeeded(const Context& cntx, Iterator it) const;
// Iterate over all expire table entries and delete expired.
void ExpireAllIfNeeded();
@ -473,7 +468,9 @@ class DbSlice {
}
// Track keys for the client represented by the weak reference to its connection.
void TrackKeys(const facade::Connection::WeakRef&, const ArgSlice&);
void TrackKey(const facade::Connection::WeakRef& conn_ref, std::string_view key) {
client_tracking_map_[key].insert(conn_ref);
}
// Delete a key referred by its iterator.
void PerformDeletion(Iterator del_it, DbTable* table);
@ -517,10 +514,11 @@ class DbSlice {
PrimeIterator it;
ExpireIterator exp_it;
};
PrimeItAndExp ExpireIfNeeded(const Context& cntx, PrimeIterator it);
PrimeItAndExp ExpireIfNeeded(const Context& cntx, PrimeIterator it) const;
OpResult<PrimeItAndExp> FindInternal(const Context& cntx, std::string_view key,
std::optional<unsigned> req_obj_type,
UpdateStatsMode stats_mode, LoadExternalMode load_mode);
UpdateStatsMode stats_mode,
LoadExternalMode load_mode) const;
OpResult<AddOrFindResult> AddOrFindInternal(const Context& cntx, std::string_view key,
LoadExternalMode load_mode);
OpResult<ItAndUpdater> FindMutableInternal(const Context& cntx, std::string_view key,

View file

@ -281,11 +281,11 @@ class Renamer {
void Renamer::Find(Transaction* t) {
auto cb = [this](Transaction* t, EngineShard* shard) {
auto args = t->GetShardArgs(shard->shard_id());
CHECK_EQ(1u, args.size());
DCHECK_EQ(1u, args.Size());
FindResult* res = (shard->shard_id() == src_sid_) ? &src_res_ : &dest_res_;
res->key = args.front();
res->key = args.Front();
auto& db_slice = EngineShard::tlocal()->db_slice();
auto [it, exp_it] = db_slice.FindReadOnly(t->GetDbContext(), res->key);
@ -615,6 +615,40 @@ OpResult<long> OpFieldTtl(Transaction* t, EngineShard* shard, string_view key, s
return res <= 0 ? res : int32_t(res - MemberTimeSeconds(db_cntx.time_now_ms));
}
OpResult<uint32_t> OpDel(const OpArgs& op_args, const ShardArgs& keys) {
DVLOG(1) << "Del: " << keys.Front();
auto& db_slice = op_args.shard->db_slice();
uint32_t res = 0;
for (string_view key : keys) {
auto fres = db_slice.FindMutable(op_args.db_cntx, key);
if (!IsValid(fres.it))
continue;
fres.post_updater.Run();
res += int(db_slice.Del(op_args.db_cntx.db_index, fres.it));
}
return res;
}
OpResult<uint32_t> OpStick(const OpArgs& op_args, const ShardArgs& keys) {
DVLOG(1) << "Stick: " << keys.Front();
auto& db_slice = op_args.shard->db_slice();
uint32_t res = 0;
for (string_view key : keys) {
auto find_res = db_slice.FindMutable(op_args.db_cntx, key);
if (IsValid(find_res.it) && !find_res.it->first.IsSticky()) {
find_res.it->first.SetSticky(true);
++res;
}
}
return res;
}
} // namespace
void GenericFamily::Init(util::ProactorPool* pp) {
@ -631,7 +665,7 @@ void GenericFamily::Del(CmdArgList args, ConnectionContext* cntx) {
bool is_mc = cntx->protocol() == Protocol::MEMCACHE;
auto cb = [&result](const Transaction* t, EngineShard* shard) {
ArgSlice args = t->GetShardArgs(shard->shard_id());
ShardArgs args = t->GetShardArgs(shard->shard_id());
auto res = OpDel(t->GetOpArgs(shard), args);
result.fetch_add(res.value_or(0), memory_order_relaxed);
@ -683,7 +717,7 @@ void GenericFamily::Exists(CmdArgList args, ConnectionContext* cntx) {
atomic_uint32_t result{0};
auto cb = [&result](Transaction* t, EngineShard* shard) {
ArgSlice args = t->GetShardArgs(shard->shard_id());
ShardArgs args = t->GetShardArgs(shard->shard_id());
auto res = OpExists(t->GetOpArgs(shard), args);
result.fetch_add(res.value_or(0), memory_order_relaxed);
@ -889,7 +923,7 @@ void GenericFamily::Stick(CmdArgList args, ConnectionContext* cntx) {
atomic_uint32_t result{0};
auto cb = [&result](const Transaction* t, EngineShard* shard) {
ArgSlice args = t->GetShardArgs(shard->shard_id());
ShardArgs args = t->GetShardArgs(shard->shard_id());
auto res = OpStick(t->GetOpArgs(shard), args);
result.fetch_add(res.value_or(0), memory_order_relaxed);
@ -1373,30 +1407,13 @@ OpResult<uint64_t> GenericFamily::OpTtl(Transaction* t, EngineShard* shard, stri
return ttl_ms;
}
OpResult<uint32_t> GenericFamily::OpDel(const OpArgs& op_args, ArgSlice keys) {
DVLOG(1) << "Del: " << keys[0];
auto& db_slice = op_args.shard->db_slice();
uint32_t res = 0;
for (uint32_t i = 0; i < keys.size(); ++i) {
auto fres = db_slice.FindMutable(op_args.db_cntx, keys[i]);
if (!IsValid(fres.it))
continue;
fres.post_updater.Run();
res += int(db_slice.Del(op_args.db_cntx.db_index, fres.it));
}
return res;
}
OpResult<uint32_t> GenericFamily::OpExists(const OpArgs& op_args, ArgSlice keys) {
DVLOG(1) << "Exists: " << keys[0];
OpResult<uint32_t> GenericFamily::OpExists(const OpArgs& op_args, const ShardArgs& keys) {
DVLOG(1) << "Exists: " << keys.Front();
auto& db_slice = op_args.shard->db_slice();
uint32_t res = 0;
for (uint32_t i = 0; i < keys.size(); ++i) {
auto find_res = db_slice.FindReadOnly(op_args.db_cntx, keys[i]);
for (string_view key : keys) {
auto find_res = db_slice.FindReadOnly(op_args.db_cntx, key);
res += IsValid(find_res.it);
}
return res;
@ -1462,23 +1479,6 @@ OpResult<void> GenericFamily::OpRen(const OpArgs& op_args, string_view from_key,
return OpStatus::OK;
}
OpResult<uint32_t> GenericFamily::OpStick(const OpArgs& op_args, ArgSlice keys) {
DVLOG(1) << "Stick: " << keys[0];
auto& db_slice = op_args.shard->db_slice();
uint32_t res = 0;
for (uint32_t i = 0; i < keys.size(); ++i) {
auto find_res = db_slice.FindMutable(op_args.db_cntx, keys[i]);
if (IsValid(find_res.it) && !find_res.it->first.IsSticky()) {
find_res.it->first.SetSticky(true);
++res;
}
}
return res;
}
// OpMove touches multiple databases (op_args.db_idx, target_db), so it assumes it runs
// as a global transaction.
// TODO: Allow running OpMove without a global transaction.

View file

@ -40,7 +40,7 @@ class GenericFamily {
static void Register(CommandRegistry* registry);
// Accessed by Service::Exec and Service::Watch as a utility.
static OpResult<uint32_t> OpExists(const OpArgs& op_args, ArgSlice keys);
static OpResult<uint32_t> OpExists(const OpArgs& op_args, const ShardArgs& keys);
private:
static void Del(CmdArgList args, ConnectionContext* cntx);
@ -76,10 +76,8 @@ class GenericFamily {
static void TtlGeneric(CmdArgList args, ConnectionContext* cntx, TimeUnit unit);
static OpResult<uint64_t> OpTtl(Transaction* t, EngineShard* shard, std::string_view key);
static OpResult<uint32_t> OpDel(const OpArgs& op_args, ArgSlice keys);
static OpResult<void> OpRen(const OpArgs& op_args, std::string_view from, std::string_view to,
bool skip_exists);
static OpResult<uint32_t> OpStick(const OpArgs& op_args, ArgSlice keys);
static OpStatus OpMove(const OpArgs& op_args, std::string_view key, DbIndex target_db);
};

View file

@ -169,11 +169,11 @@ OpResult<int64_t> CountHllsSingle(const OpArgs& op_args, string_view key) {
}
}
OpResult<vector<string>> ReadValues(const OpArgs& op_args, ArgSlice keys) {
OpResult<vector<string>> ReadValues(const OpArgs& op_args, const ShardArgs& keys) {
try {
vector<string> values;
for (size_t i = 0; i < keys.size(); ++i) {
auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, keys[i], OBJ_STRING);
for (string_view key : keys) {
auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_STRING);
if (it.ok()) {
string hll;
it.value()->second.GetString(&hll);
@ -210,7 +210,7 @@ OpResult<int64_t> PFCountMulti(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) {
ShardId sid = shard->shard_id();
ArgSlice shard_args = t->GetShardArgs(shard->shard_id());
ShardArgs shard_args = t->GetShardArgs(shard->shard_id());
auto result = ReadValues(t->GetOpArgs(shard), shard_args);
if (result.ok()) {
hlls[sid] = std::move(result.value());
@ -252,7 +252,7 @@ OpResult<int> PFMergeInternal(CmdArgList args, ConnectionContext* cntx) {
atomic_bool success = true;
auto cb = [&](Transaction* t, EngineShard* shard) {
ShardId sid = shard->shard_id();
ArgSlice shard_args = t->GetShardArgs(shard->shard_id());
ShardArgs shard_args = t->GetShardArgs(shard->shard_id());
auto result = ReadValues(t->GetOpArgs(shard), shard_args);
if (result.ok()) {
hlls[sid] = std::move(result.value());

View file

@ -1130,19 +1130,21 @@ OpResult<vector<OptLong>> OpArrIndex(const OpArgs& op_args, string_view key, Jso
// Returns string vector that represents the query result of each supplied key.
vector<OptString> OpJsonMGet(JsonPathV2 expression, const Transaction* t, EngineShard* shard) {
auto args = t->GetShardArgs(shard->shard_id());
DCHECK(!args.empty());
vector<OptString> response(args.size());
ShardArgs args = t->GetShardArgs(shard->shard_id());
DCHECK(!args.Empty());
vector<OptString> response(args.Size());
auto& db_slice = shard->db_slice();
for (size_t i = 0; i < args.size(); ++i) {
auto it_res = db_slice.FindReadOnly(t->GetDbContext(), args[i], OBJ_JSON);
unsigned index = 0;
for (string_view key : args) {
auto it_res = db_slice.FindReadOnly(t->GetDbContext(), key, OBJ_JSON);
auto& dest = response[index++];
if (!it_res.ok())
continue;
auto& dest = response[i].emplace();
dest.emplace();
JsonType* json_val = it_res.value()->second.GetJson();
DCHECK(json_val) << "should have a valid JSON object for key " << args[i];
DCHECK(json_val) << "should have a valid JSON object for key " << key;
vector<JsonType> query_result;
auto cb = [&query_result](const string_view& path, const JsonType& val) {
@ -1364,8 +1366,8 @@ void JsonFamily::MSet(CmdArgList args, ConnectionContext* cntx) {
}
auto cb = [&](Transaction* t, EngineShard* shard) {
ArgSlice args = t->GetShardArgs(shard->shard_id());
LOG(INFO) << shard->shard_id() << " " << args;
ShardArgs args = t->GetShardArgs(shard->shard_id());
(void)args; // TBD
return OpStatus::OK;
};
@ -1469,12 +1471,7 @@ void JsonFamily::MGet(CmdArgList args, ConnectionContext* cntx) {
continue;
vector<OptString>& res = mget_resp[sid];
ArgSlice slice = transaction->GetShardArgs(sid);
DCHECK(!slice.empty());
DCHECK_EQ(slice.size(), res.size());
for (size_t j = 0; j < slice.size(); ++j) {
for (size_t j = 0; j < res.size(); ++j) {
if (!res[j])
continue;

View file

@ -416,9 +416,9 @@ OpResult<string> MoveTwoShards(Transaction* trans, string_view src, string_view
//
auto cb = [&](Transaction* t, EngineShard* shard) {
auto args = t->GetShardArgs(shard->shard_id());
DCHECK_EQ(1u, args.size());
bool is_dest = args.front() == dest;
find_res[is_dest] = Peek(t->GetOpArgs(shard), args.front(), src_dir, !is_dest);
DCHECK_EQ(1u, args.Size());
bool is_dest = args.Front() == dest;
find_res[is_dest] = Peek(t->GetOpArgs(shard), args.Front(), src_dir, !is_dest);
return OpStatus::OK;
};
@ -432,7 +432,7 @@ OpResult<string> MoveTwoShards(Transaction* trans, string_view src, string_view
// Everything is ok, lets proceed with the mutations.
auto cb = [&](Transaction* t, EngineShard* shard) {
auto args = t->GetShardArgs(shard->shard_id());
auto key = args.front();
auto key = args.Front();
bool is_dest = (key == dest);
OpArgs op_args = t->GetOpArgs(shard);
@ -873,7 +873,7 @@ OpResult<string> BPopPusher::RunSingle(ConnectionContext* cntx, time_point tp) {
return op_res;
}
auto wcb = [&](Transaction* t, EngineShard* shard) { return ArgSlice{&this->pop_key_, 1}; };
auto wcb = [&](Transaction* t, EngineShard* shard) { return ShardArgs{&this->pop_key_, 1}; };
const auto key_checker = [](EngineShard* owner, const DbContext& context, Transaction*,
std::string_view key) -> bool {

View file

@ -1127,9 +1127,25 @@ std::optional<ErrorReply> Service::VerifyCommandState(const CommandId* cid, CmdA
return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions", tail_args);
}
OpResult<void> OpTrackKeys(const OpArgs& op_args, ConnectionContext* cntx, const ArgSlice& keys) {
auto& db_slice = op_args.shard->db_slice();
db_slice.TrackKeys(cntx->conn()->Borrow(), keys);
OpResult<void> OpTrackKeys(const OpArgs& op_args, const facade::Connection::WeakRef& conn_ref,
const ShardArgs& args) {
if (conn_ref.IsExpired()) {
DVLOG(2) << "Connection expired, exiting TrackKey function.";
return OpStatus::OK;
}
DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId()
<< " with thread ID: " << conn_ref.Thread();
DbSlice& db_slice = op_args.shard->db_slice();
// TODO: There is a bug here that we track all arguments instead of tracking only keys.
for (auto key : args) {
DVLOG(2) << "Inserting client ID " << conn_ref.GetClientId()
<< " into the tracking client set of key " << key;
db_slice.TrackKey(conn_ref, key);
}
return OpStatus::OK;
}
@ -1236,9 +1252,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
// start tracking all the updates to the keys in this read command
if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn() &&
cid->IsTransactional()) {
auto cb = [&](Transaction* t, EngineShard* shard) {
auto keys = t->GetShardArgs(shard->shard_id());
return OpTrackKeys(t->GetOpArgs(shard), dfly_cntx, keys);
facade::Connection::WeakRef conn_ref = dfly_cntx->conn()->Borrow();
auto cb = [&, conn_ref](Transaction* t, EngineShard* shard) {
return OpTrackKeys(t->GetOpArgs(shard), conn_ref, t->GetShardArgs(shard->shard_id()));
};
dfly_cntx->transaction->Refurbish();
dfly_cntx->transaction->ScheduleSingleHopT(cb);
@ -1610,7 +1626,7 @@ void Service::Watch(CmdArgList args, ConnectionContext* cntx) {
atomic_uint32_t keys_existed = 0;
auto cb = [&](Transaction* t, EngineShard* shard) {
ArgSlice largs = t->GetShardArgs(shard->shard_id());
ShardArgs largs = t->GetShardArgs(shard->shard_id());
for (auto k : largs) {
shard->db_slice().RegisterWatchedKey(cntx->db_index(), k, &exec_info);
}
@ -2018,7 +2034,7 @@ bool CheckWatchedKeyExpiry(ConnectionContext* cntx, const CommandRegistry& regis
atomic_uint32_t watch_exist_count{0};
auto cb = [&watch_exist_count](Transaction* t, EngineShard* shard) {
ArgSlice args = t->GetShardArgs(shard->shard_id());
ShardArgs args = t->GetShardArgs(shard->shard_id());
auto res = GenericFamily::OpExists(t->GetOpArgs(shard), args);
watch_exist_count.fetch_add(res.value_or(0), memory_order_relaxed);

View file

@ -587,10 +587,10 @@ class Mover {
};
OpStatus Mover::OpFind(Transaction* t, EngineShard* es) {
ArgSlice largs = t->GetShardArgs(es->shard_id());
ShardArgs largs = t->GetShardArgs(es->shard_id());
// In case both src and dest are in the same shard, largs size will be 2.
DCHECK_LE(largs.size(), 2u);
DCHECK_LE(largs.Size(), 2u);
for (auto k : largs) {
unsigned index = (k == src_) ? 0 : 1;
@ -609,8 +609,8 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) {
}
OpStatus Mover::OpMutate(Transaction* t, EngineShard* es) {
ArgSlice largs = t->GetShardArgs(es->shard_id());
DCHECK_LE(largs.size(), 2u);
ShardArgs largs = t->GetShardArgs(es->shard_id());
DCHECK_LE(largs.Size(), 2u);
OpArgs op_args = t->GetOpArgs(es);
for (auto k : largs) {
@ -655,12 +655,13 @@ OpResult<unsigned> Mover::Commit(Transaction* t) {
}
// Read-only OpUnion op on sets.
OpResult<StringVec> OpUnion(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!keys.empty());
OpResult<StringVec> OpUnion(const OpArgs& op_args, ShardArgs::Iterator start,
ShardArgs::Iterator end) {
DCHECK(start != end);
absl::flat_hash_set<string> uniques;
for (string_view key : keys) {
auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET);
for (; start != end; ++start) {
auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, *start, OBJ_SET);
if (find_res) {
const PrimeValue& pv = find_res.value()->second;
if (IsDenseEncoding(pv)) {
@ -683,11 +684,12 @@ OpResult<StringVec> OpUnion(const OpArgs& op_args, ArgSlice keys) {
}
// Read-only OpDiff op on sets.
OpResult<StringVec> OpDiff(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!keys.empty());
DVLOG(1) << "OpDiff from " << keys.front();
OpResult<StringVec> OpDiff(const OpArgs& op_args, ShardArgs::Iterator start,
ShardArgs::Iterator end) {
DCHECK(start != end);
DVLOG(1) << "OpDiff from " << *start;
EngineShard* es = op_args.shard;
auto find_res = es->db_slice().FindReadOnly(op_args.db_cntx, keys.front(), OBJ_SET);
auto find_res = es->db_slice().FindReadOnly(op_args.db_cntx, *start, OBJ_SET);
if (!find_res) {
return find_res.status();
@ -707,8 +709,8 @@ OpResult<StringVec> OpDiff(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!uniques.empty()); // otherwise the key would not exist.
for (size_t i = 1; i < keys.size(); ++i) {
auto diff_res = es->db_slice().FindReadOnly(op_args.db_cntx, keys[i], OBJ_SET);
for (++start; start != end; ++start) {
auto diff_res = es->db_slice().FindReadOnly(op_args.db_cntx, *start, OBJ_SET);
if (!diff_res) {
if (diff_res.status() == OpStatus::WRONG_TYPE) {
return OpStatus::WRONG_TYPE;
@ -737,15 +739,16 @@ OpResult<StringVec> OpDiff(const OpArgs& op_args, ArgSlice keys) {
// Read-only OpInter op on sets.
OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_first) {
ArgSlice keys = t->GetShardArgs(es->shard_id());
ShardArgs args = t->GetShardArgs(es->shard_id());
auto it = args.begin();
if (remove_first) {
keys.remove_prefix(1);
++it;
}
DCHECK(!keys.empty());
DCHECK(it != args.end());
StringVec result;
if (keys.size() == 1) {
auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), keys.front(), OBJ_SET);
if (args.Size() == 1 + unsigned(remove_first)) {
auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), *it, OBJ_SET);
if (!find_res)
return find_res.status();
@ -763,12 +766,13 @@ OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_f
return result;
}
vector<SetType> sets(keys.size());
vector<SetType> sets(args.Size() - int(remove_first));
OpStatus status = OpStatus::OK;
for (size_t i = 0; i < keys.size(); ++i) {
auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), keys[i], OBJ_SET);
unsigned index = 0;
for (; it != args.end(); ++it) {
auto& dest = sets[index++];
auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), *it, OBJ_SET);
if (!find_res) {
if (status == OpStatus::OK || status == OpStatus::KEY_NOTFOUND ||
find_res.status() != OpStatus::KEY_NOTFOUND) {
@ -778,7 +782,7 @@ OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_f
}
const PrimeValue& pv = find_res.value()->second;
void* ptr = pv.RObjPtr();
sets[i] = make_pair(ptr, pv.Encoding());
dest = make_pair(ptr, pv.Encoding());
}
if (status != OpStatus::OK)
@ -1089,12 +1093,12 @@ void SDiff(CmdArgList args, ConnectionContext* cntx) {
ShardId src_shard = Shard(src_key, result_set.size());
auto cb = [&](Transaction* t, EngineShard* shard) {
ArgSlice largs = t->GetShardArgs(shard->shard_id());
ShardArgs largs = t->GetShardArgs(shard->shard_id());
if (shard->shard_id() == src_shard) {
CHECK_EQ(src_key, largs.front());
result_set[shard->shard_id()] = OpDiff(t->GetOpArgs(shard), largs);
CHECK_EQ(src_key, largs.Front());
result_set[shard->shard_id()] = OpDiff(t->GetOpArgs(shard), largs.begin(), largs.end());
} else {
result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs);
result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs.begin(), largs.end());
}
return OpStatus::OK;
@ -1126,22 +1130,23 @@ void SDiffStore(CmdArgList args, ConnectionContext* cntx) {
// read-only op
auto diff_cb = [&](Transaction* t, EngineShard* shard) {
ArgSlice largs = t->GetShardArgs(shard->shard_id());
DCHECK(!largs.empty());
ShardArgs largs = t->GetShardArgs(shard->shard_id());
OpArgs op_args = t->GetOpArgs(shard);
DCHECK(!largs.Empty());
ShardArgs::Iterator start = largs.begin();
ShardArgs::Iterator end = largs.end();
if (shard->shard_id() == dest_shard) {
CHECK_EQ(largs.front(), dest_key);
largs.remove_prefix(1);
if (largs.empty())
CHECK_EQ(*start, dest_key);
++start;
if (start == end)
return OpStatus::OK;
}
OpArgs op_args = t->GetOpArgs(shard);
if (shard->shard_id() == src_shard) {
CHECK_EQ(src_key, largs.front());
result_set[shard->shard_id()] = OpDiff(op_args, largs); // Diff
CHECK_EQ(src_key, *start);
result_set[shard->shard_id()] = OpDiff(op_args, start, end); // Diff
} else {
result_set[shard->shard_id()] = OpUnion(op_args, largs); // Union
result_set[shard->shard_id()] = OpUnion(op_args, start, end); // Union
}
return OpStatus::OK;
@ -1276,10 +1281,10 @@ void SInterStore(CmdArgList args, ConnectionContext* cntx) {
atomic_uint32_t inter_shard_cnt{0};
auto inter_cb = [&](Transaction* t, EngineShard* shard) {
ArgSlice largs = t->GetShardArgs(shard->shard_id());
ShardArgs largs = t->GetShardArgs(shard->shard_id());
if (shard->shard_id() == dest_shard) {
CHECK_EQ(largs.front(), dest_key);
if (largs.size() == 1)
CHECK_EQ(largs.Front(), dest_key);
if (largs.Size() == 1)
return OpStatus::OK;
}
inter_shard_cnt.fetch_add(1, memory_order_relaxed);
@ -1337,8 +1342,8 @@ void SUnion(CmdArgList args, ConnectionContext* cntx) {
ResultStringVec result_set(shard_set->size());
auto cb = [&](Transaction* t, EngineShard* shard) {
ArgSlice largs = t->GetShardArgs(shard->shard_id());
result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs);
ShardArgs largs = t->GetShardArgs(shard->shard_id());
result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs.begin(), largs.end());
return OpStatus::OK;
};
@ -1363,14 +1368,15 @@ void SUnionStore(CmdArgList args, ConnectionContext* cntx) {
ShardId dest_shard = Shard(dest_key, result_set.size());
auto union_cb = [&](Transaction* t, EngineShard* shard) {
ArgSlice largs = t->GetShardArgs(shard->shard_id());
ShardArgs largs = t->GetShardArgs(shard->shard_id());
ShardArgs::Iterator start = largs.begin(), end = largs.end();
if (shard->shard_id() == dest_shard) {
CHECK_EQ(largs.front(), dest_key);
largs.remove_prefix(1);
if (largs.empty())
CHECK_EQ(*start, dest_key);
++start;
if (start == end)
return OpStatus::OK;
}
result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs);
result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), start, end);
return OpStatus::OK;
};

View file

@ -798,8 +798,8 @@ stream* GetReadOnlyStream(const CompactObj& cobj) {
// Returns a map of stream to the ID of the last entry in the stream. Any
// streams not found are omitted from the result.
OpResult<vector<pair<string_view, streamID>>> OpLastIDs(const OpArgs& op_args,
const ArgSlice& args) {
DCHECK(!args.empty());
const ShardArgs& args) {
DCHECK(!args.Empty());
auto& db_slice = op_args.shard->db_slice();
@ -828,8 +828,8 @@ OpResult<vector<pair<string_view, streamID>>> OpLastIDs(const OpArgs& op_args,
// Returns the range response for each stream on this shard in order of
// GetShardArgs.
vector<RecordVec> OpRead(const OpArgs& op_args, const ArgSlice& args, const ReadOpts& opts) {
DCHECK(!args.empty());
vector<RecordVec> OpRead(const OpArgs& op_args, const ShardArgs& shard_args, const ReadOpts& opts) {
DCHECK(!shard_args.Empty());
RangeOpts range_opts;
range_opts.count = opts.count;
@ -838,11 +838,11 @@ vector<RecordVec> OpRead(const OpArgs& op_args, const ArgSlice& args, const Read
.seq = UINT64_MAX,
}};
vector<RecordVec> response(args.size());
for (size_t i = 0; i < args.size(); ++i) {
string_view key = args[i];
vector<RecordVec> response(shard_args.Size());
unsigned index = 0;
for (string_view key : shard_args) {
auto sitem = opts.stream_ids.at(key);
auto& dest = response[index++];
if (!sitem.group && opts.read_group) {
continue;
}
@ -858,7 +858,7 @@ vector<RecordVec> OpRead(const OpArgs& op_args, const ArgSlice& args, const Read
else
range_res = OpRange(op_args, key, range_opts);
if (range_res) {
response[i] = std::move(range_res.value());
dest = std::move(range_res.value());
}
}
@ -1352,15 +1352,17 @@ struct GroupConsumerPairOpts {
string_view consumer;
};
vector<GroupConsumerPair> OpGetGroupConsumerPairs(ArgSlice slice_args, const OpArgs& op_args,
vector<GroupConsumerPair> OpGetGroupConsumerPairs(const ShardArgs& shard_args,
const OpArgs& op_args,
const GroupConsumerPairOpts& opts) {
vector<GroupConsumerPair> sid_items(slice_args.size());
vector<GroupConsumerPair> sid_items(shard_args.Size());
unsigned index = 0;
// get group and consumer
for (size_t i = 0; i < slice_args.size(); i++) {
string_view key = slice_args[i];
for (string_view key : shard_args) {
streamCG* group = nullptr;
streamConsumer* consumer = nullptr;
auto& dest = sid_items[index++];
auto group_res = FindGroup(op_args, key, opts.group);
if (!group_res) {
continue;
@ -1376,7 +1378,7 @@ vector<GroupConsumerPair> OpGetGroupConsumerPairs(ArgSlice slice_args, const OpA
consumer = streamCreateConsumer(group, op_args.shard->tmp_str1, NULL, 0,
SCC_NO_NOTIFY | SCC_NO_DIRTIFY);
}
sid_items[i] = {group, consumer};
dest = {group, consumer};
}
return sid_items;
}
@ -2988,12 +2990,7 @@ void XReadImpl(CmdArgList args, std::optional<ReadOpts> opts, ConnectionContext*
vector<RecordVec>& results = xread_resp[sid];
ArgSlice slice = cntx->transaction->GetShardArgs(sid);
DCHECK(!slice.empty());
DCHECK_EQ(slice.size(), results.size());
for (size_t i = 0; i < slice.size(); ++i) {
for (size_t i = 0; i < results.size(); ++i) {
if (results[i].size() == 0) {
continue;
}
@ -3039,7 +3036,7 @@ void XReadGeneric(CmdArgList args, bool read_group, ConnectionContext* cntx) {
vector<vector<GroupConsumerPair>> res_pairs(shard_set->size());
auto cb = [&](Transaction* t, EngineShard* shard) {
auto sid = shard->shard_id();
auto s_args = t->GetShardArgs(sid);
ShardArgs s_args = t->GetShardArgs(sid);
GroupConsumerPairOpts gc_opts = {opts->group_name, opts->consumer_name};
res_pairs[sid] = OpGetGroupConsumerPairs(s_args, t->GetOpArgs(shard), gc_opts);
@ -3057,11 +3054,12 @@ void XReadGeneric(CmdArgList args, bool read_group, ConnectionContext* cntx) {
if (s_item.size() == 0) {
continue;
}
for (size_t j = 0; j < s_args.size(); j++) {
string_view key = s_args[j];
unsigned index = 0;
for (string_view key : s_args) {
StreamIDsItem& item = opts->stream_ids.at(key);
item.consumer = s_item[j].consumer;
item.group = s_item[j].group;
item.consumer = s_item[index].consumer;
item.group = s_item[index].group;
++index;
}
}
}

View file

@ -109,8 +109,9 @@ TEST_F(StreamFamilyTest, Range) {
}
TEST_F(StreamFamilyTest, GroupCreate) {
Run({"xadd", "key", "1-*", "f1", "v1"});
auto resp = Run({"xgroup", "create", "key", "grname", "1"});
auto resp = Run({"xadd", "key", "1-*", "f1", "v1"});
EXPECT_EQ(resp, "1-0");
resp = Run({"xgroup", "create", "key", "grname", "1"});
EXPECT_EQ(resp, "OK");
resp = Run({"xgroup", "create", "test", "test", "0"});
EXPECT_THAT(resp, ErrArg("requires the key to exist"));

View file

@ -278,19 +278,23 @@ int64_t AbsExpiryToTtl(int64_t abs_expiry_time, bool as_milli) {
}
// Returns true if keys were set, false otherwise.
void OpMSet(const OpArgs& op_args, ArgSlice args, atomic_bool* success) {
DCHECK(!args.empty() && args.size() % 2 == 0);
void OpMSet(const OpArgs& op_args, const ShardArgs& args, atomic_bool* success) {
DCHECK(!args.Empty() && args.Size() % 2 == 0);
SetCmd::SetParams params;
SetCmd sg(op_args, false);
size_t i = 0;
for (; i < args.size(); i += 2) {
DVLOG(1) << "MSet " << args[i] << ":" << args[i + 1];
if (sg.Set(params, args[i], args[i + 1]) != OpStatus::OK) { // OOM for example.
size_t index = 0;
for (auto it = args.begin(); it != args.end(); ++it) {
string_view key = *it;
++it;
string_view value = *it;
DVLOG(1) << "MSet " << key << ":" << value;
if (sg.Set(params, key, value) != OpStatus::OK) { // OOM for example.
success->store(false);
break;
}
index += 2;
}
if (auto journal = op_args.shard->journal(); journal) {
@ -298,14 +302,14 @@ void OpMSet(const OpArgs& op_args, ArgSlice args, atomic_bool* success) {
// we replicate only what was changed.
string_view cmd;
ArgSlice cmd_args;
if (i == 0) {
if (index == 0) {
// All shards must record the tx was executed for the replica to execute it, so we send a PING
// in case nothing was changed
cmd = "PING";
} else {
// journal [0, i)
cmd = "MSET";
cmd_args = ArgSlice(&args[0], i);
cmd_args = ArgSlice(args.begin(), index);
}
RecordJournal(op_args, cmd, cmd_args, op_args.tx->GetUniqueShardCnt());
}
@ -419,27 +423,29 @@ OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view
SinkReplyBuilder::MGetResponse OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction* t,
EngineShard* shard) {
auto keys = t->GetShardArgs(shard->shard_id());
DCHECK(!keys.empty());
ShardArgs keys = t->GetShardArgs(shard->shard_id());
DCHECK(!keys.Empty());
auto& db_slice = shard->db_slice();
SinkReplyBuilder::MGetResponse response(keys.size());
absl::InlinedVector<DbSlice::ConstIterator, 32> iters(keys.size());
SinkReplyBuilder::MGetResponse response(keys.Size());
absl::InlinedVector<DbSlice::ConstIterator, 32> iters(keys.Size());
size_t total_size = 0;
for (size_t i = 0; i < keys.size(); ++i) {
auto it_res = db_slice.FindAndFetchReadOnly(t->GetDbContext(), keys[i], OBJ_STRING);
unsigned index = 0;
for (string_view key : keys) {
auto it_res = db_slice.FindAndFetchReadOnly(t->GetDbContext(), key, OBJ_STRING);
auto& dest = iters[index++];
if (!it_res)
continue;
iters[i] = *it_res;
dest = *it_res;
total_size += (*it_res)->second.Size();
}
response.storage_list = SinkReplyBuilder::AllocMGetStorage(total_size);
char* next = response.storage_list->data;
for (size_t i = 0; i < keys.size(); ++i) {
for (size_t i = 0; i < iters.size(); ++i) {
auto it = iters[i];
if (it.is_done())
continue;
@ -1139,12 +1145,7 @@ void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) {
res.storage_list = src.storage_list;
src.storage_list = nullptr;
ArgSlice slice = transaction->GetShardArgs(sid);
DCHECK(!slice.empty());
DCHECK_EQ(slice.size(), src.resp_arr.size());
for (size_t j = 0; j < slice.size(); ++j) {
for (size_t j = 0; j < src.resp_arr.size(); ++j) {
if (!src.resp_arr[j])
continue;
@ -1173,7 +1174,7 @@ void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) {
atomic_bool success = true;
auto cb = [&](Transaction* t, EngineShard* shard) {
auto args = t->GetShardArgs(shard->shard_id());
ShardArgs args = t->GetShardArgs(shard->shard_id());
OpMSet(t->GetOpArgs(shard), args, &success);
return OpStatus::OK;
};
@ -1193,8 +1194,9 @@ void StringFamily::MSetNx(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* es) {
auto args = t->GetShardArgs(es->shard_id());
for (size_t i = 0; i < args.size(); i += 2) {
auto it = es->db_slice().FindReadOnly(t->GetDbContext(), args[i]).it;
for (auto arg_it = args.begin(); arg_it != args.end(); ++arg_it) {
auto it = es->db_slice().FindReadOnly(t->GetDbContext(), *arg_it).it;
++arg_it;
if (IsValid(it)) {
exists.store(true, memory_order_relaxed);
break;

View file

@ -182,8 +182,6 @@ void Transaction::InitGlobal() {
}
void Transaction::BuildShardIndex(const KeyIndex& key_index, std::vector<PerShardCache>* out) {
auto args = full_args_;
auto& shard_index = *out;
auto add = [this, rev_mapping = key_index.has_reverse_mapping, &shard_index](uint32_t sid,
@ -196,14 +194,14 @@ void Transaction::BuildShardIndex(const KeyIndex& key_index, std::vector<PerShar
if (key_index.bonus) {
DCHECK(key_index.step == 1);
string_view key = ArgS(args, *key_index.bonus);
string_view key = ArgS(full_args_, *key_index.bonus);
unique_slot_checker_.Add(key);
uint32_t sid = Shard(key, shard_data_.size());
add(sid, *key_index.bonus);
}
for (unsigned i = key_index.start; i < key_index.end; ++i) {
string_view key = ArgS(args, i);
string_view key = ArgS(full_args_, i);
unique_slot_checker_.Add(key);
uint32_t sid = Shard(key, shard_data_.size());
shard_index[sid].key_step = key_index.step;
@ -278,18 +276,16 @@ void Transaction::PrepareMultiFps(CmdArgList keys) {
void Transaction::StoreKeysInArgs(const KeyIndex& key_index) {
DCHECK(!key_index.bonus);
DCHECK(key_index.step == 1u || key_index.step == 2u);
DCHECK(kv_fp_.empty());
// even for a single key we may have multiple arguments per key (MSET).
for (unsigned j = key_index.start; j < key_index.end; j++) {
for (unsigned j = key_index.start; j < key_index.end; j += key_index.step) {
string_view arg = ArgS(full_args_, j);
kv_args_.push_back(arg);
kv_fp_.push_back(LockTag(arg).Fingerprint());
if (key_index.step == 2) {
kv_args_.push_back(ArgS(full_args_, ++j));
}
for (unsigned k = j + 1; k < j + key_index.step; ++k)
kv_args_.push_back(ArgS(full_args_, k));
}
if (key_index.has_reverse_mapping) {
@ -318,11 +314,12 @@ void Transaction::InitByKeys(const KeyIndex& key_index) {
StoreKeysInArgs(key_index);
unique_shard_cnt_ = 1;
string_view akey = kv_args_.front();
if (is_stub) // stub transactions don't migrate
DCHECK_EQ(unique_shard_id_, Shard(kv_args_.front(), shard_set->size()));
DCHECK_EQ(unique_shard_id_, Shard(akey, shard_set->size()));
else {
unique_slot_checker_.Add(kv_args_.front());
unique_shard_id_ = Shard(kv_args_.front(), shard_set->size());
unique_slot_checker_.Add(akey);
unique_shard_id_ = Shard(akey, shard_set->size());
}
// Multi transactions that execute commands on their own (not stubs) can't shrink the backing
@ -1178,7 +1175,7 @@ bool Transaction::CancelShardCb(EngineShard* shard) {
}
// runs in engine-shard thread.
ArgSlice Transaction::GetShardArgs(ShardId sid) const {
ShardArgs Transaction::GetShardArgs(ShardId sid) const {
DCHECK(!multi_ || multi_->role != SQUASHER);
// We can read unique_shard_cnt_ only because ShardArgsInShard is called after IsArmedInShard
@ -1188,7 +1185,7 @@ ArgSlice Transaction::GetShardArgs(ShardId sid) const {
}
const auto& sd = shard_data_[sid];
return ArgSlice{kv_args_.data() + sd.arg_start, sd.arg_count};
return ShardArgs{kv_args_.data() + sd.arg_start, sd.arg_count};
}
// from local index back to original arg index skipping the command.
@ -1253,7 +1250,7 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p
return result;
}
OpStatus Transaction::WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyChecker krc) {
OpStatus Transaction::WatchInShard(const ShardArgs& keys, EngineShard* shard, KeyReadyChecker krc) {
auto& sd = shard_data_[SidToId(shard->shard_id())];
CHECK_EQ(0, sd.local_mask & SUSPENDED_Q);
@ -1261,12 +1258,12 @@ OpStatus Transaction::WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyCh
sd.local_mask &= ~OUT_OF_ORDER;
shard->EnsureBlockingController()->AddWatched(keys, std::move(krc), this);
DVLOG(2) << "WatchInShard " << DebugId() << ", first_key:" << keys.front();
DVLOG(2) << "WatchInShard " << DebugId() << ", first_key:" << keys.Front();
return OpStatus::OK;
}
void Transaction::ExpireShardCb(ArgSlice wkeys, EngineShard* shard) {
void Transaction::ExpireShardCb(const ShardArgs& wkeys, EngineShard* shard) {
// Blocking transactions don't release keys when suspending, release them now.
auto lock_args = GetLockArgs(shard->shard_id());
shard->db_slice().Release(LockMode(), lock_args);
@ -1369,9 +1366,9 @@ bool Transaction::NotifySuspended(TxId committed_txid, ShardId sid, string_view
CHECK_EQ(sd.local_mask & AWAKED_Q, 0);
// Find index of awakened key
auto args = GetShardArgs(sid);
auto it = find_if(args.begin(), args.end(), [key](auto arg) { return facade::ToSV(arg) == key; });
CHECK(it != args.end());
ShardArgs args = GetShardArgs(sid);
auto it = find_if(args.cbegin(), args.cend(), [key](string_view arg) { return arg == key; });
CHECK(it != args.cend());
// Change state to awaked and store index of awakened key
sd.local_mask &= ~SUSPENDED_Q;
@ -1427,7 +1424,7 @@ void Transaction::LogAutoJournalOnShard(EngineShard* shard, RunnableResult resul
if (unique_shard_cnt_ == 1 || kv_args_.empty()) {
entry_payload = make_pair(cmd, full_args_);
} else {
entry_payload = make_pair(cmd, GetShardArgs(shard->shard_id()));
entry_payload = make_pair(cmd, GetShardArgs(shard->shard_id()).AsSlice());
}
LogJournalOnShard(shard, std::move(entry_payload), unique_shard_cnt_, false, true);
}

View file

@ -22,6 +22,7 @@
#include "server/common.h"
#include "server/journal/types.h"
#include "server/table.h"
#include "server/tx_base.h"
#include "util/fibers/synchronization.h"
namespace dfly {
@ -129,8 +130,9 @@ class Transaction {
// Runnable that is run on shards during hop executions (often named callback).
// Callbacks should return `OpStatus` which is implicitly convertible to `RunnableResult`!
using RunnableType = absl::FunctionRef<RunnableResult(Transaction* t, EngineShard*)>;
// Provides keys to block on for specific shard.
using WaitKeysProvider = std::function<ArgSlice(Transaction*, EngineShard* shard)>;
using WaitKeysProvider = std::function<ShardArgs(Transaction*, EngineShard* shard)>;
// Modes in which a multi transaction can run.
enum MultiMode {
@ -176,7 +178,7 @@ class Transaction {
OpStatus InitByArgs(DbIndex index, CmdArgList args);
// Get command arguments for specific shard. Called from shard thread.
ArgSlice GetShardArgs(ShardId sid) const;
ShardArgs GetShardArgs(ShardId sid) const;
// Map arg_index from GetShardArgs slice to index in original command slice from InitByArgs.
size_t ReverseArgIndex(ShardId shard_id, size_t arg_index) const;
@ -511,12 +513,12 @@ class Transaction {
void RunCallback(EngineShard* shard);
// Adds itself to watched queue in the shard. Must run in that shard thread.
OpStatus WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyChecker krc);
OpStatus WatchInShard(const ShardArgs& keys, EngineShard* shard, KeyReadyChecker krc);
// Expire blocking transaction, unlock keys and unregister it from the blocking controller
void ExpireBlocking(WaitKeysProvider wcb);
void ExpireShardCb(ArgSlice wkeys, EngineShard* shard);
void ExpireShardCb(const ShardArgs& wkeys, EngineShard* shard);
// Returns true if we need to follow up with PollExecution on this shard.
bool CancelShardCb(EngineShard* shard);
@ -577,7 +579,6 @@ class Transaction {
});
}
private:
// Used for waiting for all hop callbacks to run.
util::fb2::EmbeddedBlockingCounter run_barrier_{0};

61
src/server/tx_base.cc Normal file

@ -0,0 +1,61 @@
// Copyright 2024, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/tx_base.h"
#include "base/logging.h"
#include "server/cluster/cluster_defs.h"
#include "server/engine_shard_set.h"
#include "server/journal/journal.h"
#include "server/transaction.h"
namespace dfly {
using namespace std;
void RecordJournal(const OpArgs& op_args, string_view cmd, ArgSlice args, uint32_t shard_cnt,
bool multi_commands) {
VLOG(2) << "Logging command " << cmd << " from txn " << op_args.tx->txid();
op_args.tx->LogJournalOnShard(op_args.shard, make_pair(cmd, args), shard_cnt, multi_commands,
false);
}
void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt) {
op_args.tx->FinishLogJournalOnShard(op_args.shard, shard_cnt);
}
void RecordExpiry(DbIndex dbid, string_view key) {
auto journal = EngineShard::tlocal()->journal();
CHECK(journal);
journal->RecordEntry(0, journal::Op::EXPIRED, dbid, 1, cluster::KeySlot(key),
make_pair("DEL", ArgSlice{key}), false);
}
void TriggerJournalWriteToSink() {
auto journal = EngineShard::tlocal()->journal();
CHECK(journal);
journal->RecordEntry(0, journal::Op::NOOP, 0, 0, nullopt, {}, true);
}
std::ostream& operator<<(std::ostream& os, ArgSlice list) {
os << "[";
if (!list.empty()) {
std::for_each(list.begin(), list.end() - 1, [&os](const auto& val) { os << val << ", "; });
os << (*(list.end() - 1));
}
return os << "]";
}
LockTag::LockTag(std::string_view key) {
if (LockTagOptions::instance().enabled)
str_ = LockTagOptions::instance().Tag(key);
else
str_ = key;
}
LockFp LockTag::Fingerprint() const {
return XXH64(str_.data(), str_.size(), 0x1C69B3F74AC4AE35UL);
}
} // namespace dfly

172
src/server/tx_base.h Normal file

@ -0,0 +1,172 @@
// Copyright 2024, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <absl/types/span.h>
#include <optional>
#include "src/facade/facade_types.h"
namespace dfly {
class EngineShard;
class Transaction;
using DbIndex = uint16_t;
using ShardId = uint16_t;
using LockFp = uint64_t; // a key fingerprint used by the LockTable.
using ArgSlice = absl::Span<const std::string_view>;
constexpr DbIndex kInvalidDbId = DbIndex(-1);
constexpr ShardId kInvalidSid = ShardId(-1);
constexpr DbIndex kMaxDbId = 1024; // Reasonable starting point.
struct KeyLockArgs {
DbIndex db_index = 0;
absl::Span<const LockFp> fps;
};
// Describes key indices.
struct KeyIndex {
unsigned start;
unsigned end; // does not include this index (open limit).
unsigned step; // 1 for commands like mget. 2 for commands like mset.
// If set, records the index of one additional key (usually 0).
// Relevant for commands like ZUNIONSTORE/ZINTERSTORE, where it points at the destination key.
std::optional<uint16_t> bonus{};
bool has_reverse_mapping = false;
KeyIndex(unsigned s = 0, unsigned e = 0, unsigned step = 0) : start(s), end(e), step(step) {
}
static KeyIndex Range(unsigned start, unsigned end, unsigned step = 1) {
return KeyIndex{start, end, step};
}
bool HasSingleKey() const {
return !bonus && (start + step >= end);
}
unsigned num_args() const {
return end - start + bool(bonus);
}
};
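// Illustrative values (editor's note, not part of this commit; exact positions depend on how the
// parser slices the command): a read such as MGET k1 k2 k3 maps roughly to KeyIndex::Range(0, 3, 1),
// while MSET k1 v1 k2 v2 maps to KeyIndex::Range(0, 4, 2); num_args() then counts arguments
// (not keys), and bonus can record an extra key index such as the ZUNIONSTORE destination.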
struct DbContext {
DbIndex db_index = 0;
uint64_t time_now_ms = 0;
};
struct OpArgs {
EngineShard* shard;
const Transaction* tx;
DbContext db_cntx;
OpArgs() : shard(nullptr), tx(nullptr) {
}
OpArgs(EngineShard* s, const Transaction* tx, const DbContext& cntx)
: shard(s), tx(tx), db_cntx(cntx) {
}
};
// A strong type for a lock tag. Helps to disambiguate between keys and the parts of the
// keys that are used for locking.
class LockTag {
std::string_view str_;
public:
using is_stackonly = void; // marks that this object does not use heap.
LockTag() = default;
explicit LockTag(std::string_view key);
explicit operator std::string_view() const {
return str_;
}
LockFp Fingerprint() const;
// To make it hashable.
template <typename H> friend H AbslHashValue(H h, const LockTag& tag) {
return H::combine(std::move(h), tag.str_);
}
bool operator==(const LockTag& o) const {
return str_ == o.str_;
}
};
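// Illustrative note (editor's addition, not part of this commit): with the default options the tag
// is the whole key, so LockTag("foo") fingerprints "foo" itself; when hashtag-based locking is
// enabled, keys such as "{user1}:cart" and "{user1}:orders" would presumably reduce to the same
// tag and therefore contend on the same lock fingerprint.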
// Checks whether the touched key is valid for a blocking transaction watching it.
using KeyReadyChecker =
std::function<bool(EngineShard*, const DbContext& context, Transaction* tx, std::string_view)>;
// References arguments in another array.
using IndexSlice = std::pair<uint32_t, uint32_t>; // (begin, end)
class ShardArgs : protected ArgSlice {
public:
using ArgSlice::ArgSlice;
using ArgSlice::at;
using ArgSlice::operator=;
using Iterator = ArgSlice::iterator;
ShardArgs(const ArgSlice& o) : ArgSlice(o) {
}
size_t Size() const {
return ArgSlice::size();
}
auto cbegin() const {
return ArgSlice::cbegin();
}
auto cend() const {
return ArgSlice::cend();
}
auto begin() const {
return cbegin();
}
auto end() const {
return cend();
}
bool Empty() const {
return ArgSlice::empty();
}
std::string_view Front() const {
return *cbegin();
}
ArgSlice AsSlice() const {
return ArgSlice(*this);
}
};
// Record non auto journal command with own txid and dbid.
void RecordJournal(const OpArgs& op_args, std::string_view cmd, ArgSlice args,
uint32_t shard_cnt = 1, bool multi_commands = false);
// Record non auto journal command finish. Call only when command translates to multi commands.
void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt);
// Record expiry in journal with independent transaction. Must be called from shard thread holding
// key.
void RecordExpiry(DbIndex dbid, std::string_view key);
// Trigger journal write to sink, no journal record will be added to journal.
// Must be called from shard thread of journal to sink.
void TriggerJournalWriteToSink();
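// Illustrative usage (editor's sketch, not part of this header): a shard callback that journals its
// effect manually could record a payload such as
//   RecordJournal(op_args, "DEL", ArgSlice{key});
// while lazy expiration in a shard thread would call RecordExpiry(db_index, key), which emits an
// EXPIRED entry carrying a ("DEL", key) payload under an independent transaction.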
std::ostream& operator<<(std::ostream& os, ArgSlice list);
} // namespace dfly
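A minimal sketch of how a read-only shard callback might consume the new ShardArgs type (editor's illustration: OpCountExisting and its wiring are hypothetical; GetShardArgs, Empty, the range-for over string_view and FindReadOnly mirror calls visible elsewhere in this commit):
OpResult<uint32_t> OpCountExisting(const OpArgs& op_args, const ShardArgs& keys) {
  DCHECK(!keys.Empty());
  uint32_t found = 0;
  auto& db_slice = op_args.shard->db_slice();
  for (std::string_view key : keys) {  // ShardArgs iterates directly over key views
    auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STRING);
    if (it_res)
      ++found;
  }
  return found;
}
Such a callback would typically be wrapped in a RunnableType lambda and run on each shard during a hop.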

View file

@ -822,39 +822,43 @@ double GetKeyWeight(Transaction* t, ShardId shard_id, const vector<double>& weig
OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
const vector<double>& weights, bool store) {
ArgSlice keys = t->GetShardArgs(shard->shard_id());
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << keys;
DCHECK(!keys.empty());
ShardArgs keys = t->GetShardArgs(shard->shard_id());
DCHECK(!keys.Empty());
unsigned cmdargs_keys_offset = 1; // after {numkeys} for ZUNION
unsigned removed_keys = 0;
ShardArgs::Iterator start = keys.begin(), end = keys.end();
if (store) {
// first global index is 2 after {destkey, numkeys}.
++cmdargs_keys_offset;
if (keys.front() == dest) {
keys.remove_prefix(1);
if (*start == dest) {
++start;
++removed_keys;
}
// In case ONLY the destination key is hosted in this shard no work on this shard should be
// done in this step
if (keys.empty()) {
if (start == end) {
return OpStatus::OK;
}
}
auto& db_slice = shard->db_slice();
KeyIterWeightVec key_weight_vec(keys.size());
for (unsigned j = 0; j < keys.size(); ++j) {
auto it_res = db_slice.FindReadOnly(t->GetDbContext(), keys[j], OBJ_ZSET);
if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1.
KeyIterWeightVec key_weight_vec(keys.Size() - removed_keys);
unsigned index = 0;
for (; start != end; ++start) {
auto it_res = db_slice.FindReadOnly(t->GetDbContext(), *start, OBJ_ZSET);
if (it_res == OpStatus::WRONG_TYPE) // TODO: support SET type with default score 1.
return it_res.status();
if (!it_res)
if (!it_res) {
++index;
continue;
key_weight_vec[j] = {*it_res, GetKeyWeight(t, shard->shard_id(), weights, j + removed_keys,
cmdargs_keys_offset)};
}
key_weight_vec[index] = {*it_res, GetKeyWeight(t, shard->shard_id(), weights,
index + removed_keys, cmdargs_keys_offset)};
++index;
}
return UnionShardKeysWithScore(key_weight_vec, agg_type);
@ -871,46 +875,48 @@ ScoredMap ZSetFromSet(const PrimeValue& pv, double weight) {
OpResult<ScoredMap> OpInter(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
const vector<double>& weights, bool store) {
ArgSlice keys = t->GetShardArgs(shard->shard_id());
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << keys;
DCHECK(!keys.empty());
ShardArgs keys = t->GetShardArgs(shard->shard_id());
DCHECK(!keys.Empty());
unsigned removed_keys = 0;
unsigned cmdargs_keys_offset = 1;
ShardArgs::Iterator start = keys.begin(), end = keys.end();
if (store) {
// first global index is 2 after {destkey, numkeys}.
++cmdargs_keys_offset;
if (keys.front() == dest) {
keys.remove_prefix(1);
if (*start == dest) {
++start;
++removed_keys;
}
// In case ONLY the destination key is hosted in this shard no work on this shard should be
// done in this step
if (keys.empty()) {
return OpStatus::SKIPPED;
// In case ONLY the destination key is hosted in this shard no work on this shard should be
// done in this step
if (start == end) {
return OpStatus::SKIPPED;
}
}
}
auto& db_slice = shard->db_slice();
vector<pair<DbSlice::ItAndUpdater, double>> it_arr(keys.size());
if (it_arr.empty()) // could be when only the dest key is hosted in this shard
return OpStatus::SKIPPED; // return noop
vector<pair<DbSlice::ItAndUpdater, double>> it_arr(keys.Size() - removed_keys);
for (unsigned j = 0; j < keys.size(); ++j) {
auto it_res = db_slice.FindMutable(t->GetDbContext(), keys[j]);
if (!IsValid(it_res.it))
unsigned index = 0;
for (; start != end; ++start) {
auto it_res = db_slice.FindMutable(t->GetDbContext(), *start);
if (!IsValid(it_res.it)) {
++index;
continue; // we exit in the next loop
}
// sets are supported for ZINTER* commands:
auto obj_type = it_res.it->second.ObjType();
if (obj_type != OBJ_ZSET && obj_type != OBJ_SET)
return OpStatus::WRONG_TYPE;
it_arr[j] = {std::move(it_res), GetKeyWeight(t, shard->shard_id(), weights, j + removed_keys,
cmdargs_keys_offset)};
it_arr[index] = {std::move(it_res), GetKeyWeight(t, shard->shard_id(), weights,
index + removed_keys, cmdargs_keys_offset)};
++index;
}
ScoredMap result;
@ -1343,16 +1349,15 @@ void BZPopMinMax(CmdArgList args, ConnectionContext* cntx, bool is_max) {
}
vector<ScoredMap> OpFetch(EngineShard* shard, Transaction* t) {
ArgSlice keys = t->GetShardArgs(shard->shard_id());
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << keys;
DCHECK(!keys.empty());
ShardArgs keys = t->GetShardArgs(shard->shard_id());
DCHECK(!keys.Empty());
vector<ScoredMap> results;
results.reserve(keys.size());
results.reserve(keys.Size());
auto& db_slice = shard->db_slice();
for (size_t i = 0; i < keys.size(); ++i) {
auto it = db_slice.FindReadOnly(t->GetDbContext(), keys[i], OBJ_ZSET);
for (string_view key : keys) {
auto it = db_slice.FindReadOnly(t->GetDbContext(), key, OBJ_ZSET);
if (!it) {
results.push_back({});
continue;