1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

Add tests for memcache commands. Handle flags metadata

This commit is contained in:
Roman Gershman 2022-02-20 22:07:33 +02:00
parent 8d2d49d782
commit b3c9836682
18 changed files with 309 additions and 104 deletions

View file

@ -290,7 +290,7 @@ void CompactObj::ImportRObj(robj* o) {
u_.r_obj.type = o->type;
u_.r_obj.encoding = o->encoding;
u_.r_obj.lru_unneeded = o->lru;
u_.r_obj.unneeded = o->lru;
if (o->type == OBJ_STRING) {
std::string_view src((char*)o->ptr, sdslen((sds)o->ptr));
@ -308,7 +308,7 @@ robj* CompactObj::AsRObj() const {
tmp_robj.encoding = u_.r_obj.encoding;
tmp_robj.type = u_.r_obj.type;
tmp_robj.lru = u_.r_obj.lru_unneeded;
tmp_robj.lru = u_.r_obj.unneeded;
tmp_robj.ptr = u_.r_obj.blob.ptr();
return &tmp_robj;

View file

@ -69,7 +69,7 @@ struct RobjWrapper {
uint32_t type : 4;
uint32_t encoding : 4;
uint32_t lru_unneeded : 24;
uint32_t unneeded : 24;
RobjWrapper() {
}
} __attribute__((packed));
@ -92,6 +92,7 @@ class CompactObj {
enum MaskBit {
REF_BIT = 1,
EXPIRE_BIT = 2,
FLAG_BIT = 4,
};
public:
@ -168,6 +169,18 @@ class CompactObj {
}
}
bool HasFlag() const {
return mask_ & FLAG_BIT;
}
void SetFlag(bool e) {
if (e) {
mask_ |= FLAG_BIT;
} else {
mask_ &= ~FLAG_BIT;
}
}
unsigned Encoding() const;
unsigned ObjType() const;
quicklist* GetQL() const;
@ -236,6 +249,8 @@ class CompactObj {
//
static_assert(sizeof(u_) == 16, "");
// Maybe it's possible to merge those 2 together and gain another byte
// but lets postpone it to 2023.
mutable uint8_t mask_ = 0;
uint8_t taglen_ = 0;
};

View file

@ -337,7 +337,6 @@ template <typename _Key, typename _Value, typename Policy>
DashTable<_Key, _Value, Policy>::DashTable(size_t capacity_log, const Policy& policy,
std::pmr::memory_resource* mr)
: Base(capacity_log), policy_(policy), segment_(mr) {
assert(capacity_log > 0u);
segment_.resize(unique_segments_);
std::pmr::polymorphic_allocator<SegmentType> pa(mr);

View file

@ -544,7 +544,11 @@ class DashTableBase {
protected:
uint32_t SegmentId(size_t hash) const {
return hash >> (64 - global_depth_);
if (global_depth_) {
return hash >> (64 - global_depth_);
}
return 0;
}
uint32_t global_depth_;

View file

@ -25,9 +25,9 @@ Protocol ConnectionContext::protocol() const {
}
RedisReplyBuilder* ConnectionContext::operator->() {
CHECK(Protocol::REDIS == owner_->protocol());
RedisReplyBuilder* b = static_cast<RedisReplyBuilder*>(rbuilder_.get());
return b;
CHECK(Protocol::REDIS == protocol());
return static_cast<RedisReplyBuilder*>(rbuilder_.get());
}
} // namespace dfly

View file

@ -42,6 +42,7 @@ struct ConnectionState {
};
uint32_t mask = 0; // A bitmask of Mask values.
uint32_t memcache_flag = 0; // used for memcache set command.
bool IsClosing() const {
return mask & CONN_CLOSING;

View file

@ -47,6 +47,12 @@ SliceEvents& SliceEvents::operator+=(const SliceEvents& o) {
#undef ADD
DbSlice::DbWrapper::DbWrapper(std::pmr::memory_resource* mr)
: prime_table(4, detail::PrimeTablePolicy{}, mr),
expire_table(0, detail::ExpireTablePolicy{}, mr),
mcflag_table(0, detail::ExpireTablePolicy{}, mr) {
}
DbSlice::DbSlice(uint32_t index, EngineShard* owner) : shard_id_(index), owner_(owner) {
db_arr_.emplace_back();
CreateDb(0);
@ -178,6 +184,10 @@ auto DbSlice::AddOrFind(DbIndex db_index, string_view key) -> pair<MainIterator,
if (expire_it->second <= now_ms_) {
db->expire_table.Erase(expire_it);
if (existing->second.HasFlag()) {
db->mcflag_table.Erase(existing->first);
}
// Keep the entry but reset the object.
db->stats.obj_memory_usage -= existing->second.MallocUsed();
existing->second.Reset();
@ -213,6 +223,10 @@ bool DbSlice::Del(DbIndex db_ind, MainIterator it) {
CHECK_EQ(1u, db->expire_table.Erase(it->first));
}
if (it->second.HasFlag()) {
CHECK_EQ(1u, db->mcflag_table.Erase(it->first));
}
db->stats.inline_keys -= it->first.IsInline();
db->stats.obj_memory_usage -= (it->first.MallocUsed() + it->second.MallocUsed());
db->prime_table.Erase(it);
@ -229,6 +243,7 @@ size_t DbSlice::FlushDb(DbIndex db_ind) {
size_t removed = db->prime_table.size();
db->prime_table.Clear();
db->expire_table.Clear();
db->mcflag_table.Clear();
db->stats.inline_keys = 0;
db->stats.obj_memory_usage = 0;
@ -270,24 +285,47 @@ bool DbSlice::Expire(DbIndex db_ind, MainIterator it, uint64_t at) {
return false;
}
void DbSlice::AddNew(DbIndex db_ind, string_view key, PrimeValue obj, uint64_t expire_at_ms) {
void DbSlice::SetMCFlag(DbIndex db_ind, PrimeKey key, uint32_t flag) {
auto& db = *db_arr_[db_ind];
if (flag == 0) {
db.mcflag_table.Erase(key);
} else {
auto [it, inserted] = db.mcflag_table.Insert(std::move(key), flag);
if (!inserted)
it->second = flag;
}
}
uint32_t DbSlice::GetMCFlag(DbIndex db_ind, const PrimeKey& key) const {
auto& db = *db_arr_[db_ind];
auto it = db.mcflag_table.Find(key);
return it.is_done() ? 0 : it->second;
}
MainIterator DbSlice::AddNew(DbIndex db_ind, string_view key, PrimeValue obj,
uint64_t expire_at_ms) {
for (const auto& ccb : change_cb_) {
ccb.second(db_ind, ChangeReq{key});
}
CHECK(AddIfNotExist(db_ind, key, std::move(obj), expire_at_ms));
auto [res, added] = AddIfNotExist(db_ind, key, std::move(obj), expire_at_ms);
CHECK(added);
return res;
}
bool DbSlice::AddIfNotExist(DbIndex db_ind, string_view key, PrimeValue obj,
uint64_t expire_at_ms) {
pair<MainIterator, bool> DbSlice::AddIfNotExist(DbIndex db_ind, string_view key, PrimeValue obj,
uint64_t expire_at_ms) {
DCHECK(!obj.IsRef());
auto& db = db_arr_[db_ind];
CompactObj co_key{key};
auto [new_entry, success] = db->prime_table.Insert(std::move(co_key), std::move(obj));
// in this case obj won't be moved and will be destroyed during unwinding.
if (!success)
return false; // in this case obj won't be moved and will be destroyed during unwinding.
return make_pair(new_entry, false);
new_entry.SetVersion(NextVersion());
@ -300,7 +338,7 @@ bool DbSlice::AddIfNotExist(DbIndex db_ind, string_view key, PrimeValue obj,
CHECK(db->expire_table.Insert(new_entry->first.AsRef(), expire_at_ms).second);
}
return true;
return make_pair(new_entry, true);
}
size_t DbSlice::DbSize(DbIndex db_ind) const {
@ -358,7 +396,6 @@ void DbSlice::Release(IntentLock::Mode mode, const KeyLockArgs& lock_args) {
}
}
}
}
}

View file

@ -103,12 +103,17 @@ class DbSlice {
// Does not change expiry if at != 0 and expiry already exists.
bool Expire(DbIndex db_ind, MainIterator main_it, uint64_t at);
void SetMCFlag(DbIndex db_ind, PrimeKey key, uint32_t flag);
uint32_t GetMCFlag(DbIndex db_ind, const PrimeKey& key) const;
// Adds a new entry. Requires: key does not exist in this slice.
void AddNew(DbIndex db_ind, std::string_view key, PrimeValue obj, uint64_t expire_at_ms);
// Returns the iterator to the newly added entry.
MainIterator AddNew(DbIndex db_ind, std::string_view key, PrimeValue obj, uint64_t expire_at_ms);
// Adds a new entry if a key does not exists. Returns true if insertion took place,
// false otherwise. expire_at_ms equal to 0 - means no expiry.
bool AddIfNotExist(DbIndex db_ind, std::string_view key, PrimeValue obj, uint64_t expire_at_ms);
std::pair<MainIterator, bool> AddIfNotExist(DbIndex db_ind, std::string_view key, PrimeValue obj,
uint64_t expire_at_ms);
// Creates a database with index `db_ind`. If such database exists does nothing.
void ActivateDb(DbIndex db_ind);
@ -162,7 +167,9 @@ class DbSlice {
// Current version of this slice.
// We maintain a shared versioning scheme for all databases in the slice.
uint64_t version() const { return version_; }
uint64_t version() const {
return version_;
}
// ChangeReq - describes the change to the table. If MainIterator is defined then
// it's an update on the existing entry, otherwise if string_view is defined then
@ -190,8 +197,8 @@ class DbSlice {
EngineShard* owner_;
uint64_t now_ms_ = 0; // Used for expire logic, represents a real clock.
uint64_t version_ = 1; // Used to version entries in the PrimeTable.
uint64_t now_ms_ = 0; // Used for expire logic, represents a real clock.
uint64_t version_ = 1; // Used to version entries in the PrimeTable.
mutable SliceEvents events_; // we may change this even for const operations.
using LockTable = absl::flat_hash_map<std::string, IntentLock>;
@ -199,13 +206,13 @@ class DbSlice {
struct DbWrapper {
PrimeTable prime_table;
ExpireTable expire_table;
DashTable<PrimeKey, uint32_t, detail::ExpireTablePolicy> mcflag_table;
LockTable lock_table;
mutable InternalDbStats stats;
explicit DbWrapper(std::pmr::memory_resource* mr)
: prime_table(4, detail::PrimeTablePolicy{}, mr) {
}
explicit DbWrapper(std::pmr::memory_resource* mr);
};
std::vector<std::unique_ptr<DbWrapper>> db_arr_;

View file

@ -259,6 +259,27 @@ TEST_F(DflyEngineTest, EvalSha) {
EXPECT_THAT(resp[0], ErrArg("No matching"));
}
TEST_F(DflyEngineTest, Memcache) {
using MP = MemcacheParser;
auto resp = RunMC(MP::SET, "foo", "bar", 1);
EXPECT_EQ(resp, "STORED\r\n");
resp = RunMC(MP::GET, "foo");
EXPECT_EQ(resp, "VALUE foo 1 3\r\nbar\r\nEND\r\n");
resp = RunMC(MP::ADD, "foo", "bar", 1);
EXPECT_EQ(resp, "NOT_STORED\r\n");
resp = RunMC(MP::REPLACE, "foo2", "bar", 1);
EXPECT_EQ(resp, "NOT_STORED\r\n");
resp = RunMC(MP::ADD, "foo2", "bar2", 2);
EXPECT_EQ(resp, "STORED\r\n");
resp = GetMC(MP::GET, {"foo2", "foo"});
// EXPECT_EQ(resp, "");
}
// TODO: to test transactions with a single shard since then all transactions become local.
// To consider having a parameter in dragonfly engine controlling number of shards
// unconditionally from number of cpus. TO TEST BLPOP under multi for single/multi argument case.

View file

@ -499,6 +499,9 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
break;
case MemcacheParser::GET:
strcpy(cmd_name, "GET");
if (cmd.keys_ext.size() > 0) {
return mc_builder->SendClientError("multiple keys are not suported");
}
break;
default:
mc_builder->SendClientError("bad command line format");
@ -516,10 +519,13 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
if (set_opt[0]) {
args.emplace_back(set_opt, strlen(set_opt));
}
cntx->conn_state.memcache_flag = cmd.flags;
}
CmdArgList arg_list{args.data(), args.size()};
DispatchCommand(arg_list, cntx);
DispatchCommand(CmdArgList{args}, cntx);
// Reset back.
cntx->conn_state.memcache_flag = 0;
}
bool Service::IsLocked(DbIndex db_index, std::string_view key) const {

View file

@ -42,9 +42,6 @@ MemcacheParser::Result ParseStore(const std::string_view* tokens, unsigned num_t
!absl::SimpleAtoi(tokens[2], &res->bytes_len))
return MemcacheParser::BAD_INT;
if (flags > 0xFFFF)
return MemcacheParser::BAD_INT;
if (res->type == MemcacheParser::CAS && !absl::SimpleAtoi(tokens[3], &res->cas_unique)) {
return MemcacheParser::BAD_INT;
}

View file

@ -35,15 +35,16 @@ class MemcacheParser {
DECR = 23,
};
// According to https://github.com/memcached/memcached/wiki/Commands#standard-protocol
struct Command {
CmdType type = INVALID;
std::string_view key;
std::vector<std::string_view> keys_ext;
uint64_t cas_unique = 0;
uint32_t expire_ts = 0;
uint32_t expire_ts = 0; // relative time in seconds.
uint32_t bytes_len = 0;
uint16_t flags = 0;
uint32_t flags = 0;
bool no_reply = false;
};

View file

@ -98,11 +98,19 @@ void MCReplyBuilder::SendError(string_view str) {
SendDirect("ERROR\r\n");
}
void MCReplyBuilder::EndMultiLine() {
SendDirect("END\r\n");
}
void MCReplyBuilder::SendClientError(string_view str) {
iovec v[] = {IoVec("CLIENT_ERROR"), IoVec(str), IoVec(kCRLF)};
iovec v[] = {IoVec("CLIENT_ERROR "), IoVec(str), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
}
void MCReplyBuilder::SendSetSkipped() {
SendDirect("NOT_STORED\r\n");
}
RedisReplyBuilder::RedisReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) {
}
@ -128,6 +136,10 @@ void RedisReplyBuilder::SendStored() {
SendSimpleString("OK");
}
void RedisReplyBuilder::SendSetSkipped() {
SendNull();
}
void RedisReplyBuilder::SendNull() {
constexpr char kNullStr[] = "$-1\r\n";

View file

@ -24,6 +24,9 @@ class ReplyBuilderInterface {
virtual void SendGetNotFound() = 0;
virtual void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) = 0;
virtual void SendSetSkipped() = 0;
virtual void EndMultiLine() {}
};
class SinkReplyBuilder : public ReplyBuilderInterface {
@ -88,6 +91,9 @@ class MCReplyBuilder : public SinkReplyBuilder {
void SendStored() final;
void EndMultiLine() final;
void SendSetSkipped() final;
void SendClientError(std::string_view str);
};
@ -103,6 +109,7 @@ class RedisReplyBuilder : public SinkReplyBuilder {
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override;
void SendGetNotFound() override;
void SendStored() override;
void SendSetSkipped() override;
void SendError(OpStatus status);
virtual void SendSimpleString(std::string_view str);

View file

@ -30,8 +30,6 @@ DEFINE_VARZ(VarzQps, get_qps);
} // namespace
SetCmd::SetCmd(DbSlice* db_slice) : db_slice_(db_slice) {
}
@ -60,27 +58,36 @@ OpResult<void> SetCmd::Set(const SetParams& params, std::string_view key, std::s
params.prev_val->emplace(move(val));
}
return SetExisting(params.db_index, value, at_ms, it, expire_it);
if (IsValid(expire_it) && at_ms) {
expire_it->second = at_ms;
} else {
db_slice_->Expire(params.db_index, it, at_ms);
}
db_slice_->PreUpdate(params.db_index, it);
// Check whether we need to update flags table.
bool req_flag_update = (params.memcache_flags != 0) != it->second.HasFlag();
if (req_flag_update) {
it->second.SetFlag(params.memcache_flags != 0);
db_slice_->SetMCFlag(params.db_index, it->first.AsRef(), params.memcache_flags);
}
it->second.SetString(value);
db_slice_->PostUpdate(params.db_index, it);
return OpStatus::OK;
}
// New entry
if (params.how == SET_IF_EXISTS)
return OpStatus::SKIPPED;
db_slice_->AddNew(params.db_index, key, PrimeValue{value}, at_ms);
PrimeValue tvalue{value};
tvalue.SetFlag(params.memcache_flags != 0);
it = db_slice_->AddNew(params.db_index, key, std::move(tvalue), at_ms);
return OpStatus::OK;
}
OpResult<void> SetCmd::SetExisting(DbIndex db_ind, std::string_view value, uint64_t expire_at_ms,
MainIterator dest, ExpireIterator exp_it) {
if (IsValid(exp_it) && expire_at_ms) {
exp_it->second = expire_at_ms;
} else {
db_slice_->Expire(db_ind, dest, expire_at_ms);
}
db_slice_->PreUpdate(db_ind, dest);
dest->second.SetString(value);
db_slice_->PostUpdate(db_ind, dest);
if (params.memcache_flags)
db_slice_->SetMCFlag(params.db_index, it->first.AsRef(), params.memcache_flags);
return OpStatus::OK;
}
@ -93,7 +100,10 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
VLOG(2) << "Set " << key << " " << value;
SetCmd::SetParams sparams{cntx->db_index()};
sparams.memcache_flags = cntx->conn_state.memcache_flag;
int64_t int_arg;
ReplyBuilderInterface* builder = cntx->reply_builder();
for (size_t i = 3; i < args.size(); ++i) {
ToUpper(&args[i]);
@ -104,15 +114,17 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
bool is_ms = (cur_arg == "PX");
++i;
if (i == args.size()) {
(*cntx)->SendError(kSyntaxErr);
builder->SendError(kSyntaxErr);
}
std::string_view ex = ArgS(args, i);
if (!absl::SimpleAtoi(ex, &int_arg)) {
return (*cntx)->SendError(kInvalidIntErr);
return builder->SendError(kInvalidIntErr);
}
if (int_arg <= 0 || (!is_ms && int_arg >= 500000000)) {
return (*cntx)->SendError("invalid expire time in set");
return builder->SendError("invalid expire time in set");
}
if (!is_ms) {
int_arg *= 1000;
}
@ -124,7 +136,7 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
} else if (cur_arg == "KEEPTTL") {
sparams.keep_expire = true;
} else {
return (*cntx)->SendError(kSyntaxErr);
return builder->SendError(kSyntaxErr);
}
}
@ -138,25 +150,30 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
OpResult<void> result = cntx->transaction->ScheduleSingleHop(std::move(cb));
if (result == OpStatus::OK) {
return (*cntx)->SendStored();
return builder->SendStored();
}
CHECK_EQ(result, OpStatus::SKIPPED); // in case of NX option
return (*cntx)->SendNull();
return builder->SendSetSkipped();
}
void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) {
get_qps.Inc();
std::string_view key = ArgS(args, 1);
uint32_t mc_flag = 0;
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<string> {
OpResult<MainIterator> it_res = shard->db_slice().Find(cntx->db_index(), key, OBJ_STRING);
OpResult<MainIterator> it_res = shard->db_slice().Find(t->db_index(), key, OBJ_STRING);
if (!it_res.ok())
return it_res.status();
string val;
it_res.value()->second.GetString(&val);
if ((*it_res)->second.HasFlag() && cntx->protocol() == Protocol::MEMCACHE) {
mc_flag = shard->db_slice().GetMCFlag(t->db_index(), (*it_res)->first);
}
return val;
};
@ -165,19 +182,22 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) {
Transaction* trans = cntx->transaction;
OpResult<string> result = trans->ScheduleSingleHopT(std::move(cb));
// This method is being used by both MC and Redis. We should use common interface.
ReplyBuilderInterface* builder = cntx->reply_builder();
if (result) {
DVLOG(1) << "GET " << trans->DebugId() << ": " << key << " " << result.value();
(*cntx)->SendGetReply(key, 0, result.value());
builder->SendGetReply(key, mc_flag, result.value());
} else {
switch (result.status()) {
case OpStatus::WRONG_TYPE:
(*cntx)->SendError(kWrongTypeErr);
builder->SendError(kWrongTypeErr);
break;
default:
DVLOG(1) << "GET " << key << " nil";
(*cntx)->SendGetNotFound();
builder->SendGetNotFound();
}
}
builder->EndMultiLine();
}
void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) {
@ -386,7 +406,6 @@ OpResult<int64_t> StringFamily::OpIncrBy(const OpArgs& op_args, std::string_view
return new_val;
}
void StringFamily::Init(util::ProactorPool* pp) {
set_qps.Init(pp);
get_qps.Init(pp);

View file

@ -26,6 +26,7 @@ class SetCmd {
SetHow how = SET_ALWAYS;
DbIndex db_index;
uint32_t memcache_flags = 0;
// Relative value based on now. 0 means no expiration.
uint64_t expire_after_ms = 0;
mutable std::optional<std::string>* prev_val = nullptr; // GETSET option
@ -36,10 +37,6 @@ class SetCmd {
};
OpResult<void> Set(const SetParams& params, std::string_view key, std::string_view value);
private:
OpResult<void> SetExisting(DbIndex db_ind, std::string_view value, uint64_t expire_at_ms,
MainIterator dest, ExpireIterator exp_it);
};
class StringFamily {

View file

@ -16,6 +16,7 @@ namespace dfly {
using namespace testing;
using namespace util;
using namespace std;
using MP = MemcacheParser;
bool RespMatcher::MatchAndExplain(const RespExpr& e, MatchResultListener* listener) const {
if (e.type != type_) {
@ -113,12 +114,11 @@ vector<int64_t> ToIntArr(const RespVec& vec) {
return res;
}
BaseFamilyTest::TestConn::TestConn()
: dummy_conn(new Connection(Protocol::REDIS, nullptr, nullptr)),
cmd_cntx(&sink, dummy_conn.get()) {
BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto)
: dummy_conn(new Connection(proto, nullptr, nullptr)), cmd_cntx(&sink, dummy_conn.get()) {
}
BaseFamilyTest::TestConn::~TestConn() {
BaseFamilyTest::TestConnWrapper::~TestConnWrapper() {
}
BaseFamilyTest::BaseFamilyTest() {
@ -161,41 +161,103 @@ RespVec BaseFamilyTest::Run(initializer_list<std::string_view> list) {
return pp_->at(0)->Await([&] { return this->Run(list); });
}
string id = GetId();
return Run(id, list);
return Run(GetId(), list);
}
RespVec BaseFamilyTest::Run(std::string_view id, std::initializer_list<std::string_view> list) {
mu_.lock();
auto [it, inserted] = connections_.emplace(id, nullptr);
if (inserted) {
it->second.reset(new TestConn);
} else {
it->second->sink.Clear();
}
TestConn* conn = it->second.get();
mu_.unlock();
TestConnWrapper* conn = AddFindConn(Protocol::REDIS, id);
CmdArgVec args = conn->Args(list);
CmdArgList cmd_arg_list{args.data(), args.size()};
auto& context = conn->cmd_cntx;
context.shard_set = ess_;
DCHECK(context.transaction == nullptr);
service_->DispatchCommand(cmd_arg_list, &context);
service_->DispatchCommand(CmdArgList{args}, &context);
DCHECK(context.transaction == nullptr);
unique_lock lk(mu_);
last_cmd_dbg_info_ = context.last_command_debug;
RespVec vec = conn->ParseResp();
RespVec vec = conn->ParseResponse();
return vec;
}
string BaseFamilyTest::RunMC(MP::CmdType cmd_type, string_view key, string_view value,
uint32_t flags, chrono::seconds ttl) {
if (!ProactorBase::IsProactorThread()) {
return pp_->at(0)->Await([&] { return this->RunMC(cmd_type, key, value, flags, ttl); });
}
MP::Command cmd;
cmd.type = cmd_type;
cmd.key = key;
cmd.flags = flags;
cmd.bytes_len = value.size();
cmd.expire_ts = ttl.count();
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
auto& context = conn->cmd_cntx;
context.shard_set = ess_;
DCHECK(context.transaction == nullptr);
service_->DispatchMC(cmd, value, &context);
DCHECK(context.transaction == nullptr);
return conn->sink.str();
}
string BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) {
if (!ProactorBase::IsProactorThread()) {
return pp_->at(0)->Await([&] { return this->RunMC(cmd_type, key); });
}
MP::Command cmd;
cmd.type = cmd_type;
cmd.key = key;
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
auto& context = conn->cmd_cntx;
context.shard_set = ess_;
service_->DispatchMC(cmd, string_view{}, &context);
return conn->sink.str();
}
string BaseFamilyTest::GetMC(MP::CmdType cmd_type,
std::initializer_list<std::string_view> list) {
CHECK_GT(list.size(), 0u);
CHECK(base::_in(cmd_type, {MP::GET, MP::GAT, MP::GETS, MP::GATS}));
if (!ProactorBase::IsProactorThread()) {
return pp_->at(0)->Await([&] { return this->GetMC(cmd_type, list); });
}
MP::Command cmd;
cmd.type = cmd_type;
auto src = list.begin();
cmd.key = *src++;
for (; src != list.end(); ++src) {
cmd.keys_ext.push_back(*src);
}
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
auto& context = conn->cmd_cntx;
context.shard_set = ess_;
service_->DispatchMC(cmd, string_view{}, &context);
return conn->sink.str();
}
int64_t BaseFamilyTest::CheckedInt(std::initializer_list<std::string_view> list) {
RespVec resp = Run(list);
CHECK_EQ(1u, resp.size());
@ -212,7 +274,7 @@ int64_t BaseFamilyTest::CheckedInt(std::initializer_list<std::string_view> list)
return res;
}
CmdArgVec BaseFamilyTest::TestConn::Args(std::initializer_list<std::string_view> list) {
CmdArgVec BaseFamilyTest::TestConnWrapper::Args(std::initializer_list<std::string_view> list) {
CHECK_NE(0u, list.size());
CmdArgVec res;
@ -226,7 +288,7 @@ CmdArgVec BaseFamilyTest::TestConn::Args(std::initializer_list<std::string_view>
return res;
}
RespVec BaseFamilyTest::TestConn::ParseResp() {
RespVec BaseFamilyTest::TestConnWrapper::ParseResponse() {
tmp_str_vec.emplace_back(new string{sink.str()});
auto& s = *tmp_str_vec.back();
auto buf = RespExpr::buffer(&s);
@ -264,4 +326,17 @@ ConnectionContext::DebugInfo BaseFamilyTest::GetDebugInfo(const std::string& id)
return it->second->cmd_cntx.last_command_debug;
}
auto BaseFamilyTest::AddFindConn(Protocol proto, std::string_view id) -> TestConnWrapper* {
unique_lock lk(mu_);
auto [it, inserted] = connections_.emplace(id, nullptr);
if (inserted) {
it->second.reset(new TestConnWrapper(proto));
} else {
it->second->sink.Clear();
}
return it->second.get();
}
} // namespace dfly

View file

@ -9,6 +9,7 @@
#include "io/io.h"
#include "server/conn_context.h"
#include "server/main_service.h"
#include "server/memcache_parser.h"
#include "server/redis_parser.h"
#include "util/proactor_pool.h"
@ -19,8 +20,7 @@ class RespMatcher {
RespMatcher(std::string_view val, RespExpr::Type t = RespExpr::STRING) : type_(t), exp_str_(val) {
}
RespMatcher(int64_t val, RespExpr::Type t = RespExpr::INT64)
: type_(t), exp_int_(val) {
RespMatcher(int64_t val, RespExpr::Type t = RespExpr::INT64) : type_(t), exp_int_(val) {
}
using is_gtest_matcher = void;
@ -96,10 +96,32 @@ class BaseFamilyTest : public ::testing::Test {
void TearDown() override;
protected:
struct TestConnWrapper {
::io::StringSink sink; // holds the response blob
std::unique_ptr<Connection> dummy_conn;
ConnectionContext cmd_cntx;
std::vector<std::unique_ptr<std::string>> tmp_str_vec;
std::unique_ptr<RedisParser> parser;
TestConnWrapper(Protocol proto);
~TestConnWrapper();
CmdArgVec Args(std::initializer_list<std::string_view> list);
RespVec ParseResponse();
};
RespVec Run(std::initializer_list<std::string_view> list);
RespVec Run(std::string_view id, std::initializer_list<std::string_view> list);
std::string RunMC(MemcacheParser::CmdType cmd_type, std::string_view key, std::string_view value,
uint32_t flags = 0, std::chrono::seconds ttl = std::chrono::seconds{});
std::string RunMC(MemcacheParser::CmdType cmd_type, std::string_view key = std::string_view{});
std::string GetMC(MemcacheParser::CmdType cmd_type, std::initializer_list<std::string_view> list);
int64_t CheckedInt(std::initializer_list<std::string_view> list);
bool IsLocked(DbIndex db_index, std::string_view key) const;
@ -109,6 +131,8 @@ class BaseFamilyTest : public ::testing::Test {
return GetDebugInfo("IO0");
}
TestConnWrapper* AddFindConn(Protocol proto, std::string_view id);
// ts is ms
void UpdateTime(uint64_t ms);
std::string GetId() const;
@ -118,24 +142,7 @@ class BaseFamilyTest : public ::testing::Test {
EngineShardSet* ess_ = nullptr;
unsigned num_threads_ = 3;
struct TestConn {
::io::StringSink sink;
std::unique_ptr<Connection> dummy_conn;
ConnectionContext cmd_cntx;
std::vector<std::unique_ptr<std::string>> tmp_str_vec;
std::unique_ptr<RedisParser> parser;
TestConn();
~TestConn();
CmdArgVec Args(std::initializer_list<std::string_view> list);
RespVec ParseResp();
};
absl::flat_hash_map<std::string, std::unique_ptr<TestConn>> connections_;
absl::flat_hash_map<std::string, std::unique_ptr<TestConnWrapper>> connections_;
::boost::fibers::mutex mu_;
ConnectionContext::DebugInfo last_cmd_dbg_info_;
};