Sliding sync: Add classes for per-connection state (#17574)

This is some prep work ahead of correctly tracking receipts, where we
will also want to track the room status in terms of last receipt we had
sent down.

Essentially, we add two classes `PerConnectionState` and a mutable
version, and then operate on those.

---------

Co-authored-by: Eric Eastwood <eric.eastwood@beta.gouv.fr>
This commit is contained in:
Erik Johnston 2024-08-19 20:09:41 +01:00 committed by GitHub
parent 993644ded0
commit 261e746281
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 196 additions and 100 deletions

1
changelog.d/17574.misc Normal file
View file

@ -0,0 +1 @@
Refactor per-connection state in experimental sliding sync handler.

View file

@ -19,6 +19,8 @@
# #
import enum import enum
import logging import logging
import typing
from collections import ChainMap
from enum import Enum from enum import Enum
from itertools import chain from itertools import chain
from typing import ( from typing import (
@ -30,11 +32,13 @@ from typing import (
List, List,
Literal, Literal,
Mapping, Mapping,
MutableMapping,
Optional, Optional,
Sequence, Sequence,
Set, Set,
Tuple, Tuple,
Union, Union,
cast,
) )
import attr import attr
@ -571,21 +575,21 @@ class SlidingSyncHandler:
# See https://github.com/matrix-org/matrix-doc/issues/1144 # See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError() raise NotImplementedError()
if from_token: # Get the per-connection state (if any).
# Check that we recognize the connection position, if not tell the
# clients that they need to start again.
# #
# If we don't do this and the client asks for the full range of # Raises an exception if there is a `connection_position` that we don't
# rooms, we end up sending down all rooms and their state from # recognize. If we don't do this and the client asks for the full range
# scratch (which can be very slow). By expiring the connection we # of rooms, we end up sending down all rooms and their state from
# allow the client a chance to do an initial request with a smaller # scratch (which can be very slow). By expiring the connection we allow
# range of rooms to get them some results sooner but will end up # the client a chance to do an initial request with a smaller range of
# taking the same amount of time (more with round-trips and # rooms to get them some results sooner but will end up taking the same
# re-processing) in the end to get everything again. # amount of time (more with round-trips and re-processing) in the end to
if not await self.connection_store.is_valid_token( # get everything again.
sync_config, from_token.connection_position previous_connection_state = (
): await self.connection_store.get_per_connection_state(
raise SlidingSyncUnknownPosition() sync_config, from_token
)
)
await self.connection_store.mark_token_seen( await self.connection_store.mark_token_seen(
sync_config=sync_config, sync_config=sync_config,
@ -781,11 +785,7 @@ class SlidingSyncHandler:
# we haven't sent the room down, or we have but there are missing # we haven't sent the room down, or we have but there are missing
# updates). # updates).
for room_id in relevant_room_map: for room_id in relevant_room_map:
status = await self.connection_store.have_sent_room( status = previous_connection_state.rooms.have_sent_room(room_id)
sync_config,
from_token.connection_position,
room_id,
)
if ( if (
# The room was never sent down before so the client needs to know # The room was never sent down before so the client needs to know
# about it regardless of any updates. # about it regardless of any updates.
@ -821,6 +821,7 @@ class SlidingSyncHandler:
async def handle_room(room_id: str) -> None: async def handle_room(room_id: str) -> None:
room_sync_result = await self.get_room_sync_data( room_sync_result = await self.get_room_sync_data(
sync_config=sync_config, sync_config=sync_config,
per_connection_state=previous_connection_state,
room_id=room_id, room_id=room_id,
room_sync_config=relevant_rooms_to_send_map[room_id], room_sync_config=relevant_rooms_to_send_map[room_id],
room_membership_for_user_at_to_token=room_membership_for_user_map[ room_membership_for_user_at_to_token=room_membership_for_user_map[
@ -853,6 +854,8 @@ class SlidingSyncHandler:
) )
if has_lists or has_room_subscriptions: if has_lists or has_room_subscriptions:
new_connection_state = previous_connection_state.get_mutable()
# We now calculate if any rooms outside the range have had updates, # We now calculate if any rooms outside the range have had updates,
# which we are not sending down. # which we are not sending down.
# #
@ -882,11 +885,18 @@ class SlidingSyncHandler:
) )
unsent_room_ids = list(missing_event_map_by_room) unsent_room_ids = list(missing_event_map_by_room)
connection_position = await self.connection_store.record_rooms( new_connection_state.rooms.record_unsent_rooms(
unsent_room_ids, from_token.stream_token
)
new_connection_state.rooms.record_sent_rooms(
relevant_rooms_to_send_map.keys()
)
connection_position = await self.connection_store.record_new_state(
sync_config=sync_config, sync_config=sync_config,
from_token=from_token, from_token=from_token,
sent_room_ids=relevant_rooms_to_send_map.keys(), per_connection_state=new_connection_state,
unsent_room_ids=unsent_room_ids,
) )
elif from_token: elif from_token:
connection_position = from_token.connection_position connection_position = from_token.connection_position
@ -1939,6 +1949,7 @@ class SlidingSyncHandler:
async def get_room_sync_data( async def get_room_sync_data(
self, self,
sync_config: SlidingSyncConfig, sync_config: SlidingSyncConfig,
per_connection_state: "PerConnectionState",
room_id: str, room_id: str,
room_sync_config: RoomSyncConfig, room_sync_config: RoomSyncConfig,
room_membership_for_user_at_to_token: _RoomMembershipForUser, room_membership_for_user_at_to_token: _RoomMembershipForUser,
@ -1986,11 +1997,7 @@ class SlidingSyncHandler:
from_bound = None from_bound = None
initial = True initial = True
if from_token and not room_membership_for_user_at_to_token.newly_joined: if from_token and not room_membership_for_user_at_to_token.newly_joined:
room_status = await self.connection_store.have_sent_room( room_status = per_connection_state.rooms.have_sent_room(room_id)
sync_config=sync_config,
connection_token=from_token.connection_position,
room_id=room_id,
)
if room_status.status == HaveSentRoomFlag.LIVE: if room_status.status == HaveSentRoomFlag.LIVE:
from_bound = from_token.stream_token.room_key from_bound = from_token.stream_token.room_key
initial = False initial = False
@ -3034,6 +3041,121 @@ HAVE_SENT_ROOM_NEVER = HaveSentRoom(HaveSentRoomFlag.NEVER, None)
HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None) HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None)
@attr.s(auto_attribs=True, slots=True, frozen=True)
class RoomStatusMap:
"""For a given stream, e.g. events, records what we have or have not sent
down for that stream in a given room."""
# `room_id` -> `HaveSentRoom`
_statuses: Mapping[str, HaveSentRoom] = attr.Factory(dict)
def have_sent_room(self, room_id: str) -> HaveSentRoom:
"""Return whether we have previously sent the room down"""
return self._statuses.get(room_id, HAVE_SENT_ROOM_NEVER)
def get_mutable(self) -> "MutableRoomStatusMap":
"""Get a mutable copy of this state."""
return MutableRoomStatusMap(
statuses=self._statuses,
)
def copy(self) -> "RoomStatusMap":
"""Make a copy of the class. Useful for converting from a mutable to
immutable version."""
return RoomStatusMap(statuses=dict(self._statuses))
class MutableRoomStatusMap(RoomStatusMap):
"""A mutable version of `RoomStatusMap`"""
# We use a ChainMap here so that we can easily track what has been updated
# and what hasn't. Note that when we persist the per connection state this
# will get flattened to a normal dict (via calling `.copy()`)
_statuses: typing.ChainMap[str, HaveSentRoom]
def __init__(
self,
statuses: Mapping[str, HaveSentRoom],
) -> None:
# ChainMap requires a mutable mapping, but we're not actually going to
# mutate it.
statuses = cast(MutableMapping, statuses)
super().__init__(
statuses=ChainMap({}, statuses),
)
def get_updates(self) -> Mapping[str, HaveSentRoom]:
"""Return only the changes that were made"""
return self._statuses.maps[0]
def record_sent_rooms(self, room_ids: StrCollection) -> None:
"""Record that we have sent these rooms in the response"""
for room_id in room_ids:
current_status = self._statuses.get(room_id, HAVE_SENT_ROOM_NEVER)
if current_status.status == HaveSentRoomFlag.LIVE:
continue
self._statuses[room_id] = HAVE_SENT_ROOM_LIVE
def record_unsent_rooms(
self, room_ids: StrCollection, from_token: StreamToken
) -> None:
"""Record that we have not sent these rooms in the response, but there
have been updates.
"""
# Whether we add/update the entries for unsent rooms depends on the
# existing entry:
# - LIVE: We have previously sent down everything up to
# `last_room_token, so we update the entry to be `PREVIOUSLY` with
# `last_room_token`.
# - PREVIOUSLY: We have previously sent down everything up to *a*
# given token, so we don't need to update the entry.
# - NEVER: We have never previously sent down the room, and we haven't
# sent anything down this time either so we leave it as NEVER.
for room_id in room_ids:
current_status = self._statuses.get(room_id, HAVE_SENT_ROOM_NEVER)
if current_status.status != HaveSentRoomFlag.LIVE:
continue
self._statuses[room_id] = HaveSentRoom.previously(from_token.room_key)
@attr.s(auto_attribs=True)
class PerConnectionState:
"""The per-connection state. A snapshot of what we've sent down the connection before.
Currently, we track whether we've sent down various aspects of a given room before.
We use the `rooms` field to store the position in the events stream for each room that we've previously sent to the client before. On the next request that includes the room, we can then send only what's changed since that recorded position.
Same goes for the `receipts` field so we only need to send the new receipts since the last time you made a sync request.
Attributes:
rooms: The status of each room for the events stream.
"""
rooms: RoomStatusMap = attr.Factory(RoomStatusMap)
def get_mutable(self) -> "MutablePerConnectionState":
"""Get a mutable copy of this state."""
return MutablePerConnectionState(
rooms=self.rooms.get_mutable(),
)
@attr.s(auto_attribs=True)
class MutablePerConnectionState(PerConnectionState):
"""A mutable version of `PerConnectionState`"""
rooms: MutableRoomStatusMap
def has_updates(self) -> bool:
return bool(self.rooms.get_updates())
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class SlidingSyncConnectionStore: class SlidingSyncConnectionStore:
"""In-memory store of per-connection state, including what rooms we have """In-memory store of per-connection state, including what rooms we have
@ -3063,9 +3185,9 @@ class SlidingSyncConnectionStore:
to mapping of room ID to `HaveSentRoom`. to mapping of room ID to `HaveSentRoom`.
""" """
# `(user_id, conn_id)` -> `token` -> `room_id` -> `HaveSentRoom` # `(user_id, conn_id)` -> `connection_position` -> `PerConnectionState`
_connections: Dict[Tuple[str, str], Dict[int, Dict[str, HaveSentRoom]]] = ( _connections: Dict[Tuple[str, str], Dict[int, PerConnectionState]] = attr.Factory(
attr.Factory(dict) dict
) )
async def is_valid_token( async def is_valid_token(
@ -3078,48 +3200,52 @@ class SlidingSyncConnectionStore:
conn_key = self._get_connection_key(sync_config) conn_key = self._get_connection_key(sync_config)
return connection_token in self._connections.get(conn_key, {}) return connection_token in self._connections.get(conn_key, {})
async def have_sent_room( async def get_per_connection_state(
self, sync_config: SlidingSyncConfig, connection_token: int, room_id: str
) -> HaveSentRoom:
"""For the given user_id/conn_id/token, return whether we have
previously sent the room down
"""
conn_key = self._get_connection_key(sync_config)
sync_statuses = self._connections.setdefault(conn_key, {})
room_status = sync_statuses.get(connection_token, {}).get(
room_id, HAVE_SENT_ROOM_NEVER
)
return room_status
@trace
async def record_rooms(
self, self,
sync_config: SlidingSyncConfig, sync_config: SlidingSyncConfig,
from_token: Optional[SlidingSyncStreamToken], from_token: Optional[SlidingSyncStreamToken],
*, ) -> PerConnectionState:
sent_room_ids: StrCollection, """Fetch the per-connection state for the token.
unsent_room_ids: StrCollection,
) -> int:
"""Record which rooms we have/haven't sent down in a new response
Attributes: Raises:
sync_config SlidingSyncUnknownPosition if the connection_token is unknown
from_token: The since token from the request, if any """
sent_room_ids: The set of room IDs that we have sent down as if from_token is None:
part of this request (only needs to be ones we didn't return PerConnectionState()
previously sent down).
unsent_room_ids: The set of room IDs that have had updates connection_position = from_token.connection_position
since the `from_token`, but which were not included in if connection_position == 0:
this request # Initial sync (request without a `from_token`) starts at `0` so
# there is no existing per-connection state
return PerConnectionState()
conn_key = self._get_connection_key(sync_config)
sync_statuses = self._connections.get(conn_key, {})
connection_state = sync_statuses.get(connection_position)
if connection_state is None:
raise SlidingSyncUnknownPosition()
return connection_state
@trace
async def record_new_state(
self,
sync_config: SlidingSyncConfig,
from_token: Optional[SlidingSyncStreamToken],
per_connection_state: MutablePerConnectionState,
) -> int:
"""Record updated per-connection state, returning the connection
position associated with the new state.
If there are no changes to the state this may return the same token as
the existing per-connection state.
""" """
prev_connection_token = 0 prev_connection_token = 0
if from_token is not None: if from_token is not None:
prev_connection_token = from_token.connection_position prev_connection_token = from_token.connection_position
# If there are no changes then this is a noop. if not per_connection_state.has_updates():
if not sent_room_ids and not unsent_room_ids:
return prev_connection_token return prev_connection_token
conn_key = self._get_connection_key(sync_config) conn_key = self._get_connection_key(sync_config)
@ -3130,42 +3256,11 @@ class SlidingSyncConnectionStore:
new_store_token = prev_connection_token + 1 new_store_token = prev_connection_token + 1
sync_statuses.pop(new_store_token, None) sync_statuses.pop(new_store_token, None)
# Copy over and update the room mappings. # We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
new_room_statuses = dict(sync_statuses.get(prev_connection_token, {})) # don't grow forever.
sync_statuses[new_store_token] = PerConnectionState(
# Whether we have updated the `new_room_statuses`, if we don't by the rooms=per_connection_state.rooms.copy(),
# end we can treat this as a noop. )
have_updated = False
for room_id in sent_room_ids:
new_room_statuses[room_id] = HAVE_SENT_ROOM_LIVE
have_updated = True
# Whether we add/update the entries for unsent rooms depends on the
# existing entry:
# - LIVE: We have previously sent down everything up to
# `last_room_token, so we update the entry to be `PREVIOUSLY` with
# `last_room_token`.
# - PREVIOUSLY: We have previously sent down everything up to *a*
# given token, so we don't need to update the entry.
# - NEVER: We have never previously sent down the room, and we haven't
# sent anything down this time either so we leave it as NEVER.
# Work out the new state for unsent rooms that were `LIVE`.
if from_token:
new_unsent_state = HaveSentRoom.previously(from_token.stream_token.room_key)
else:
new_unsent_state = HAVE_SENT_ROOM_NEVER
for room_id in unsent_room_ids:
prev_state = new_room_statuses.get(room_id)
if prev_state is not None and prev_state.status == HaveSentRoomFlag.LIVE:
new_room_statuses[room_id] = new_unsent_state
have_updated = True
if not have_updated:
return prev_connection_token
sync_statuses[new_store_token] = new_room_statuses
return new_store_token return new_store_token