diff --git a/changelog.d/17271.misc b/changelog.d/17271.misc new file mode 100644 index 0000000000..915d717ad7 --- /dev/null +++ b/changelog.d/17271.misc @@ -0,0 +1 @@ +Handle OTK uploads off master. diff --git a/changelog.d/17273.misc b/changelog.d/17273.misc new file mode 100644 index 0000000000..2c1c6bc0d5 --- /dev/null +++ b/changelog.d/17273.misc @@ -0,0 +1 @@ +Don't try and resync devices for remote users whose servers are marked as down. diff --git a/changelog.d/17275.bugfix b/changelog.d/17275.bugfix new file mode 100644 index 0000000000..eb522bb997 --- /dev/null +++ b/changelog.d/17275.bugfix @@ -0,0 +1 @@ +Fix bug where OTKs were not always included in `/sync` response when using workers. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 560530a7b3..668cec513b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -35,6 +35,7 @@ from synapse.api.errors import CodeMessageException, Codes, NotFoundError, Synap from synapse.handlers.device import DeviceHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace +from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.types import ( JsonDict, JsonMapping, @@ -45,7 +46,10 @@ from synapse.types import ( from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.cancellation import cancellable -from synapse.util.retryutils import NotRetryingDestination +from synapse.util.retryutils import ( + NotRetryingDestination, + filter_destinations_by_retry_limiter, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -86,6 +90,12 @@ class E2eKeysHandler: edu_updater.incoming_signing_key_update, ) + self.device_key_uploader = self.upload_device_keys_for_user + else: + self.device_key_uploader = ( + ReplicationUploadKeysForUserRestServlet.make_client(hs) + ) + # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. @@ -268,10 +278,8 @@ class E2eKeysHandler: "%d destinations to query devices for", len(remote_queries_not_in_cache) ) - async def _query( - destination_queries: Tuple[str, Dict[str, Iterable[str]]] - ) -> None: - destination, queries = destination_queries + async def _query(destination: str) -> None: + queries = remote_queries_not_in_cache[destination] return await self._query_devices_for_destination( results, cross_signing_keys, @@ -281,9 +289,20 @@ class E2eKeysHandler: timeout, ) + # Only try and fetch keys for destinations that are not marked as + # down. + filtered_destinations = await filter_destinations_by_retry_limiter( + remote_queries_not_in_cache.keys(), + self.clock, + self.store, + # Let's give an arbitrary grace period for those hosts that are + # only recently down + retry_due_within_ms=60 * 1000, + ) + await concurrently_execute( _query, - remote_queries_not_in_cache.items(), + filtered_destinations, 10, delay_cancellation=True, ) @@ -784,36 +803,17 @@ class E2eKeysHandler: "one_time_keys": A mapping from algorithm to number of keys for that algorithm, including those previously persisted. """ - # This can only be called from the main process. - assert isinstance(self.device_handler, DeviceHandler) - time_now = self.clock.time_msec() # TODO: Validate the JSON to make sure it has the right keys. device_keys = keys.get("device_keys", None) if device_keys: - logger.info( - "Updating device_keys for device %r for user %s at %d", - device_id, - user_id, - time_now, + await self.device_key_uploader( + user_id=user_id, + device_id=device_id, + keys={"device_keys": device_keys}, ) - log_kv( - { - "message": "Updating device_keys for user.", - "user_id": user_id, - "device_id": device_id, - } - ) - # TODO: Sign the JSON with the server key - changed = await self.store.set_e2e_device_keys( - user_id, device_id, time_now, device_keys - ) - if changed: - # Only notify about device updates *if* the keys actually changed - await self.device_handler.notify_device_update(user_id, [device_id]) - else: - log_kv({"message": "Not updating device_keys for user", "user_id": user_id}) + one_time_keys = keys.get("one_time_keys", None) if one_time_keys: log_kv( @@ -849,6 +849,49 @@ class E2eKeysHandler: {"message": "Did not update fallback_keys", "reason": "no keys given"} ) + result = await self.store.count_e2e_one_time_keys(user_id, device_id) + + set_tag("one_time_key_counts", str(result)) + return {"one_time_key_counts": result} + + @tag_args + async def upload_device_keys_for_user( + self, user_id: str, device_id: str, keys: JsonDict + ) -> None: + """ + Args: + user_id: user whose keys are being uploaded. + device_id: device whose keys are being uploaded. + device_keys: the `device_keys` of an /keys/upload request. + + """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + + time_now = self.clock.time_msec() + + device_keys = keys["device_keys"] + logger.info( + "Updating device_keys for device %r for user %s at %d", + device_id, + user_id, + time_now, + ) + log_kv( + { + "message": "Updating device_keys for user.", + "user_id": user_id, + "device_id": device_id, + } + ) + # TODO: Sign the JSON with the server key + changed = await self.store.set_e2e_device_keys( + user_id, device_id, time_now, device_keys + ) + if changed: + # Only notify about device updates *if* the keys actually changed + await self.device_handler.notify_device_update(user_id, [device_id]) + # the device should have been registered already, but it may have been # deleted due to a race with a DELETE request. Or we may be using an # old access_token without an associated device_id. Either way, we @@ -856,11 +899,6 @@ class E2eKeysHandler: # keys without a corresponding device. await self.device_handler.check_device_registered(user_id, device_id) - result = await self.store.count_e2e_one_time_keys(user_id, device_id) - - set_tag("one_time_key_counts", str(result)) - return {"one_time_key_counts": result} - async def _upload_one_time_keys_for_user( self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict ) -> None: diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 6e4ac23a87..08c6aadff6 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -53,14 +53,13 @@ def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> sender: The person who sent the membership event """ - return ( - # Everything except `Membership.LEAVE` because we want everything that's *still* - # relevant to the user. There are few more things to include in the sync response - # (newly_left) but those are handled separately. - membership in (Membership.LIST - {Membership.LEAVE}) - # Include kicks - or (membership == Membership.LEAVE and sender != user_id) - ) + # Everything except `Membership.LEAVE` because we want everything that's *still* + # relevant to the user. There are few more things to include in the sync response + # (newly_left) but those are handled separately. + # + # This logic includes kicks (leave events where the sender is not the same user) and + # can be read as "anything that isn't a leave or a leave with a different sender". + return membership != Membership.LEAVE or sender != user_id class SlidingSyncConfig(SlidingSyncBody): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 778d68ad3f..39964726c5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -285,7 +285,11 @@ class SyncResult: ) @staticmethod - def empty(next_batch: StreamToken) -> "SyncResult": + def empty( + next_batch: StreamToken, + device_one_time_keys_count: JsonMapping, + device_unused_fallback_key_types: List[str], + ) -> "SyncResult": "Return a new empty result" return SyncResult( next_batch=next_batch, @@ -297,8 +301,8 @@ class SyncResult: archived=[], to_device=[], device_lists=DeviceListUpdates(), - device_one_time_keys_count={}, - device_unused_fallback_key_types=[], + device_one_time_keys_count=device_one_time_keys_count, + device_unused_fallback_key_types=device_unused_fallback_key_types, ) @@ -523,7 +527,28 @@ class SyncHandler: logger.warning( "Timed out waiting for worker to catch up. Returning empty response" ) - return SyncResult.empty(since_token) + device_id = sync_config.device_id + one_time_keys_count: JsonMapping = {} + unused_fallback_key_types: List[str] = [] + if device_id: + user_id = sync_config.user.to_string() + # TODO: We should have a way to let clients differentiate between the states of: + # * no change in OTK count since the provided since token + # * the server has zero OTKs left for this device + # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 + one_time_keys_count = await self.store.count_e2e_one_time_keys( + user_id, device_id + ) + unused_fallback_key_types = list( + await self.store.get_e2e_unused_fallback_key_types( + user_id, device_id + ) + ) + + cache_context.should_cache = False # Don't cache empty responses + return SyncResult.empty( + since_token, one_time_keys_count, unused_fallback_key_types + ) # If we've spent significant time waiting to catch up, take it off # the timeout. diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index a0017257ce..306db07b86 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -36,7 +36,6 @@ from synapse.http.servlet import ( ) from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag -from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.types import JsonDict, StreamToken from synapse.util.cancellation import cancellable @@ -105,13 +104,8 @@ class KeyUploadServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = hs.get_device_handler() - - if hs.config.worker.worker_app is None: - # if main process - self.key_uploader = self.e2e_keys_handler.upload_keys_for_user - else: - # then a worker - self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs) + self._clock = hs.get_clock() + self._store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, device_id: Optional[str] @@ -151,9 +145,10 @@ class KeyUploadServlet(RestServlet): 400, "To upload keys, you must pass device_id when authenticating" ) - result = await self.key_uploader( + result = await self.e2e_keys_handler.upload_keys_for_user( user_id=user_id, device_id=device_id, keys=body ) + return 200, result diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index 498e3be592..220683b9d6 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -547,6 +547,108 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): # Room should still show up because it's newly_left during the from/to range self.assertEqual(room_id_results, {room_id1}) + def test_no_from_token(self) -> None: + """ + Test that if we don't provide a `from_token`, we get all the rooms that we we're + joined to up to the `to_token`. + + Providing `from_token` only really has the effect that it adds `newly_left` + rooms to the response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # We create the room with user2 so the room isn't left with no members when we + # leave and can still re-join. + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) + room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) + + # Join room1 + self.helper.join(room_id1, user1_id, tok=user1_tok) + + # Join and leave the room2 before the `to_token` + self.helper.join(room_id2, user1_id, tok=user1_tok) + self.helper.leave(room_id2, user1_id, tok=user1_tok) + + after_room1_token = self.event_sources.get_current_token() + + # Join the room2 after we already have our tokens + self.helper.join(room_id2, user1_id, tok=user1_tok) + + room_id_results = self.get_success( + self.sliding_sync_handler.get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_room1_token, + ) + ) + + # Only rooms we were joined to before the `to_token` should show up + self.assertEqual(room_id_results, {room_id1}) + + def test_from_token_ahead_of_to_token(self) -> None: + """ + Test when the provided `from_token` comes after the `to_token`. We should + basically expect the same result as having no `from_token`. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # We create the room with user2 so the room isn't left with no members when we + # leave and can still re-join. + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) + room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) + room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) + room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) + + # Join room1 before `before_room_token` + self.helper.join(room_id1, user1_id, tok=user1_tok) + + # Join and leave the room2 before `before_room_token` + self.helper.join(room_id2, user1_id, tok=user1_tok) + self.helper.leave(room_id2, user1_id, tok=user1_tok) + + # Note: These are purposely swapped. The `from_token` should come after + # the `to_token` in this test + to_token = self.event_sources.get_current_token() + + # Join room2 after `before_room_token` + self.helper.join(room_id2, user1_id, tok=user1_tok) + + # -------- + + # Join room3 after `before_room_token` + self.helper.join(room_id3, user1_id, tok=user1_tok) + + # Join and leave the room4 after `before_room_token` + self.helper.join(room_id4, user1_id, tok=user1_tok) + self.helper.leave(room_id4, user1_id, tok=user1_tok) + + # Note: These are purposely swapped. The `from_token` should come after the + # `to_token` in this test + from_token = self.event_sources.get_current_token() + + # Join the room4 after we already have our tokens + self.helper.join(room_id4, user1_id, tok=user1_tok) + + room_id_results = self.get_success( + self.sliding_sync_handler.get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=from_token, + to_token=to_token, + ) + ) + + # Only rooms we were joined to before the `to_token` should show up + # + # There won't be any newly_left rooms because the `from_token` is ahead of the + # `to_token` and that range will give no membership changes to check. + self.assertEqual(room_id_results, {room_id1}) + def test_leave_before_range_and_join_leave_after_to_token(self) -> None: """ Old left room shouldn't show up. But we're also testing that joining and leaving