1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-15 17:51:06 +00:00

Implement PSUBSCRIBE/PUNSUBSCRIBE commands.

Add minimal tests.
This commit is contained in:
Roman Gershman 2022-06-02 07:54:34 +03:00
parent 8570a12d81
commit ec9754150f
13 changed files with 373 additions and 85 deletions

View file

@ -268,15 +268,15 @@ API 2.0
- [X] HSETNX
- [X] HVALS
- [X] HSCAN
- [ ] PubSub family
- [X] PubSub family
- [X] PUBLISH
- [ ] PUBSUB
- [ ] PUBSUB CHANNELS
- [X] SUBSCRIBE
- [X] UNSUBSCRIBE
- [ ] PSUBSCRIBE
- [ ] PUNSUBSCRIBE
- [ ] Server Family
- [X] PSUBSCRIBE
- [X] PUNSUBSCRIBE
- [X] Server Family
- [ ] WATCH
- [ ] UNWATCH
- [X] DISCARD

View file

@ -69,11 +69,11 @@ constexpr size_t kMinReadSize = 256;
constexpr size_t kMaxReadSize = 32_KB;
struct AsyncMsg {
absl::Span<const std::string_view> msg_vec;
Connection::PubMessage pub_msg;
fibers_ext::BlockingCounter bc;
AsyncMsg(absl::Span<const std::string_view> vec, fibers_ext::BlockingCounter b)
: msg_vec(vec), bc(move(b)) {
AsyncMsg(const Connection::PubMessage& pmsg, fibers_ext::BlockingCounter b)
: pub_msg(pmsg), bc(move(b)) {
}
};
@ -245,15 +245,17 @@ void Connection::RegisterOnBreak(BreakerCb breaker_cb) {
breaker_cb_ = breaker_cb;
}
void Connection::SendMsgVecAsync(absl::Span<const std::string_view> msg_vec,
void Connection::SendMsgVecAsync(const PubMessage& pub_msg,
fibers_ext::BlockingCounter bc) {
DCHECK(cc_);
if (cc_->conn_closing) {
bc.Dec();
return;
}
void* ptr = mi_malloc(sizeof(AsyncMsg));
AsyncMsg* amsg = new (ptr) AsyncMsg(msg_vec, move(bc));
AsyncMsg* amsg = new (ptr) AsyncMsg(pub_msg, move(bc));
ptr = mi_malloc(sizeof(Request));
Request* req = new (ptr) Request(0, 0);
@ -571,7 +573,24 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) {
if (req->async_msg) {
++stats->async_writes_cnt;
builder->SendRawVec(req->async_msg->msg_vec);
RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder;
const PubMessage& pub_msg = req->async_msg->pub_msg;
string_view arr[4];
if (pub_msg.pattern.empty()) {
arr[0] = "message";
arr[1] = pub_msg.channel;
arr[2] = pub_msg.message;
rbuilder->SendStringArr(absl::Span<string_view>{arr, 3});
} else {
arr[0] = "pmessage";
arr[1] = pub_msg.pattern;
arr[2] = pub_msg.channel;
arr[3] = pub_msg.message;
rbuilder->SendStringArr(absl::Span<string_view>{arr, 4});
}
req->async_msg->bc.Dec();
req->async_msg->~AsyncMsg();

View file

@ -46,11 +46,20 @@ class Connection : public util::Connection {
using BreakerCb = std::function<void(uint32_t)>;
void RegisterOnBreak(BreakerCb breaker_cb);
// This interface is used to pass a raw message directly to the socket via zero-copy interface.
// This interface is used to pass a published message directly to the socket without
// copying strings.
// Once the msg is sent "bc" will be decreased so that caller could release the underlying
// storage for the message.
void SendMsgVecAsync(absl::Span<const std::string_view> msg_vec,
util::fibers_ext::BlockingCounter bc);
// virtual - to allow the testing code to override it.
struct PubMessage {
// if empty - means its a regular message, otherwise it's pmessage.
std::string_view pattern;
std::string_view channel;
std::string_view message;
};
virtual void SendMsgVecAsync(const PubMessage& pub_msg, util::fibers_ext::BlockingCounter bc);
void SetName(std::string_view name) {
CopyCharBuf(name, sizeof(name_), name_);

View file

@ -4,6 +4,10 @@
#include "server/channel_slice.h"
extern "C" {
#include "redis/util.h"
}
namespace dfly {
using namespace std;
@ -11,6 +15,14 @@ ChannelSlice::Subscriber::Subscriber(ConnectionContext* cntx, uint32_t tid)
: conn_cntx(cntx), borrow_token(cntx->conn_state.subscribe_info->borrow_token), thread_id(tid) {
}
void ChannelSlice::AddSubscription(string_view channel, ConnectionContext* me, uint32_t thread_id) {
auto [it, added] = channels_.emplace(channel, nullptr);
if (added) {
it->second.reset(new Channel);
}
it->second->subscribers.emplace(me, SubscriberInternal{thread_id});
}
void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me) {
auto it = channels_.find(channel);
if (it != channels_.end()) {
@ -20,29 +32,52 @@ void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me
}
}
void ChannelSlice::AddSubscription(string_view channel, ConnectionContext* me, uint32_t thread_id) {
auto [it, added] = channels_.emplace(channel, nullptr);
void ChannelSlice::AddGlobPattern(string_view pattern, ConnectionContext* me, uint32_t thread_id) {
auto [it, added] = patterns_.emplace(pattern, nullptr);
if (added) {
it->second.reset(new Channel);
}
it->second->subscribers.emplace(me, SubscriberInternal{thread_id});
}
void ChannelSlice::RemoveGlobPattern(string_view pattern, ConnectionContext* me) {
auto it = patterns_.find(pattern);
if (it != patterns_.end()) {
it->second->subscribers.erase(me);
if (it->second->subscribers.empty())
patterns_.erase(it);
}
}
auto ChannelSlice::FetchSubscribers(string_view channel) -> vector<Subscriber> {
vector<Subscriber> res;
auto it = channels_.find(channel);
if (it != channels_.end()) {
res.reserve(it->second->subscribers.size());
for (const auto& k_v : it->second->subscribers) {
Subscriber s(k_v.first, k_v.second.thread_id);
s.borrow_token.Inc();
CopySubsribers(it->second->subscribers, string{}, &res);
}
res.push_back(std::move(s));
for (const auto& k_v : patterns_) {
const string& pat = k_v.first;
// 1 - match
if (stringmatchlen(pat.data(), pat.size(), channel.data(), channel.size(), 0) == 1) {
CopySubsribers(k_v.second->subscribers, pat, &res);
}
}
return res;
}
void ChannelSlice::CopySubsribers(const SubsribeMap& src, const std::string& pattern,
vector<Subscriber>* dest) {
for (const auto& sub : src) {
Subscriber s(sub.first, sub.second.thread_id);
s.pattern = pattern;
s.borrow_token.Inc();
dest->push_back(std::move(s));
}
}
} // namespace dfly

View file

@ -19,6 +19,9 @@ class ChannelSlice {
util::fibers_ext::BlockingCounter borrow_token;
uint32_t thread_id;
// non-empty if was registered via psubscribe
std::string pattern;
Subscriber(ConnectionContext* cntx, uint32_t tid);
// Subscriber() : borrow_token(0) {}
@ -31,18 +34,27 @@ class ChannelSlice {
std::vector<Subscriber> FetchSubscribers(std::string_view channel);
void RemoveSubscription(std::string_view channel, ConnectionContext* me);
void AddSubscription(std::string_view channel, ConnectionContext* me, uint32_t thread_id);
void RemoveSubscription(std::string_view channel, ConnectionContext* me);
void AddGlobPattern(std::string_view pattern, ConnectionContext* me, uint32_t thread_id);
void RemoveGlobPattern(std::string_view pattern, ConnectionContext* me);
private:
struct SubscriberInternal {
uint32_t thread_id; // proactor thread id.
SubscriberInternal(uint32_t tid) : thread_id(tid) {}
SubscriberInternal(uint32_t tid) : thread_id(tid) {
}
};
using SubsribeMap = absl::flat_hash_map<ConnectionContext*, SubscriberInternal>;
static void CopySubsribers(const SubsribeMap& src, const std::string& pattern,
std::vector<Subscriber>* dest);
struct Channel {
absl::flat_hash_map<ConnectionContext*, SubscriberInternal> subscribers;
SubsribeMap subscribers;
};
absl::flat_hash_map<std::string, std::unique_ptr<Channel>> channels_;

View file

@ -23,9 +23,11 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
DCHECK(to_add);
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
// to be able to read input and still write the output.
this->force_dispatch = true;
}
// Gather all the channels we need to subsribe to / remove.
for (size_t i = 0; i < args.size(); ++i) {
bool res = false;
string_view channel = ArgS(args, i);
@ -44,13 +46,14 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
}
}
if (!to_add && conn_state.subscribe_info->channels.empty()) {
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
force_dispatch = false;
}
sort(channels.begin(), channels.end());
// prepare the array in order to distribute the updates to the shards.
vector<unsigned> shard_idx(shard_set->size() + 1, 0);
for (const auto& k_v : channels) {
shard_idx[k_v.first]++;
@ -68,6 +71,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
int32_t tid = util::ProactorBase::GetIndex();
DCHECK_GE(tid, 0);
// Update the subsribers on publisher's side.
auto cb = [&](EngineShard* shard) {
ChannelSlice& cs = shard->channel_slice();
unsigned start = shard_idx[shard->shard_id()];
@ -83,6 +87,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
}
};
// Update subscription
shard_set->RunBriefInParallel(move(cb),
[&](ShardId sid) { return shard_idx[sid + 1] > shard_idx[sid]; });
}
@ -90,6 +95,77 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
if (to_reply) {
const char* action[2] = {"unsubscribe", "subscribe"};
for (size_t i = 0; i < result.size(); ++i) {
(*this)->StartArray(3);
(*this)->SendBulkString(action[to_add]);
(*this)->SendBulkString(ArgS(args, i)); // channel
// number of subsribed channels for this connection *right after*
// we subsribe.
(*this)->SendLong(result[i]);
}
}
}
void ConnectionContext::ChangePSub(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result(to_reply ? args.size() : 0, 0);
if (to_add || conn_state.subscribe_info) {
std::vector<string_view> patterns;
patterns.reserve(args.size());
if (!conn_state.subscribe_info) {
DCHECK(to_add);
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
this->force_dispatch = true;
}
// Gather all the patterns we need to subsribe to / remove.
for (size_t i = 0; i < args.size(); ++i) {
bool res = false;
string_view pattern = ArgS(args, i);
if (to_add) {
res = conn_state.subscribe_info->patterns.emplace(pattern).second;
} else {
res = conn_state.subscribe_info->patterns.erase(pattern) > 0;
}
if (to_reply)
result[i] = conn_state.subscribe_info->patterns.size();
if (res) {
patterns.emplace_back(pattern);
}
}
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
force_dispatch = false;
}
int32_t tid = util::ProactorBase::GetIndex();
DCHECK_GE(tid, 0);
// Update the subsribers on publisher's side.
auto cb = [&](EngineShard* shard) {
ChannelSlice& cs = shard->channel_slice();
for (string_view pattern : patterns) {
if (to_add) {
cs.AddGlobPattern(pattern, this, tid);
} else {
cs.RemoveGlobPattern(pattern, this);
}
}
};
// Update pattern subscription. Run on all shards.
shard_set->RunBriefInParallel(move(cb));
}
if (to_reply) {
const char* action[2] = {"punsubscribe", "psubscribe"};
for (size_t i = 0; i < result.size(); ++i) {
(*this)->StartArray(3);
(*this)->SendBulkString(action[to_add]);
@ -100,18 +176,35 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
}
void ConnectionContext::OnClose() {
if (conn_state.subscribe_info) {
if (!conn_state.subscribe_info)
return;
if (!conn_state.subscribe_info->channels.empty()) {
StringVec channels(conn_state.subscribe_info->channels.begin(),
conn_state.subscribe_info->channels.end());
CmdArgVec arg_vec(channels.begin(), channels.end());
auto token = conn_state.subscribe_info->borrow_token;
ChangeSubscription(false, false, CmdArgList{arg_vec});
DCHECK(!conn_state.subscribe_info);
// Check that all borrowers finished processing
token.Wait();
}
if (conn_state.subscribe_info) {
DCHECK(!conn_state.subscribe_info->patterns.empty());
StringVec patterns(conn_state.subscribe_info->patterns.begin(),
conn_state.subscribe_info->patterns.end());
CmdArgVec arg_vec(patterns.begin(), patterns.end());
auto token = conn_state.subscribe_info->borrow_token;
ChangePSub(false, false, CmdArgList{arg_vec});
// Check that all borrowers finished processing
token.Wait();
DCHECK(!conn_state.subscribe_info);
}
}
} // namespace dfly

View file

@ -50,10 +50,16 @@ struct ConnectionState {
struct SubscribeInfo {
// TODO: to provide unique_strings across service. This will allow us to use string_view here.
absl::flat_hash_set<std::string> channels;
absl::flat_hash_set<std::string> patterns;
util::fibers_ext::BlockingCounter borrow_token;
SubscribeInfo() : borrow_token(0) {}
bool IsEmpty() const {
return channels.empty() && patterns.empty();
}
SubscribeInfo() : borrow_token(0) {
}
};
std::unique_ptr<SubscribeInfo> subscribe_info;
@ -85,6 +91,7 @@ class ConnectionContext : public facade::ConnectionContext {
}
void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args);
void ChangePSub(bool to_add, bool to_reply, CmdArgList args);
bool is_replicating = false;
};

View file

@ -22,13 +22,13 @@ extern "C" {
namespace dfly {
using namespace absl;
using namespace boost;
using namespace std;
using namespace util;
using ::io::Result;
using testing::ElementsAre;
using testing::HasSubstr;
using absl::StrCat;
namespace this_fiber = boost::this_fiber;
namespace {
@ -411,6 +411,20 @@ TEST_F(DflyEngineTest, OOM) {
}
}
TEST_F(DflyEngineTest, PSubscribe) {
auto resp = pp_->at(1)->Await([&] { return Run({"psubscribe", "a*", "b*"}); });
EXPECT_THAT(resp, ArrLen(3));
resp = pp_->at(0)->Await([&] { return Run({"publish", "ab", "foo"}); });
EXPECT_THAT(resp, IntArg(1));
ASSERT_EQ(1, SubsriberMessagesLen("IO1"));
facade::Connection::PubMessage msg = GetPublishedMessage("IO1", 0);
EXPECT_EQ("foo", msg.message);
EXPECT_EQ("ab", msg.channel);
EXPECT_EQ("a*", msg.pattern);
}
// TODO: to test transactions with a single shard since then all transactions become local.
// To consider having a parameter in dragonfly engine controlling number of shards
// unconditionally from number of cpus. TO TEST BLPOP under multi for single/multi argument case.

View file

@ -912,23 +912,22 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
}
fibers_ext::BlockingCounter bc(subsriber_arr.size());
char prefix[] = "*3\r\n$7\r\nmessage\r\n$";
char msg_size[32] = {0};
char channel_size[32] = {0};
absl::SNPrintF(msg_size, sizeof(msg_size), "%u\r\n", message.size());
absl::SNPrintF(channel_size, sizeof(channel_size), "%u\r\n", channel.size());
string_view msg_arr[] = {prefix, channel_size, channel, "\r\n$", msg_size, message, "\r\n"};
auto publish_cb = [&, bc](unsigned idx, util::ProactorBase*) mutable {
unsigned start = slices[idx];
for (unsigned i = start; i < subsriber_arr.size(); ++i) {
if (subsriber_arr[i].thread_id != idx)
const ChannelSlice::Subscriber& subscriber = subsriber_arr[i];
if (subscriber.thread_id != idx)
break;
published.fetch_add(1, memory_order_relaxed);
subsriber_arr[i].conn_cntx->owner()->SendMsgVecAsync(msg_arr, bc);
facade::Connection* conn = subsriber_arr[i].conn_cntx->owner();
DCHECK(conn);
facade::Connection::PubMessage pmsg;
pmsg.channel = channel;
pmsg.message = message;
pmsg.pattern = subscriber.pattern;
conn->SendMsgVecAsync(pmsg, bc);
}
};
@ -959,6 +958,17 @@ void Service::Unsubscribe(CmdArgList args, ConnectionContext* cntx) {
cntx->ChangeSubscription(false, true, std::move(args));
}
void Service::PSubscribe(CmdArgList args, ConnectionContext* cntx) {
args.remove_prefix(1);
cntx->ChangePSub(true, true, args);
}
void Service::PUnsubscribe(CmdArgList args, ConnectionContext* cntx) {
args.remove_prefix(1);
cntx->ChangePSub(false, true, args);
}
// Not a real implementation. Serves as a decorator to accept some function commands
// for testing.
void Service::Function(CmdArgList args, ConnectionContext* cntx) {
@ -1024,6 +1034,8 @@ void Service::RegisterCommands() {
<< CI{"PUBLISH", CO::LOADING | CO::FAST, 3, 0, 0, 0}.MFUNC(Publish)
<< CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Subscribe)
<< CI{"UNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Unsubscribe)
<< CI{"PSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(PSubscribe)
<< CI{"PUNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(PUnsubscribe)
<< CI{"FUNCTION", CO::NOSCRIPT, 2, 0, 0, 0}.MFUNC(Function);
StringFamily::Register(&registry_);

View file

@ -94,6 +94,8 @@ class Service : public facade::ServiceInterface {
void Publish(CmdArgList args, ConnectionContext* cntx);
void Subscribe(CmdArgList args, ConnectionContext* cntx);
void Unsubscribe(CmdArgList args, ConnectionContext* cntx);
void PSubscribe(CmdArgList args, ConnectionContext* cntx);
void PUnsubscribe(CmdArgList args, ConnectionContext* cntx);
void Function(CmdArgList args, ConnectionContext* cntx);
struct EvalArgs {
@ -113,6 +115,7 @@ class Service : public facade::ServiceInterface {
ServerFamily server_family_;
CommandRegistry registry_;
absl::flat_hash_map<std::string, unsigned> unknown_cmds_;
mutable ::boost::fibers::mutex mu_;
GlobalState global_state_ = GlobalState::ACTIVE; // protected by mu_;

View file

@ -37,9 +37,68 @@ static vector<string> SplitLines(const std::string& src) {
return res;
}
TestConnection::TestConnection(Protocol protocol)
: facade::Connection(protocol, nullptr, nullptr, nullptr) {
}
void TestConnection::SendMsgVecAsync(const PubMessage& pmsg, util::fibers_ext::BlockingCounter bc) {
backing_str_.emplace_back(new string(pmsg.channel));
PubMessage dest;
dest.channel = *backing_str_.back();
backing_str_.emplace_back(new string(pmsg.message));
dest.message = *backing_str_.back();
if (!pmsg.pattern.empty()) {
backing_str_.emplace_back(new string(pmsg.pattern));
dest.pattern = *backing_str_.back();
}
messages.push_back(dest);
bc.Dec();
}
class BaseFamilyTest::TestConnWrapper {
public:
TestConnWrapper(Protocol proto);
~TestConnWrapper();
CmdArgVec Args(ArgSlice list);
RespVec ParseResponse();
// returns: type(pmessage), pattern, channel, message.
facade::Connection::PubMessage GetPubMessage(size_t index) const;
ConnectionContext* cmd_cntx() {
return &cmd_cntx_;
}
StringVec SplitLines() const {
return dfly::SplitLines(sink_.str());
}
void ClearSink() {
sink_.Clear();
}
TestConnection* conn() {
return dummy_conn_.get();
}
private:
::io::StringSink sink_; // holds the response blob
std::unique_ptr<TestConnection> dummy_conn_;
ConnectionContext cmd_cntx_;
std::vector<std::unique_ptr<std::string>> tmp_str_vec_;
std::unique_ptr<RedisParser> parser_;
};
BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto)
: dummy_conn(new facade::Connection(proto, nullptr, nullptr, nullptr)),
cmd_cntx(&sink, dummy_conn.get()) {
: dummy_conn_(new TestConnection(proto)), cmd_cntx_(&sink_, dummy_conn_.get()) {
}
BaseFamilyTest::TestConnWrapper::~TestConnWrapper() {
@ -102,22 +161,22 @@ RespExpr BaseFamilyTest::Run(ArgSlice list) {
}
RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) {
TestConnWrapper* conn = AddFindConn(Protocol::REDIS, id);
TestConnWrapper* conn_wrapper = AddFindConn(Protocol::REDIS, id);
CmdArgVec args = conn->Args(slice);
CmdArgVec args = conn_wrapper->Args(slice);
auto& context = conn->cmd_cntx;
auto* context = conn_wrapper->cmd_cntx();
DCHECK(context.transaction == nullptr);
DCHECK(context->transaction == nullptr);
service_->DispatchCommand(CmdArgList{args}, &context);
service_->DispatchCommand(CmdArgList{args}, context);
DCHECK(context.transaction == nullptr);
DCHECK(context->transaction == nullptr);
unique_lock lk(mu_);
last_cmd_dbg_info_ = context.last_command_debug;
last_cmd_dbg_info_ = context->last_command_debug;
RespVec vec = conn->ParseResponse();
RespVec vec = conn_wrapper->ParseResponse();
if (vec.size() == 1)
return vec.front();
RespVec* new_vec = new RespVec(vec);
@ -144,15 +203,15 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, string_view key, string_view va
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
auto& context = conn->cmd_cntx;
auto* context = conn->cmd_cntx();
DCHECK(context.transaction == nullptr);
DCHECK(context->transaction == nullptr);
service_->DispatchMC(cmd, value, &context);
service_->DispatchMC(cmd, value, context);
DCHECK(context.transaction == nullptr);
DCHECK(context->transaction == nullptr);
return SplitLines(conn->sink.str());
return conn->SplitLines();
}
auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResponse {
@ -165,11 +224,11 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResp
cmd.key = key;
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
auto& context = conn->cmd_cntx;
auto* context = conn->cmd_cntx();
service_->DispatchMC(cmd, string_view{}, &context);
service_->DispatchMC(cmd, string_view{}, context);
return SplitLines(conn->sink.str());
return conn->SplitLines();
}
auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list<std::string_view> list)
@ -191,11 +250,11 @@ auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list<std::stri
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
auto& context = conn->cmd_cntx;
auto* context = conn->cmd_cntx();
service_->DispatchMC(cmd, string_view{}, &context);
service_->DispatchMC(cmd, string_view{}, context);
return SplitLines(conn->sink.str());
return conn->SplitLines();
}
int64_t BaseFamilyTest::CheckedInt(std::initializer_list<std::string_view> list) {
@ -222,8 +281,8 @@ CmdArgVec BaseFamilyTest::TestConnWrapper::Args(ArgSlice list) {
if (v.empty()) {
res.push_back(MutableSlice{});
} else {
tmp_str_vec.emplace_back(new string{v});
auto& s = *tmp_str_vec.back();
tmp_str_vec_.emplace_back(new string{v});
auto& s = *tmp_str_vec_.back();
res.emplace_back(s.data(), s.size());
}
@ -233,19 +292,24 @@ CmdArgVec BaseFamilyTest::TestConnWrapper::Args(ArgSlice list) {
}
RespVec BaseFamilyTest::TestConnWrapper::ParseResponse() {
tmp_str_vec.emplace_back(new string{sink.str()});
auto& s = *tmp_str_vec.back();
tmp_str_vec_.emplace_back(new string{sink_.str()});
auto& s = *tmp_str_vec_.back();
auto buf = RespExpr::buffer(&s);
uint32_t consumed = 0;
parser.reset(new RedisParser{false}); // Client mode.
parser_.reset(new RedisParser{false}); // Client mode.
RespVec res;
RedisParser::Result st = parser->Parse(buf, &consumed, &res);
RedisParser::Result st = parser_->Parse(buf, &consumed, &res);
CHECK_EQ(RedisParser::OK, st);
return res;
}
facade::Connection::PubMessage BaseFamilyTest::TestConnWrapper::GetPubMessage(size_t index) const {
CHECK_LT(index, dummy_conn_->messages.size());
return dummy_conn_->messages[index];
}
bool BaseFamilyTest::IsLocked(DbIndex db_index, std::string_view key) const {
ShardId sid = Shard(key, shard_set->size());
KeyLockArgs args;
@ -263,11 +327,30 @@ string BaseFamilyTest::GetId() const {
return absl::StrCat("IO", id);
}
size_t BaseFamilyTest::SubsriberMessagesLen(string_view conn_id) const {
auto it = connections_.find(conn_id);
if (it == connections_.end())
return 0;
return it->second->conn()->messages.size();
}
facade::Connection::PubMessage BaseFamilyTest::GetPublishedMessage(string_view conn_id,
size_t index) const {
facade::Connection::PubMessage res;
auto it = connections_.find(conn_id);
if (it == connections_.end())
return res;
return it->second->GetPubMessage(index);
}
ConnectionContext::DebugInfo BaseFamilyTest::GetDebugInfo(const std::string& id) const {
auto it = connections_.find(id);
CHECK(it != connections_.end());
return it->second->cmd_cntx.last_command_debug;
return it->second->cmd_cntx()->last_command_debug;
}
auto BaseFamilyTest::AddFindConn(Protocol proto, std::string_view id) -> TestConnWrapper* {
@ -278,7 +361,7 @@ auto BaseFamilyTest::AddFindConn(Protocol proto, std::string_view id) -> TestCon
if (inserted) {
it->second.reset(new TestConnWrapper(proto));
} else {
it->second->sink.Clear();
it->second->ClearSink();
}
return it->second.get();
}

View file

@ -6,6 +6,7 @@
#include <gmock/gmock.h>
#include "facade/dragonfly_connection.h"
#include "facade/memcache_parser.h"
#include "facade/redis_parser.h"
#include "io/io.h"
@ -16,6 +17,18 @@
namespace dfly {
using namespace facade;
class TestConnection : public facade::Connection {
public:
TestConnection(Protocol protocol);
void SendMsgVecAsync(const PubMessage& pmsg, util::fibers_ext::BlockingCounter bc) final;
std::vector<PubMessage> messages;
private:
std::vector<std::unique_ptr<std::string>> backing_str_;
};
class BaseFamilyTest : public ::testing::Test {
protected:
BaseFamilyTest();
@ -27,23 +40,7 @@ class BaseFamilyTest : public ::testing::Test {
void TearDown() override;
protected:
struct TestConnWrapper {
::io::StringSink sink; // holds the response blob
std::unique_ptr<facade::Connection> dummy_conn;
ConnectionContext cmd_cntx;
std::vector<std::unique_ptr<std::string>> tmp_str_vec;
std::unique_ptr<RedisParser> parser;
TestConnWrapper(Protocol proto);
~TestConnWrapper();
CmdArgVec Args(ArgSlice list);
RespVec ParseResponse();
};
class TestConnWrapper;
RespExpr Run(std::initializer_list<const std::string_view> list) {
return Run(ArgSlice{list.begin(), list.size()});
@ -75,6 +72,11 @@ class BaseFamilyTest : public ::testing::Test {
void UpdateTime(uint64_t ms);
std::string GetId() const;
size_t SubsriberMessagesLen(std::string_view conn_id) const;
// Returns message parts as returned by RESP:
// pmessage, pattern, channel, message
facade::Connection::PubMessage GetPublishedMessage(std::string_view conn_id, size_t index) const;
std::unique_ptr<util::ProactorPool> pp_;
std::unique_ptr<Service> service_;

View file

@ -13,7 +13,6 @@
using namespace testing;
using namespace std;
using namespace util;
using namespace boost;
namespace dfly {