1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

Implement memcache get for multiple keys. Fetch flag values as well

This commit is contained in:
Roman Gershman 2022-02-21 20:27:18 +02:00
parent b3c9836682
commit 3d8af8b413
7 changed files with 115 additions and 69 deletions

View file

@ -42,7 +42,15 @@ struct ConnectionState {
};
uint32_t mask = 0; // A bitmask of Mask values.
uint32_t memcache_flag = 0; // used for memcache set command.
enum MCGetMask {
FETCH_CAS_VER = 1,
};
// used for memcache set/get commands.
// For set op - it's the flag value we are storing along with the value.
// For get op - we use it as a mask of MCGetMask values.
uint32_t memcache_flag = 0;
bool IsClosing() const {
return mask & CONN_CLOSING;

View file

@ -261,23 +261,23 @@ TEST_F(DflyEngineTest, EvalSha) {
TEST_F(DflyEngineTest, Memcache) {
using MP = MemcacheParser;
auto resp = RunMC(MP::SET, "foo", "bar", 1);
auto resp = RunMC(MP::SET, "key", "bar", 1);
EXPECT_EQ(resp, "STORED\r\n");
resp = RunMC(MP::GET, "foo");
EXPECT_EQ(resp, "VALUE foo 1 3\r\nbar\r\nEND\r\n");
resp = RunMC(MP::GET, "key");
EXPECT_EQ(resp, "VALUE key 1 3\r\nbar\r\nEND\r\n");
resp = RunMC(MP::ADD, "foo", "bar", 1);
resp = RunMC(MP::ADD, "key", "bar", 1);
EXPECT_EQ(resp, "NOT_STORED\r\n");
resp = RunMC(MP::REPLACE, "foo2", "bar", 1);
resp = RunMC(MP::REPLACE, "key2", "bar", 1);
EXPECT_EQ(resp, "NOT_STORED\r\n");
resp = RunMC(MP::ADD, "foo2", "bar2", 2);
resp = RunMC(MP::ADD, "key2", "bar2", 2);
EXPECT_EQ(resp, "STORED\r\n");
resp = GetMC(MP::GET, {"foo2", "foo"});
// EXPECT_EQ(resp, "");
resp = GetMC(MP::GET, {"key2", "key"});
EXPECT_EQ(resp, "VALUE key2 2 4\r\nbar2\r\nVALUE key 1 3\r\nbar\r\nEND\r\n");
}
// TODO: to test transactions with a single shard since then all transactions become local.

View file

@ -57,15 +57,14 @@ class InterpreterReplier : public RedisReplyBuilder {
void SendError(std::string_view str) override;
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override;
void SendGetNotFound() override;
void SendStored() override;
void SendSimpleString(std::string_view str) final;
void SendMGetResponse(const StrOrNil* arr, uint32_t count) final;
void SendSimpleStrArr(const std::string_view* arr, uint32_t count) final;
void SendMGetResponse(const OptResp* resp, uint32_t count) final;
void SendSimpleStrArr(const string_view* arr, uint32_t count) final;
void SendNullArray() final;
void SendStringArr(absl::Span<const std::string_view> arr) final;
void SendStringArr(absl::Span<const string_view> arr) final;
void SendNull() final;
void SendLong(long val) final;
@ -160,11 +159,6 @@ void InterpreterReplier::SendGetReply(string_view key, uint32_t flags, string_vi
explr_->OnString(value);
}
void InterpreterReplier::SendGetNotFound() {
DCHECK(array_len_.empty());
explr_->OnNil();
}
void InterpreterReplier::SendStored() {
DCHECK(array_len_.empty());
SendSimpleString("OK");
@ -178,13 +172,13 @@ void InterpreterReplier::SendSimpleString(string_view str) {
PostItem();
}
void InterpreterReplier::SendMGetResponse(const StrOrNil* arr, uint32_t count) {
void InterpreterReplier::SendMGetResponse(const OptResp* resp, uint32_t count) {
DCHECK(array_len_.empty());
explr_->OnArrayStart(count);
for (uint32_t i = 0; i < count; ++i) {
if (arr[i].has_value()) {
explr_->OnString(*arr[i]);
if (resp[i].has_value()) {
explr_->OnString(resp[i]->value);
} else {
explr_->OnNil();
}
@ -498,10 +492,7 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
strcpy(set_opt, "NX");
break;
case MemcacheParser::GET:
strcpy(cmd_name, "GET");
if (cmd.keys_ext.size() > 0) {
return mc_builder->SendClientError("multiple keys are not suported");
}
strcpy(cmd_name, "MGET");
break;
default:
mc_builder->SendClientError("bad command line format");
@ -520,6 +511,11 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
args.emplace_back(set_opt, strlen(set_opt));
}
cntx->conn_state.memcache_flag = cmd.flags;
} else {
for (auto s : cmd.keys_ext) {
char* key = const_cast<char*>(s.data());
args.emplace_back(key, s.size());
}
}
DispatchCommand(CmdArgList{args}, cntx);

View file

@ -94,12 +94,28 @@ void MCReplyBuilder::SendGetReply(std::string_view key, uint32_t flags, std::str
Send(v, ABSL_ARRAYSIZE(v));
}
void MCReplyBuilder::SendError(string_view str) {
SendDirect("ERROR\r\n");
void MCReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
string header;
for (unsigned i = 0; i < count; ++i) {
if (resp[i]) {
const auto& src = *resp[i];
absl::StrAppend(&header, "VALUE ", src.key, " ", src.mc_flag, " ",
src.value.size());
if (src.mc_ver) {
absl::StrAppend(&header, " ", src.mc_ver);
}
absl::StrAppend(&header, "\r\n");
iovec v[] = {IoVec(header), IoVec(src.value), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
header.clear();
}
}
SendDirect("END\r\n");
}
void MCReplyBuilder::EndMultiLine() {
SendDirect("END\r\n");
void MCReplyBuilder::SendError(string_view str) {
SendDirect("ERROR\r\n");
}
void MCReplyBuilder::SendClientError(string_view str) {
@ -128,10 +144,6 @@ void RedisReplyBuilder::SendGetReply(std::string_view key, uint32_t flags, std::
SendBulkString(value);
}
void RedisReplyBuilder::SendGetNotFound() {
SendNull();
}
void RedisReplyBuilder::SendStored() {
SendSimpleString("OK");
}
@ -193,12 +205,12 @@ void RedisReplyBuilder::SendDouble(double val) {
SendBulkString(absl::StrCat(val));
}
void RedisReplyBuilder::SendMGetResponse(const StrOrNil* arr, uint32_t count) {
void RedisReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
string res = absl::StrCat("*", count, kCRLF);
for (size_t i = 0; i < count; ++i) {
if (arr[i]) {
StrAppend(&res, "$", arr[i]->size(), kCRLF);
res.append(*arr[i]).append(kCRLF);
if (resp[i]) {
StrAppend(&res, "$", resp[i]->value.size(), kCRLF);
res.append(resp[i]->value).append(kCRLF);
} else {
res.append("$-1\r\n");
}

View file

@ -22,11 +22,20 @@ class ReplyBuilderInterface {
virtual std::error_code GetError() const = 0;
virtual void SendGetNotFound() = 0;
virtual void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) = 0;
virtual void SendSetSkipped() = 0;
virtual void EndMultiLine() {}
struct ResponseValue {
std::string_view key;
std::string value;
uint64_t mc_ver = 0; // 0 means we do not output it (i.e has not been requested).
uint32_t mc_flag = 0;
};
using OptResp = std::optional<ResponseValue>;
virtual void SendMGetResponse(const OptResp* resp, uint32_t count) = 0;
virtual void SendSetSkipped() = 0;
};
class SinkReplyBuilder : public ReplyBuilderInterface {
@ -84,14 +93,10 @@ class MCReplyBuilder : public SinkReplyBuilder {
void SendError(std::string_view str) final;
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) final;
// memcache does not print keys that are not found.
void SendGetNotFound() final {
}
void SendMGetResponse(const OptResp* resp, uint32_t count) final;
void SendStored() final;
void EndMultiLine() final;
void SendSetSkipped() final;
void SendClientError(std::string_view str);
@ -107,15 +112,14 @@ class RedisReplyBuilder : public SinkReplyBuilder {
void SendError(std::string_view str) override;
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override;
void SendGetNotFound() override;
void SendMGetResponse(const OptResp* resp, uint32_t count) override;
void SendStored() override;
void SendSetSkipped() override;
void SendError(OpStatus status);
virtual void SendSimpleString(std::string_view str);
using StrOrNil = std::optional<std::string_view>;
virtual void SendMGetResponse(const StrOrNil* arr, uint32_t count);
virtual void SendSimpleStrArr(const std::string_view* arr, uint32_t count);
virtual void SendNullArray();

View file

@ -171,10 +171,6 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) {
string val;
it_res.value()->second.GetString(&val);
if ((*it_res)->second.HasFlag() && cntx->protocol() == Protocol::MEMCACHE) {
mc_flag = shard->db_slice().GetMCFlag(t->db_index(), (*it_res)->first);
}
return val;
};
@ -182,22 +178,19 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) {
Transaction* trans = cntx->transaction;
OpResult<string> result = trans->ScheduleSingleHopT(std::move(cb));
// This method is being used by both MC and Redis. We should use common interface.
ReplyBuilderInterface* builder = cntx->reply_builder();
if (result) {
DVLOG(1) << "GET " << trans->DebugId() << ": " << key << " " << result.value();
builder->SendGetReply(key, mc_flag, result.value());
(*cntx)->SendGetReply(key, mc_flag, result.value());
} else {
switch (result.status()) {
case OpStatus::WRONG_TYPE:
builder->SendError(kWrongTypeErr);
(*cntx)->SendError(kWrongTypeErr);
break;
default:
DVLOG(1) << "GET " << key << " nil";
builder->SendGetNotFound();
(*cntx)->SendNull();
}
}
builder->EndMultiLine();
}
void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) {
@ -290,9 +283,12 @@ void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) {
unsigned shard_count = transaction->shard_set()->size();
std::vector<MGetResponse> mget_resp(shard_count);
bool fetch_mcflag = cntx->protocol() == Protocol::MEMCACHE;
bool fetch_mcver = fetch_mcflag && (cntx->conn_state.mask & ConnectionState::FETCH_CAS_VER);
auto cb = [&](Transaction* t, EngineShard* shard) {
ShardId sid = shard->shard_id();
mget_resp[sid] = OpMGet(t, shard);
mget_resp[sid] = OpMGet(fetch_mcflag, fetch_mcver, t, shard);
return OpStatus::OK;
};
@ -303,21 +299,34 @@ void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) {
CHECK_EQ(OpStatus::OK, result);
// reorder the responses back according to the order of their corresponding keys.
vector<std::optional<std::string_view>> res(args.size() - 1);
vector<ReplyBuilderInterface::OptResp> res(args.size() - 1);
for (ShardId sid = 0; sid < shard_count; ++sid) {
if (!transaction->IsActive(sid))
continue;
auto& values = mget_resp[sid];
MGetResponse& results = mget_resp[sid];
ArgSlice slice = transaction->ShardArgsInShard(sid);
DCHECK(!slice.empty());
DCHECK_EQ(slice.size(), values.size());
DCHECK_EQ(slice.size(), results.size());
for (size_t j = 0; j < slice.size(); ++j) {
if (!results[j])
continue;
uint32_t indx = transaction->ReverseArgIndex(sid, j);
res[indx] = values[j];
auto& dest = res[indx].emplace();
auto& src = *results[j];
dest.key = ArgS(args, indx + 1);
dest.value = std::move(src.value);
dest.mc_flag = src.mc_flag;
dest.mc_ver = src.mc_ver;
}
}
return (*cntx)->SendMGetResponse(res.data(), res.size());
return cntx->reply_builder()->SendMGetResponse(res.data(), res.size());
}
void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) {
@ -339,7 +348,8 @@ void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendOk();
}
auto StringFamily::OpMGet(const Transaction* t, EngineShard* shard) -> MGetResponse {
auto StringFamily::OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction* t,
EngineShard* shard) -> MGetResponse {
auto args = t->ShardArgsInShard(shard->shard_id());
DCHECK(!args.empty());
@ -348,8 +358,18 @@ auto StringFamily::OpMGet(const Transaction* t, EngineShard* shard) -> MGetRespo
auto& db_slice = shard->db_slice();
for (size_t i = 0; i < args.size(); ++i) {
OpResult<MainIterator> it_res = db_slice.Find(t->db_index(), args[i], OBJ_STRING);
if (it_res.ok()) {
it_res.value()->second.GetString(&response[i].emplace());
if (!it_res)
continue;
const MainIterator& it = *it_res;
auto& dest = response[i].emplace();
it->second.GetString(&dest.value);
if (fetch_mcflag) {
dest.mc_flag = db_slice.GetMCFlag(t->db_index(), it->first);
if (fetch_mcver) {
dest.mc_ver = it.GetVersion();
}
}
}

View file

@ -59,9 +59,15 @@ class StringFamily {
static void IncrByGeneric(std::string_view key, int64_t val, ConnectionContext* cntx);
using MGetResponse = std::vector<std::optional<std::string>>;
struct GetResp {
std::string value;
uint64_t mc_ver = 0; // 0 means we do not output it (i.e has not been requested).
uint32_t mc_flag = 0;
};
static MGetResponse OpMGet(const Transaction* t, EngineShard* shard);
using MGetResponse = std::vector<std::optional<GetResp>>;
static MGetResponse OpMGet(bool fetch_mcflag, bool fetch_mcver,
const Transaction* t, EngineShard* shard);
static OpStatus OpMSet(const Transaction* t, EngineShard* es);
static OpResult<int64_t> OpIncrBy(const OpArgs& op_args, std::string_view key, int64_t val);