1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-15 17:51:06 +00:00

Implement the initial version of PUBSUB.

There are some things left to polish, mainly around
processing a subset of commands while being blocked in a subscribed state.
Also I need to solve the issue of atomic replies when publishing messages in parallel with
processing the whitelisted commands in a subscribed state.
This commit is contained in:
Roman Gershman 2022-03-29 20:07:49 +03:00
parent e46f2b5384
commit e29f76ad4d
13 changed files with 307 additions and 14 deletions

2
helio

@ -1 +1 @@
Subproject commit 408d201f365ce886f1b1e762810e81e0509ea8b9
Subproject commit 2cf77beb5b9ae70f594380e8df3c0e347f39f0af

View file

@ -47,6 +47,8 @@ class ConnectionContext {
bool req_auth: 1;
bool replica_conn: 1;
bool authenticated: 1;
virtual void OnClose() {}
private:
Connection* owner_;
std::unique_ptr<SinkReplyBuilder> rbuilder_;

View file

@ -284,6 +284,7 @@ void Connection::ConnectionFlow(FiberSocketBase* peer) {
cc_->conn_closing = true; // Signal dispatch to close.
evc_.notify();
dispatch_fb.join();
cc_->OnClose();
stats->read_buf_capacity -= io_buf_.Capacity();

View file

@ -1,4 +1,4 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//

View file

@ -1,11 +1,11 @@
add_executable(dragonfly dfly_main.cc)
cxx_link(dragonfly base dragonfly_lib)
add_library(dragonfly_lib command_registry.cc common.cc config_flags.cc
db_slice.cc debugcmd.cc
add_library(dragonfly_lib channel_slice.cc command_registry.cc common.cc config_flags.cc
conn_context.cc db_slice.cc debugcmd.cc
engine_shard_set.cc generic_family.cc hset_family.cc
list_family.cc main_service.cc rdb_load.cc rdb_save.cc replica.cc
snapshot.cc script_mgr.cc server_family.cc
snapshot.cc script_mgr.cc server_family.cc
set_family.cc
string_family.cc transaction.cc zset_family.cc)
@ -26,5 +26,5 @@ cxx_test(zset_family_test dfly_test_lib LABELS DFLY)
add_custom_target(check_dfly WORKING_DIRECTORY .. COMMAND ctest -L DFLY)
add_dependencies(check_dfly dragonfly_test list_family_test
generic_family_test memcache_parser_test rdb_test
generic_family_test memcache_parser_test rdb_test
redis_parser_test string_family_test)

View file

@ -0,0 +1,48 @@
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/channel_slice.h"
namespace dfly {
using namespace std;
ChannelSlice::Subscriber::Subscriber(ConnectionContext* cntx, uint32_t tid)
: conn_cntx(cntx), borrow_token(cntx->conn_state.subscribe_info->borrow_token), thread_id(tid) {
}
void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me) {
auto it = channels_.find(channel);
if (it != channels_.end()) {
it->second->subscribers.erase(me);
if (it->second->subscribers.empty())
channels_.erase(it);
}
}
void ChannelSlice::AddSubscription(string_view channel, ConnectionContext* me, uint32_t thread_id) {
auto [it, added] = channels_.emplace(channel, nullptr);
if (added) {
it->second.reset(new Channel);
}
it->second->subscribers.emplace(me, SubscriberInternal{thread_id});
}
auto ChannelSlice::FetchSubscribers(string_view channel) -> vector<Subscriber> {
vector<Subscriber> res;
auto it = channels_.find(channel);
if (it != channels_.end()) {
res.reserve(it->second->subscribers.size());
for (const auto& k_v : it->second->subscribers) {
Subscriber s(k_v.first, k_v.second.thread_id);
s.borrow_token.Inc();
res.push_back(std::move(s));
}
}
return res;
}
} // namespace dfly

View file

@ -0,0 +1,52 @@
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <absl/container/flat_hash_map.h>
#include <string_view>
#include "server/conn_context.h"
namespace dfly {
// Database holding pubsub subscribers.
class ChannelSlice {
public:
struct Subscriber {
ConnectionContext* conn_cntx;
util::fibers_ext::BlockingCounter borrow_token;
uint32_t thread_id;
Subscriber(ConnectionContext* cntx, uint32_t tid);
// Subscriber() : borrow_token(0) {}
Subscriber(Subscriber&&) noexcept = default;
Subscriber& operator=(Subscriber&&) noexcept = default;
Subscriber(const Subscriber&) = delete;
void operator=(const Subscriber&) = delete;
};
std::vector<Subscriber> FetchSubscribers(std::string_view channel);
void RemoveSubscription(std::string_view channel, ConnectionContext* me);
void AddSubscription(std::string_view channel, ConnectionContext* me, uint32_t thread_id);
private:
struct SubscriberInternal {
uint32_t thread_id; // proactor thread id.
SubscriberInternal(uint32_t tid) : thread_id(tid) {}
};
struct Channel {
absl::flat_hash_map<ConnectionContext*, SubscriberInternal> subscribers;
};
absl::flat_hash_map<std::string, std::unique_ptr<Channel>> channels_;
absl::flat_hash_map<std::string, std::unique_ptr<Channel>> patterns_;
};
} // namespace dfly

115
src/server/conn_context.cc Normal file
View file

@ -0,0 +1,115 @@
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/conn_context.h"
#include "base/logging.h"
#include "server/engine_shard_set.h"
#include "util/proactor_base.h"
namespace dfly {
using namespace std;
void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result(to_reply ? args.size() : 0, 0);
if (to_add || conn_state.subscribe_info) {
std::vector<pair<ShardId, string_view>> channels;
channels.reserve(args.size());
if (!conn_state.subscribe_info) {
DCHECK(to_add);
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
}
for (size_t i = 0; i < args.size(); ++i) {
bool res = false;
string_view channel = ArgS(args, i);
if (to_add) {
res = conn_state.subscribe_info->channels.emplace(channel).second;
} else {
res = conn_state.subscribe_info->channels.erase(channel) > 0;
}
if (to_reply)
result[i] = conn_state.subscribe_info->channels.size();
if (res) {
ShardId sid = Shard(channel, shard_set->size());
channels.emplace_back(sid, channel);
}
}
if (!to_add && conn_state.subscribe_info->channels.empty()) {
conn_state.subscribe_info.reset();
}
sort(channels.begin(), channels.end());
vector<unsigned> shard_idx(shard_set->size() + 1, 0);
for (const auto& k_v : channels) {
shard_idx[k_v.first]++;
}
unsigned prev = shard_idx[0];
shard_idx[0] = 0;
// compute cumulitive sum, or in other words a beginning index in channels for each shard.
for (size_t i = 1; i < shard_idx.size(); ++i) {
unsigned cur = shard_idx[i];
shard_idx[i] = shard_idx[i - 1] + prev;
prev = cur;
}
int32_t tid = util::ProactorBase::GetIndex();
DCHECK_GE(tid, 0);
auto cb = [&](EngineShard* shard) {
ChannelSlice& cs = shard->channel_slice();
unsigned start = shard_idx[shard->shard_id()];
unsigned end = shard_idx[shard->shard_id() + 1];
DCHECK_LT(start, end);
for (unsigned i = start; i < end; ++i) {
if (to_add) {
cs.AddSubscription(channels[i].second, this, tid);
} else {
cs.RemoveSubscription(channels[i].second, this);
}
}
};
shard_set->RunBriefInParallel(move(cb),
[&](ShardId sid) { return shard_idx[sid + 1] > shard_idx[sid]; });
}
if (to_reply) {
const char* action[2] = {"unsubscribe", "subscribe"};
for (size_t i = 0; i < result.size(); ++i) {
(*this)->StartArray(3);
(*this)->SendBulkString(action[to_add]);
(*this)->SendBulkString(ArgS(args, i));
(*this)->SendLong(result[i]);
}
}
}
void ConnectionContext::OnClose() {
if (conn_state.subscribe_info) {
StringVec channels(conn_state.subscribe_info->channels.begin(),
conn_state.subscribe_info->channels.end());
CmdArgVec arg_vec(channels.begin(), channels.end());
auto token = conn_state.subscribe_info->borrow_token;
ChangeSubscription(false, false, CmdArgList{arg_vec});
DCHECK(!conn_state.subscribe_info);
// Check that all borrowers finished processing
token.Wait();
}
}
} // namespace dfly

View file

@ -8,6 +8,7 @@
#include "facade/conn_context.h"
#include "server/common_types.h"
#include "util/fibers/fibers_ext.h"
namespace dfly {
@ -45,6 +46,17 @@ struct ConnectionState {
absl::flat_hash_set<std::string_view> keys;
};
std::optional<Script> script_info;
struct SubscribeInfo {
// TODO: to provide unique_strings across service. This will allow us to use string_view here.
absl::flat_hash_set<std::string> channels;
util::fibers_ext::BlockingCounter borrow_token;
SubscribeInfo() : borrow_token(0) {}
};
std::unique_ptr<SubscribeInfo> subscribe_info;
};
class ConnectionContext : public facade::ConnectionContext {
@ -52,6 +64,9 @@ class ConnectionContext : public facade::ConnectionContext {
ConnectionContext(::io::Sink* stream, facade::Connection* owner)
: facade::ConnectionContext(stream, owner) {
}
void OnClose() override;
struct DebugInfo {
uint32_t shards_count = 0;
TxClock clock = 0;
@ -69,6 +84,8 @@ class ConnectionContext : public facade::ConnectionContext {
DbIndex db_index() const {
return conn_state.db_index;
}
void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args);
};
} // namespace dfly

View file

@ -18,7 +18,6 @@ extern "C" {
namespace dfly {
using namespace boost;
using namespace std;
using namespace util;
using facade::OpStatus;

View file

@ -15,6 +15,7 @@ extern "C" {
#include "base/string_view_sso.h"
#include "core/mi_memory_resource.h"
#include "core/tx_queue.h"
#include "server/channel_slice.h"
#include "server/db_slice.h"
#include "util/fibers/fiberqueue_threadpool.h"
#include "util/fibers/fibers_ext.h"
@ -54,6 +55,10 @@ class EngineShard {
return db_slice_;
}
ChannelSlice& channel_slice() {
return channel_slice_;
}
std::pmr::memory_resource* memory_resource() {
return &mi_resource_;
}
@ -163,6 +168,8 @@ class EngineShard {
TxQueue txq_;
MiMemoryResource mi_resource_;
DbSlice db_slice_;
ChannelSlice channel_slice_;
Stats stats_;
// Logical ts used to order distributed transactions.

View file

@ -41,7 +41,6 @@ DEFINE_uint32(memcache_port, 0, "Memcached port");
DECLARE_string(requirepass);
DEFINE_uint64(maxmemory, 0, "Limit on maximum-memory that is used by the database");
namespace dfly {
using namespace std;
@ -856,13 +855,64 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
}
void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendLong(0);
string_view channel = ArgS(args, 1);
string_view message = ArgS(args, 2);
ShardId sid = Shard(channel, shard_count());
auto cb = [&] { return EngineShard::tlocal()->channel_slice().FetchSubscribers(channel); };
vector<ChannelSlice::Subscriber> res = shard_set_.Await(sid, std::move(cb));
atomic_uint32_t published{0};
if (!res.empty()) {
sort(res.begin(), res.end(),
[](const auto& left, const auto& right) { return left.thread_id < right.thread_id; });
vector<unsigned> slices(shard_set_.pool()->size(), UINT_MAX);
for (size_t i = 0; i < res.size(); ++i) {
if (slices[res[i].thread_id] > i) {
slices[res[i].thread_id] = i;
}
}
auto cb = [&](unsigned idx, util::ProactorBase*) {
unsigned start = slices[idx];
for (unsigned i = start; i < res.size(); ++i) {
if (res[i].thread_id != idx)
break;
if (!res[i].conn_cntx->conn_closing) {
published.fetch_add(1, memory_order_relaxed);
// TODO: this is wrong because ReplyBuilder does not guarantee atomicity if used
// concurrently by multiple fibers.
string_view msg_arr[3] = {"message", channel, message};
(*res[i].conn_cntx)->SendStringArr(msg_arr);
}
}
};
shard_set_.pool()->AwaitFiberOnAll(cb);
}
for (auto& s : res) {
s.borrow_token.Dec();
}
(*cntx)->SendLong(published.load(memory_order_relaxed));
}
void Service::Subscribe(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendOk();
args.remove_prefix(1);
cntx->ChangeSubscription(true /*add*/, true /* reply*/, std::move(args));
}
void Service::Unsubscribe(CmdArgList args, ConnectionContext* cntx) {
args.remove_prefix(1);
cntx->ChangeSubscription(false, true, std::move(args));
}
VarzValue::Map Service::GetVarzStats() {
VarzValue::Map res;
@ -893,8 +943,9 @@ void Service::RegisterCommands() {
<< CI{"EVAL", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(Eval).SetValidator(&EvalValidator)
<< CI{"EVALSHA", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(EvalSha).SetValidator(&EvalValidator)
<< CI{"EXEC", kExecMask, 1, 0, 0, 0}.MFUNC(Exec)
<< CI{"PUBLISH", CO::LOADING| CO::FAST, 3, 0, 0, 0}.HFUNC(Publish)
<< CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.HFUNC(Subscribe);
<< CI{"PUBLISH", CO::LOADING | CO::FAST, 3, 0, 0, 0}.MFUNC(Publish)
<< CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Subscribe)
<< CI{"UNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Unsubscribe);
StringFamily::Register(&registry_);
GenericFamily::Register(&registry_);

View file

@ -70,12 +70,13 @@ class Service : public facade::ServiceInterface {
private:
static void Quit(CmdArgList args, ConnectionContext* cntx);
static void Multi(CmdArgList args, ConnectionContext* cntx);
static void Publish(CmdArgList args, ConnectionContext* cntx);
static void Subscribe(CmdArgList args, ConnectionContext* cntx);
void Eval(CmdArgList args, ConnectionContext* cntx);
void EvalSha(CmdArgList args, ConnectionContext* cntx);
void Exec(CmdArgList args, ConnectionContext* cntx);
void Publish(CmdArgList args, ConnectionContext* cntx);
void Subscribe(CmdArgList args, ConnectionContext* cntx);
void Unsubscribe(CmdArgList args, ConnectionContext* cntx);
struct EvalArgs {
std::string_view sha; // only one of them is defined.