From 7975848c36ba2b9c7b3cf96827b75f85d814c7a8 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Sat, 11 Mar 2023 18:34:42 +0200 Subject: [PATCH] chore: Expose a corruption when connection writes interleaving messages The problem happens when a publisher sends a message and a new subscriber registers. In that case it sends "subscribe" response and the publish messages and those interleave sometimes. Signed-off-by: Roman Gershman --- src/facade/dragonfly_connection.cc | 4 ++++ src/facade/reply_builder.h | 2 ++ src/server/conn_context.cc | 13 ++++++---- tests/README.md | 4 ++++ tests/dragonfly/connection_test.py | 38 ++++++++++++++++++++++++++++++ 5 files changed, 56 insertions(+), 5 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 969862b42..601faefbf 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -275,7 +275,10 @@ void Connection::DispatchOperations::operator()(const PubMsgRecord& msg) { ++stats->async_writes_cnt; const PubMessage& pub_msg = msg.pub_msg; string_view arr[4]; + DCHECK(!rbuilder->is_sending); + rbuilder->is_sending = true; if (pub_msg.pattern.empty()) { + DVLOG(1) << "Sending message, from channel: " << pub_msg.channel << " " << *pub_msg.message; arr[0] = "message"; arr[1] = pub_msg.channel; arr[2] = *pub_msg.message; @@ -287,6 +290,7 @@ void Connection::DispatchOperations::operator()(const PubMsgRecord& msg) { arr[3] = *pub_msg.message; rbuilder->SendStringArr(absl::Span{arr, 4}); } + rbuilder->is_sending = false; } void Connection::DispatchOperations::operator()(Request::PipelineMsg& msg) { diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index e49541ff2..40c008bb7 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -82,6 +82,8 @@ class SinkReplyBuilder { virtual void SendStored() = 0; virtual void SendSetSkipped() = 0; + bool is_sending = false; + protected: void Send(const iovec* v, uint32_t len); diff 
--git a/src/server/conn_context.cc b/src/server/conn_context.cc index 6d54f80e1..2d400885e 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -136,16 +136,19 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis if (to_reply) { const char* action[2] = {"unsubscribe", "subscribe"}; - + facade::RedisReplyBuilder* rbuilder = this->operator->(); + DCHECK(!rbuilder->is_sending); + rbuilder->is_sending = true; for (size_t i = 0; i < result.size(); ++i) { - (*this)->StartArray(3); - (*this)->SendBulkString(action[to_add]); - (*this)->SendBulkString(ArgS(args, i)); // channel + rbuilder->StartArray(3); + rbuilder->SendBulkString(action[to_add]); + rbuilder->SendBulkString(ArgS(args, i)); // channel // number of subscribed channels for this connection *right after* // we subscribe. - (*this)->SendLong(result[i]); + rbuilder->SendLong(result[i]); } + rbuilder->is_sending = false; } } diff --git a/tests/README.md b/tests/README.md index e3db30f6e..dbb8af28f 100644 --- a/tests/README.md +++ b/tests/README.md @@ -38,6 +38,10 @@ pip install -r dragonfly/requirements.txt to run pytest, run: `pytest -xv dragonfly` +to run selectively, use: +`pytest -xv dragonfly -k <test name substring>` +For more pytest flags [check here](https://fig.io/manual/pytest). + ## Writing tests The [Getting Started](https://docs.pytest.org/en/7.1.x/getting-started.html) guide is a great resource to become familiar with writing pytest test cases. diff --git a/tests/dragonfly/connection_test.py b/tests/dragonfly/connection_test.py index 58008012a..ffc014e4a 100644 --- a/tests/dragonfly/connection_test.py +++ b/tests/dragonfly/connection_test.py @@ -4,6 +4,8 @@ import asyncio import aioredis import async_timeout +from . 
import DflyInstance + async def run_monitor_eval(monitor, expected): async with monitor as mon: @@ -278,6 +280,42 @@ async def test_multi_pubsub(async_client): assert state, message +@pytest.mark.asyncio +async def test_subscribers_with_active_publisher(df_server: DflyInstance, max_connections=100): + # TODO: I am not sure how to customize the max connections for the pool. + async_pool = aioredis.ConnectionPool(host="localhost", port=df_server.port, + db=0, decode_responses=True, max_connections=max_connections) + + async def publish_worker(): + client = aioredis.Redis(connection_pool=async_pool) + for i in range(0, 2000): + await client.publish("channel", f"message-{i}") + await client.close() + + async def channel_reader(channel: aioredis.client.PubSub): + for i in range(0, 150): + try: + async with async_timeout.timeout(1): + message = await channel.get_message(ignore_subscribe_messages=True) + except asyncio.TimeoutError: + break + + async def subscribe_worker(): + client = aioredis.Redis(connection_pool=async_pool) + pubsub = client.pubsub() + async with pubsub as p: + await pubsub.subscribe("channel") + await channel_reader(pubsub) + await pubsub.unsubscribe("channel") + + # Create a publisher that constantly sends messages to the channel + # Then create subscribers that will subscribe to already active channel + pub_task = asyncio.create_task(publish_worker()) + await asyncio.gather(*(subscribe_worker() for _ in range(max_connections - 10))) + await pub_task + await async_pool.disconnect() + + @pytest.mark.asyncio async def test_big_command(df_server, size=8 * 1024): reader, writer = await asyncio.open_connection('127.0.0.1', df_server.port)