diff --git a/src/server/channel_store.cc b/src/server/channel_store.cc index 06d118352..8eeae4ee1 100644 --- a/src/server/channel_store.cc +++ b/src/server/channel_store.cc @@ -11,10 +11,20 @@ extern "C" { } #include "base/logging.h" +#include "server/engine_shard_set.h" +#include "server/server_state.h" namespace dfly { using namespace std; +namespace { + +bool Matches(string_view pattern, string_view channel) { + return stringmatchlen(pattern.data(), pattern.size(), channel.data(), channel.size(), 0) == 1; +} + +} // namespace + ChannelStore::Subscriber::Subscriber(ConnectionContext* cntx, uint32_t tid) : conn_cntx(cntx), borrow_token(cntx->conn_state.subscribe_info->borrow_token), thread_id(tid) { } @@ -23,10 +33,36 @@ ChannelStore::Subscriber::Subscriber(uint32_t tid) : conn_cntx(nullptr), borrow_token(0), thread_id(tid) { } +bool ChannelStore::Subscriber::ByThread(const Subscriber& lhs, const Subscriber& rhs) { + if (lhs.thread_id == rhs.thread_id) + return (lhs.conn_cntx != nullptr) < (rhs.conn_cntx != nullptr); + return lhs.thread_id < rhs.thread_id; +} + +ChannelStore::UpdatablePointer::UpdatablePointer(const UpdatablePointer& other) { + ptr.store(other.ptr.load(memory_order_relaxed), memory_order_relaxed); +} + +ChannelStore::SubscribeMap* ChannelStore::UpdatablePointer::Get() const { + return ptr.load(memory_order_relaxed); +} + +void ChannelStore::UpdatablePointer::Set(ChannelStore::SubscribeMap* sm) { + ptr.store(sm, memory_order_relaxed); +} + +ChannelStore::SubscribeMap* ChannelStore::UpdatablePointer::operator->() { + return Get(); +} + +const ChannelStore::SubscribeMap& ChannelStore::UpdatablePointer::operator*() const { + return *Get(); +} + void ChannelStore::ChannelMap::Add(string_view key, ConnectionContext* me, uint32_t thread_id) { auto it = find(key); if (it == end()) - it = emplace(key, make_unique()).first; + it = emplace(key, new SubscribeMap{}).first; it->second->emplace(me, thread_id); } @@ -38,38 +74,42 @@ void ChannelStore::ChannelMap::Remove(string_view key, ConnectionContext* me) { } } -void ChannelStore::AddSub(string_view channel, ConnectionContext* me, uint32_t thread_id) { - unique_lock lk{lock_}; - channels_.Add(channel, me, thread_id); +void ChannelStore::ChannelMap::DeleteAll() { + for (auto [k, ptr] : *this) + delete ptr.Get(); } -void ChannelStore::AddPatternSub(string_view pattern, ConnectionContext* me, uint32_t thread_id) { - unique_lock lk{lock_}; - patterns_.Add(pattern, me, thread_id); +ChannelStore::ChannelStore() : channels_{new ChannelMap{}}, patterns_{new ChannelMap{}} { + control_block.most_recent = this; } -void ChannelStore::RemoveSub(string_view channel, ConnectionContext* me) { - unique_lock lk{lock_}; - channels_.Remove(channel, me); +ChannelStore::ChannelStore(ChannelMap* channels, ChannelMap* patterns) + : channels_{channels}, patterns_{patterns} { } -void ChannelStore::RemovePatternSub(string_view pattern, ConnectionContext* me) { - unique_lock lk{lock_}; - patterns_.Remove(pattern, me); +void ChannelStore::Destroy() { + control_block.update_mu.lock(); + control_block.update_mu.unlock(); + + auto* store = control_block.most_recent.load(memory_order_relaxed); + for (auto* chan_map : {store->channels_, store->patterns_}) { + chan_map->DeleteAll(); + delete chan_map; + } + delete control_block.most_recent; } -vector ChannelStore::FetchSubscribers(string_view channel) { - shared_lock lk{lock_}; +ChannelStore::ControlBlock ChannelStore::control_block; + +vector ChannelStore::FetchSubscribers(string_view channel) const { vector res; - if (auto it = channels_.find(channel); it != channels_.end()) { + if (auto it = channels_->find(channel); it != channels_->end()) Fill(*it->second, string{}, &res); - } - for (const auto& [pat, subs] : patterns_) { - if (stringmatchlen(pat.data(), pat.size(), channel.data(), channel.size(), 0) == 1) { + for (const auto& [pat, subs] : *patterns_) { + if (Matches(pat, channel)) Fill(*subs, pat, &res); - } } sort(res.begin(), res.end(), Subscriber::ByThread); @@ -90,20 +130,115 @@ void ChannelStore::Fill(const SubscribeMap& src, const string& pattern, vector ChannelStore::ListChannels(const string_view pattern) const { - shared_lock lk{lock_}; vector res; - for (const auto& [channel, _] : channels_) { - if (pattern.empty() || - stringmatchlen(pattern.data(), pattern.size(), channel.data(), channel.size(), 0) == 1) { + for (const auto& [channel, _] : *channels_) { + if (pattern.empty() || Matches(pattern, channel)) res.push_back(channel); - } } return res; } size_t ChannelStore::PatternCount() const { - shared_lock lk{lock_}; - return patterns_.size(); + return patterns_->size(); +} + +ChannelStoreUpdater::ChannelStoreUpdater(bool pattern, bool to_add, ConnectionContext* cntx, + uint32_t thread_id) + : pattern_{pattern}, to_add_{to_add}, cntx_{cntx}, thread_id_{thread_id} { +} + +void ChannelStoreUpdater::Record(string_view key) { + ops_.emplace_back(key); +} + +pair ChannelStoreUpdater::GetTargetMap(ChannelStore* store) { + auto* target = pattern_ ? store->patterns_ : store->channels_; + + for (auto key : ops_) { + auto it = target->find(key); + DCHECK(it != target->end() || to_add_); + // We need to make a copy, if we are going to add or delete new map slot. + if ((to_add_ && it == target->end()) || (!to_add_ && it->second->size() == 1)) + return {new ChannelStore::ChannelMap{*target}, true}; + } + + return {target, false}; +} + +void ChannelStoreUpdater::Modify(ChannelMap* target, string_view key) { + using SubscribeMap = ChannelStore::SubscribeMap; + + auto it = target->find(key); + + // New key, add new slot. + if (to_add_ && it == target->end()) { + target->emplace(key, new SubscribeMap{{cntx_, thread_id_}}); + return; + } + + // Last entry for key, remove slot. + if (!to_add_ && it->second->size() == 1) { + DCHECK(it->second->begin()->first == cntx_); + freelist_.push_back(it->second.Get()); + target->erase(it); + return; + } + + // RCU update existing SubscribeMap entry. + DCHECK(it->second->size() > 0); + auto* replacement = new SubscribeMap{*it->second}; + if (to_add_) + replacement->emplace(cntx_, thread_id_); + else + replacement->erase(cntx_); + + // The pointer can still be in use, so delay freeing it + // until the dispatch and update the slot atomically. + freelist_.push_back(it->second.Get()); + it->second.Set(replacement); +} + +void ChannelStoreUpdater::Apply() { + // Wait for other updates to finish, lock the control block and update store pointer. + auto& cb = ChannelStore::control_block; + cb.update_mu.lock(); + auto* store = cb.most_recent.load(memory_order_relaxed); + + // Get target map (copied if needed) and apply operations. + auto [target, copied] = GetTargetMap(store); + for (auto key : ops_) + Modify(target, key); + + // Prepare replacement. + auto* replacement = store; + if (copied) { + auto* new_chans = pattern_ ? store->channels_ : target; + auto* new_patterns = pattern_ ? target : store->patterns_; + replacement = new ChannelStore{new_chans, new_patterns}; + } + + // Update control block and unlock it. + cb.most_recent.store(replacement, memory_order_relaxed); + cb.update_mu.unlock(); + + // Update thread local references. Readers fetch subscribers via FetchSubscribers, + // which runs without preemption, and store references to them in self container Subscriber + // structs. This means that any point on the other thread is safe to update the channel store. + // Regardless of whether we need to replace, we dispatch to make sure all + // queued SubscribeMaps in the freelist are no longer in use. + shard_set->pool()->Await([](unsigned idx, util::ProactorBase*) { + ServerState::tlocal()->UpdateChannelStore( + ChannelStore::control_block.most_recent.load(memory_order_relaxed)); + }); + + // Delete previous map and channel store. + if (copied) { + delete (pattern_ ? store->patterns_ : store->channels_); + delete store; + } + + for (auto ptr : freelist_) + delete ptr; } } // namespace dfly diff --git a/src/server/channel_store.h b/src/server/channel_store.h index 2b23999ce..767956cea 100644 --- a/src/server/channel_store.h +++ b/src/server/channel_store.h @@ -4,28 +4,45 @@ #pragma once #include -#include +#include #include #include "server/conn_context.h" namespace dfly { -// Centralized store holding pubsub subscribers. All public functions are thread safe. +class ChannelStoreUpdater; + +// ChannelStore manages PUB/SUB subscriptions. +// +// Updates are carried out via RCU (read-copy-update). Each thread stores a pointer to ChannelStore +// in its local ServerState and uses it for reads. Whenever an update needs to be performed, +// a new ChannelStore is constructed with the requested modifications and broadcasted to all +// threads. +// +// ServerState ChannelStore* -> ChannelMap* -> atomic (cntx -> thread) +// +// Specifically, whenever a new channel is registered or a channel is removed fully, +// a new ChannelMap for the specified type (channel/pattern) needs to be constructed. However, if +// only a single SubscribeMap is modified (no map ChannelMap slots are added or removed), +// we can update only it with a simpler version of RCU, as SubscribeMap is stored as an atomic +// pointer inside ChannelMap. +// +// To prevent parallel (and thus overlapping) updates, a centralized ControlBlock is used. +// Update operations are carried out by the ChannelStoreUpdater. +// +// A centralized ChannelStore, contrary to sharded storage, avoids contention on a single shard +// thread for heavy throughput on a single channel and thus seamlessly scales on multiple threads +// even with a small number of channels. In general, it has a slightly lower latency, due to the +// fact that no hop is required to fetch the subscribers. class ChannelStore { + friend class ChannelStoreUpdater; + public: struct Subscriber { - ConnectionContext* conn_cntx; - util::fibers_ext::BlockingCounter borrow_token; - uint32_t thread_id; - - // non-empty if was registered via psubscribe - std::string pattern; - Subscriber(ConnectionContext* cntx, uint32_t tid); Subscriber(uint32_t tid); - // Subscriber() : borrow_token(0) {} Subscriber(Subscriber&&) noexcept = default; Subscriber& operator=(Subscriber&&) noexcept = default; @@ -34,39 +51,108 @@ class ChannelStore { void operator=(const Subscriber&) = delete; // Sort by thread-id. Subscriber without owner comes first. - static bool ByThread(const Subscriber& lhs, const Subscriber& rhs) { - if (lhs.thread_id == rhs.thread_id) - return (lhs.conn_cntx != nullptr) < (rhs.conn_cntx != nullptr); - return lhs.thread_id < rhs.thread_id; - } + static bool ByThread(const Subscriber& lhs, const Subscriber& rhs); + + ConnectionContext* conn_cntx; + util::fibers_ext::BlockingCounter borrow_token; // to keep connection alive + uint32_t thread_id; + std::string pattern; // non-empty if registered via psubscribe }; - void AddSub(std::string_view channel, ConnectionContext* me, uint32_t thread_id); - void RemoveSub(std::string_view channel, ConnectionContext* me); + ChannelStore(); - void AddPatternSub(std::string_view pattern, ConnectionContext* me, uint32_t thread_id); - void RemovePatternSub(std::string_view pattern, ConnectionContext* me); - - std::vector FetchSubscribers(std::string_view channel); + // Fetch all subscribers for channel, including matching patterns. + std::vector FetchSubscribers(std::string_view channel) const; std::vector ListChannels(const std::string_view pattern) const; size_t PatternCount() const; + // Destroy current instance and delete it. + static void Destroy(); + private: using ThreadId = unsigned; + + // Subscribers for a single channel/pattern. using SubscribeMap = absl::flat_hash_map; - struct ChannelMap : absl::flat_hash_map> { + // Wrapper around atomic pointer that allows copying and moving. + // Made to overcome restrictions of absl::flat_hash_map. + // Copy/Move don't need to be atomic with RCU. + struct UpdatablePointer { + UpdatablePointer(SubscribeMap* sm) : ptr{sm} { + } + + UpdatablePointer(const UpdatablePointer& other); + + SubscribeMap* Get() const; + void Set(SubscribeMap* sm); + + SubscribeMap* operator->(); + const SubscribeMap& operator*() const; + + private: + std::atomic ptr; + }; + + // SubscriberMaps for channels/patterns. + struct ChannelMap : absl::flat_hash_map { void Add(std::string_view key, ConnectionContext* me, uint32_t thread_id); void Remove(std::string_view key, ConnectionContext* me); + + // Delete all stored SubscribeMap pointers. + void DeleteAll(); }; + // Centralized controller to prevent overlaping updates. + struct ControlBlock { + std::atomic most_recent; + ::boost::fibers::mutex update_mu; // locked during updates. + }; + + private: + static ControlBlock control_block; + + ChannelStore(ChannelMap* channels, ChannelMap* patterns); + static void Fill(const SubscribeMap& src, const std::string& pattern, std::vector* out); - mutable folly::RWSpinLock lock_; - ChannelMap channels_; - ChannelMap patterns_; + ChannelMap* channels_; + ChannelMap* patterns_; +}; + +// Performs RCU (read-copy-update) updates to the channel store. +// See ChannelStore header top for design details. +// Queues operations and performs them with Apply(). +class ChannelStoreUpdater { + public: + ChannelStoreUpdater(bool pattern, bool to_add, ConnectionContext* cntx, uint32_t thread_id); + + void Record(std::string_view key); + void Apply(); + + private: + using ChannelMap = ChannelStore::ChannelMap; + + // Get target map and flag whether it was copied. + // Must be called with locked control block. + std::pair GetTargetMap(ChannelStore* store); + + // Apply modify operation to target map. + void Modify(ChannelMap* target, std::string_view key); + + private: + bool pattern_; + bool to_add_; + ConnectionContext* cntx_; + uint32_t thread_id_; + + // Pending operations. + std::vector ops_; + + // Replaced SubscribeMaps that need to be deleted safely. + std::vector freelist_; }; } // namespace dfly diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 55c3d67a8..d8d526c14 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -53,7 +53,7 @@ void ConnectionContext::ChangeMonitor(bool start) { } vector ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add, bool to_reply, - ConnectionContext* conn, ChannelStore* store) { + ConnectionContext* conn) { vector result(to_reply ? args.size() : 0, 0); auto& conn_state = conn->conn_state; @@ -70,27 +70,25 @@ vector ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add, auto& sinfo = *conn->conn_state.subscribe_info.get(); auto& local_store = pattern ? sinfo.patterns : sinfo.channels; - auto sadd = pattern ? &ChannelStore::AddPatternSub : &ChannelStore::AddSub; - auto sremove = pattern ? &ChannelStore::RemovePatternSub : &ChannelStore::RemoveSub; - int32_t tid = util::ProactorBase::GetIndex(); DCHECK_GE(tid, 0); + ChannelStoreUpdater csu{pattern, to_add, conn, uint32_t(tid)}; + // Gather all the channels we need to subscribe to / remove. for (size_t i = 0; i < args.size(); ++i) { string_view channel = ArgS(args, i); - if (to_add) { - if (local_store.emplace(channel).second) - (store->*sadd)(channel, conn, tid); - } else { - if (local_store.erase(channel) > 0) - (store->*sremove)(channel, conn); - } + if (to_add && local_store.emplace(channel).second) + csu.Record(channel); + else if (!to_add && local_store.erase(channel) > 0) + csu.Record(channel); if (to_reply) result[i] = sinfo.SubscriptionCount(); } + csu.Apply(); + // Important to reset conn_state.subscribe_info only after all references to it were // removed. if (!to_add && conn_state.subscribe_info->IsEmpty()) { @@ -101,9 +99,8 @@ vector ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add, return result; } -void ConnectionContext::ChangeSubscription(ChannelStore* store, bool to_add, bool to_reply, - CmdArgList args) { - vector result = ChangeSubscriptions(false, args, to_add, to_reply, this, store); +void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args) { + vector result = ChangeSubscriptions(false, args, to_add, to_reply, this); if (to_reply) { for (size_t i = 0; i < result.size(); ++i) { @@ -112,9 +109,8 @@ void ConnectionContext::ChangeSubscription(ChannelStore* store, bool to_add, boo } } -void ConnectionContext::ChangePSubscription(ChannelStore* store, bool to_add, bool to_reply, - CmdArgList args) { - vector result = ChangeSubscriptions(true, args, to_add, to_reply, this, store); +void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args) { + vector result = ChangeSubscriptions(true, args, to_add, to_reply, this); if (to_reply) { const char* action[2] = {"punsubscribe", "psubscribe"}; @@ -128,17 +124,17 @@ void ConnectionContext::ChangePSubscription(ChannelStore* store, bool to_add, bo } } -void ConnectionContext::UnsubscribeAll(ChannelStore* store, bool to_reply) { +void ConnectionContext::UnsubscribeAll(bool to_reply) { if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->channels.empty())) { return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0); } StringVec channels(conn_state.subscribe_info->channels.begin(), conn_state.subscribe_info->channels.end()); CmdArgVec arg_vec(channels.begin(), channels.end()); - ChangeSubscription(store, false, to_reply, CmdArgList{arg_vec}); + ChangeSubscription(false, to_reply, CmdArgList{arg_vec}); } -void ConnectionContext::PUnsubscribeAll(ChannelStore* store, bool to_reply) { +void ConnectionContext::PUnsubscribeAll(bool to_reply) { if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->patterns.empty())) { return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0); } @@ -146,7 +142,7 @@ void ConnectionContext::PUnsubscribeAll(ChannelStore* store, bool to_reply) { StringVec patterns(conn_state.subscribe_info->patterns.begin(), conn_state.subscribe_info->patterns.end()); CmdArgVec arg_vec(patterns.begin(), patterns.end()); - ChangePSubscription(store, false, to_reply, CmdArgList{arg_vec}); + ChangePSubscription(false, to_reply, CmdArgList{arg_vec}); } void ConnectionContext::SendSubscriptionChangedResponse(string_view action, diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 57ab5dbb8..d655c2006 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -134,10 +134,10 @@ class ConnectionContext : public facade::ConnectionContext { return conn_state.db_index; } - void ChangeSubscription(ChannelStore* store, bool to_add, bool to_reply, CmdArgList args); - void ChangePSubscription(ChannelStore* store, bool to_add, bool to_reply, CmdArgList args); - void UnsubscribeAll(ChannelStore* store, bool to_reply); - void PUnsubscribeAll(ChannelStore* store, bool to_reply); + void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args); + void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args); + void UnsubscribeAll(bool to_reply); + void PUnsubscribeAll(bool to_reply); void ChangeMonitor(bool start); // either start or stop monitor on a given connection bool is_replicating = false; diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc index 32a115ff3..45eaa49d0 100644 --- a/src/server/dragonfly_test.cc +++ b/src/server/dragonfly_test.cc @@ -399,6 +399,8 @@ TEST_F(DflyEngineTest, PSubscribe) { resp = pp_->at(0)->Await([&] { return Run({"publish", "ab", "foo"}); }); EXPECT_THAT(resp, IntArg(1)); + pp_->AwaitFiberOnAll([](ProactorBase* pb) {}); + ASSERT_EQ(1, SubscriberMessagesLen("IO1")); const facade::Connection::PubMessage& msg = GetPublishedMessage("IO1", 0); diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 5911f62eb..6919abffd 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -525,6 +525,10 @@ void Service::Init(util::AcceptServer* acceptor, util::ListenerInterface* main_i StringFamily::Init(&pp_); GenericFamily::Init(&pp_); server_family_.Init(acceptor, main_interface); + + ChannelStore* cs = new ChannelStore{}; + pp_.Await( + [cs](uint32_t index, ProactorBase* pb) { ServerState::tlocal()->UpdateChannelStore(cs); }); } void Service::Shutdown() { @@ -545,6 +549,8 @@ void Service::Shutdown() { engine_varz.reset(); request_latency_usec.Shutdown(); + ChannelStore::Destroy(); + shard_set->Shutdown(); pp_.Await([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); }); @@ -1381,64 +1387,63 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) { } void Service::Publish(CmdArgList args, ConnectionContext* cntx) { - auto* store = server_family_.channel_store(); string_view channel = ArgS(args, 1); - shared_ptr msg_ptr = make_shared(ArgS(args, 2)); - shared_ptr channel_ptr = make_shared(channel); + auto* cs = ServerState::tlocal()->channel_store(); + vector subscribers = cs->FetchSubscribers(channel); + int num_published = subscribers.size(); - auto clients = store->FetchSubscribers(channel); + if (!subscribers.empty()) { + auto subscribers_ptr = make_shared(move(subscribers)); + auto msg_ptr = make_shared(ArgS(args, 2)); + auto channel_ptr = make_shared(channel); - atomic_uint32_t published{0}; - auto cb = [&published, &clients, msg_ptr, channel_ptr](unsigned idx, util::ProactorBase*) { - auto it = lower_bound(clients.begin(), clients.end(), idx, ChannelStore::Subscriber::ByThread); - while (it != clients.end() && it->thread_id == idx) { - facade::Connection* conn = it->conn_cntx->owner(); - DCHECK(conn); + auto cb = [subscribers_ptr, msg_ptr, channel_ptr](unsigned idx, util::ProactorBase*) { + auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx, + ChannelStore::Subscriber::ByThread); - conn->SendMsgVecAsync({move(it->pattern), move(channel_ptr), move(msg_ptr)}); - published.fetch_add(1, memory_order_relaxed); - it++; - } - }; - shard_set->pool()->Await(std::move(cb)); - - for (auto& c : clients) { - c.borrow_token.Dec(); + while (it != subscribers_ptr->end() && it->thread_id == idx) { + facade::Connection* conn = it->conn_cntx->owner(); + DCHECK(conn); + conn->SendMsgVecAsync({move(it->pattern), move(channel_ptr), move(msg_ptr)}); + it->borrow_token.Dec(); + it++; + } + }; + shard_set->pool()->DispatchBrief(std::move(cb)); } - (*cntx)->SendLong(published.load(memory_order_relaxed)); + (*cntx)->SendLong(num_published); } void Service::Subscribe(CmdArgList args, ConnectionContext* cntx) { args.remove_prefix(1); - cntx->ChangeSubscription(server_family_.channel_store(), true /*add*/, true /* reply*/, - std::move(args)); + cntx->ChangeSubscription(true /*add*/, true /* reply*/, std::move(args)); } void Service::Unsubscribe(CmdArgList args, ConnectionContext* cntx) { args.remove_prefix(1); if (args.size() == 0) { - cntx->UnsubscribeAll(server_family_.channel_store(), true); + cntx->UnsubscribeAll(true); } else { - cntx->ChangeSubscription(server_family_.channel_store(), false, true, args); + cntx->ChangeSubscription(false, true, args); } } void Service::PSubscribe(CmdArgList args, ConnectionContext* cntx) { args.remove_prefix(1); - cntx->ChangePSubscription(server_family_.channel_store(), true, true, args); + cntx->ChangePSubscription(true, true, args); } void Service::PUnsubscribe(CmdArgList args, ConnectionContext* cntx) { args.remove_prefix(1); if (args.size() == 0) { - cntx->PUnsubscribeAll(server_family_.channel_store(), true); + cntx->PUnsubscribeAll(true); } else { - cntx->ChangePSubscription(server_family_.channel_store(), false, true, args); + cntx->ChangePSubscription(false, true, args); } } @@ -1457,11 +1462,11 @@ void Service::Function(CmdArgList args, ConnectionContext* cntx) { } void Service::PubsubChannels(string_view pattern, ConnectionContext* cntx) { - (*cntx)->SendStringArr(server_family_.channel_store()->ListChannels(pattern)); + (*cntx)->SendStringArr(ServerState::tlocal()->channel_store()->ListChannels(pattern)); } void Service::PubsubPatterns(ConnectionContext* cntx) { - size_t pattern_count = server_family_.channel_store()->PatternCount(); + size_t pattern_count = ServerState::tlocal()->channel_store()->PatternCount(); (*cntx)->SendLong(pattern_count); } @@ -1551,7 +1556,7 @@ void Service::OnClose(facade::ConnectionContext* cntx) { if (conn_state.subscribe_info) { // Clean-ups related to PUBSUB if (!conn_state.subscribe_info->channels.empty()) { auto token = conn_state.subscribe_info->borrow_token; - server_cntx->UnsubscribeAll(server_family_.channel_store(), false); + server_cntx->UnsubscribeAll(false); // Check that all borrowers finished processing. // token is increased in channel_slice (the publisher side). @@ -1561,7 +1566,7 @@ void Service::OnClose(facade::ConnectionContext* cntx) { if (conn_state.subscribe_info) { DCHECK(!conn_state.subscribe_info->patterns.empty()); auto token = conn_state.subscribe_info->borrow_token; - server_cntx->PUnsubscribeAll(server_family_.channel_store(), false); + server_cntx->PUnsubscribeAll(false); // Check that all borrowers finished processing token.Wait(); DCHECK(!conn_state.subscribe_info); diff --git a/src/server/server_family.cc b/src/server/server_family.cc index f18109104..4139c86e1 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -357,7 +357,6 @@ ServerFamily::ServerFamily(Service* service) : service_(*service) { last_save_info_ = make_shared(); last_save_info_->save_time = start_time_; script_mgr_.reset(new ScriptMgr()); - channel_store_.reset(new ChannelStore()); journal_.reset(new journal::Journal()); { diff --git a/src/server/server_family.h b/src/server/server_family.h index 3163c8a11..13d435ce4 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -121,10 +121,6 @@ class ServerFamily { return journal_.get(); } - ChannelStore* channel_store() { - return channel_store_.get(); - } - void OnClose(ConnectionContext* cntx); void BreakOnShutdown(); @@ -182,7 +178,6 @@ class ServerFamily { std::unique_ptr script_mgr_; std::unique_ptr journal_; std::unique_ptr dfly_cmd_; - std::unique_ptr channel_store_; std::string master_id_; diff --git a/src/server/server_state.h b/src/server/server_state.h index d021fea1d..65af93cd9 100644 --- a/src/server/server_state.h +++ b/src/server/server_state.h @@ -110,11 +110,6 @@ class ServerState { // public struct - to allow initialization. state_->gstate_ = GlobalState::SHUTTING_DOWN; } - bool is_master = true; - std::string remote_client_id_; // for cluster support - - facade::ConnectionStats connection_stats; - void TxCountInc() { ++live_transactions_; } @@ -190,8 +185,22 @@ class ServerState { // public struct - to allow initialization. return thread_index_; } + ChannelStore* channel_store() const { + return channel_store_; + } + + void UpdateChannelStore(ChannelStore* replacement) { + channel_store_ = replacement; + } + + public: Stats stats; + bool is_master = true; + std::string remote_client_id_; // for cluster support + + facade::ConnectionStats connection_stats; + private: int64_t live_transactions_ = 0; mi_heap_t* data_heap_; @@ -200,6 +209,8 @@ class ServerState { // public struct - to allow initialization. InterpreterManager interpreter_mgr_; absl::flat_hash_map cached_script_params_; + ChannelStore* channel_store_; + GlobalState gstate_ = GlobalState::ACTIVE; using Counter = util::SlidingCounter<7>;