From 51df675c054369576ed9bde8d1865c904fa6056c Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 17 Mar 2025 12:21:45 -0400 Subject: [PATCH] MSC4140: don't cancel delayed state on own state (#17810) When a user sends a state event, do not cancel their own delayed events for the same piece of state. For context, see [the relevant section in the MSC](https://github.com/matrix-org/matrix-spec-proposals/blob/a09a883d9a013ac4b6ffddebd7ea87a827d211b9/proposals/4140-delayed-events-futures.md#delayed-state-events-are-cancelled-by-a-more-recent-state-event). --- changelog.d/17810.feature | 1 + synapse/handlers/delayed_events.py | 20 ++- .../storage/databases/main/delayed_events.py | 30 ++-- tests/rest/client/test_delayed_events.py | 143 ++++++++++++++---- 4 files changed, 158 insertions(+), 36 deletions(-) create mode 100644 changelog.d/17810.feature diff --git a/changelog.d/17810.feature b/changelog.d/17810.feature new file mode 100644 index 0000000000..5c65e54ceb --- /dev/null +++ b/changelog.d/17810.feature @@ -0,0 +1 @@ +Update MSC4140 implementation to no longer cancel a user's own delayed state events with an event type & state key that match a more recent state event sent by that user. diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index b3f40809a1..80cb1cec9b 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -191,18 +191,36 @@ class DelayedEventsHandler: async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None: """ - Process current state deltas to cancel pending delayed events + Process current state deltas to cancel other users' pending delayed events that target the same state. """ for delta in deltas: + if delta.event_id is None: + logger.debug( + "Not handling delta for deleted state: %r %r", + delta.event_type, + delta.state_key, + ) + continue + logger.debug( "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id ) + event = await self._store.get_event( + delta.event_id, check_room_id=delta.room_id + ) + sender = UserID.from_string(event.sender) + next_send_ts = await self._store.cancel_delayed_state_events( room_id=delta.room_id, event_type=delta.event_type, state_key=delta.state_key, + not_from_localpart=( + sender.localpart + if sender.domain == self._config.server.server_name + else "" + ), ) if self._next_send_ts_changed(next_send_ts): diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 1616e30e22..c88682d55c 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -424,25 +424,37 @@ class DelayedEventsStore(SQLBaseStore): room_id: str, event_type: str, state_key: str, + not_from_localpart: str, ) -> Optional[Timestamp]: """ Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed. + Args: + room_id: The room ID to match against. + event_type: The event type to match against. + state_key: The state key to match against. + not_from_localpart: The localpart of a user whose delayed events to not cancel. + If set to the empty string, any users' delayed events may be cancelled. + Returns: The send time of the next delayed event to be sent, if any. """ def cancel_delayed_state_events_txn( txn: LoggingTransaction, ) -> Optional[Timestamp]: - self.db_pool.simple_delete_txn( - txn, - table="delayed_events", - keyvalues={ - "room_id": room_id, - "event_type": event_type, - "state_key": state_key, - "is_processed": False, - }, + txn.execute( + """ + DELETE FROM delayed_events + WHERE room_id = ? AND event_type = ? AND state_key = ? + AND user_localpart <> ? + AND NOT is_processed + """, + ( + room_id, + event_type, + state_key, + not_from_localpart, + ), ) return self._get_next_delayed_event_send_ts_txn(txn) diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py index 2c938390c8..9f9d241f12 100644 --- a/tests/rest/client/test_delayed_events.py +++ b/tests/rest/client/test_delayed_events.py @@ -22,7 +22,8 @@ from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes -from synapse.rest.client import delayed_events, room, versions +from synapse.rest import admin +from synapse.rest.client import delayed_events, login, room, versions from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -32,7 +33,6 @@ from tests.unittest import HomeserverTestCase PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events" -_HS_NAME = "red" _EVENT_TYPE = "com.example.test" @@ -54,23 +54,41 @@ class DelayedEventsUnstableSupportTestCase(HomeserverTestCase): class DelayedEventsTestCase(HomeserverTestCase): """Tests getting and managing delayed events.""" - servlets = [delayed_events.register_servlets, room.register_servlets] - user_id = f"@sid1:{_HS_NAME}" + servlets = [ + admin.register_servlets, + delayed_events.register_servlets, + login.register_servlets, + room.register_servlets, + ] def default_config(self) -> JsonDict: config = super().default_config() - config["server_name"] = _HS_NAME config["max_event_delay_duration"] = "24h" return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user1_user_id = self.register_user("user1", "pass") + self.user1_access_token = self.login("user1", "pass") + self.user2_user_id = self.register_user("user2", "pass") + self.user2_access_token = self.login("user2", "pass") + self.room_id = self.helper.create_room_as( - self.user_id, + self.user1_user_id, + tok=self.user1_access_token, extra_content={ - "preset": "trusted_private_chat", + "preset": "public_chat", + "power_level_content_override": { + "events": { + _EVENT_TYPE: 0, + } + }, }, ) + self.helper.join( + room=self.room_id, user=self.user2_user_id, tok=self.user2_access_token + ) + def test_delayed_events_empty_on_startup(self) -> None: self.assertListEqual([], self._get_delayed_events()) @@ -85,6 +103,7 @@ class DelayedEventsTestCase(HomeserverTestCase): { setter_key: setter_expected, }, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) events = self._get_delayed_events() @@ -94,7 +113,7 @@ class DelayedEventsTestCase(HomeserverTestCase): self.helper.get_state( self.room_id, _EVENT_TYPE, - "", + self.user1_access_token, state_key=state_key, expect_code=HTTPStatus.NOT_FOUND, ) @@ -104,7 +123,7 @@ class DelayedEventsTestCase(HomeserverTestCase): content = self.helper.get_state( self.room_id, _EVENT_TYPE, - "", + self.user1_access_token, state_key=state_key, ) self.assertEqual(setter_expected, content.get(setter_key), content) @@ -113,7 +132,7 @@ class DelayedEventsTestCase(HomeserverTestCase): {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}} ) def test_get_delayed_events_ratelimit(self) -> None: - args = ("GET", PATH_PREFIX) + args = ("GET", PATH_PREFIX, b"", self.user1_access_token) channel = self.make_request(*args) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) @@ -123,7 +142,9 @@ class DelayedEventsTestCase(HomeserverTestCase): # Add the current user to the ratelimit overrides, allowing them no ratelimiting. self.get_success( - self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) + self.hs.get_datastores().main.set_ratelimit_for_user( + self.user1_user_id, 0, 0 + ) ) # Test that the request isn't ratelimited anymore. @@ -134,6 +155,7 @@ class DelayedEventsTestCase(HomeserverTestCase): channel = self.make_request( "POST", f"{PATH_PREFIX}/", + access_token=self.user1_access_token, ) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result) @@ -141,6 +163,7 @@ class DelayedEventsTestCase(HomeserverTestCase): channel = self.make_request( "POST", f"{PATH_PREFIX}/abc", + access_token=self.user1_access_token, ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual( @@ -153,6 +176,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/abc", {}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual( @@ -165,6 +189,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/abc", {"action": "oops"}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual( @@ -178,6 +203,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/abc", {"action": action}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result) @@ -192,6 +218,7 @@ class DelayedEventsTestCase(HomeserverTestCase): { setter_key: setter_expected, }, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") @@ -205,7 +232,7 @@ class DelayedEventsTestCase(HomeserverTestCase): self.helper.get_state( self.room_id, _EVENT_TYPE, - "", + self.user1_access_token, state_key=state_key, expect_code=HTTPStatus.NOT_FOUND, ) @@ -214,6 +241,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/{delay_id}", {"action": "cancel"}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertListEqual([], self._get_delayed_events()) @@ -222,7 +250,7 @@ class DelayedEventsTestCase(HomeserverTestCase): content = self.helper.get_state( self.room_id, _EVENT_TYPE, - "", + self.user1_access_token, state_key=state_key, expect_code=HTTPStatus.NOT_FOUND, ) @@ -237,6 +265,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), {}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") @@ -247,6 +276,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/{delay_ids.pop(0)}", {"action": "cancel"}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) @@ -254,13 +284,16 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/{delay_ids.pop(0)}", {"action": "cancel"}, + self.user1_access_token, ) channel = self.make_request(*args) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) # Add the current user to the ratelimit overrides, allowing them no ratelimiting. self.get_success( - self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) + self.hs.get_datastores().main.set_ratelimit_for_user( + self.user1_user_id, 0, 0 + ) ) # Test that the request isn't ratelimited anymore. @@ -278,6 +311,7 @@ class DelayedEventsTestCase(HomeserverTestCase): { setter_key: setter_expected, }, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") @@ -291,7 +325,7 @@ class DelayedEventsTestCase(HomeserverTestCase): self.helper.get_state( self.room_id, _EVENT_TYPE, - "", + self.user1_access_token, state_key=state_key, expect_code=HTTPStatus.NOT_FOUND, ) @@ -300,13 +334,14 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/{delay_id}", {"action": "send"}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertListEqual([], self._get_delayed_events()) content = self.helper.get_state( self.room_id, _EVENT_TYPE, - "", + self.user1_access_token, state_key=state_key, ) self.assertEqual(setter_expected, content.get(setter_key), content) @@ -319,6 +354,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), {}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") @@ -329,6 +365,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/{delay_ids.pop(0)}", {"action": "send"}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) @@ -336,13 +373,16 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/{delay_ids.pop(0)}", {"action": "send"}, + self.user1_access_token, ) channel = self.make_request(*args) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) # Add the current user to the ratelimit overrides, allowing them no ratelimiting. self.get_success( - self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) + self.hs.get_datastores().main.set_ratelimit_for_user( + self.user1_user_id, 0, 0 + ) ) # Test that the request isn't ratelimited anymore. @@ -360,6 +400,7 @@ class DelayedEventsTestCase(HomeserverTestCase): { setter_key: setter_expected, }, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") @@ -373,7 +414,7 @@ class DelayedEventsTestCase(HomeserverTestCase): self.helper.get_state( self.room_id, _EVENT_TYPE, - "", + self.user1_access_token, state_key=state_key, expect_code=HTTPStatus.NOT_FOUND, ) @@ -382,6 +423,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/{delay_id}", {"action": "restart"}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) @@ -393,7 +435,7 @@ class DelayedEventsTestCase(HomeserverTestCase): self.helper.get_state( self.room_id, _EVENT_TYPE, - "", + self.user1_access_token, state_key=state_key, expect_code=HTTPStatus.NOT_FOUND, ) @@ -403,7 +445,7 @@ class DelayedEventsTestCase(HomeserverTestCase): content = self.helper.get_state( self.room_id, _EVENT_TYPE, - "", + self.user1_access_token, state_key=state_key, ) self.assertEqual(setter_expected, content.get(setter_key), content) @@ -418,6 +460,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), {}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") @@ -428,6 +471,7 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/{delay_ids.pop(0)}", {"action": "restart"}, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) @@ -435,21 +479,66 @@ class DelayedEventsTestCase(HomeserverTestCase): "POST", f"{PATH_PREFIX}/{delay_ids.pop(0)}", {"action": "restart"}, + self.user1_access_token, ) channel = self.make_request(*args) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) # Add the current user to the ratelimit overrides, allowing them no ratelimiting. self.get_success( - self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) + self.hs.get_datastores().main.set_ratelimit_for_user( + self.user1_user_id, 0, 0 + ) ) # Test that the request isn't ratelimited anymore. channel = self.make_request(*args) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - def test_delayed_state_events_are_cancelled_by_more_recent_state(self) -> None: - state_key = "to_be_cancelled" + def test_delayed_state_is_not_cancelled_by_new_state_from_same_user( + self, + ) -> None: + state_key = "to_not_be_cancelled_by_same_user" + + setter_key = "setter" + setter_expected = "on_timeout" + channel = self.make_request( + "PUT", + _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900), + { + setter_key: setter_expected, + }, + self.user1_access_token, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + events = self._get_delayed_events() + self.assertEqual(1, len(events), events) + + self.helper.send_state( + self.room_id, + _EVENT_TYPE, + { + setter_key: "manual", + }, + self.user1_access_token, + state_key=state_key, + ) + events = self._get_delayed_events() + self.assertEqual(1, len(events), events) + + self.reactor.advance(1) + content = self.helper.get_state( + self.room_id, + _EVENT_TYPE, + self.user1_access_token, + state_key=state_key, + ) + self.assertEqual(setter_expected, content.get(setter_key), content) + + def test_delayed_state_is_cancelled_by_new_state_from_other_user( + self, + ) -> None: + state_key = "to_be_cancelled_by_other_user" setter_key = "setter" channel = self.make_request( @@ -458,19 +547,20 @@ class DelayedEventsTestCase(HomeserverTestCase): { setter_key: "on_timeout", }, + self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) events = self._get_delayed_events() self.assertEqual(1, len(events), events) - setter_expected = "manual" + setter_expected = "other_user" self.helper.send_state( self.room_id, _EVENT_TYPE, { setter_key: setter_expected, }, - None, + self.user2_access_token, state_key=state_key, ) self.assertListEqual([], self._get_delayed_events()) @@ -479,7 +569,7 @@ class DelayedEventsTestCase(HomeserverTestCase): content = self.helper.get_state( self.room_id, _EVENT_TYPE, - "", + self.user1_access_token, state_key=state_key, ) self.assertEqual(setter_expected, content.get(setter_key), content) @@ -488,6 +578,7 @@ class DelayedEventsTestCase(HomeserverTestCase): channel = self.make_request( "GET", PATH_PREFIX, + access_token=self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)