mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-15 17:51:06 +00:00
Implement the initial version of PUBSUB.
There are some things left to polish, mainly around processing a subset of commands while being blocked in a subscribed state. Also I need to solve the issue of atomic replies when publishing messages in parallel with processing the whitelisted commands in a subscribed state.
This commit is contained in:
parent
e46f2b5384
commit
e29f76ad4d
13 changed files with 307 additions and 14 deletions
2
helio
2
helio
|
@ -1 +1 @@
|
|||
Subproject commit 408d201f365ce886f1b1e762810e81e0509ea8b9
|
||||
Subproject commit 2cf77beb5b9ae70f594380e8df3c0e347f39f0af
|
|
@ -47,6 +47,8 @@ class ConnectionContext {
|
|||
bool req_auth: 1;
|
||||
bool replica_conn: 1;
|
||||
bool authenticated: 1;
|
||||
|
||||
virtual void OnClose() {}
|
||||
private:
|
||||
Connection* owner_;
|
||||
std::unique_ptr<SinkReplyBuilder> rbuilder_;
|
||||
|
|
|
@ -284,6 +284,7 @@ void Connection::ConnectionFlow(FiberSocketBase* peer) {
|
|||
cc_->conn_closing = true; // Signal dispatch to close.
|
||||
evc_.notify();
|
||||
dispatch_fb.join();
|
||||
cc_->OnClose();
|
||||
|
||||
stats->read_buf_capacity -= io_buf_.Capacity();
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright 2021, Roman Gershman. All rights reserved.
|
||||
// Copyright 2022, Roman Gershman. All rights reserved.
|
||||
// See LICENSE for licensing terms.
|
||||
//
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
add_executable(dragonfly dfly_main.cc)
|
||||
cxx_link(dragonfly base dragonfly_lib)
|
||||
|
||||
add_library(dragonfly_lib command_registry.cc common.cc config_flags.cc
|
||||
db_slice.cc debugcmd.cc
|
||||
add_library(dragonfly_lib channel_slice.cc command_registry.cc common.cc config_flags.cc
|
||||
conn_context.cc db_slice.cc debugcmd.cc
|
||||
engine_shard_set.cc generic_family.cc hset_family.cc
|
||||
list_family.cc main_service.cc rdb_load.cc rdb_save.cc replica.cc
|
||||
snapshot.cc script_mgr.cc server_family.cc
|
||||
snapshot.cc script_mgr.cc server_family.cc
|
||||
set_family.cc
|
||||
string_family.cc transaction.cc zset_family.cc)
|
||||
|
||||
|
@ -26,5 +26,5 @@ cxx_test(zset_family_test dfly_test_lib LABELS DFLY)
|
|||
|
||||
add_custom_target(check_dfly WORKING_DIRECTORY .. COMMAND ctest -L DFLY)
|
||||
add_dependencies(check_dfly dragonfly_test list_family_test
|
||||
generic_family_test memcache_parser_test rdb_test
|
||||
generic_family_test memcache_parser_test rdb_test
|
||||
redis_parser_test string_family_test)
|
||||
|
|
48
src/server/channel_slice.cc
Normal file
48
src/server/channel_slice.cc
Normal file
|
@ -0,0 +1,48 @@
|
|||
// Copyright 2022, Roman Gershman. All rights reserved.
|
||||
// See LICENSE for licensing terms.
|
||||
//
|
||||
|
||||
#include "server/channel_slice.h"
|
||||
|
||||
namespace dfly {
|
||||
using namespace std;
|
||||
|
||||
ChannelSlice::Subscriber::Subscriber(ConnectionContext* cntx, uint32_t tid)
|
||||
: conn_cntx(cntx), borrow_token(cntx->conn_state.subscribe_info->borrow_token), thread_id(tid) {
|
||||
}
|
||||
|
||||
void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me) {
|
||||
auto it = channels_.find(channel);
|
||||
if (it != channels_.end()) {
|
||||
it->second->subscribers.erase(me);
|
||||
if (it->second->subscribers.empty())
|
||||
channels_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
void ChannelSlice::AddSubscription(string_view channel, ConnectionContext* me, uint32_t thread_id) {
|
||||
auto [it, added] = channels_.emplace(channel, nullptr);
|
||||
if (added) {
|
||||
it->second.reset(new Channel);
|
||||
}
|
||||
it->second->subscribers.emplace(me, SubscriberInternal{thread_id});
|
||||
}
|
||||
|
||||
auto ChannelSlice::FetchSubscribers(string_view channel) -> vector<Subscriber> {
|
||||
vector<Subscriber> res;
|
||||
|
||||
auto it = channels_.find(channel);
|
||||
if (it != channels_.end()) {
|
||||
res.reserve(it->second->subscribers.size());
|
||||
for (const auto& k_v : it->second->subscribers) {
|
||||
Subscriber s(k_v.first, k_v.second.thread_id);
|
||||
s.borrow_token.Inc();
|
||||
|
||||
res.push_back(std::move(s));
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace dfly
|
52
src/server/channel_slice.h
Normal file
52
src/server/channel_slice.h
Normal file
|
@ -0,0 +1,52 @@
|
|||
// Copyright 2022, Roman Gershman. All rights reserved.
|
||||
// See LICENSE for licensing terms.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <absl/container/flat_hash_map.h>
|
||||
|
||||
#include <string_view>
|
||||
|
||||
#include "server/conn_context.h"
|
||||
|
||||
namespace dfly {
|
||||
|
||||
// Database holding pubsub subscribers.
|
||||
class ChannelSlice {
|
||||
public:
|
||||
struct Subscriber {
|
||||
ConnectionContext* conn_cntx;
|
||||
util::fibers_ext::BlockingCounter borrow_token;
|
||||
uint32_t thread_id;
|
||||
|
||||
Subscriber(ConnectionContext* cntx, uint32_t tid);
|
||||
// Subscriber() : borrow_token(0) {}
|
||||
|
||||
Subscriber(Subscriber&&) noexcept = default;
|
||||
Subscriber& operator=(Subscriber&&) noexcept = default;
|
||||
|
||||
Subscriber(const Subscriber&) = delete;
|
||||
void operator=(const Subscriber&) = delete;
|
||||
};
|
||||
|
||||
std::vector<Subscriber> FetchSubscribers(std::string_view channel);
|
||||
|
||||
void RemoveSubscription(std::string_view channel, ConnectionContext* me);
|
||||
void AddSubscription(std::string_view channel, ConnectionContext* me, uint32_t thread_id);
|
||||
|
||||
private:
|
||||
struct SubscriberInternal {
|
||||
uint32_t thread_id; // proactor thread id.
|
||||
|
||||
SubscriberInternal(uint32_t tid) : thread_id(tid) {}
|
||||
};
|
||||
|
||||
struct Channel {
|
||||
absl::flat_hash_map<ConnectionContext*, SubscriberInternal> subscribers;
|
||||
};
|
||||
|
||||
absl::flat_hash_map<std::string, std::unique_ptr<Channel>> channels_;
|
||||
absl::flat_hash_map<std::string, std::unique_ptr<Channel>> patterns_;
|
||||
};
|
||||
|
||||
} // namespace dfly
|
115
src/server/conn_context.cc
Normal file
115
src/server/conn_context.cc
Normal file
|
@ -0,0 +1,115 @@
|
|||
// Copyright 2022, Roman Gershman. All rights reserved.
|
||||
// See LICENSE for licensing terms.
|
||||
//
|
||||
|
||||
#include "server/conn_context.h"
|
||||
|
||||
#include "base/logging.h"
|
||||
#include "server/engine_shard_set.h"
|
||||
#include "util/proactor_base.h"
|
||||
|
||||
namespace dfly {
|
||||
|
||||
using namespace std;
|
||||
|
||||
void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args) {
|
||||
vector<unsigned> result(to_reply ? args.size() : 0, 0);
|
||||
|
||||
if (to_add || conn_state.subscribe_info) {
|
||||
std::vector<pair<ShardId, string_view>> channels;
|
||||
channels.reserve(args.size());
|
||||
|
||||
if (!conn_state.subscribe_info) {
|
||||
DCHECK(to_add);
|
||||
|
||||
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
bool res = false;
|
||||
string_view channel = ArgS(args, i);
|
||||
if (to_add) {
|
||||
res = conn_state.subscribe_info->channels.emplace(channel).second;
|
||||
} else {
|
||||
res = conn_state.subscribe_info->channels.erase(channel) > 0;
|
||||
}
|
||||
|
||||
if (to_reply)
|
||||
result[i] = conn_state.subscribe_info->channels.size();
|
||||
|
||||
if (res) {
|
||||
ShardId sid = Shard(channel, shard_set->size());
|
||||
channels.emplace_back(sid, channel);
|
||||
}
|
||||
}
|
||||
|
||||
if (!to_add && conn_state.subscribe_info->channels.empty()) {
|
||||
conn_state.subscribe_info.reset();
|
||||
}
|
||||
|
||||
sort(channels.begin(), channels.end());
|
||||
|
||||
vector<unsigned> shard_idx(shard_set->size() + 1, 0);
|
||||
for (const auto& k_v : channels) {
|
||||
shard_idx[k_v.first]++;
|
||||
}
|
||||
unsigned prev = shard_idx[0];
|
||||
shard_idx[0] = 0;
|
||||
|
||||
// compute cumulitive sum, or in other words a beginning index in channels for each shard.
|
||||
for (size_t i = 1; i < shard_idx.size(); ++i) {
|
||||
unsigned cur = shard_idx[i];
|
||||
shard_idx[i] = shard_idx[i - 1] + prev;
|
||||
prev = cur;
|
||||
}
|
||||
|
||||
int32_t tid = util::ProactorBase::GetIndex();
|
||||
DCHECK_GE(tid, 0);
|
||||
|
||||
auto cb = [&](EngineShard* shard) {
|
||||
ChannelSlice& cs = shard->channel_slice();
|
||||
unsigned start = shard_idx[shard->shard_id()];
|
||||
unsigned end = shard_idx[shard->shard_id() + 1];
|
||||
|
||||
DCHECK_LT(start, end);
|
||||
for (unsigned i = start; i < end; ++i) {
|
||||
if (to_add) {
|
||||
cs.AddSubscription(channels[i].second, this, tid);
|
||||
} else {
|
||||
cs.RemoveSubscription(channels[i].second, this);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
shard_set->RunBriefInParallel(move(cb),
|
||||
[&](ShardId sid) { return shard_idx[sid + 1] > shard_idx[sid]; });
|
||||
}
|
||||
|
||||
if (to_reply) {
|
||||
const char* action[2] = {"unsubscribe", "subscribe"};
|
||||
|
||||
for (size_t i = 0; i < result.size(); ++i) {
|
||||
(*this)->StartArray(3);
|
||||
(*this)->SendBulkString(action[to_add]);
|
||||
(*this)->SendBulkString(ArgS(args, i));
|
||||
(*this)->SendLong(result[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ConnectionContext::OnClose() {
|
||||
if (conn_state.subscribe_info) {
|
||||
StringVec channels(conn_state.subscribe_info->channels.begin(),
|
||||
conn_state.subscribe_info->channels.end());
|
||||
CmdArgVec arg_vec(channels.begin(), channels.end());
|
||||
|
||||
auto token = conn_state.subscribe_info->borrow_token;
|
||||
ChangeSubscription(false, false, CmdArgList{arg_vec});
|
||||
DCHECK(!conn_state.subscribe_info);
|
||||
|
||||
// Check that all borrowers finished processing
|
||||
token.Wait();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace dfly
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "facade/conn_context.h"
|
||||
#include "server/common_types.h"
|
||||
#include "util/fibers/fibers_ext.h"
|
||||
|
||||
namespace dfly {
|
||||
|
||||
|
@ -45,6 +46,17 @@ struct ConnectionState {
|
|||
absl::flat_hash_set<std::string_view> keys;
|
||||
};
|
||||
std::optional<Script> script_info;
|
||||
|
||||
struct SubscribeInfo {
|
||||
// TODO: to provide unique_strings across service. This will allow us to use string_view here.
|
||||
absl::flat_hash_set<std::string> channels;
|
||||
|
||||
util::fibers_ext::BlockingCounter borrow_token;
|
||||
|
||||
SubscribeInfo() : borrow_token(0) {}
|
||||
};
|
||||
|
||||
std::unique_ptr<SubscribeInfo> subscribe_info;
|
||||
};
|
||||
|
||||
class ConnectionContext : public facade::ConnectionContext {
|
||||
|
@ -52,6 +64,9 @@ class ConnectionContext : public facade::ConnectionContext {
|
|||
ConnectionContext(::io::Sink* stream, facade::Connection* owner)
|
||||
: facade::ConnectionContext(stream, owner) {
|
||||
}
|
||||
|
||||
void OnClose() override;
|
||||
|
||||
struct DebugInfo {
|
||||
uint32_t shards_count = 0;
|
||||
TxClock clock = 0;
|
||||
|
@ -69,6 +84,8 @@ class ConnectionContext : public facade::ConnectionContext {
|
|||
DbIndex db_index() const {
|
||||
return conn_state.db_index;
|
||||
}
|
||||
|
||||
void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args);
|
||||
};
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -18,7 +18,6 @@ extern "C" {
|
|||
|
||||
namespace dfly {
|
||||
|
||||
using namespace boost;
|
||||
using namespace std;
|
||||
using namespace util;
|
||||
using facade::OpStatus;
|
||||
|
|
|
@ -15,6 +15,7 @@ extern "C" {
|
|||
#include "base/string_view_sso.h"
|
||||
#include "core/mi_memory_resource.h"
|
||||
#include "core/tx_queue.h"
|
||||
#include "server/channel_slice.h"
|
||||
#include "server/db_slice.h"
|
||||
#include "util/fibers/fiberqueue_threadpool.h"
|
||||
#include "util/fibers/fibers_ext.h"
|
||||
|
@ -54,6 +55,10 @@ class EngineShard {
|
|||
return db_slice_;
|
||||
}
|
||||
|
||||
ChannelSlice& channel_slice() {
|
||||
return channel_slice_;
|
||||
}
|
||||
|
||||
std::pmr::memory_resource* memory_resource() {
|
||||
return &mi_resource_;
|
||||
}
|
||||
|
@ -163,6 +168,8 @@ class EngineShard {
|
|||
TxQueue txq_;
|
||||
MiMemoryResource mi_resource_;
|
||||
DbSlice db_slice_;
|
||||
ChannelSlice channel_slice_;
|
||||
|
||||
Stats stats_;
|
||||
|
||||
// Logical ts used to order distributed transactions.
|
||||
|
|
|
@ -41,7 +41,6 @@ DEFINE_uint32(memcache_port, 0, "Memcached port");
|
|||
DECLARE_string(requirepass);
|
||||
DEFINE_uint64(maxmemory, 0, "Limit on maximum-memory that is used by the database");
|
||||
|
||||
|
||||
namespace dfly {
|
||||
|
||||
using namespace std;
|
||||
|
@ -856,13 +855,64 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
|
|||
}
|
||||
|
||||
void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
|
||||
(*cntx)->SendLong(0);
|
||||
string_view channel = ArgS(args, 1);
|
||||
string_view message = ArgS(args, 2);
|
||||
ShardId sid = Shard(channel, shard_count());
|
||||
|
||||
auto cb = [&] { return EngineShard::tlocal()->channel_slice().FetchSubscribers(channel); };
|
||||
|
||||
vector<ChannelSlice::Subscriber> res = shard_set_.Await(sid, std::move(cb));
|
||||
atomic_uint32_t published{0};
|
||||
|
||||
if (!res.empty()) {
|
||||
sort(res.begin(), res.end(),
|
||||
[](const auto& left, const auto& right) { return left.thread_id < right.thread_id; });
|
||||
|
||||
vector<unsigned> slices(shard_set_.pool()->size(), UINT_MAX);
|
||||
for (size_t i = 0; i < res.size(); ++i) {
|
||||
if (slices[res[i].thread_id] > i) {
|
||||
slices[res[i].thread_id] = i;
|
||||
}
|
||||
}
|
||||
|
||||
auto cb = [&](unsigned idx, util::ProactorBase*) {
|
||||
unsigned start = slices[idx];
|
||||
for (unsigned i = start; i < res.size(); ++i) {
|
||||
if (res[i].thread_id != idx)
|
||||
break;
|
||||
|
||||
if (!res[i].conn_cntx->conn_closing) {
|
||||
published.fetch_add(1, memory_order_relaxed);
|
||||
|
||||
// TODO: this is wrong because ReplyBuilder does not guarantee atomicity if used
|
||||
// concurrently by multiple fibers.
|
||||
string_view msg_arr[3] = {"message", channel, message};
|
||||
(*res[i].conn_cntx)->SendStringArr(msg_arr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
shard_set_.pool()->AwaitFiberOnAll(cb);
|
||||
}
|
||||
|
||||
for (auto& s : res) {
|
||||
s.borrow_token.Dec();
|
||||
}
|
||||
|
||||
(*cntx)->SendLong(published.load(memory_order_relaxed));
|
||||
}
|
||||
|
||||
void Service::Subscribe(CmdArgList args, ConnectionContext* cntx) {
|
||||
(*cntx)->SendOk();
|
||||
args.remove_prefix(1);
|
||||
|
||||
cntx->ChangeSubscription(true /*add*/, true /* reply*/, std::move(args));
|
||||
}
|
||||
|
||||
void Service::Unsubscribe(CmdArgList args, ConnectionContext* cntx) {
|
||||
args.remove_prefix(1);
|
||||
|
||||
cntx->ChangeSubscription(false, true, std::move(args));
|
||||
}
|
||||
|
||||
VarzValue::Map Service::GetVarzStats() {
|
||||
VarzValue::Map res;
|
||||
|
@ -893,8 +943,9 @@ void Service::RegisterCommands() {
|
|||
<< CI{"EVAL", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(Eval).SetValidator(&EvalValidator)
|
||||
<< CI{"EVALSHA", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(EvalSha).SetValidator(&EvalValidator)
|
||||
<< CI{"EXEC", kExecMask, 1, 0, 0, 0}.MFUNC(Exec)
|
||||
<< CI{"PUBLISH", CO::LOADING| CO::FAST, 3, 0, 0, 0}.HFUNC(Publish)
|
||||
<< CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.HFUNC(Subscribe);
|
||||
<< CI{"PUBLISH", CO::LOADING | CO::FAST, 3, 0, 0, 0}.MFUNC(Publish)
|
||||
<< CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Subscribe)
|
||||
<< CI{"UNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Unsubscribe);
|
||||
|
||||
StringFamily::Register(®istry_);
|
||||
GenericFamily::Register(®istry_);
|
||||
|
|
|
@ -70,12 +70,13 @@ class Service : public facade::ServiceInterface {
|
|||
private:
|
||||
static void Quit(CmdArgList args, ConnectionContext* cntx);
|
||||
static void Multi(CmdArgList args, ConnectionContext* cntx);
|
||||
static void Publish(CmdArgList args, ConnectionContext* cntx);
|
||||
static void Subscribe(CmdArgList args, ConnectionContext* cntx);
|
||||
|
||||
void Eval(CmdArgList args, ConnectionContext* cntx);
|
||||
void EvalSha(CmdArgList args, ConnectionContext* cntx);
|
||||
void Exec(CmdArgList args, ConnectionContext* cntx);
|
||||
void Publish(CmdArgList args, ConnectionContext* cntx);
|
||||
void Subscribe(CmdArgList args, ConnectionContext* cntx);
|
||||
void Unsubscribe(CmdArgList args, ConnectionContext* cntx);
|
||||
|
||||
struct EvalArgs {
|
||||
std::string_view sha; // only one of them is defined.
|
||||
|
|
Loading…
Reference in a new issue