mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-15 17:51:06 +00:00
Implement PSUBSCRIBE/PUNSUBSCRIBE commands.
Add minimal tests.
This commit is contained in:
parent
8570a12d81
commit
ec9754150f
13 changed files with 373 additions and 85 deletions
|
@ -268,15 +268,15 @@ API 2.0
|
|||
- [X] HSETNX
|
||||
- [X] HVALS
|
||||
- [X] HSCAN
|
||||
- [ ] PubSub family
|
||||
- [X] PubSub family
|
||||
- [X] PUBLISH
|
||||
- [ ] PUBSUB
|
||||
- [ ] PUBSUB CHANNELS
|
||||
- [X] SUBSCRIBE
|
||||
- [X] UNSUBSCRIBE
|
||||
- [ ] PSUBSCRIBE
|
||||
- [ ] PUNSUBSCRIBE
|
||||
- [ ] Server Family
|
||||
- [X] PSUBSCRIBE
|
||||
- [X] PUNSUBSCRIBE
|
||||
- [X] Server Family
|
||||
- [ ] WATCH
|
||||
- [ ] UNWATCH
|
||||
- [X] DISCARD
|
||||
|
|
|
@ -69,11 +69,11 @@ constexpr size_t kMinReadSize = 256;
|
|||
constexpr size_t kMaxReadSize = 32_KB;
|
||||
|
||||
struct AsyncMsg {
|
||||
absl::Span<const std::string_view> msg_vec;
|
||||
Connection::PubMessage pub_msg;
|
||||
fibers_ext::BlockingCounter bc;
|
||||
|
||||
AsyncMsg(absl::Span<const std::string_view> vec, fibers_ext::BlockingCounter b)
|
||||
: msg_vec(vec), bc(move(b)) {
|
||||
AsyncMsg(const Connection::PubMessage& pmsg, fibers_ext::BlockingCounter b)
|
||||
: pub_msg(pmsg), bc(move(b)) {
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -245,15 +245,17 @@ void Connection::RegisterOnBreak(BreakerCb breaker_cb) {
|
|||
breaker_cb_ = breaker_cb;
|
||||
}
|
||||
|
||||
void Connection::SendMsgVecAsync(absl::Span<const std::string_view> msg_vec,
|
||||
void Connection::SendMsgVecAsync(const PubMessage& pub_msg,
|
||||
fibers_ext::BlockingCounter bc) {
|
||||
DCHECK(cc_);
|
||||
|
||||
if (cc_->conn_closing) {
|
||||
bc.Dec();
|
||||
return;
|
||||
}
|
||||
|
||||
void* ptr = mi_malloc(sizeof(AsyncMsg));
|
||||
AsyncMsg* amsg = new (ptr) AsyncMsg(msg_vec, move(bc));
|
||||
AsyncMsg* amsg = new (ptr) AsyncMsg(pub_msg, move(bc));
|
||||
|
||||
ptr = mi_malloc(sizeof(Request));
|
||||
Request* req = new (ptr) Request(0, 0);
|
||||
|
@ -571,7 +573,24 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) {
|
|||
|
||||
if (req->async_msg) {
|
||||
++stats->async_writes_cnt;
|
||||
builder->SendRawVec(req->async_msg->msg_vec);
|
||||
|
||||
RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder;
|
||||
const PubMessage& pub_msg = req->async_msg->pub_msg;
|
||||
string_view arr[4];
|
||||
|
||||
if (pub_msg.pattern.empty()) {
|
||||
arr[0] = "message";
|
||||
arr[1] = pub_msg.channel;
|
||||
arr[2] = pub_msg.message;
|
||||
rbuilder->SendStringArr(absl::Span<string_view>{arr, 3});
|
||||
} else {
|
||||
arr[0] = "pmessage";
|
||||
arr[1] = pub_msg.pattern;
|
||||
arr[2] = pub_msg.channel;
|
||||
arr[3] = pub_msg.message;
|
||||
rbuilder->SendStringArr(absl::Span<string_view>{arr, 4});
|
||||
}
|
||||
|
||||
req->async_msg->bc.Dec();
|
||||
|
||||
req->async_msg->~AsyncMsg();
|
||||
|
|
|
@ -46,11 +46,20 @@ class Connection : public util::Connection {
|
|||
using BreakerCb = std::function<void(uint32_t)>;
|
||||
void RegisterOnBreak(BreakerCb breaker_cb);
|
||||
|
||||
// This interface is used to pass a raw message directly to the socket via zero-copy interface.
|
||||
// This interface is used to pass a published message directly to the socket without
|
||||
// copying strings.
|
||||
// Once the msg is sent "bc" will be decreased so that caller could release the underlying
|
||||
// storage for the message.
|
||||
void SendMsgVecAsync(absl::Span<const std::string_view> msg_vec,
|
||||
util::fibers_ext::BlockingCounter bc);
|
||||
// virtual - to allow the testing code to override it.
|
||||
|
||||
struct PubMessage {
|
||||
// if empty - means its a regular message, otherwise it's pmessage.
|
||||
std::string_view pattern;
|
||||
std::string_view channel;
|
||||
std::string_view message;
|
||||
};
|
||||
|
||||
virtual void SendMsgVecAsync(const PubMessage& pub_msg, util::fibers_ext::BlockingCounter bc);
|
||||
|
||||
void SetName(std::string_view name) {
|
||||
CopyCharBuf(name, sizeof(name_), name_);
|
||||
|
|
|
@ -4,6 +4,10 @@
|
|||
|
||||
#include "server/channel_slice.h"
|
||||
|
||||
extern "C" {
|
||||
#include "redis/util.h"
|
||||
}
|
||||
|
||||
namespace dfly {
|
||||
using namespace std;
|
||||
|
||||
|
@ -11,6 +15,14 @@ ChannelSlice::Subscriber::Subscriber(ConnectionContext* cntx, uint32_t tid)
|
|||
: conn_cntx(cntx), borrow_token(cntx->conn_state.subscribe_info->borrow_token), thread_id(tid) {
|
||||
}
|
||||
|
||||
void ChannelSlice::AddSubscription(string_view channel, ConnectionContext* me, uint32_t thread_id) {
|
||||
auto [it, added] = channels_.emplace(channel, nullptr);
|
||||
if (added) {
|
||||
it->second.reset(new Channel);
|
||||
}
|
||||
it->second->subscribers.emplace(me, SubscriberInternal{thread_id});
|
||||
}
|
||||
|
||||
void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me) {
|
||||
auto it = channels_.find(channel);
|
||||
if (it != channels_.end()) {
|
||||
|
@ -20,29 +32,52 @@ void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me
|
|||
}
|
||||
}
|
||||
|
||||
void ChannelSlice::AddSubscription(string_view channel, ConnectionContext* me, uint32_t thread_id) {
|
||||
auto [it, added] = channels_.emplace(channel, nullptr);
|
||||
void ChannelSlice::AddGlobPattern(string_view pattern, ConnectionContext* me, uint32_t thread_id) {
|
||||
auto [it, added] = patterns_.emplace(pattern, nullptr);
|
||||
if (added) {
|
||||
it->second.reset(new Channel);
|
||||
}
|
||||
it->second->subscribers.emplace(me, SubscriberInternal{thread_id});
|
||||
}
|
||||
|
||||
void ChannelSlice::RemoveGlobPattern(string_view pattern, ConnectionContext* me) {
|
||||
auto it = patterns_.find(pattern);
|
||||
if (it != patterns_.end()) {
|
||||
it->second->subscribers.erase(me);
|
||||
if (it->second->subscribers.empty())
|
||||
patterns_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
auto ChannelSlice::FetchSubscribers(string_view channel) -> vector<Subscriber> {
|
||||
vector<Subscriber> res;
|
||||
|
||||
auto it = channels_.find(channel);
|
||||
if (it != channels_.end()) {
|
||||
res.reserve(it->second->subscribers.size());
|
||||
for (const auto& k_v : it->second->subscribers) {
|
||||
Subscriber s(k_v.first, k_v.second.thread_id);
|
||||
s.borrow_token.Inc();
|
||||
CopySubsribers(it->second->subscribers, string{}, &res);
|
||||
}
|
||||
|
||||
res.push_back(std::move(s));
|
||||
for (const auto& k_v : patterns_) {
|
||||
const string& pat = k_v.first;
|
||||
// 1 - match
|
||||
if (stringmatchlen(pat.data(), pat.size(), channel.data(), channel.size(), 0) == 1) {
|
||||
CopySubsribers(k_v.second->subscribers, pat, &res);
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void ChannelSlice::CopySubsribers(const SubsribeMap& src, const std::string& pattern,
|
||||
vector<Subscriber>* dest) {
|
||||
for (const auto& sub : src) {
|
||||
Subscriber s(sub.first, sub.second.thread_id);
|
||||
s.pattern = pattern;
|
||||
s.borrow_token.Inc();
|
||||
|
||||
dest->push_back(std::move(s));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -19,6 +19,9 @@ class ChannelSlice {
|
|||
util::fibers_ext::BlockingCounter borrow_token;
|
||||
uint32_t thread_id;
|
||||
|
||||
// non-empty if was registered via psubscribe
|
||||
std::string pattern;
|
||||
|
||||
Subscriber(ConnectionContext* cntx, uint32_t tid);
|
||||
// Subscriber() : borrow_token(0) {}
|
||||
|
||||
|
@ -31,18 +34,27 @@ class ChannelSlice {
|
|||
|
||||
std::vector<Subscriber> FetchSubscribers(std::string_view channel);
|
||||
|
||||
void RemoveSubscription(std::string_view channel, ConnectionContext* me);
|
||||
void AddSubscription(std::string_view channel, ConnectionContext* me, uint32_t thread_id);
|
||||
void RemoveSubscription(std::string_view channel, ConnectionContext* me);
|
||||
|
||||
void AddGlobPattern(std::string_view pattern, ConnectionContext* me, uint32_t thread_id);
|
||||
void RemoveGlobPattern(std::string_view pattern, ConnectionContext* me);
|
||||
|
||||
private:
|
||||
struct SubscriberInternal {
|
||||
uint32_t thread_id; // proactor thread id.
|
||||
|
||||
SubscriberInternal(uint32_t tid) : thread_id(tid) {}
|
||||
SubscriberInternal(uint32_t tid) : thread_id(tid) {
|
||||
}
|
||||
};
|
||||
|
||||
using SubsribeMap = absl::flat_hash_map<ConnectionContext*, SubscriberInternal>;
|
||||
|
||||
static void CopySubsribers(const SubsribeMap& src, const std::string& pattern,
|
||||
std::vector<Subscriber>* dest);
|
||||
|
||||
struct Channel {
|
||||
absl::flat_hash_map<ConnectionContext*, SubscriberInternal> subscribers;
|
||||
SubsribeMap subscribers;
|
||||
};
|
||||
|
||||
absl::flat_hash_map<std::string, std::unique_ptr<Channel>> channels_;
|
||||
|
|
|
@ -23,9 +23,11 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
|
|||
DCHECK(to_add);
|
||||
|
||||
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
|
||||
// to be able to read input and still write the output.
|
||||
this->force_dispatch = true;
|
||||
}
|
||||
|
||||
// Gather all the channels we need to subsribe to / remove.
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
bool res = false;
|
||||
string_view channel = ArgS(args, i);
|
||||
|
@ -44,13 +46,14 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
|
|||
}
|
||||
}
|
||||
|
||||
if (!to_add && conn_state.subscribe_info->channels.empty()) {
|
||||
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
|
||||
conn_state.subscribe_info.reset();
|
||||
force_dispatch = false;
|
||||
}
|
||||
|
||||
sort(channels.begin(), channels.end());
|
||||
|
||||
// prepare the array in order to distribute the updates to the shards.
|
||||
vector<unsigned> shard_idx(shard_set->size() + 1, 0);
|
||||
for (const auto& k_v : channels) {
|
||||
shard_idx[k_v.first]++;
|
||||
|
@ -68,6 +71,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
|
|||
int32_t tid = util::ProactorBase::GetIndex();
|
||||
DCHECK_GE(tid, 0);
|
||||
|
||||
// Update the subsribers on publisher's side.
|
||||
auto cb = [&](EngineShard* shard) {
|
||||
ChannelSlice& cs = shard->channel_slice();
|
||||
unsigned start = shard_idx[shard->shard_id()];
|
||||
|
@ -83,6 +87,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
|
|||
}
|
||||
};
|
||||
|
||||
// Update subscription
|
||||
shard_set->RunBriefInParallel(move(cb),
|
||||
[&](ShardId sid) { return shard_idx[sid + 1] > shard_idx[sid]; });
|
||||
}
|
||||
|
@ -90,6 +95,77 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
|
|||
if (to_reply) {
|
||||
const char* action[2] = {"unsubscribe", "subscribe"};
|
||||
|
||||
for (size_t i = 0; i < result.size(); ++i) {
|
||||
(*this)->StartArray(3);
|
||||
(*this)->SendBulkString(action[to_add]);
|
||||
(*this)->SendBulkString(ArgS(args, i)); // channel
|
||||
|
||||
// number of subsribed channels for this connection *right after*
|
||||
// we subsribe.
|
||||
(*this)->SendLong(result[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ConnectionContext::ChangePSub(bool to_add, bool to_reply, CmdArgList args) {
|
||||
vector<unsigned> result(to_reply ? args.size() : 0, 0);
|
||||
|
||||
if (to_add || conn_state.subscribe_info) {
|
||||
std::vector<string_view> patterns;
|
||||
patterns.reserve(args.size());
|
||||
|
||||
if (!conn_state.subscribe_info) {
|
||||
DCHECK(to_add);
|
||||
|
||||
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
|
||||
this->force_dispatch = true;
|
||||
}
|
||||
|
||||
// Gather all the patterns we need to subsribe to / remove.
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
bool res = false;
|
||||
string_view pattern = ArgS(args, i);
|
||||
if (to_add) {
|
||||
res = conn_state.subscribe_info->patterns.emplace(pattern).second;
|
||||
} else {
|
||||
res = conn_state.subscribe_info->patterns.erase(pattern) > 0;
|
||||
}
|
||||
|
||||
if (to_reply)
|
||||
result[i] = conn_state.subscribe_info->patterns.size();
|
||||
|
||||
if (res) {
|
||||
patterns.emplace_back(pattern);
|
||||
}
|
||||
}
|
||||
|
||||
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
|
||||
conn_state.subscribe_info.reset();
|
||||
force_dispatch = false;
|
||||
}
|
||||
|
||||
int32_t tid = util::ProactorBase::GetIndex();
|
||||
DCHECK_GE(tid, 0);
|
||||
|
||||
// Update the subsribers on publisher's side.
|
||||
auto cb = [&](EngineShard* shard) {
|
||||
ChannelSlice& cs = shard->channel_slice();
|
||||
for (string_view pattern : patterns) {
|
||||
if (to_add) {
|
||||
cs.AddGlobPattern(pattern, this, tid);
|
||||
} else {
|
||||
cs.RemoveGlobPattern(pattern, this);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Update pattern subscription. Run on all shards.
|
||||
shard_set->RunBriefInParallel(move(cb));
|
||||
}
|
||||
|
||||
if (to_reply) {
|
||||
const char* action[2] = {"punsubscribe", "psubscribe"};
|
||||
|
||||
for (size_t i = 0; i < result.size(); ++i) {
|
||||
(*this)->StartArray(3);
|
||||
(*this)->SendBulkString(action[to_add]);
|
||||
|
@ -100,18 +176,35 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
|
|||
}
|
||||
|
||||
void ConnectionContext::OnClose() {
|
||||
if (conn_state.subscribe_info) {
|
||||
if (!conn_state.subscribe_info)
|
||||
return;
|
||||
|
||||
if (!conn_state.subscribe_info->channels.empty()) {
|
||||
StringVec channels(conn_state.subscribe_info->channels.begin(),
|
||||
conn_state.subscribe_info->channels.end());
|
||||
CmdArgVec arg_vec(channels.begin(), channels.end());
|
||||
|
||||
auto token = conn_state.subscribe_info->borrow_token;
|
||||
ChangeSubscription(false, false, CmdArgList{arg_vec});
|
||||
DCHECK(!conn_state.subscribe_info);
|
||||
|
||||
// Check that all borrowers finished processing
|
||||
token.Wait();
|
||||
}
|
||||
|
||||
if (conn_state.subscribe_info) {
|
||||
DCHECK(!conn_state.subscribe_info->patterns.empty());
|
||||
|
||||
StringVec patterns(conn_state.subscribe_info->patterns.begin(),
|
||||
conn_state.subscribe_info->patterns.end());
|
||||
CmdArgVec arg_vec(patterns.begin(), patterns.end());
|
||||
|
||||
auto token = conn_state.subscribe_info->borrow_token;
|
||||
ChangePSub(false, false, CmdArgList{arg_vec});
|
||||
|
||||
// Check that all borrowers finished processing
|
||||
token.Wait();
|
||||
DCHECK(!conn_state.subscribe_info);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -50,10 +50,16 @@ struct ConnectionState {
|
|||
struct SubscribeInfo {
|
||||
// TODO: to provide unique_strings across service. This will allow us to use string_view here.
|
||||
absl::flat_hash_set<std::string> channels;
|
||||
absl::flat_hash_set<std::string> patterns;
|
||||
|
||||
util::fibers_ext::BlockingCounter borrow_token;
|
||||
|
||||
SubscribeInfo() : borrow_token(0) {}
|
||||
bool IsEmpty() const {
|
||||
return channels.empty() && patterns.empty();
|
||||
}
|
||||
|
||||
SubscribeInfo() : borrow_token(0) {
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<SubscribeInfo> subscribe_info;
|
||||
|
@ -85,6 +91,7 @@ class ConnectionContext : public facade::ConnectionContext {
|
|||
}
|
||||
|
||||
void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args);
|
||||
void ChangePSub(bool to_add, bool to_reply, CmdArgList args);
|
||||
|
||||
bool is_replicating = false;
|
||||
};
|
||||
|
|
|
@ -22,13 +22,13 @@ extern "C" {
|
|||
|
||||
namespace dfly {
|
||||
|
||||
using namespace absl;
|
||||
using namespace boost;
|
||||
using namespace std;
|
||||
using namespace util;
|
||||
using ::io::Result;
|
||||
using testing::ElementsAre;
|
||||
using testing::HasSubstr;
|
||||
using absl::StrCat;
|
||||
namespace this_fiber = boost::this_fiber;
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -411,6 +411,20 @@ TEST_F(DflyEngineTest, OOM) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(DflyEngineTest, PSubscribe) {
|
||||
auto resp = pp_->at(1)->Await([&] { return Run({"psubscribe", "a*", "b*"}); });
|
||||
EXPECT_THAT(resp, ArrLen(3));
|
||||
resp = pp_->at(0)->Await([&] { return Run({"publish", "ab", "foo"}); });
|
||||
EXPECT_THAT(resp, IntArg(1));
|
||||
|
||||
ASSERT_EQ(1, SubsriberMessagesLen("IO1"));
|
||||
|
||||
facade::Connection::PubMessage msg = GetPublishedMessage("IO1", 0);
|
||||
EXPECT_EQ("foo", msg.message);
|
||||
EXPECT_EQ("ab", msg.channel);
|
||||
EXPECT_EQ("a*", msg.pattern);
|
||||
}
|
||||
|
||||
// TODO: to test transactions with a single shard since then all transactions become local.
|
||||
// To consider having a parameter in dragonfly engine controlling number of shards
|
||||
// unconditionally from number of cpus. TO TEST BLPOP under multi for single/multi argument case.
|
||||
|
|
|
@ -912,23 +912,22 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
|
|||
}
|
||||
|
||||
fibers_ext::BlockingCounter bc(subsriber_arr.size());
|
||||
char prefix[] = "*3\r\n$7\r\nmessage\r\n$";
|
||||
char msg_size[32] = {0};
|
||||
char channel_size[32] = {0};
|
||||
absl::SNPrintF(msg_size, sizeof(msg_size), "%u\r\n", message.size());
|
||||
absl::SNPrintF(channel_size, sizeof(channel_size), "%u\r\n", channel.size());
|
||||
|
||||
string_view msg_arr[] = {prefix, channel_size, channel, "\r\n$", msg_size, message, "\r\n"};
|
||||
|
||||
auto publish_cb = [&, bc](unsigned idx, util::ProactorBase*) mutable {
|
||||
unsigned start = slices[idx];
|
||||
|
||||
for (unsigned i = start; i < subsriber_arr.size(); ++i) {
|
||||
if (subsriber_arr[i].thread_id != idx)
|
||||
const ChannelSlice::Subscriber& subscriber = subsriber_arr[i];
|
||||
if (subscriber.thread_id != idx)
|
||||
break;
|
||||
|
||||
published.fetch_add(1, memory_order_relaxed);
|
||||
subsriber_arr[i].conn_cntx->owner()->SendMsgVecAsync(msg_arr, bc);
|
||||
facade::Connection* conn = subsriber_arr[i].conn_cntx->owner();
|
||||
DCHECK(conn);
|
||||
facade::Connection::PubMessage pmsg;
|
||||
pmsg.channel = channel;
|
||||
pmsg.message = message;
|
||||
pmsg.pattern = subscriber.pattern;
|
||||
conn->SendMsgVecAsync(pmsg, bc);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -959,6 +958,17 @@ void Service::Unsubscribe(CmdArgList args, ConnectionContext* cntx) {
|
|||
cntx->ChangeSubscription(false, true, std::move(args));
|
||||
}
|
||||
|
||||
void Service::PSubscribe(CmdArgList args, ConnectionContext* cntx) {
|
||||
args.remove_prefix(1);
|
||||
cntx->ChangePSub(true, true, args);
|
||||
}
|
||||
|
||||
void Service::PUnsubscribe(CmdArgList args, ConnectionContext* cntx) {
|
||||
args.remove_prefix(1);
|
||||
|
||||
cntx->ChangePSub(false, true, args);
|
||||
}
|
||||
|
||||
// Not a real implementation. Serves as a decorator to accept some function commands
|
||||
// for testing.
|
||||
void Service::Function(CmdArgList args, ConnectionContext* cntx) {
|
||||
|
@ -1024,6 +1034,8 @@ void Service::RegisterCommands() {
|
|||
<< CI{"PUBLISH", CO::LOADING | CO::FAST, 3, 0, 0, 0}.MFUNC(Publish)
|
||||
<< CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Subscribe)
|
||||
<< CI{"UNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Unsubscribe)
|
||||
<< CI{"PSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(PSubscribe)
|
||||
<< CI{"PUNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(PUnsubscribe)
|
||||
<< CI{"FUNCTION", CO::NOSCRIPT, 2, 0, 0, 0}.MFUNC(Function);
|
||||
|
||||
StringFamily::Register(®istry_);
|
||||
|
|
|
@ -94,6 +94,8 @@ class Service : public facade::ServiceInterface {
|
|||
void Publish(CmdArgList args, ConnectionContext* cntx);
|
||||
void Subscribe(CmdArgList args, ConnectionContext* cntx);
|
||||
void Unsubscribe(CmdArgList args, ConnectionContext* cntx);
|
||||
void PSubscribe(CmdArgList args, ConnectionContext* cntx);
|
||||
void PUnsubscribe(CmdArgList args, ConnectionContext* cntx);
|
||||
void Function(CmdArgList args, ConnectionContext* cntx);
|
||||
|
||||
struct EvalArgs {
|
||||
|
@ -113,6 +115,7 @@ class Service : public facade::ServiceInterface {
|
|||
ServerFamily server_family_;
|
||||
CommandRegistry registry_;
|
||||
absl::flat_hash_map<std::string, unsigned> unknown_cmds_;
|
||||
|
||||
mutable ::boost::fibers::mutex mu_;
|
||||
|
||||
GlobalState global_state_ = GlobalState::ACTIVE; // protected by mu_;
|
||||
|
|
|
@ -37,9 +37,68 @@ static vector<string> SplitLines(const std::string& src) {
|
|||
return res;
|
||||
}
|
||||
|
||||
TestConnection::TestConnection(Protocol protocol)
|
||||
: facade::Connection(protocol, nullptr, nullptr, nullptr) {
|
||||
}
|
||||
|
||||
void TestConnection::SendMsgVecAsync(const PubMessage& pmsg, util::fibers_ext::BlockingCounter bc) {
|
||||
backing_str_.emplace_back(new string(pmsg.channel));
|
||||
PubMessage dest;
|
||||
dest.channel = *backing_str_.back();
|
||||
|
||||
backing_str_.emplace_back(new string(pmsg.message));
|
||||
dest.message = *backing_str_.back();
|
||||
|
||||
if (!pmsg.pattern.empty()) {
|
||||
backing_str_.emplace_back(new string(pmsg.pattern));
|
||||
dest.pattern = *backing_str_.back();
|
||||
}
|
||||
messages.push_back(dest);
|
||||
|
||||
bc.Dec();
|
||||
}
|
||||
|
||||
class BaseFamilyTest::TestConnWrapper {
|
||||
public:
|
||||
TestConnWrapper(Protocol proto);
|
||||
~TestConnWrapper();
|
||||
|
||||
CmdArgVec Args(ArgSlice list);
|
||||
|
||||
RespVec ParseResponse();
|
||||
|
||||
// returns: type(pmessage), pattern, channel, message.
|
||||
facade::Connection::PubMessage GetPubMessage(size_t index) const;
|
||||
|
||||
ConnectionContext* cmd_cntx() {
|
||||
return &cmd_cntx_;
|
||||
}
|
||||
|
||||
StringVec SplitLines() const {
|
||||
return dfly::SplitLines(sink_.str());
|
||||
}
|
||||
|
||||
void ClearSink() {
|
||||
sink_.Clear();
|
||||
}
|
||||
|
||||
TestConnection* conn() {
|
||||
return dummy_conn_.get();
|
||||
}
|
||||
|
||||
private:
|
||||
::io::StringSink sink_; // holds the response blob
|
||||
|
||||
std::unique_ptr<TestConnection> dummy_conn_;
|
||||
|
||||
ConnectionContext cmd_cntx_;
|
||||
std::vector<std::unique_ptr<std::string>> tmp_str_vec_;
|
||||
|
||||
std::unique_ptr<RedisParser> parser_;
|
||||
};
|
||||
|
||||
BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto)
|
||||
: dummy_conn(new facade::Connection(proto, nullptr, nullptr, nullptr)),
|
||||
cmd_cntx(&sink, dummy_conn.get()) {
|
||||
: dummy_conn_(new TestConnection(proto)), cmd_cntx_(&sink_, dummy_conn_.get()) {
|
||||
}
|
||||
|
||||
BaseFamilyTest::TestConnWrapper::~TestConnWrapper() {
|
||||
|
@ -102,22 +161,22 @@ RespExpr BaseFamilyTest::Run(ArgSlice list) {
|
|||
}
|
||||
|
||||
RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) {
|
||||
TestConnWrapper* conn = AddFindConn(Protocol::REDIS, id);
|
||||
TestConnWrapper* conn_wrapper = AddFindConn(Protocol::REDIS, id);
|
||||
|
||||
CmdArgVec args = conn->Args(slice);
|
||||
CmdArgVec args = conn_wrapper->Args(slice);
|
||||
|
||||
auto& context = conn->cmd_cntx;
|
||||
auto* context = conn_wrapper->cmd_cntx();
|
||||
|
||||
DCHECK(context.transaction == nullptr);
|
||||
DCHECK(context->transaction == nullptr);
|
||||
|
||||
service_->DispatchCommand(CmdArgList{args}, &context);
|
||||
service_->DispatchCommand(CmdArgList{args}, context);
|
||||
|
||||
DCHECK(context.transaction == nullptr);
|
||||
DCHECK(context->transaction == nullptr);
|
||||
|
||||
unique_lock lk(mu_);
|
||||
last_cmd_dbg_info_ = context.last_command_debug;
|
||||
last_cmd_dbg_info_ = context->last_command_debug;
|
||||
|
||||
RespVec vec = conn->ParseResponse();
|
||||
RespVec vec = conn_wrapper->ParseResponse();
|
||||
if (vec.size() == 1)
|
||||
return vec.front();
|
||||
RespVec* new_vec = new RespVec(vec);
|
||||
|
@ -144,15 +203,15 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, string_view key, string_view va
|
|||
|
||||
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
|
||||
|
||||
auto& context = conn->cmd_cntx;
|
||||
auto* context = conn->cmd_cntx();
|
||||
|
||||
DCHECK(context.transaction == nullptr);
|
||||
DCHECK(context->transaction == nullptr);
|
||||
|
||||
service_->DispatchMC(cmd, value, &context);
|
||||
service_->DispatchMC(cmd, value, context);
|
||||
|
||||
DCHECK(context.transaction == nullptr);
|
||||
DCHECK(context->transaction == nullptr);
|
||||
|
||||
return SplitLines(conn->sink.str());
|
||||
return conn->SplitLines();
|
||||
}
|
||||
|
||||
auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResponse {
|
||||
|
@ -165,11 +224,11 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResp
|
|||
cmd.key = key;
|
||||
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
|
||||
|
||||
auto& context = conn->cmd_cntx;
|
||||
auto* context = conn->cmd_cntx();
|
||||
|
||||
service_->DispatchMC(cmd, string_view{}, &context);
|
||||
service_->DispatchMC(cmd, string_view{}, context);
|
||||
|
||||
return SplitLines(conn->sink.str());
|
||||
return conn->SplitLines();
|
||||
}
|
||||
|
||||
auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list<std::string_view> list)
|
||||
|
@ -191,11 +250,11 @@ auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list<std::stri
|
|||
|
||||
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
|
||||
|
||||
auto& context = conn->cmd_cntx;
|
||||
auto* context = conn->cmd_cntx();
|
||||
|
||||
service_->DispatchMC(cmd, string_view{}, &context);
|
||||
service_->DispatchMC(cmd, string_view{}, context);
|
||||
|
||||
return SplitLines(conn->sink.str());
|
||||
return conn->SplitLines();
|
||||
}
|
||||
|
||||
int64_t BaseFamilyTest::CheckedInt(std::initializer_list<std::string_view> list) {
|
||||
|
@ -222,8 +281,8 @@ CmdArgVec BaseFamilyTest::TestConnWrapper::Args(ArgSlice list) {
|
|||
if (v.empty()) {
|
||||
res.push_back(MutableSlice{});
|
||||
} else {
|
||||
tmp_str_vec.emplace_back(new string{v});
|
||||
auto& s = *tmp_str_vec.back();
|
||||
tmp_str_vec_.emplace_back(new string{v});
|
||||
auto& s = *tmp_str_vec_.back();
|
||||
|
||||
res.emplace_back(s.data(), s.size());
|
||||
}
|
||||
|
@ -233,19 +292,24 @@ CmdArgVec BaseFamilyTest::TestConnWrapper::Args(ArgSlice list) {
|
|||
}
|
||||
|
||||
RespVec BaseFamilyTest::TestConnWrapper::ParseResponse() {
|
||||
tmp_str_vec.emplace_back(new string{sink.str()});
|
||||
auto& s = *tmp_str_vec.back();
|
||||
tmp_str_vec_.emplace_back(new string{sink_.str()});
|
||||
auto& s = *tmp_str_vec_.back();
|
||||
auto buf = RespExpr::buffer(&s);
|
||||
uint32_t consumed = 0;
|
||||
|
||||
parser.reset(new RedisParser{false}); // Client mode.
|
||||
parser_.reset(new RedisParser{false}); // Client mode.
|
||||
RespVec res;
|
||||
RedisParser::Result st = parser->Parse(buf, &consumed, &res);
|
||||
RedisParser::Result st = parser_->Parse(buf, &consumed, &res);
|
||||
CHECK_EQ(RedisParser::OK, st);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
facade::Connection::PubMessage BaseFamilyTest::TestConnWrapper::GetPubMessage(size_t index) const {
|
||||
CHECK_LT(index, dummy_conn_->messages.size());
|
||||
return dummy_conn_->messages[index];
|
||||
}
|
||||
|
||||
bool BaseFamilyTest::IsLocked(DbIndex db_index, std::string_view key) const {
|
||||
ShardId sid = Shard(key, shard_set->size());
|
||||
KeyLockArgs args;
|
||||
|
@ -263,11 +327,30 @@ string BaseFamilyTest::GetId() const {
|
|||
return absl::StrCat("IO", id);
|
||||
}
|
||||
|
||||
size_t BaseFamilyTest::SubsriberMessagesLen(string_view conn_id) const {
|
||||
auto it = connections_.find(conn_id);
|
||||
if (it == connections_.end())
|
||||
return 0;
|
||||
|
||||
return it->second->conn()->messages.size();
|
||||
}
|
||||
|
||||
facade::Connection::PubMessage BaseFamilyTest::GetPublishedMessage(string_view conn_id,
|
||||
size_t index) const {
|
||||
facade::Connection::PubMessage res;
|
||||
|
||||
auto it = connections_.find(conn_id);
|
||||
if (it == connections_.end())
|
||||
return res;
|
||||
|
||||
return it->second->GetPubMessage(index);
|
||||
}
|
||||
|
||||
ConnectionContext::DebugInfo BaseFamilyTest::GetDebugInfo(const std::string& id) const {
|
||||
auto it = connections_.find(id);
|
||||
CHECK(it != connections_.end());
|
||||
|
||||
return it->second->cmd_cntx.last_command_debug;
|
||||
return it->second->cmd_cntx()->last_command_debug;
|
||||
}
|
||||
|
||||
auto BaseFamilyTest::AddFindConn(Protocol proto, std::string_view id) -> TestConnWrapper* {
|
||||
|
@ -278,7 +361,7 @@ auto BaseFamilyTest::AddFindConn(Protocol proto, std::string_view id) -> TestCon
|
|||
if (inserted) {
|
||||
it->second.reset(new TestConnWrapper(proto));
|
||||
} else {
|
||||
it->second->sink.Clear();
|
||||
it->second->ClearSink();
|
||||
}
|
||||
return it->second.get();
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include "facade/dragonfly_connection.h"
|
||||
#include "facade/memcache_parser.h"
|
||||
#include "facade/redis_parser.h"
|
||||
#include "io/io.h"
|
||||
|
@ -16,6 +17,18 @@
|
|||
namespace dfly {
|
||||
using namespace facade;
|
||||
|
||||
class TestConnection : public facade::Connection {
|
||||
public:
|
||||
TestConnection(Protocol protocol);
|
||||
|
||||
void SendMsgVecAsync(const PubMessage& pmsg, util::fibers_ext::BlockingCounter bc) final;
|
||||
|
||||
std::vector<PubMessage> messages;
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<std::string>> backing_str_;
|
||||
};
|
||||
|
||||
class BaseFamilyTest : public ::testing::Test {
|
||||
protected:
|
||||
BaseFamilyTest();
|
||||
|
@ -27,23 +40,7 @@ class BaseFamilyTest : public ::testing::Test {
|
|||
void TearDown() override;
|
||||
|
||||
protected:
|
||||
struct TestConnWrapper {
|
||||
::io::StringSink sink; // holds the response blob
|
||||
|
||||
std::unique_ptr<facade::Connection> dummy_conn;
|
||||
|
||||
ConnectionContext cmd_cntx;
|
||||
std::vector<std::unique_ptr<std::string>> tmp_str_vec;
|
||||
|
||||
std::unique_ptr<RedisParser> parser;
|
||||
|
||||
TestConnWrapper(Protocol proto);
|
||||
~TestConnWrapper();
|
||||
|
||||
CmdArgVec Args(ArgSlice list);
|
||||
|
||||
RespVec ParseResponse();
|
||||
};
|
||||
class TestConnWrapper;
|
||||
|
||||
RespExpr Run(std::initializer_list<const std::string_view> list) {
|
||||
return Run(ArgSlice{list.begin(), list.size()});
|
||||
|
@ -75,6 +72,11 @@ class BaseFamilyTest : public ::testing::Test {
|
|||
void UpdateTime(uint64_t ms);
|
||||
|
||||
std::string GetId() const;
|
||||
size_t SubsriberMessagesLen(std::string_view conn_id) const;
|
||||
|
||||
// Returns message parts as returned by RESP:
|
||||
// pmessage, pattern, channel, message
|
||||
facade::Connection::PubMessage GetPublishedMessage(std::string_view conn_id, size_t index) const;
|
||||
|
||||
std::unique_ptr<util::ProactorPool> pp_;
|
||||
std::unique_ptr<Service> service_;
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
using namespace testing;
|
||||
using namespace std;
|
||||
using namespace util;
|
||||
using namespace boost;
|
||||
|
||||
namespace dfly {
|
||||
|
||||
|
|
Loading…
Reference in a new issue