1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

Dispatch queue memory optimizations (#1103)

Dispatch queue entry optimizations
This commit is contained in:
Vladislav 2023-04-22 09:02:07 +03:00 committed by GitHub
parent 2d73b2bfb0
commit 71147c20a9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 206 additions and 247 deletions

View file

@ -78,17 +78,11 @@ bool MatchHttp11Line(string_view line) {
constexpr size_t kMinReadSize = 256;
constexpr size_t kMaxReadSize = 32_KB;
#ifdef ABSL_HAVE_ADDRESS_SANITIZER
constexpr size_t kReqStorageSize = 88;
#else
constexpr size_t kReqStorageSize = 120;
#endif
thread_local uint32_t free_req_release_weight = 0;
} // namespace
thread_local vector<Connection::RequestPtr> Connection::pipeline_req_pool_;
thread_local vector<Connection::PipelineMessagePtr> Connection::pipeline_req_pool_;
struct Connection::Shutdown {
absl::flat_hash_map<ShutdownHandle, ShutdownCb> map;
@ -104,77 +98,21 @@ struct Connection::Shutdown {
}
};
// Used as custom deleter for Request object
struct Connection::RequestDeleter {
void operator()(Request* req) const;
};
using PubMessage = Connection::PubMessage;
using MonitorMessage = std::string;
// Please note: The call to the Dtor is mandatory for this!!
// This class contain types that don't have trivial destructed objects
class Connection::Request {
public:
struct PipelineMessage {
// mi_stl_allocator uses mi heap internally.
// The capacity is chosen so that we allocate a fully utilized (256 bytes) block.
using StorageType = absl::InlinedVector<char, kReqStorageSize, mi_stl_allocator<char>>;
PipelineMessage(size_t nargs, size_t capacity) : args(nargs), storage(capacity) {
}
void Reset(size_t nargs, size_t capacity);
absl::InlinedVector<MutableSlice, 6> args;
StorageType storage;
};
using MessagePayload = std::variant<PipelineMessage, PubMessage, MonitorMessage>;
// Overload to create the a new pipeline message
static RequestPtr New(mi_heap_t* heap, const RespVec& args, size_t capacity);
// Overload to create a new pubsub message
static RequestPtr New(PubMessage pub_msg);
// Overload to create a new the monitor message
static RequestPtr New(MonitorMessage msg);
void Emplace(const RespVec& args, size_t capacity);
size_t StorageCapacity() const;
bool IsPipelineMsg() const;
private:
static constexpr size_t kSizeOfPipelineMsg = sizeof(PipelineMessage);
Request(size_t nargs, size_t capacity) : payload(PipelineMessage{nargs, capacity}) {
}
Request(PubMessage msg) : payload(move(msg)) {
}
Request(MonitorMessage msg) : payload(move(msg)) {
}
Request(const Request&) = delete;
// Store arguments for pipeline message.
void SetArgs(const RespVec& args);
public:
MessagePayload payload;
};
Connection::PubMessage::PubMessage(string pattern, shared_ptr<string> channel,
shared_ptr<string> message)
: type{kPublish}, pattern{move(pattern)}, channel{move(channel)}, message{move(message)} {
Connection::PubMessage::PubMessage(string pattern, shared_ptr<char[]> buf, size_t channel_len,
size_t message_len)
: data{MessageData{pattern, move(buf), channel_len, message_len}} {
}
Connection::PubMessage::PubMessage(bool add, shared_ptr<string> channel, uint32_t channel_cnt)
: type{add ? kSubscribe : kUnsubscribe}, channel{move(channel)}, channel_cnt{channel_cnt} {
Connection::PubMessage::PubMessage(bool add, string_view channel, uint32_t channel_cnt)
: data{SubscribeData{add, string{channel}, channel_cnt}} {
}
string_view Connection::PubMessage::MessageData::Channel() const {
return {buf.get(), channel_len};
}
string_view Connection::PubMessage::MessageData::Message() const {
return {buf.get() + channel_len, message_len};
}
struct Connection::DispatchOperations {
@ -183,90 +121,67 @@ struct Connection::DispatchOperations {
}
void operator()(const PubMessage& msg);
void operator()(Request::PipelineMessage& msg);
void operator()(Connection::PipelineMessage& msg);
void operator()(const MonitorMessage& msg);
template <typename T, typename D> void operator()(unique_ptr<T, D>& ptr) {
operator()(*ptr.get());
}
ConnectionStats* stats = nullptr;
SinkReplyBuilder* builder = nullptr;
Connection* self = nullptr;
};
Connection::RequestPtr Connection::Request::New(MonitorMessage msg) {
void* ptr = mi_malloc(sizeof(Request));
Request* req = new (ptr) Request(move(msg));
return Connection::RequestPtr{req, Connection::RequestDeleter{}};
}
Connection::RequestPtr Connection::Request::New(mi_heap_t* heap, const RespVec& args,
size_t capacity) {
constexpr auto kReqSz = sizeof(Request);
void* ptr = mi_heap_malloc_small(heap, kReqSz);
// We must construct in place here, since there is a slice that uses memory locations
Request* req = new (ptr) Request(args.size(), capacity);
req->SetArgs(args);
return Connection::RequestPtr{req, Connection::RequestDeleter{}};
}
Connection::RequestPtr Connection::Request::New(PubMessage pub_msg) {
// This will generate a new request for pubsub message
// Please note that unlike the above case, we don't need to "protect", the internals here
// since we are currently using a borrow token for it - i.e. the BlockingCounter will
// ensure that the message is not deleted until we are finish sending it at the other
// side of the queue
void* ptr = mi_malloc(sizeof(Request));
Request* req = new (ptr) Request(move(pub_msg));
return Connection::RequestPtr{req, Connection::RequestDeleter{}};
}
void Connection::Request::SetArgs(const RespVec& args) {
// At this point we know that we have PipelineMessage in Request so next op is safe.
PipelineMessage& pipeline_msg = std::get<PipelineMessage>(payload);
auto* next = pipeline_msg.storage.data();
void Connection::PipelineMessage::SetArgs(const RespVec& args) {
auto* next = storage.data();
for (size_t i = 0; i < args.size(); ++i) {
auto buf = args[i].GetBuf();
size_t s = buf.size();
memcpy(next, buf.data(), s);
pipeline_msg.args[i] = MutableSlice(next, s);
this->args[i] = MutableSlice(next, s);
next += s;
}
}
void Connection::RequestDeleter::operator()(Request* req) const {
req->~Request();
mi_free(req);
void Connection::MessageDeleter::operator()(PipelineMessage* msg) const {
msg->~PipelineMessage();
mi_free(msg);
}
void Connection::Request::Emplace(const RespVec& args, size_t capacity) {
PipelineMessage* msg = get_if<PipelineMessage>(&payload);
if (msg) {
msg->Reset(args.size(), capacity);
} else {
payload = PipelineMessage{args.size(), capacity};
}
SetArgs(args);
void Connection::MessageDeleter::operator()(PubMessage* msg) const {
msg->~PubMessage();
mi_free(msg);
}
void Connection::Request::PipelineMessage::Reset(size_t nargs, size_t capacity) {
void Connection::PipelineMessage::Reset(size_t nargs, size_t capacity) {
storage.resize(capacity);
args.resize(nargs);
}
template <class... Ts> struct Overloaded : Ts... { using Ts::operator()...; };
template <class... Ts> Overloaded(Ts...) -> Overloaded<Ts...>;
size_t Connection::Request::StorageCapacity() const {
return std::visit(Overloaded{[](const PubMessage& msg) -> size_t { return 0; },
[](const PipelineMessage& arg) -> size_t {
return arg.storage.capacity() + arg.args.capacity();
},
[](const MonitorMessage& arg) -> size_t { return arg.capacity(); }},
payload);
size_t Connection::PipelineMessage::StorageCapacity() const {
return storage.capacity() + args.capacity();
}
bool Connection::Request::IsPipelineMsg() const {
return std::get_if<PipelineMessage>(&payload) != nullptr;
template <class... Ts> struct Overloaded : Ts... {
using Ts::operator()...;
template <typename T, typename D> size_t operator()(const unique_ptr<T, D>& ptr) {
return operator()(*ptr.get());
}
};
template <class... Ts> Overloaded(Ts...) -> Overloaded<Ts...>;
size_t Connection::MessageHandle::StorageCapacity() const {
auto pub_size = [](const PubMessage& msg) -> size_t { return 0; };
auto msg_size = [](const PipelineMessage& arg) -> size_t { return arg.StorageCapacity(); };
auto monitor_size = [](const MonitorMessage& arg) -> size_t { return 0; };
return visit(Overloaded{pub_size, msg_size, monitor_size}, this->handle);
}
bool Connection::MessageHandle::IsPipelineMsg() const {
return get_if<PipelineMessagePtr>(&this->handle) != nullptr;
}
void Connection::DispatchOperations::operator()(const MonitorMessage& msg) {
@ -277,33 +192,31 @@ void Connection::DispatchOperations::operator()(const MonitorMessage& msg) {
void Connection::DispatchOperations::operator()(const PubMessage& pub_msg) {
RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder;
++stats->async_writes_cnt;
string_view arr[4];
if (pub_msg.type == PubMessage::kPublish) {
if (pub_msg.pattern.empty()) {
DVLOG(1) << "Sending message, from channel: " << *pub_msg.channel << " " << *pub_msg.message;
arr[0] = "message";
arr[1] = *pub_msg.channel;
arr[2] = *pub_msg.message;
rbuilder->SendStringArr(absl::Span<string_view>{arr, 3},
RedisReplyBuilder::CollectionType::PUSH);
auto send_msg = [rbuilder](const PubMessage::MessageData& data) {
unsigned i = 0;
string_view arr[4];
if (data.pattern.empty()) {
arr[i++] = "message";
} else {
arr[0] = "pmessage";
arr[1] = pub_msg.pattern;
arr[2] = *pub_msg.channel;
arr[3] = *pub_msg.message;
rbuilder->SendStringArr(absl::Span<string_view>{arr, 4},
RedisReplyBuilder::CollectionType::PUSH);
arr[i++] = "pmessage";
arr[i++] = data.pattern;
}
} else {
arr[i++] = data.Channel();
arr[i++] = data.Message();
rbuilder->SendStringArr(absl::Span<string_view>{arr, i},
RedisReplyBuilder::CollectionType::PUSH);
};
auto send_sub = [rbuilder](const PubMessage::SubscribeData& data) {
const char* action[2] = {"unsubscribe", "subscribe"};
rbuilder->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rbuilder->SendBulkString(action[pub_msg.type == PubMessage::kSubscribe]);
rbuilder->SendBulkString(*pub_msg.channel);
rbuilder->SendLong(pub_msg.channel_cnt);
}
rbuilder->SendBulkString(action[data.add]);
rbuilder->SendBulkString(data.channel);
rbuilder->SendLong(data.channel_cnt);
};
visit(Overloaded{send_msg, send_sub}, pub_msg.data);
}
void Connection::DispatchOperations::operator()(Request::PipelineMessage& msg) {
void Connection::DispatchOperations::operator()(Connection::PipelineMessage& msg) {
++stats->pipelined_cmd_cnt;
self->pipeline_msg_cnt_--;
@ -324,7 +237,7 @@ Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener,
protocol_ = protocol;
constexpr size_t kReqSz = sizeof(Connection::Request);
constexpr size_t kReqSz = sizeof(Connection::PipelineMessage);
static_assert(kReqSz <= 256 && kReqSz >= 232);
switch (protocol) {
@ -461,20 +374,6 @@ void Connection::RegisterBreakHook(BreakerCb breaker_cb) {
breaker_cb_ = breaker_cb;
}
void Connection::SendPubMessageAsync(PubMessage pub_msg) {
DCHECK(cc_);
if (cc_->conn_closing) {
return;
}
RequestPtr req = Request::New(move(pub_msg));
dispatch_q_.push_back(std::move(req));
if (dispatch_q_.size() == 1) {
evc_.notify();
}
}
std::string Connection::LocalBindAddress() const {
LinuxSocketBase* lsb = static_cast<LinuxSocketBase*>(socket_.get());
auto le = lsb->LocalEndpoint();
@ -668,14 +567,10 @@ auto Connection::ParseRedis() -> ParserStatus {
last_interaction_ = time(nullptr);
} else {
// Dispatch via queue to speedup input reading.
RequestPtr req = FromArgs(std::move(tmp_parse_args_), tlh);
SendAsync(MessageHandle{FromArgs(move(tmp_parse_args_), tlh)});
++pipeline_msg_cnt_;
dispatch_q_.push_back(std::move(req));
if (dispatch_q_.size() == 1) {
evc_.notify();
} else if (dispatch_q_.size() > 10) {
if (dispatch_q_.size() > 10)
ThisFiber::Yield();
}
}
}
io_buf_.ConsumeInput(consumed);
@ -843,13 +738,15 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) {
if (cc_->conn_closing)
break;
RequestPtr req{std::move(dispatch_q_.front())};
MessageHandle msg = move(dispatch_q_.front());
dispatch_q_.pop_front();
std::visit(dispatch_op, req->payload);
std::visit(dispatch_op, msg.handle);
if (req->IsPipelineMsg() && stats_->pipeline_cache_capacity < request_cache_limit) {
stats_->pipeline_cache_capacity += req->StorageCapacity();
pipeline_req_pool_.push_back(std::move(req));
if (auto* pipe = get_if<PipelineMessagePtr>(&msg.handle); pipe) {
if (stats_->pipeline_cache_capacity < request_cache_limit) {
stats_->pipeline_cache_capacity += (*pipe)->StorageCapacity();
pipeline_req_pool_.push_back(move(*pipe));
}
}
}
@ -859,7 +756,7 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) {
dispatch_q_.clear();
}
auto Connection::FromArgs(RespVec args, mi_heap_t* heap) -> RequestPtr {
Connection::PipelineMessagePtr Connection::FromArgs(RespVec args, mi_heap_t* heap) {
DCHECK(!args.empty());
size_t backed_sz = 0;
for (const auto& arg : args) {
@ -868,18 +765,21 @@ auto Connection::FromArgs(RespVec args, mi_heap_t* heap) -> RequestPtr {
}
DCHECK(backed_sz);
constexpr auto kReqSz = sizeof(Request);
constexpr auto kReqSz = sizeof(PipelineMessage);
static_assert(kReqSz < MI_SMALL_SIZE_MAX);
static_assert(alignof(Request) == 8);
static_assert(alignof(PipelineMessage) == 8);
RequestPtr req;
if (req = GetFromPipelinePool(); req) {
req->Emplace(move(args), backed_sz);
PipelineMessagePtr ptr;
if (ptr = GetFromPipelinePool(); ptr) {
ptr->Reset(args.size(), backed_sz);
} else {
req = Request::New(heap, args, backed_sz);
void* heap_ptr = mi_heap_malloc_small(heap, sizeof(PipelineMessage));
// We must construct in place here, since there is a slice that uses memory locations
ptr.reset(new (heap_ptr) PipelineMessage(args.size(), backed_sz));
}
return req;
ptr->SetArgs(args);
return ptr;
}
void Connection::ShrinkPipelinePool() {
@ -899,15 +799,15 @@ void Connection::ShrinkPipelinePool() {
}
}
Connection::RequestPtr Connection::GetFromPipelinePool() {
Connection::PipelineMessagePtr Connection::GetFromPipelinePool() {
if (pipeline_req_pool_.empty())
return {};
return nullptr;
free_req_release_weight = 0; // Reset the release weight.
RequestPtr req = move(pipeline_req_pool_.back());
stats_->pipeline_cache_capacity -= req->StorageCapacity();
auto ptr = move(pipeline_req_pool_.back());
stats_->pipeline_cache_capacity -= ptr->StorageCapacity();
pipeline_req_pool_.pop_back();
return req;
return ptr;
}
void Connection::ShutdownSelf() {
@ -927,15 +827,24 @@ void RespToArgList(const RespVec& src, CmdArgVec* dest) {
}
}
void Connection::SendMonitorMessageAsync(std::string monitor_msg) {
void Connection::SendPubMessageAsync(PubMessage msg) {
void* ptr = mi_malloc(sizeof(PubMessage));
SendAsync({PubMessagePtr{new (ptr) PubMessage{move(msg)}, MessageDeleter{}}});
}
void Connection::SendMonitorMessageAsync(string msg) {
SendAsync({MonitorMessage{move(msg)}});
}
void Connection::SendAsync(MessageHandle msg) {
DCHECK(cc_);
if (!cc_->conn_closing) {
RequestPtr req = Request::New(std::move(monitor_msg));
dispatch_q_.push_back(std::move(req));
if (dispatch_q_.size() == 1) {
evc_.notify();
}
if (cc_->conn_closing)
return;
dispatch_q_.push_back(move(msg));
if (dispatch_q_.size() == 1) {
evc_.notify();
}
}

View file

@ -5,6 +5,7 @@
#pragma once
#include <absl/container/fixed_array.h>
#include <mimalloc.h>
#include <sys/socket.h>
#include <deque>
@ -31,6 +32,12 @@ typedef struct mi_heap_s mi_heap_t;
#define SO_INCOMING_NAPI_ID 56
#endif
#ifdef ABSL_HAVE_ADDRESS_SANITIZER
constexpr size_t kReqStorageSize = 88;
#else
constexpr size_t kReqStorageSize = 120;
#endif
namespace facade {
class ConnectionContext;
@ -55,21 +62,65 @@ class Connection : public util::Connection {
// PubSub message, either incoming message for active subscription or reply for new subscription.
struct PubMessage {
enum Type { kSubscribe, kUnsubscribe, kPublish } type;
// Represents incoming message.
struct MessageData {
std::string pattern{}; // non-empty for pattern subscriber
std::shared_ptr<char[]> buf; // stores channel name and message
size_t channel_len, message_len; // lengths in buf
std::string pattern{}; // non-empty for pattern subscriber
std::shared_ptr<std::string> channel{};
std::shared_ptr<std::string> message{};
std::string_view Channel() const;
std::string_view Message() const;
};
uint32_t channel_cnt = 0;
// Represents reply for subscribe/unsubscribe.
struct SubscribeData {
bool add;
std::string channel;
uint32_t channel_cnt;
};
PubMessage(bool add, std::shared_ptr<std::string> channel, uint32_t channel_cnt);
PubMessage(std::string pattern, std::shared_ptr<std::string> channel,
std::shared_ptr<std::string> message);
std::variant<MessageData, SubscribeData> data;
PubMessage(const PubMessage&) = delete;
PubMessage& operator=(const PubMessage&) = delete;
PubMessage(PubMessage&&) = default;
PubMessage(bool add, std::string_view channel, uint32_t channel_cnt);
PubMessage(std::string pattern, std::shared_ptr<char[]> buf, size_t channel_len,
size_t message_len);
};
struct MonitorMessage : public std::string {};
struct PipelineMessage {
PipelineMessage(size_t nargs, size_t capacity) : args(nargs), storage(capacity) {
}
void Reset(size_t nargs, size_t capacity);
void SetArgs(const RespVec& args);
size_t StorageCapacity() const;
// mi_stl_allocator uses mi heap internally.
// The capacity is chosen so that we allocate a fully utilized (256 bytes) block.
using StorageType = absl::InlinedVector<char, kReqStorageSize, mi_stl_allocator<char>>;
absl::InlinedVector<MutableSlice, 6> args;
StorageType storage;
};
struct MessageDeleter {
void operator()(PipelineMessage* msg) const;
void operator()(PubMessage* msg) const;
};
// Requests are allocated on the mimalloc heap and thus require a custom deleter.
using PipelineMessagePtr = std::unique_ptr<PipelineMessage, MessageDeleter>;
using PubMessagePtr = std::unique_ptr<PubMessage, MessageDeleter>;
struct MessageHandle {
size_t StorageCapacity() const;
bool IsPipelineMsg() const;
std::variant<MonitorMessage, PubMessagePtr, PipelineMessagePtr> handle;
};
enum Phase { READ_SOCKET, PROCESS };
@ -77,10 +128,10 @@ class Connection : public util::Connection {
public:
// Add PubMessage to dispatch queue.
// Virtual because behaviour is overwritten in test_utils.
virtual void SendPubMessageAsync(PubMessage pub_msg);
virtual void SendPubMessageAsync(PubMessage);
// Add monitor message to dispatch queue.
void SendMonitorMessageAsync(std::string monitor_msg);
void SendMonitorMessageAsync(std::string);
// Register hook that is executed on connection shutdown.
ShutdownHandle RegisterShutdownHook(ShutdownCb cb);
@ -121,15 +172,10 @@ class Connection : public util::Connection {
private:
enum ParserStatus { OK, NEED_MORE, ERROR };
class Request;
struct DispatchOperations;
struct DispatchCleanup;
struct RequestDeleter;
struct Shutdown;
// Requests are allocated on the mimalloc heap and thus require a custom deleter.
using RequestPtr = std::unique_ptr<Request, RequestDeleter>;
private:
// Check protocol and handle connection.
void HandleRequests() final;
@ -146,8 +192,10 @@ class Connection : public util::Connection {
// Handles events from dispatch queue.
void DispatchFiber(util::FiberSocketBase* peer);
void SendAsync(MessageHandle msg);
// Create new pipeline request, re-use from pool when possible.
RequestPtr FromArgs(RespVec args, mi_heap_t* heap);
PipelineMessagePtr FromArgs(RespVec args, mi_heap_t* heap);
ParserStatus ParseRedis();
ParserStatus ParseMemcache();
@ -158,11 +206,11 @@ class Connection : public util::Connection {
void ShrinkPipelinePool();
// Returns non-null request ptr if pool has vacant entries.
RequestPtr GetFromPipelinePool();
PipelineMessagePtr GetFromPipelinePool();
private:
std::deque<RequestPtr> dispatch_q_; // dispatch queue
dfly::EventCount evc_; // dispatch queue waker
std::deque<MessageHandle> dispatch_q_; // dispatch queue
dfly::EventCount evc_; // dispatch queue waker
base::IoBuf io_buf_; // used in io loop and parsers
std::unique_ptr<RedisParser> redis_parser_;
@ -198,7 +246,7 @@ class Connection : public util::Connection {
// Pooled pipieline messages per-thread.
// Aggregated while handling pipelines,
// graudally released while handling regular commands.
static thread_local std::vector<RequestPtr> pipeline_req_pool_;
static thread_local std::vector<PipelineMessagePtr> pipeline_req_pool_;
};
void RespToArgList(const RespVec& src, CmdArgVec* dest);

View file

@ -127,7 +127,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
if (to_reply) {
for (size_t i = 0; i < result.size(); ++i) {
owner()->SendPubMessageAsync({to_add, make_shared<string>(ArgS(args, i)), result[i]});
owner()->SendPubMessageAsync({to_add, ArgS(args, i), result[i]});
}
}
}

View file

@ -403,9 +403,9 @@ TEST_F(DflyEngineTest, PSubscribe) {
ASSERT_EQ(1, SubscriberMessagesLen("IO1"));
const facade::Connection::PubMessage& msg = GetPublishedMessage("IO1", 0);
EXPECT_EQ("foo", *msg.message);
EXPECT_EQ("ab", *msg.channel);
const auto& msg = GetPublishedMessage("IO1", 0);
EXPECT_EQ("foo", msg.Message());
EXPECT_EQ("ab", msg.Channel());
EXPECT_EQ("a*", msg.pattern);
}

View file

@ -1446,6 +1446,7 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
string_view channel = ArgS(args, 0);
string_view msg = ArgS(args, 1);
auto* cs = ServerState::tlocal()->channel_store();
vector<ChannelStore::Subscriber> subscribers = cs->FetchSubscribers(channel);
@ -1453,17 +1454,18 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
if (!subscribers.empty()) {
auto subscribers_ptr = make_shared<decltype(subscribers)>(move(subscribers));
auto msg_ptr = make_shared<string>(ArgS(args, 1));
auto channel_ptr = make_shared<string>(channel);
auto buf = shared_ptr<char[]>{new char[channel.size() + msg.size()]};
memcpy(buf.get(), channel.data(), channel.size());
memcpy(buf.get() + channel.size(), msg.data(), msg.size());
auto cb = [subscribers_ptr, msg_ptr, channel_ptr](unsigned idx, util::ProactorBase*) {
auto cb = [subscribers_ptr, buf, channel, msg](unsigned idx, util::ProactorBase*) {
auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx,
ChannelStore::Subscriber::ByThread);
while (it != subscribers_ptr->end() && it->thread_id == idx) {
facade::Connection* conn = it->conn_cntx->owner();
DCHECK(conn);
conn->SendPubMessageAsync({move(it->pattern), move(channel_ptr), move(msg_ptr)});
conn->SendPubMessageAsync({move(it->pattern), move(buf), channel.size(), msg.size()});
it->borrow_token.Dec();
it++;
}

View file

@ -62,15 +62,15 @@ TestConnection::TestConnection(Protocol protocol, io::StringSink* sink)
}
void TestConnection::SendPubMessageAsync(PubMessage pmsg) {
if (pmsg.type == PubMessage::kPublish) {
messages.push_back(move(pmsg));
} else {
if (auto* ptr = std::get_if<PubMessage::MessageData>(&pmsg.data); ptr != nullptr) {
messages.push_back(move(*ptr));
} else if (auto* ptr = std::get_if<PubMessage::SubscribeData>(&pmsg.data); ptr != nullptr) {
RedisReplyBuilder builder(sink_);
const char* action[2] = {"unsubscribe", "subscribe"};
builder.StartArray(3);
builder.SendBulkString(action[pmsg.type == PubMessage::kSubscribe]);
builder.SendBulkString(*pmsg.channel);
builder.SendLong(pmsg.channel_cnt);
builder.SendBulkString(action[ptr->add]);
builder.SendBulkString(ptr->channel);
builder.SendLong(ptr->channel_cnt);
}
}
@ -84,7 +84,7 @@ class BaseFamilyTest::TestConnWrapper {
RespVec ParseResponse(bool fully_consumed);
// returns: type(pmessage), pattern, channel, message.
const facade::Connection::PubMessage& GetPubMessage(size_t index) const;
const facade::Connection::PubMessage::MessageData& GetPubMessage(size_t index) const;
ConnectionContext* cmd_cntx() {
return &cmd_cntx_;
@ -360,7 +360,7 @@ RespVec BaseFamilyTest::TestConnWrapper::ParseResponse(bool fully_consumed) {
return res;
}
const facade::Connection::PubMessage& BaseFamilyTest::TestConnWrapper::GetPubMessage(
const facade::Connection::PubMessage::MessageData& BaseFamilyTest::TestConnWrapper::GetPubMessage(
size_t index) const {
CHECK_LT(index, dummy_conn_->messages.size());
return dummy_conn_->messages[index];
@ -391,8 +391,8 @@ size_t BaseFamilyTest::SubscriberMessagesLen(string_view conn_id) const {
return it->second->conn()->messages.size();
}
const facade::Connection::PubMessage& BaseFamilyTest::GetPublishedMessage(string_view conn_id,
size_t index) const {
const facade::Connection::PubMessage::MessageData& BaseFamilyTest::GetPublishedMessage(
string_view conn_id, size_t index) const {
auto it = connections_.find(conn_id);
CHECK(it != connections_.end());

View file

@ -23,7 +23,7 @@ class TestConnection : public facade::Connection {
void SendPubMessageAsync(PubMessage pmsg) final;
std::vector<PubMessage> messages;
std::vector<PubMessage::MessageData> messages;
private:
io::StringSink* sink_;
@ -87,8 +87,8 @@ class BaseFamilyTest : public ::testing::Test {
std::string GetId() const;
size_t SubscriberMessagesLen(std::string_view conn_id) const;
const facade::Connection::PubMessage& GetPublishedMessage(std::string_view conn_id,
size_t index) const;
const facade::Connection::PubMessage::MessageData& GetPublishedMessage(std::string_view conn_id,
size_t index) const;
std::unique_ptr<util::ProactorPool> pp_;
std::unique_ptr<Service> service_;