diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 3411179a2a..d930c80ef8 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -448,3 +448,6 @@ class ExperimentalConfig(Config): # MSC4222: Adding `state_after` to sync v2 self.msc4222_enabled: bool = experimental.get("msc4222_enabled", False) + + # MSC4229: Pass through `unsigned` data from `/keys/upload` to `/keys/query` + self.msc4229_enabled: bool = experimental.get("msc4229_enabled", False) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 540995e062..82d45cf09d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -542,7 +542,9 @@ class E2eKeysHandler: result_dict[user_id] = {} results = await self.store.get_e2e_device_keys_for_cs_api( - local_query, include_displaynames + local_query, + include_displaynames, + include_uploaded_unsigned_data=self.config.experimental.msc4229_enabled, ) # Check if the application services have any additional results. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 3bb8fccb5e..b28f9ef591 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -220,12 +220,15 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker self, query_list: Collection[Tuple[str, Optional[str]]], include_displaynames: bool = True, + include_uploaded_unsigned_data: bool = False, ) -> Dict[str, Dict[str, JsonDict]]: """Fetch a list of device keys, formatted suitably for the C/S API. Args: query_list: List of pairs of user_ids and device_ids. include_displaynames: Whether to include the displayname of returned devices (if one exists). + include_uploaded_unsigned_data: Whether to include uploaded `unsigned` data + in the response Returns: Dict mapping from user-id to dict mapping from device_id to key data. The key data will be a dict in the same format as the @@ -247,7 +250,13 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if r is None: continue - r["unsigned"] = {} + # If there was already an `unsigned` dict in the uploaded key, keep it. + # Otherwise, create a new one. + if not include_uploaded_unsigned_data or not isinstance( + r.get("unsigned"), dict + ): + r["unsigned"] = {} + if include_displaynames: # Include the device's display name in the "unsigned" dictionary display_name = device_info.display_name diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py index d9a210b616..10d857a08b 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py @@ -19,6 +19,7 @@ # # import urllib.parse +from copy import deepcopy from http import HTTPStatus from unittest.mock import patch @@ -205,6 +206,141 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK, channel.result) +class UnsignedKeyDataTestCase(unittest.HomeserverTestCase): + servlets = [ + keys.register_servlets, + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc4229_enabled": True} + return config + + def make_key_data(self, user_id: str, device_id: str) -> JsonDict: + return { + "algorithms": ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "device_id": device_id, + "keys": { + f"curve25519:{device_id}": "keykeykey", + f"ed25519:{device_id}": "keykeykey", + }, + "signatures": {user_id: {f"ed25519:{device_id}": "sigsigsig"}}, + "user_id": user_id, + } + + def test_unsigned_uploaded_data_returned_in_keys_query(self) -> None: + password = "wonderland" + device_id = "ABCDEFGHI" + alice_id = self.register_user("alice", password) + alice_token = self.login( + "alice", + password, + device_id=device_id, + additional_request_fields={"initial_device_display_name": "mydevice"}, + ) + + # Alice uploads some keys, with a bit of unsigned data + keys1 = self.make_key_data(alice_id, device_id) + keys1["unsigned"] = {"a": "b"} + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/upload", + {"device_keys": keys1}, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # /keys/query should return the unsigned data, with the device display name merged in. + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/query", + {"device_keys": {alice_id: []}}, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + device_response = channel.json_body["device_keys"][alice_id][device_id] + expected_device_response = deepcopy(keys1) + expected_device_response["unsigned"]["device_display_name"] = "mydevice" + self.assertEqual(device_response, expected_device_response) + + # /_matrix/federation/v1/user/devices/{userId} should return the unsigned data too + fed_response = self.get_success( + self.hs.get_device_handler().on_federation_query_user_devices(alice_id) + ) + self.assertEqual( + fed_response["devices"][0], + {"device_id": device_id, "keys": keys1}, + ) + + # so should /_matrix/federation/v1/user/keys/query + fed_response = self.get_success( + self.hs.get_e2e_keys_handler().on_federation_query_client_keys( + {"device_keys": {alice_id: []}} + ) + ) + fed_device_response = fed_response["device_keys"][alice_id][device_id] + self.assertEqual(fed_device_response, keys1) + + def test_non_dict_unsigned_is_ignored(self) -> None: + password = "wonderland" + device_id = "ABCDEFGHI" + alice_id = self.register_user("alice", password) + alice_token = self.login( + "alice", + password, + device_id=device_id, + additional_request_fields={"initial_device_display_name": "mydevice"}, + ) + + # Alice uploads some keys, with a malformed unsigned data + keys1 = self.make_key_data(alice_id, device_id) + keys1["unsigned"] = ["a", "b"] # a list! + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/upload", + {"device_keys": keys1}, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # /keys/query should return the unsigned data, with the device display name merged in. + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/query", + {"device_keys": {alice_id: []}}, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + device_response = channel.json_body["device_keys"][alice_id][device_id] + expected_device_response = deepcopy(keys1) + expected_device_response["unsigned"] = {"device_display_name": "mydevice"} + self.assertEqual(device_response, expected_device_response) + + # /_matrix/federation/v1/user/devices/{userId} should return the unsigned data too + fed_response = self.get_success( + self.hs.get_device_handler().on_federation_query_user_devices(alice_id) + ) + self.assertEqual( + fed_response["devices"][0], + {"device_id": device_id, "keys": keys1}, + ) + + # so should /_matrix/federation/v1/user/keys/query + fed_response = self.get_success( + self.hs.get_e2e_keys_handler().on_federation_query_client_keys( + {"device_keys": {alice_id: []}} + ) + ) + fed_device_response = fed_response["device_keys"][alice_id][device_id] + expected_device_response = deepcopy(keys1) + expected_device_response["unsigned"] = {} + self.assertEqual(fed_device_response, expected_device_response) + + class SigningKeyUploadServletTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets,