From 340ae9b459cb324dc7ddc099e279a2ce1efc8388 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 28 Mar 2025 16:20:16 +0000 Subject: [PATCH] WIP This handles membership changes for users that are not in current state event delta stream table (e.g. rejecting remote invites). --- synapse/handlers/sliding_sync/__init__.py | 10 ++ synapse/handlers/sliding_sync/room_lists.py | 20 +++ synapse/storage/databases/main/events.py | 1 + synapse/storage/databases/main/stream.py | 165 ++++++++++++++++++ .../client/sliding_sync/test_rooms_invites.py | 72 +++++++- 5 files changed, 267 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py index 459d3c3e24..7297370393 100644 --- a/synapse/handlers/sliding_sync/__init__.py +++ b/synapse/handlers/sliding_sync/__init__.py @@ -245,12 +245,16 @@ class SlidingSyncHandler: to_token=to_token, ) + print("interested_rooms:", interested_rooms) + lists = interested_rooms.lists relevant_room_map = interested_rooms.relevant_room_map all_rooms = interested_rooms.all_rooms room_membership_for_user_map = interested_rooms.room_membership_for_user_map relevant_rooms_to_send_map = interested_rooms.relevant_rooms_to_send_map + print("relevant_rooms_to_send_map:", relevant_rooms_to_send_map) + # Fetch room data rooms: Dict[str, SlidingSyncResult.RoomResult] = {} @@ -274,6 +278,8 @@ class SlidingSyncHandler: is_dm=room_id in interested_rooms.dm_room_ids, ) + print("room_sync_result:", room_id, room_sync_result) + # Filter out empty room results during incremental sync if room_sync_result or not from_token: rooms[room_id] = room_sync_result @@ -856,6 +862,10 @@ class SlidingSyncHandler: # TODO: Limit the number of state events we're about to send down # the room, if its too many we should change this to an # `initial=True`? + + # TODO: We need to pull out membership changes if the user isn't in + # the room, i.e. to deal with rejecting remote invites. + deltas = await self.get_current_state_deltas_for_room( room_id=room_id, room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, diff --git a/synapse/handlers/sliding_sync/room_lists.py b/synapse/handlers/sliding_sync/room_lists.py index a1730b7e05..7d88d01144 100644 --- a/synapse/handlers/sliding_sync/room_lists.py +++ b/synapse/handlers/sliding_sync/room_lists.py @@ -299,6 +299,8 @@ class SlidingSyncRoomLists: ) dm_room_ids = await self._get_dm_rooms_for_user(user_id) + print("newly left:", newly_left_room_map) + # Add back `newly_left` rooms (rooms left in the from -> to token range). # # We do this because `get_sliding_sync_rooms_for_user(...)` doesn't include @@ -1110,6 +1112,24 @@ class SlidingSyncRoomLists: newly_joined_room_ids: Set[str] = set() newly_left_room_map: Dict[str, RoomsForUserStateReset] = {} + if not from_token: + return (), {} + + changes = await self.store.get_sliding_sync_membership_changes( + user_id, + from_key=from_token.room_key, + to_key=to_token.room_key, + excluded_room_ids=self.rooms_to_exclude_globally, + ) + + for room_id, entry in changes.items(): + if entry.membership == Membership.JOIN: + newly_joined_room_ids.add(room_id) + elif entry.membership == Membership.LEAVE: + newly_left_room_map[room_id] = entry + + return newly_joined_room_ids, newly_left_room_map + # We need to figure out the # # - 1) Figure out which rooms are `newly_left` rooms (> `from_token` and <= `to_token`) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 26fbc1a483..852e5e2300 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -410,6 +410,7 @@ class PersistEventsStore: Returns: SlidingSyncTableChanges """ + print("Changes:", delta_state) to_insert = delta_state.to_insert to_delete = delta_state.to_delete diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 00e5208674..8d6205c876 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -80,6 +80,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.roommember import RoomsForUserStateReset from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection from synapse.util.caches.descriptors import cached, cachedList @@ -1136,6 +1137,170 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if membership_change.room_id not in room_ids_to_exclude ] + @trace + async def get_sliding_sync_membership_changes( + self, + user_id: str, + from_key: RoomStreamToken, + to_key: RoomStreamToken, + excluded_room_ids: Optional[List[str]] = None, + ) -> Dict[str, RoomsForUserStateReset]: + # Start by ruling out cases where a DB query is not necessary. + if from_key == to_key: + return [] + + if from_key: + has_changed = self._membership_stream_cache.has_entity_changed( + user_id, int(from_key.stream) + ) + if not has_changed: + return [] + + room_ids_to_exclude: AbstractSet[str] = set() + if excluded_room_ids is not None: + room_ids_to_exclude = set(excluded_room_ids) + + def f(txn: LoggingTransaction) -> Dict[str, RoomsForUserStateReset]: + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + # This query looks at membership changes in + # `sliding_sync_membership_snapshots`. These will not include where + # users get state reset out of rooms, so we need to look for that + # case in `current_state_delta_stream`. + # + # TODO: Add an index a better index on sliding_sync_membership_snapshots + sql = """ + SELECT + room_id, + membership_event_id, + event_instance_name, + event_stream_ordering, + membership, + sender, + prev_membership, + room_version + FROM + ( + SELECT + room_id, + membership_event_id, + event_instance_name, + event_stream_ordering, + membership, + sender, + null AS prev_membership + FROM sliding_sync_membership_snapshots + + UNION + + SELECT + s.room_id, + e.event_id, + s.instance_name, + s.stream_id, + m.membership, + e.sender, + m_prev.membership AS prev_membership + FROM current_state_delta_stream AS s + LEFT JOIN events AS e ON e.event_id = s.event_id + LEFT JOIN room_memberships AS m ON m.event_id = s.event_id + LEFT JOIN room_memberships AS m_prev ON m_prev.event_id = s.prev_event_id + WHERE + s.type = ? + AND s.state_key = ? + AND s.event_id IS NULL + ) AS c + INNER JOIN rooms USING (room_id) + WHERE event_stream_ordering > ? AND event_stream_ordering <= ? + ORDER BY event_stream_ordering ASC + """ + + txn.execute( + sql, + (EventTypes.Member, user_id, min_from_id, max_to_id), + ) + + membership_changes: Dict[str, RoomsForUserStateReset] = {} + for ( + room_id, + membership_event_id, + event_instance_name, + event_stream_ordering, + membership, + sender, + prev_membership, + room_version_id, + ) in txn: + assert room_id is not None + assert event_stream_ordering is not None + + if room_id in room_ids_to_exclude: + continue + + print( + room_id, + membership_event_id, + event_instance_name, + event_stream_ordering, + membership, + sender, + prev_membership, + room_version_id, + ) + + if _filter_results_by_stream( + from_key, + to_key, + event_instance_name, + event_stream_ordering, + ): + # When the server leaves a room, it will insert new rows into the + # `current_state_delta_stream` table with `event_id = null` for all + # current state. This means we might already have a row for the + # leave event and then another for the same leave where the + # `event_id=null` but the `prev_event_id` is pointing back at the + # earlier leave event. We don't want to report the leave, if we + # already have a leave event. + if ( + membership_event_id is None + and prev_membership == Membership.LEAVE + ): + continue + + if membership_event_id is None and room_id in membership_changes: + continue + + if membership is None: + membership = Membership.LEAVE + + # TODO: If we see a JOIN we need to check if the user newly + # joined the room (instead of just changing their display + # name) + + membership_change = RoomsForUserStateReset( + room_id=room_id, + sender=sender, + membership=membership, + event_id=membership_event_id, + event_pos=PersistedEventPosition( + event_instance_name, event_stream_ordering + ), + room_version_id=room_version_id, + ) + + membership_changes[room_id] = membership_change + + return membership_changes + + membership_changes = await self.db_pool.runInteraction( + "get_sliding_sync_membership_changes", f + ) + + return membership_changes + @cancellable async def get_membership_changes_for_user( self, diff --git a/tests/rest/client/sliding_sync/test_rooms_invites.py b/tests/rest/client/sliding_sync/test_rooms_invites.py index 882762ca29..a3d0765771 100644 --- a/tests/rest/client/sliding_sync/test_rooms_invites.py +++ b/tests/rest/client/sliding_sync/test_rooms_invites.py @@ -18,13 +18,18 @@ from parameterized import parameterized_class from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import EventTypes, HistoryVisibility +from synapse.api.constants import EventTypes, HistoryVisibility, Membership +from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict +from synapse.handlers.room import EventContext from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.types import UserID +from synapse.types.handlers.sliding_sync import StateValues from synapse.util import Clock from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase +from tests.unittest import override_config logger = logging.getLogger(__name__) @@ -64,6 +69,7 @@ class SlidingSyncRoomsInvitesTestCase(SlidingSyncBase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.storage_controllers = hs.get_storage_controllers() + self.federation_handler = hs.get_federation_handler() super().prepare(reactor, clock, hs) @@ -526,3 +532,67 @@ class SlidingSyncRoomsInvitesTestCase(SlidingSyncBase): ], response_body["rooms"][room_id1]["invite_state"], ) + + @override_config({"federation_domain_whitelist": []}) + def test_reject_invite(self) -> None: + """Test that rejecting an invite gets sent down sliding sync""" + + user_id = self.register_user("user1", "pass") + user_tok = self.login(user_id, "pass") + + room_id = "!room:remote.server" + self._create_remote_invite_room_for_user(room_id, user_id) + + # Make the Sliding Sync request + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [(EventTypes.Member, StateValues.ME)], + "timeline_limit": 3, + } + } + } + response_body, from_token = self.do_sync(sync_body, tok=user_tok) + + self.assertIn(room_id, response_body["rooms"]) + print(response_body["rooms"][room_id]) + + self.helper.leave(room_id, user_id, tok=user_tok) + + response_body, _ = self.do_sync(sync_body, since=from_token, tok=user_tok) + + print(response_body["rooms"][room_id]) + self.assertIn(room_id, response_body["rooms"]) + + raise NotImplementedError() + + def _create_remote_invite_room_for_user( + self, + room_id: str, + user_id: str, + ) -> None: + invite_event_dict = { + "room_id": room_id, + "sender": "@inviter:remote.server", + "state_key": user_id, + "depth": 1, + "origin_server_ts": 1, + "type": EventTypes.Member, + "content": {"membership": Membership.INVITE}, + "auth_events": [], + "prev_events": [], + } + + invite_event = make_event_from_dict( + invite_event_dict, + room_version=RoomVersions.V10, + ) + + self.get_success( + self.federation_handler.on_invite_request( + "remote.server", + invite_event, + room_version=invite_event.room_version, + ) + )