Reintroduce "Reduce device lists replication traffic." (#17361)

Reintroduces https://github.com/element-hq/synapse/pull/17333


Turns out the reason for the revert was down to two master instances running.
This commit is contained in:
Erik Johnston 2024-06-25 10:34:34 +01:00 committed by GitHub
parent a98cb87bee
commit 554a92601a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 89 additions and 48 deletions

1
changelog.d/17333.misc Normal file
View file

@@ -0,0 +1 @@
Handle device lists notifications for large accounts more efficiently in worker mode.

View file

@@ -114,13 +114,19 @@ class ReplicationDataHandler:
""" """
all_room_ids: Set[str] = set() all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
if any(row.entity.startswith("@") and not row.is_signature for row in rows): if any(not row.is_signature and not row.hosts_calculated for row in rows):
prev_token = self.store.get_device_stream_token() prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes( all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token prev_token, token
) )
self.store.device_lists_in_rooms_have_changed(all_room_ids, token) self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
# If we're sending federation we need to update the device lists
# outbound pokes stream change cache with updated hosts.
if self.send_handler and any(row.hosts_calculated for row in rows):
hosts = await self.store.get_destinations_for_device(token)
self.store.device_lists_outbound_pokes_have_changed(hosts, token)
self.store.process_replication_rows(stream_name, instance_name, token, rows) self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any # NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances. # cache invalidations are first handled before any stream ID advances.
@@ -433,12 +439,11 @@ class FederationSenderHandler:
# The entities are either user IDs (starting with '@') whose devices # The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about # have changed, or remote servers that we need to tell about
# changes. # changes.
hosts = { if any(row.hosts_calculated for row in rows):
row.entity hosts = await self.store.get_destinations_for_device(token)
for row in rows await self.federation_sender.send_device_messages(
if not row.entity.startswith("@") and not row.is_signature hosts, immediate=False
} )
await self.federation_sender.send_device_messages(hosts, immediate=False)
elif stream_name == ToDeviceStream.NAME: elif stream_name == ToDeviceStream.NAME:
# The to_device stream includes stuff to be pushed to both local # The to_device stream includes stuff to be pushed to both local

View file

@@ -549,10 +549,14 @@ class DeviceListsStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow: class DeviceListsStreamRow:
entity: str user_id: str
# Indicates that a user has signed their own device with their user-signing key # Indicates that a user has signed their own device with their user-signing key
is_signature: bool is_signature: bool
# Indicates if this is a notification that we've calculated the hosts we
# need to send the update to.
hosts_calculated: bool
NAME = "device_lists" NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow ROW_TYPE = DeviceListsStreamRow
@@ -594,13 +598,13 @@ class DeviceListsStream(_StreamFromIdGen):
upper_limit_token = min(upper_limit_token, signatures_to_token) upper_limit_token = min(upper_limit_token, signatures_to_token)
device_updates = [ device_updates = [
(stream_id, (entity, False)) (stream_id, (entity, False, hosts))
for stream_id, (entity,) in device_updates for stream_id, (entity, hosts) in device_updates
if stream_id <= upper_limit_token if stream_id <= upper_limit_token
] ]
signatures_updates = [ signatures_updates = [
(stream_id, (entity, True)) (stream_id, (entity, True, False))
for stream_id, (entity,) in signatures_updates for stream_id, (entity,) in signatures_updates
if stream_id <= upper_limit_token if stream_id <= upper_limit_token
] ]

View file

@@ -164,22 +164,24 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
prefilled_cache=user_signature_stream_prefill, prefilled_cache=user_signature_stream_prefill,
) )
( self._device_list_federation_stream_cache = None
device_list_federation_prefill, if hs.should_send_federation():
device_list_federation_list_id, (
) = self.db_pool.get_cache_dict( device_list_federation_prefill,
db_conn, device_list_federation_list_id,
"device_lists_outbound_pokes", ) = self.db_pool.get_cache_dict(
entity_column="destination", db_conn,
stream_column="stream_id", "device_lists_outbound_pokes",
max_value=device_list_max, entity_column="destination",
limit=10000, stream_column="stream_id",
) max_value=device_list_max,
self._device_list_federation_stream_cache = StreamChangeCache( limit=10000,
"DeviceListFederationStreamChangeCache", )
device_list_federation_list_id, self._device_list_federation_stream_cache = StreamChangeCache(
prefilled_cache=device_list_federation_prefill, "DeviceListFederationStreamChangeCache",
) device_list_federation_list_id,
prefilled_cache=device_list_federation_prefill,
)
if hs.config.worker.run_background_tasks: if hs.config.worker.run_background_tasks:
self._clock.looping_call( self._clock.looping_call(
@@ -207,23 +209,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None: ) -> None:
for row in rows: for row in rows:
if row.is_signature: if row.is_signature:
self._user_signature_stream_cache.entity_has_changed(row.entity, token) self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
continue continue
# The entities are either user IDs (starting with '@') whose devices # The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about # have changed, or remote servers that we need to tell about
# changes. # changes.
if row.entity.startswith("@"): if not row.hosts_calculated:
self._device_list_stream_cache.entity_has_changed(row.entity, token) self._device_list_stream_cache.entity_has_changed(row.user_id, token)
self.get_cached_devices_for_user.invalidate((row.entity,)) self.get_cached_devices_for_user.invalidate((row.user_id,))
self._get_cached_user_device.invalidate((row.entity,)) self._get_cached_user_device.invalidate((row.user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) self.get_device_list_last_stream_id_for_remote.invalidate(
(row.user_id,)
else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
) )
def device_lists_outbound_pokes_have_changed(
self, destinations: StrCollection, token: int
) -> None:
assert self._device_list_federation_stream_cache is not None
for destination in destinations:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
def device_lists_in_rooms_have_changed( def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int self, room_ids: StrCollection, token: int
) -> None: ) -> None:
@@ -363,6 +372,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
EDU contents. EDU contents.
""" """
now_stream_id = self.get_device_stream_token() now_stream_id = self.get_device_stream_token()
if from_stream_id == now_stream_id:
return now_stream_id, []
if self._device_list_federation_stream_cache is None:
raise Exception("Func can only be used on federation senders")
has_changed = self._device_list_federation_stream_cache.has_entity_changed( has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id) destination, int(from_stream_id)
@@ -1018,10 +1032,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# This query Does The Right Thing where it'll correctly apply the # This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries. # bounds to the inner queries.
sql = """ sql = """
SELECT stream_id, entity FROM ( SELECT stream_id, user_id, hosts FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
UNION ALL UNION ALL
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
) AS e ) AS e
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC ORDER BY stream_id ASC
@@ -1577,6 +1591,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
get_device_list_changes_in_room_txn, get_device_list_changes_in_room_txn,
) )
async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
return await self.db_pool.simple_select_onecol(
table="device_lists_outbound_pokes",
keyvalues={"stream_id": stream_id},
retcol="destination",
desc="get_destinations_for_device",
)
class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__( def __init__(
@@ -2112,12 +2134,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_ids: List[int], stream_ids: List[int],
context: Optional[Dict[str, str]], context: Optional[Dict[str, str]],
) -> None: ) -> None:
for host in hosts: if self._device_list_federation_stream_cache:
txn.call_after( for host in hosts:
self._device_list_federation_stream_cache.entity_has_changed, txn.call_after(
host, self._device_list_federation_stream_cache.entity_has_changed,
stream_ids[-1], host,
) stream_ids[-1],
)
now = self._clock.time_msec() now = self._clock.time_msec()
stream_id_iterator = iter(stream_ids) stream_id_iterator = iter(stream_ids)

View file

@@ -123,9 +123,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
for row in rows: for row in rows:
assert isinstance(row, DeviceListsStream.DeviceListsStreamRow) assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
if row.entity.startswith("@"): if not row.hosts_calculated:
self._get_e2e_device_keys_for_federation_query_inner.invalidate( self._get_e2e_device_keys_for_federation_query_inner.invalidate(
(row.entity,) (row.user_id,)
) )
super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@@ -36,6 +36,14 @@ class DeviceStoreTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
def default_config(self) -> JsonDict:
config = super().default_config()
# We 'enable' federation otherwise `get_device_updates_by_remote` will
# throw an exception.
config["federation_sender_instances"] = ["master"]
return config
def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None: def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to """Add a device list change for the given device to
`device_lists_outbound_pokes` table. `device_lists_outbound_pokes` table.