From 652d1669c5a103b1c20478770c4aaf18849c09a3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 16 Dec 2022 06:53:01 -0500 Subject: [PATCH] Add missing type hints to tests.handlers. (#14680) And do not allow untyped defs in tests.handlers. --- changelog.d/14680.misc | 1 + mypy.ini | 5 +- synapse/handlers/auth.py | 2 +- tests/handlers/test_appservice.py | 54 +++---- tests/handlers/test_cas.py | 2 +- tests/handlers/test_directory.py | 27 ++-- tests/handlers/test_e2e_room_keys.py | 76 ++++++---- tests/handlers/test_federation.py | 2 +- tests/handlers/test_federation_event.py | 10 +- tests/handlers/test_message.py | 26 ++-- tests/handlers/test_oidc.py | 48 +++--- tests/handlers/test_password_providers.py | 144 +++++++++--------- tests/handlers/test_presence.py | 100 +++++++------ tests/handlers/test_profile.py | 4 +- tests/handlers/test_receipts.py | 6 +- tests/handlers/test_register.py | 169 +++++++++++++--------- tests/handlers/test_room.py | 6 +- tests/handlers/test_room_summary.py | 76 ++++++---- tests/handlers/test_saml.py | 33 +++-- tests/handlers/test_send_email.py | 29 ++-- tests/handlers/test_stats.py | 74 +++++++--- tests/handlers/test_sync.py | 11 +- 22 files changed, 527 insertions(+), 378 deletions(-) create mode 100644 changelog.d/14680.misc diff --git a/changelog.d/14680.misc b/changelog.d/14680.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14680.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index 37acf589c9..1a37414e58 100644 --- a/mypy.ini +++ b/mypy.ini @@ -95,10 +95,7 @@ disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] disallow_untyped_defs = True -[mypy-tests.handlers.test_sso] -disallow_untyped_defs = True - -[mypy-tests.handlers.test_user_directory] +[mypy-tests.handlers.*] disallow_untyped_defs = True [mypy-tests.metrics.test_background_process_metrics] diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8b9ef25d29..30f2d46c3c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2031,7 +2031,7 @@ class PasswordAuthProvider: self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] # Mapping from login type to login parameters - self._supported_login_types: Dict[str, Iterable[str]] = {} + self._supported_login_types: Dict[str, Tuple[str, ...]] = {} # Mapping from login type to auth checker callbacks self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 57bfbd7734..a7495ab21a 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -31,7 +31,7 @@ from synapse.appservice import ( from synapse.handlers.appservice import ApplicationServicesHandler from synapse.rest.client import login, receipts, register, room, sendtodevice from synapse.server import HomeServer -from synapse.types import RoomStreamToken +from synapse.types import JsonDict, RoomStreamToken from synapse.util import Clock from synapse.util.stringutils import random_string @@ -44,7 +44,7 @@ from tests.utils import MockClock class AppServiceHandlerTestCase(unittest.TestCase): """Tests the ApplicationServicesHandler.""" - def setUp(self): + def setUp(self) -> None: self.mock_store = Mock() self.mock_as_api = Mock() self.mock_scheduler = Mock() @@ -61,7 +61,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.handler = ApplicationServicesHandler(hs) self.event_source = hs.get_event_sources() - def test_notify_interested_services(self): + def test_notify_interested_services(self) -> None: interested_service = self._mkservice(is_interested_in_event=True) services = [ self._mkservice(is_interested_in_event=False), @@ -90,7 +90,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): interested_service, events=[event] ) - def test_query_user_exists_unknown_user(self): + def test_query_user_exists_unknown_user(self) -> None: user_id = "@someone:anywhere" services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True @@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) - def test_query_user_exists_known_user(self): + def test_query_user_exists_known_user(self) -> None: user_id = "@someone:anywhere" services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True @@ -127,7 +127,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): "query_user called when it shouldn't have been.", ) - def test_query_room_alias_exists(self): + def test_query_room_alias_exists(self) -> None: room_alias_str = "#foo:bar" room_alias = Mock() room_alias.to_string.return_value = room_alias_str @@ -157,7 +157,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.assertEqual(result.room_id, room_id) self.assertEqual(result.servers, servers) - def test_get_3pe_protocols_no_appservices(self): + def test_get_3pe_protocols_no_appservices(self) -> None: self.mock_store.get_app_services.return_value = [] response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) @@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_not_called() self.assertEqual(response, {}) - def test_get_3pe_protocols_no_protocols(self): + def test_get_3pe_protocols_no_protocols(self) -> None: service = self._mkservice(False, []) self.mock_store.get_app_services.return_value = [service] response = self.successResultOf( @@ -174,7 +174,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_not_called() self.assertEqual(response, {}) - def test_get_3pe_protocols_protocol_no_response(self): + def test_get_3pe_protocols_protocol_no_response(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) @@ -186,7 +186,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) self.assertEqual(response, {}) - def test_get_3pe_protocols_select_one_protocol(self): + def test_get_3pe_protocols_select_one_protocol(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( @@ -202,7 +202,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) - def test_get_3pe_protocols_one_protocol(self): + def test_get_3pe_protocols_one_protocol(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( @@ -218,7 +218,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) - def test_get_3pe_protocols_multiple_protocol(self): + def test_get_3pe_protocols_multiple_protocol(self) -> None: service_one = self._mkservice(False, ["my-protocol"]) service_two = self._mkservice(False, ["other-protocol"]) self.mock_store.get_app_services.return_value = [service_one, service_two] @@ -237,11 +237,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): }, ) - def test_get_3pe_protocols_multiple_info(self): + def test_get_3pe_protocols_multiple_info(self) -> None: service_one = self._mkservice(False, ["my-protocol"]) service_two = self._mkservice(False, ["my-protocol"]) - async def get_3pe_protocol(service, unusedProtocol): + async def get_3pe_protocol( + service: ApplicationService, protocol: str + ) -> Optional[JsonDict]: if service == service_one: return { "x-protocol-data": 42, @@ -276,7 +278,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): }, ) - def test_notify_interested_services_ephemeral(self): + def test_notify_interested_services_ephemeral(self) -> None: """ Test sending ephemeral events to the appservice handler are scheduled to be pushed out to interested appservices, and that the stream ID is @@ -306,7 +308,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): 580, ) - def test_notify_interested_services_ephemeral_out_of_order(self): + def test_notify_interested_services_ephemeral_out_of_order(self) -> None: """ Test sending out of order ephemeral events to the appservice handler are ignored. @@ -390,7 +392,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.hs = hs # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track any outgoing ephemeral events @@ -417,7 +419,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): "exclusive_as_user", "password", self.exclusive_as_user_device_id ) - def _notify_interested_services(self): + def _notify_interested_services(self) -> None: # This is normally set in `notify_interested_services` but we need to call the # internal async version so the reactor gets pushed to completion. self.hs.get_application_service_handler().current_max += 1 @@ -443,7 +445,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): ) def test_match_interesting_room_members( self, interesting_user: str, should_notify: bool - ): + ) -> None: """ Test to make sure that a interesting user (local or remote) in the room is notified as expected when someone else in the room sends a message. @@ -512,7 +514,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): else: self.send_mock.assert_not_called() - def test_application_services_receive_events_sent_by_interesting_local_user(self): + def test_application_services_receive_events_sent_by_interesting_local_user( + self, + ) -> None: """ Test to make sure that a messages sent from a local user can be interesting and picked up by the appservice. @@ -568,7 +572,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self.assertEqual(events[0]["type"], "m.room.message") self.assertEqual(events[0]["sender"], alice) - def test_sending_read_receipt_batches_to_application_services(self): + def test_sending_read_receipt_batches_to_application_services(self) -> None: """Tests that a large batch of read receipts are sent correctly to interested application services. """ @@ -644,7 +648,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"experimental_features": {"msc2409_to_device_messages_enabled": True}} ) - def test_application_services_receive_local_to_device(self): + def test_application_services_receive_local_to_device(self) -> None: """ Test that when a user sends a to-device message to another user that is an application service's user namespace, the @@ -722,7 +726,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"experimental_features": {"msc2409_to_device_messages_enabled": True}} ) - def test_application_services_receive_bursts_of_to_device(self): + def test_application_services_receive_bursts_of_to_device(self) -> None: """ Test that when a user sends >100 to-device messages at once, any interested AS's will receive them in separate transactions. @@ -913,7 +917,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) experimental_feature_enabled: bool, as_supports_txn_extensions: bool, as_should_receive_device_list_updates: bool, - ): + ) -> None: """ Tests that an application service receives notice of changed device lists for a user, when a user changes their device lists. @@ -1070,7 +1074,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): and a room for the users to talk in. """ - async def preparation(): + async def preparation() -> None: await self._add_otks_for_device(self._sender_user, self._sender_device, 42) await self._add_fallback_key_for_device( self._sender_user, self._sender_device, used=True diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 2b21547d0f..2733719d82 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -199,7 +199,7 @@ class CasHandlerTestCase(HomeserverTestCase): ) -def _mock_request(): +def _mock_request() -> Mock: """Returns a mock which will stand in as a SynapseRequest""" mock = Mock( spec=[ diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 3b72c4c9d0..90aec484c4 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.api.errors import synapse.rest.admin from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.rest.client import directory, login, room from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, create_requester @@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def _create_alias(self, user) -> None: + def _create_alias(self, user: str) -> None: # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( @@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) return room_alias - def _set_canonical_alias(self, content) -> None: + def _set_canonical_alias(self, content: JsonDict) -> None: """Configure the canonical alias state on the room.""" self.helper.send_state( self.room_id, @@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def _get_canonical_alias(self): + def _get_canonical_alias(self) -> EventBase: """Get the canonical alias state of the room.""" - return self.get_success( + result = self.get_success( self._storage_controllers.state.get_current_state_event( self.room_id, EventTypes.CanonicalAlias, "" ) ) + assert result is not None + return result def test_remove_alias(self) -> None: """Removing an alias that is the canonical alias should remove it there too.""" @@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertEqual(data["content"]["alias"], self.test_alias) - self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + self.assertEqual(data.content["alias"], self.test_alias) + self.assertEqual(data.content["alt_aliases"], [self.test_alias]) # Finally, delete the alias. self.get_success( @@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertNotIn("alias", data["content"]) - self.assertNotIn("alt_aliases", data["content"]) + self.assertNotIn("alias", data.content) + self.assertNotIn("alt_aliases", data.content) def test_remove_other_alias(self) -> None: """Removing an alias listed as in alt_aliases should remove it there too.""" @@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data.content["alias"], self.test_alias) self.assertEqual( - data["content"]["alt_aliases"], [self.test_alias, other_test_alias] + data.content["alt_aliases"], [self.test_alias, other_test_alias] ) # Delete the second alias. @@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertEqual(data["content"]["alias"], self.test_alias) - self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + self.assertEqual(data.content["alias"], self.test_alias) + self.assertEqual(data.content["alt_aliases"], [self.test_alias]) class TestCreateAliasACL(unittest.HomeserverTestCase): diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 9b7e7a8e9a..6c0b30de9e 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -17,7 +17,11 @@ import copy from unittest import mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import SynapseError +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -39,14 +43,14 @@ room_keys = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(replication_layer=mock.Mock()) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_room_keys_handler() self.local_user = "@boris:" + hs.hostname - def test_get_missing_current_version_info(self): + def test_get_missing_current_version_info(self) -> None: """Check that we get a 404 if we ask for info about the current version if there is no version. """ @@ -56,7 +60,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_get_missing_version_info(self): + def test_get_missing_version_info(self) -> None: """Check that we get a 404 if we ask for info about a specific version if it doesn't exist. """ @@ -67,9 +71,9 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_create_version(self): + def test_create_version(self) -> None: """Check that we can create and then retrieve versions.""" - res = self.get_success( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -78,7 +82,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) ) - self.assertEqual(res, "1") + self.assertEqual(version, "1") # check we can retrieve it as the current version res = self.get_success(self.handler.get_version_info(self.local_user)) @@ -110,7 +114,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) # upload a new one... - res = self.get_success( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -119,7 +123,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) ) - self.assertEqual(res, "2") + self.assertEqual(version, "2") # check we can retrieve it as the current version res = self.get_success(self.handler.get_version_info(self.local_user)) @@ -134,7 +138,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_update_version(self): + def test_update_version(self) -> None: """Check that we can update versions.""" version = self.get_success( self.handler.create_version( @@ -173,7 +177,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_update_missing_version(self): + def test_update_missing_version(self) -> None: """Check that we get a 404 on updating nonexistent versions""" e = self.get_failure( self.handler.update_version( @@ -190,7 +194,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_update_omitted_version(self): + def test_update_omitted_version(self) -> None: """Check that the update succeeds if the version is missing from the body""" version = self.get_success( self.handler.create_version( @@ -227,7 +231,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_update_bad_version(self): + def test_update_bad_version(self) -> None: """Check that we get a 400 if the version in the body doesn't match""" version = self.get_success( self.handler.create_version( @@ -255,7 +259,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 400) - def test_delete_missing_version(self): + def test_delete_missing_version(self) -> None: """Check that we get a 404 on deleting nonexistent versions""" e = self.get_failure( self.handler.delete_version(self.local_user, "1"), SynapseError @@ -263,15 +267,15 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_delete_missing_current_version(self): + def test_delete_missing_current_version(self) -> None: """Check that we get a 404 on deleting nonexistent current version""" e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError) res = e.value.code self.assertEqual(res, 404) - def test_delete_version(self): + def test_delete_version(self) -> None: """Check that we can create and then delete versions.""" - res = self.get_success( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -280,7 +284,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) ) - self.assertEqual(res, "1") + self.assertEqual(version, "1") # check we can delete it self.get_success(self.handler.delete_version(self.local_user, "1")) @@ -292,7 +296,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_get_missing_backup(self): + def test_get_missing_backup(self) -> None: """Check that we get a 404 on querying missing backup""" e = self.get_failure( self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError @@ -300,7 +304,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_get_missing_room_keys(self): + def test_get_missing_room_keys(self) -> None: """Check we get an empty response from an empty backup""" version = self.get_success( self.handler.create_version( @@ -319,7 +323,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): # TODO: test the locking semantics when uploading room_keys, # although this is probably best done in sytest - def test_upload_room_keys_no_versions(self): + def test_upload_room_keys_no_versions(self) -> None: """Check that we get a 404 on uploading keys when no versions are defined""" e = self.get_failure( self.handler.upload_room_keys(self.local_user, "no_version", room_keys), @@ -328,7 +332,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_upload_room_keys_bogus_version(self): + def test_upload_room_keys_bogus_version(self) -> None: """Check that we get a 404 on uploading keys when an nonexistent version is specified """ @@ -350,7 +354,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_upload_room_keys_wrong_version(self): + def test_upload_room_keys_wrong_version(self) -> None: """Check that we get a 403 on uploading keys for an old version""" version = self.get_success( self.handler.create_version( @@ -380,7 +384,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 403) - def test_upload_room_keys_insert(self): + def test_upload_room_keys_insert(self) -> None: """Check that we can insert and retrieve keys for a session""" version = self.get_success( self.handler.create_version( @@ -416,7 +420,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.assertDictEqual(res, room_keys) - def test_upload_room_keys_merge(self): + def test_upload_room_keys_merge(self) -> None: """Check that we can upload a new room_key for an existing session and have it correctly merged""" version = self.get_success( @@ -449,9 +453,11 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = self.get_success(self.handler.get_room_keys(self.local_user, version)) + res_keys = self.get_success( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( - res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], + res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) @@ -465,9 +471,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = self.get_success(self.handler.get_room_keys(self.local_user, version)) + res_keys = self.get_success( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( - res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" + res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], + "new", ) # the etag should NOT be equal now, since the key changed @@ -483,9 +492,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = self.get_success(self.handler.get_room_keys(self.local_user, version)) + res_keys = self.get_success( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( - res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" + res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], + "new", ) # the etag should be the same since the session did not change @@ -494,7 +506,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): # TODO: check edge cases as well as the common variations here - def test_delete_room_keys(self): + def test_delete_room_keys(self) -> None: """Check that we can insert and delete keys for a session""" version = self.get_success( self.handler.create_version( diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index d00c69c229..cedbb9fafc 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -439,7 +439,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") - def create_invite(): + def create_invite() -> EventBase: room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) room_version = self.get_success(self.store.get_room_version(room_id)) return event_from_pdu_json( diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index e448cb1901..70ea4d15d4 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -14,6 +14,8 @@ from typing import Optional from unittest import mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import AuthError, StoreError from synapse.api.room_versions import RoomVersion from synapse.event_auth import ( @@ -26,8 +28,10 @@ from synapse.federation.transport.client import StateRequestResponse from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import event_injection, make_awaitable @@ -40,7 +44,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # mock out the federation transport client self.mock_federation_transport_client = mock.Mock( spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] @@ -165,7 +169,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) else: - async def get_event(destination: str, event_id: str, timeout=None): + async def get_event( + destination: str, event_id: str, timeout: Optional[int] = None + ) -> JsonDict: self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(event_id, prev_event.event_id) return {"pdus": [prev_event.get_pdu_json()]} diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 99384837d0..c4727ab917 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -14,12 +14,16 @@ import logging from typing import Tuple +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -35,7 +39,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = self.hs.get_event_creation_handler() self._persist_event_storage_controller = ( self.hs.get_storage_controllers().persistence @@ -94,7 +98,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): ) ) - def test_duplicated_txn_id(self): + def test_duplicated_txn_id(self) -> None: """Test that attempting to handle/persist an event with a transaction ID that has already been persisted correctly returns the old event and does *not* produce duplicate messages. @@ -161,7 +165,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # rather than the new one. self.assertEqual(ret_event1.event_id, ret_event4.event_id) - def test_duplicated_txn_id_one_call(self): + def test_duplicated_txn_id_one_call(self) -> None: """Test that we correctly handle duplicates that we try and persist at the same time. """ @@ -185,7 +189,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[0].event_id, events[1].event_id) - def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self): + def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events( + self, + ) -> None: """When we set allow_no_prev_events=True, should be able to create a event without any prev_events (only auth_events). """ @@ -214,7 +220,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events( self, - ): + ) -> None: """When we set allow_no_prev_events=False, shouldn't be able to create a event without any prev_events even if it has auth_events. Expect an exception to be raised. @@ -245,7 +251,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events( self, - ): + ) -> None: """When we set allow_no_prev_events=True, should be able to create a event without any prev_events or auth_events. Expect an exception to be raised. @@ -277,12 +283,12 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("tester", "foobar") self.access_token = self.login("tester", "foobar") self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) - def test_allow_server_acl(self): + def test_allow_server_acl(self) -> None: """Test that sending an ACL that blocks everyone but ourselves works.""" self.helper.send_state( @@ -293,7 +299,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): expect_code=200, ) - def test_deny_server_acl_block_outselves(self): + def test_deny_server_acl_block_outselves(self) -> None: """Test that sending an ACL that blocks ourselves does not work.""" self.helper.send_state( self.room_id, @@ -303,7 +309,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): expect_code=400, ) - def test_deny_redact_server_acl(self): + def test_deny_redact_server_acl(self) -> None: """Test that attempting to redact an ACL is blocked.""" body = self.helper.send_state( diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 5955410524..49a1842b5c 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Dict, Tuple +from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse @@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.sso import MappingException from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID from synapse.util import Clock from synapse.util.macaroons import get_value_from_macaroon from synapse.util.stringutils import random_string @@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config try: import authlib # noqa: F401 + from authlib.oidc.core import UserInfo + from authlib.oidc.discovery import OpenIDProviderMetadata + + from synapse.handlers.oidc import Token, UserAttributeDict HAS_OIDC = True except ImportError: @@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = { class TestMappingProvider: @staticmethod - def parse_config(config): - return + def parse_config(config: JsonDict) -> None: + return None - def __init__(self, config): + def __init__(self, config: None): pass - def get_remote_user_id(self, userinfo): + def get_remote_user_id(self, userinfo: "UserInfo") -> str: return userinfo["sub"] - async def map_user_attributes(self, userinfo, token): - return {"localpart": userinfo["username"], "display_name": None} + async def map_user_attributes( + self, userinfo: "UserInfo", token: "Token" + ) -> "UserAttributeDict": + # This is testing not providing the full map. + return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item] # Do not include get_extra_attributes to test backwards compatibility paths. class TestMappingProviderExtra(TestMappingProvider): - async def get_extra_attributes(self, userinfo, token): + async def get_extra_attributes( + self, userinfo: "UserInfo", token: "Token" + ) -> JsonDict: return {"phone": userinfo["phone"]} class TestMappingProviderFailures(TestMappingProvider): - async def map_user_attributes(self, userinfo, token, failures): - return { + # Superclass is testing the legacy interface for map_user_attributes. + async def map_user_attributes( # type: ignore[override] + self, userinfo: "UserInfo", token: "Token", failures: int + ) -> "UserAttributeDict": + return { # type: ignore[typeddict-item] "localpart": userinfo["username"] + (str(failures) if failures else ""), "display_name": None, } @@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.hs_patcher.stop() return super().tearDown() - def reset_mocks(self): + def reset_mocks(self) -> None: """Reset all the Mocks.""" self.fake_server.reset_mocks() self.render_error.reset_mock() self.complete_sso_login.reset_mock() - def metadata_edit(self, values): + def metadata_edit(self, values: dict) -> ContextManager[Mock]: """Modify the result that will be returned by the well-known query""" metadata = self.fake_server.get_metadata() @@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase): session = self._generate_oidc_session_token(state, nonce, client_redirect_url) return _build_callback_request(code, state, session), grant - def assertRenderedError(self, error, error_description=None): + def assertRenderedError( + self, error: str, error_description: Optional[str] = None + ) -> Tuple[Any, ...]: self.render_error.assert_called_once() args = self.render_error.call_args[0] self.assertEqual(args[1], error) @@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase): """Provider metadatas are extensively validated.""" h = self.provider - def force_load_metadata(): - async def force_load(): + def force_load_metadata() -> Awaitable[None]: + async def force_load() -> "OpenIDProviderMetadata": return await h.load_metadata(force=True) return get_awaitable_result(force_load()) @@ -1198,7 +1212,7 @@ def _build_callback_request( state: str, session: str, ip_address: str = "10.0.0.1", -): +) -> Mock: """Builds a fake SynapseRequest to mock the browser callback Returns a Mock object which looks like the SynapseRequest we get from a browser diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 75934b1707..0916de64f5 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -15,12 +15,13 @@ """Tests for the password_auth_provider interface""" from http import HTTPStatus -from typing import Any, Type, Union +from typing import Any, Dict, List, Optional, Type, Union from unittest.mock import Mock import synapse from synapse.api.constants import LoginType from synapse.api.errors import Codes +from synapse.handlers.account import AccountHandler from synapse.module_api import ModuleApi from synapse.rest.client import account, devices, login, logout, register from synapse.types import JsonDict, UserID @@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider: """A legacy password_provider which only implements `check_password`.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, account_handler): + def __init__(self, config: None, account_handler: AccountHandler): pass - def check_password(self, *args): + def check_password(self, *args: str) -> Mock: return mock_password_provider.check_password(*args) @@ -58,16 +59,16 @@ class LegacyCustomAuthProvider: """A legacy password_provider which implements a custom login type.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, account_handler): + def __init__(self, config: None, account_handler: AccountHandler): pass - def get_supported_login_types(self): + def get_supported_login_types(self) -> Dict[str, List[str]]: return {"test.login_type": ["test_field"]} - def check_auth(self, *args): + def check_auth(self, *args: str) -> Mock: return mock_password_provider.check_auth(*args) @@ -75,15 +76,15 @@ class CustomAuthProvider: """A module which registers password_auth_provider callbacks for a custom login type.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, api: ModuleApi): + def __init__(self, config: None, api: ModuleApi): api.register_password_auth_provider_callbacks( auth_checkers={("test.login_type", ("test_field",)): self.check_auth} ) - def check_auth(self, *args): + def check_auth(self, *args: Any) -> Mock: return mock_password_provider.check_auth(*args) @@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider: as a custom type.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, account_handler): + def __init__(self, config: None, account_handler: AccountHandler): pass - def get_supported_login_types(self): + def get_supported_login_types(self) -> Dict[str, List[str]]: return {"m.login.password": ["password"], "test.login_type": ["test_field"]} - def check_auth(self, *args): + def check_auth(self, *args: str) -> Mock: return mock_password_provider.check_auth(*args) @@ -110,10 +111,10 @@ class PasswordCustomAuthProvider: as well as a password login""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, api: ModuleApi): + def __init__(self, config: None, api: ModuleApi): api.register_password_auth_provider_callbacks( auth_checkers={ ("test.login_type", ("test_field",)): self.check_auth, @@ -121,10 +122,10 @@ class PasswordCustomAuthProvider: } ) - def check_auth(self, *args): + def check_auth(self, *args: Any) -> Mock: return mock_password_provider.check_auth(*args) - def check_pass(self, *args): + def check_pass(self, *args: str) -> Mock: return mock_password_provider.check_password(*args) @@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): CALLBACK_USERNAME = "get_username_for_registration" CALLBACK_DISPLAYNAME = "get_displayname_for_registration" - def setUp(self): + def setUp(self) -> None: # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() super().setUp() @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_password_only_auth_progiver_login_legacy(self): + def test_password_only_auth_progiver_login_legacy(self) -> None: self.password_only_auth_provider_login_test_body() - def password_only_auth_provider_login_test_body(self): + def password_only_auth_provider_login_test_body(self) -> None: # login flows should only have m.login.password flows = self._get_login_flows() self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) @@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_password_only_auth_provider_ui_auth_legacy(self): + def test_password_only_auth_provider_ui_auth_legacy(self) -> None: self.password_only_auth_provider_ui_auth_test_body() - def password_only_auth_provider_ui_auth_test_body(self): + def password_only_auth_provider_ui_auth_test_body(self) -> None: """UI Auth should delegate correctly to the password provider""" # create the user, otherwise access doesn't work @@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.assert_called_once_with("@u:test", "p") @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_local_user_fallback_login_legacy(self): + def test_local_user_fallback_login_legacy(self) -> None: self.local_user_fallback_login_test_body() - def local_user_fallback_login_test_body(self): + def local_user_fallback_login_test_body(self) -> None: """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual("@localuser:test", channel.json_body["user_id"]) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_local_user_fallback_ui_auth_legacy(self): + def test_local_user_fallback_ui_auth_legacy(self) -> None: self.local_user_fallback_ui_auth_test_body() - def local_user_fallback_ui_auth_test_body(self): + def local_user_fallback_ui_auth_test_body(self) -> None: """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_login_legacy(self): + def test_no_local_user_fallback_login_legacy(self) -> None: self.no_local_user_fallback_login_test_body() - def no_local_user_fallback_login_test_body(self): + def no_local_user_fallback_login_test_body(self) -> None: """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") @@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_ui_auth_legacy(self): + def test_no_local_user_fallback_ui_auth_legacy(self) -> None: self.no_local_user_fallback_ui_auth_test_body() - def no_local_user_fallback_ui_auth_test_body(self): + def no_local_user_fallback_ui_auth_test_body(self) -> None: """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") @@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_auth_disabled_legacy(self): + def test_password_auth_disabled_legacy(self) -> None: self.password_auth_disabled_test_body() - def password_auth_disabled_test_body(self): + def password_auth_disabled_test_body(self) -> None: """password auth doesn't work if it's disabled across the board""" # login flows should be empty flows = self._get_login_flows() @@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.assert_not_called() @override_config(legacy_providers_config(LegacyCustomAuthProvider)) - def test_custom_auth_provider_login_legacy(self): + def test_custom_auth_provider_login_legacy(self) -> None: self.custom_auth_provider_login_test_body() @override_config(providers_config(CustomAuthProvider)) - def test_custom_auth_provider_login(self): + def test_custom_auth_provider_login(self) -> None: self.custom_auth_provider_login_test_body() - def custom_auth_provider_login_test_body(self): + def custom_auth_provider_login_test_body(self) -> None: # login flows should have the custom flow and m.login.password, since we # haven't disabled local password lookup. # (password must come first, because reasons) @@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) @override_config(legacy_providers_config(LegacyCustomAuthProvider)) - def test_custom_auth_provider_ui_auth_legacy(self): + def test_custom_auth_provider_ui_auth_legacy(self) -> None: self.custom_auth_provider_ui_auth_test_body() @override_config(providers_config(CustomAuthProvider)) - def test_custom_auth_provider_ui_auth(self): + def test_custom_auth_provider_ui_auth(self) -> None: self.custom_auth_provider_ui_auth_test_body() - def custom_auth_provider_ui_auth_test_body(self): + def custom_auth_provider_ui_auth_test_body(self) -> None: # register the user and log in twice, to get two devices self.register_user("localuser", "localpass") tok1 = self.login("localuser", "localpass") @@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) @override_config(legacy_providers_config(LegacyCustomAuthProvider)) - def test_custom_auth_provider_callback_legacy(self): + def test_custom_auth_provider_callback_legacy(self) -> None: self.custom_auth_provider_callback_test_body() @override_config(providers_config(CustomAuthProvider)) - def test_custom_auth_provider_callback(self): + def test_custom_auth_provider_callback(self) -> None: self.custom_auth_provider_callback_test_body() - def custom_auth_provider_callback_test_body(self): + def custom_auth_provider_callback_test_body(self) -> None: callback = Mock(return_value=make_awaitable(None)) mock_password_provider.check_auth.return_value = make_awaitable( @@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_custom_auth_password_disabled_legacy(self): + def test_custom_auth_password_disabled_legacy(self) -> None: self.custom_auth_password_disabled_test_body() @override_config( {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} ) - def test_custom_auth_password_disabled(self): + def test_custom_auth_password_disabled(self) -> None: self.custom_auth_password_disabled_test_body() - def custom_auth_password_disabled_test_body(self): + def custom_auth_password_disabled_test_body(self) -> None: """Test login with a custom auth provider where password login is disabled""" self.register_user("localuser", "localpass") @@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False, "localdb_enabled": False}, } ) - def test_custom_auth_password_disabled_localdb_enabled_legacy(self): + def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None: self.custom_auth_password_disabled_localdb_enabled_test_body() @override_config( @@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False, "localdb_enabled": False}, } ) - def test_custom_auth_password_disabled_localdb_enabled(self): + def test_custom_auth_password_disabled_localdb_enabled(self) -> None: self.custom_auth_password_disabled_localdb_enabled_test_body() - def custom_auth_password_disabled_localdb_enabled_test_body(self): + def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None: """Check the localdb_enabled == enabled == False Regression test for https://github.com/matrix-org/synapse/issues/8914: check @@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_login_legacy(self): + def test_password_custom_auth_password_disabled_login_legacy(self) -> None: self.password_custom_auth_password_disabled_login_test_body() @override_config( @@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_login(self): + def test_password_custom_auth_password_disabled_login(self) -> None: self.password_custom_auth_password_disabled_login_test_body() - def password_custom_auth_password_disabled_login_test_body(self): + def password_custom_auth_password_disabled_login_test_body(self) -> None: """log in with a custom auth provider which implements password, but password login is disabled""" self.register_user("localuser", "localpass") @@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_ui_auth_legacy(self): + def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None: self.password_custom_auth_password_disabled_ui_auth_test_body() @override_config( @@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_ui_auth(self): + def test_password_custom_auth_password_disabled_ui_auth(self) -> None: self.password_custom_auth_password_disabled_ui_auth_test_body() - def password_custom_auth_password_disabled_ui_auth_test_body(self): + def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None: """UI Auth with a custom auth provider which implements password, but password login is disabled""" # register the user and log in twice via the test login type to get two devices, @@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_custom_auth_no_local_user_fallback_legacy(self): + def test_custom_auth_no_local_user_fallback_legacy(self) -> None: self.custom_auth_no_local_user_fallback_test_body() @override_config( @@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_custom_auth_no_local_user_fallback(self): + def test_custom_auth_no_local_user_fallback(self) -> None: self.custom_auth_no_local_user_fallback_test_body() - def custom_auth_no_local_user_fallback_test_body(self): + def custom_auth_no_local_user_fallback_test_body(self) -> None: """Test login with a custom auth provider where the local db is disabled""" self.register_user("localuser", "localpass") @@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - def test_on_logged_out(self): + def test_on_logged_out(self) -> None: """Tests that the on_logged_out callback is called when the user logs out.""" self.register_user("rin", "password") tok = self.login("rin", "password") self.called = False - async def on_logged_out(user_id, device_id, access_token): + async def on_logged_out( + user_id: str, device_id: Optional[str], access_token: str + ) -> None: self.called = True on_logged_out = Mock(side_effect=on_logged_out) @@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): on_logged_out.assert_called_once() self.assertTrue(self.called) - def test_username(self): + def test_username(self) -> None: """Tests that the get_username_for_registration callback can define the username of a user when registering. """ @@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mxid = channel.json_body["user_id"] self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") - def test_username_uia(self): + def test_username_uia(self) -> None: """Tests that the get_username_for_registration callback is only called at the end of the UIA flow. """ @@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # Set some email configuration so the test doesn't fail because of its absence. @override_config({"email": {"notif_from": "noreply@test"}}) - def test_3pid_allowed(self): + def test_3pid_allowed(self) -> None: """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind the 3PID. Also checks that the module is passed a boolean indicating whether the @@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self._test_3pid_allowed("rin", False) self._test_3pid_allowed("kitay", True) - def test_displayname(self): + def test_displayname(self) -> None: """Tests that the get_displayname_for_registration callback can define the display name of a user when registering. """ @@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(display_name, username + "-foo") - def test_displayname_uia(self): + def test_displayname_uia(self) -> None: """Tests that the get_displayname_for_registration callback is only called at the end of the UIA flow. """ @@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # Check that the callback has been called. m.assert_called_once() - def _test_3pid_allowed(self, username: str, registration: bool): + def _test_3pid_allowed(self, username: str, registration: bool) -> None: """Tests that the "is_3pid_allowed" module callback is called correctly, using either /register or /account URLs depending on the arguments. @@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): client is trying to register. """ - async def callback(uia_results, params): + async def callback(uia_results: JsonDict, params: JsonDict) -> str: self.assertIn(LoginType.DUMMY, uia_results) username = params["username"] return username + "-foo" @@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): def _send_password_login(self, user: str, password: str) -> FakeChannel: return self._send_login(type="m.login.password", user=user, password=password) - def _send_login(self, type, user, **params) -> FakeChannel: - params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) + def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel: + params = {"identifier": {"type": "m.id.user", "user": user}, "type": type} + params.update(extra_params) channel = self.make_request("POST", "/_matrix/client/r0/login", params) return channel - def _start_delete_device_session(self, access_token, device_id) -> str: + def _start_delete_device_session(self, access_token: str, device_id: str) -> str: """Make an initial delete device request, and return the UI Auth session ID""" channel = self._delete_device(access_token, device_id) self.assertEqual(channel.code, 401) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 584e7b8971..19f5322317 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, cast from unittest.mock import Mock, call from parameterized import parameterized from signedjson.key import generate_signing_key +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -35,7 +37,9 @@ from synapse.handlers.presence import ( ) from synapse.rest import admin from synapse.rest.client import room -from synapse.types import UserID, get_domain_from_id +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.util import Clock from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase class PresenceUpdateTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main - def test_offline_to_online(self): + def test_offline_to_online(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_online_to_online(self): + def test_online_to_online(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_online_to_online_last_active_noop(self): + def test_online_to_online_last_active_noop(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_online_to_online_last_active(self): + def test_online_to_online_last_active(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_remote_ping_timer(self): + def test_remote_ping_timer(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_online_to_offline(self): + def test_online_to_offline(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertEqual(wheel_timer.insert.call_count, 0) - def test_online_to_idle(self): + def test_online_to_idle(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_persisting_presence_updates(self): + def test_persisting_presence_updates(self) -> None: """Tests that the latest presence state for each user is persisted correctly""" # Create some test users and presence states for them presence_states = [] @@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.get_success(self.store.update_presence(presence_states)) # Check that each update is present in the database - db_presence_states = self.get_success( + db_presence_states_raw = self.get_success( self.store.get_all_presence_updates( instance_name="master", last_id=0, @@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): ) # Extract presence update user ID and state information into lists of tuples - db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]] + db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]] presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states] # Compare what we put into the storage with what we got out. @@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): class PresenceTimeoutTestCase(unittest.TestCase): """Tests different timers and that the timer does not change `status_msg` of user.""" - def test_idle_timer(self): + def test_idle_timer(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) self.assertEqual(new_state.status_msg, status_msg) - def test_busy_no_idle(self): + def test_busy_no_idle(self) -> None: """ Tests that a user setting their presence to busy but idling doesn't turn their presence state into unavailable. @@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(new_state.state, PresenceState.BUSY) self.assertEqual(new_state.status_msg, status_msg) - def test_sync_timeout(self): + def test_sync_timeout(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) - def test_sync_online(self): + def test_sync_online(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(new_state.state, PresenceState.ONLINE) self.assertEqual(new_state.status_msg, status_msg) - def test_federation_ping(self): + def test_federation_ping(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEqual(state, new_state) - def test_no_timeout(self): + def test_no_timeout(self) -> None: user_id = "@foo:bar" now = 5000000 @@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNone(new_state) - def test_federation_timeout(self): + def test_federation_timeout(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) - def test_last_active(self): + def test_last_active(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase): class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() - def test_external_process_timeout(self): + def test_external_process_timeout(self) -> None: """Test that if an external process doesn't update the records for a while we time out their syncing users presence. """ - process_id = 1 + process_id = "1" user_id = "@test:server" # Notify handler that a user is now syncing. @@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): ) self.assertEqual(state.state, PresenceState.OFFLINE) - def test_user_goes_offline_by_timeout_status_msg_remain(self): + def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None: """Test that if a user doesn't update the records for a while users presence goes `OFFLINE` because of timeout and `status_msg` remains. """ @@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(state.state, PresenceState.OFFLINE) self.assertEqual(state.status_msg, status_msg) - def test_user_goes_offline_manually_with_no_status_msg(self): + def test_user_goes_offline_manually_with_no_status_msg(self) -> None: """Test that if a user change presence manually to `OFFLINE` and no status is set, that `status_msg` is `None`. """ @@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(state.state, PresenceState.OFFLINE) self.assertEqual(state.status_msg, None) - def test_user_goes_offline_manually_with_status_msg(self): + def test_user_goes_offline_manually_with_status_msg(self) -> None: """Test that if a user change presence manually to `OFFLINE` and a status is set, that `status_msg` appears. """ @@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): user_id, PresenceState.OFFLINE, "And now here." ) - def test_user_reset_online_with_no_status(self): + def test_user_reset_online_with_no_status(self) -> None: """Test that if a user set again the presence manually and no status is set, that `status_msg` is `None`. """ @@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(state.state, PresenceState.ONLINE) self.assertEqual(state.status_msg, None) - def test_set_presence_with_status_msg_none(self): + def test_set_presence_with_status_msg_none(self) -> None: """Test that if a user set again the presence manually and status is `None`, that `status_msg` is `None`. """ @@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): # Mark user as online and `status_msg = None` self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) - def test_set_presence_from_syncing_not_set(self): + def test_set_presence_from_syncing_not_set(self) -> None: """Test that presence is not set by syncing if affect_presence is false""" user_id = "@test:server" status_msg = "I'm here!" @@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): # and status message should still be the same self.assertEqual(state.status_msg, status_msg) - def test_set_presence_from_syncing_is_set(self): + def test_set_presence_from_syncing_is_set(self) -> None: """Test that presence is set by syncing if affect_presence is true""" user_id = "@test:server" status_msg = "I'm here!" @@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): # we should now be online self.assertEqual(state.state, PresenceState.ONLINE) - def test_set_presence_from_syncing_keeps_status(self): + def test_set_presence_from_syncing_keeps_status(self) -> None: """Test that presence set by syncing retains status message""" user_id = "@test:server" status_msg = "I'm here!" @@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): }, } ) - def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool): + def test_set_presence_from_syncing_keeps_busy( + self, test_with_workers: bool + ) -> None: """Test that presence set by syncing doesn't affect busy status Args: @@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): def _set_presencestate_with_status_msg( self, user_id: str, state: str, status_msg: Optional[str] - ): + ) -> None: """Set a PresenceState and status_msg and check the result. Args: @@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() self.instance_name = hs.get_instance_name() self.queue = self.presence_handler.get_federation_queue() - def test_send_and_get(self): + def test_send_and_get(self) -> None: state1 = UserPresenceState.default("@user1:test") state2 = UserPresenceState.default("@user2:test") state3 = UserPresenceState.default("@user3:test") @@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): self.assertFalse(limited) self.assertCountEqual(rows, []) - def test_send_and_get_split(self): + def test_send_and_get_split(self) -> None: state1 = UserPresenceState.default("@user1:test") state2 = UserPresenceState.default("@user2:test") state3 = UserPresenceState.default("@user3:test") @@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): self.assertCountEqual(rows, expected_rows) - def test_clear_queue_all(self): + def test_clear_queue_all(self) -> None: state1 = UserPresenceState.default("@user1:test") state2 = UserPresenceState.default("@user2:test") state3 = UserPresenceState.default("@user3:test") @@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): self.assertCountEqual(rows, expected_rows) - def test_partially_clear_queue(self): + def test_partially_clear_queue(self) -> None: state1 = UserPresenceState.default("@user1:test") state2 = UserPresenceState.default("@user2:test") state3 = UserPresenceState.default("@user3:test") @@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): servlets = [room.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver( "server", federation_http_client=None, @@ -990,14 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) return hs - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() # Enable federation sending on the main process. config["federation_sender_instances"] = None return config - def prepare(self, reactor, clock, hs): - self.federation_sender = hs.get_federation_sender() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.federation_sender = cast(Mock, hs.get_federation_sender()) self.event_builder_factory = hs.get_event_builder_factory() self.federation_event_handler = hs.get_federation_event_handler() self.presence_handler = hs.get_presence_handler() @@ -1013,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # random key to use. self.random_signing_key = generate_signing_key("ver") - def test_remote_joins(self): + def test_remote_joins(self) -> None: # We advance time to something that isn't 0, as we use 0 as a special # value. self.reactor.advance(1000000000000) @@ -1061,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): destinations={"server3"}, states=[expected_state] ) - def test_remote_gets_presence_when_local_user_joins(self): + def test_remote_gets_presence_when_local_user_joins(self) -> None: # We advance time to something that isn't 0, as we use 0 as a special # value. self.reactor.advance(1000000000000) @@ -1110,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): destinations={"server2", "server3"}, states=[expected_state] ) - def _add_new_user(self, room_id, user_id): + def _add_new_user(self, room_id: str, user_id: str) -> None: """Add new user to the room by creating an event and poking the federation API.""" hostname = get_domain_from_id(user_id) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 675aa023ac..7c174782da 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -332,7 +332,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]} ) - def test_avatar_constraint_on_local_server_with_port(self): + def test_avatar_constraint_on_local_server_with_port(self) -> None: """Test that avatar metadata is correctly fetched when the media is on a local server and the server has an explicit port. @@ -376,7 +376,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc)) ) - def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None: """Stores metadata about files in the database. Args: diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index b55238650c..f60400ff8d 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,14 +15,18 @@ from copy import deepcopy from typing import List +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EduTypes, ReceiptTypes +from synapse.server import HomeServer from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest class ReceiptsTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.event_source = hs.get_event_sources().sources.receipt def test_filters_out_private_receipt(self) -> None: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 765df75d91..b9332d97dc 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Collection, List, Optional, Tuple from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -22,8 +25,18 @@ from synapse.api.errors import ( ResourceLimitError, SynapseError, ) +from synapse.module_api import ModuleApi +from synapse.server import HomeServer from synapse.spam_checker_api import RegistrationBehaviour -from synapse.types import RoomAlias, RoomID, UserID, create_requester +from synapse.types import ( + JsonDict, + Requester, + RoomAlias, + RoomID, + UserID, + create_requester, +) +from synapse.util import Clock from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -33,94 +46,98 @@ from .. import unittest class TestSpamChecker: - def __init__(self, config, api): + def __init__(self, config: None, api: ModuleApi): api.register_spam_checker_callbacks( check_registration_for_spam=self.check_registration_for_spam, ) @staticmethod - def parse_config(config): - return config + def parse_config(config: JsonDict) -> None: + return None async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - auth_provider_id, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str], + ) -> RegistrationBehaviour: pass class DenyAll(TestSpamChecker): async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - auth_provider_id, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str], + ) -> RegistrationBehaviour: return RegistrationBehaviour.DENY class BanAll(TestSpamChecker): async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - auth_provider_id, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str], + ) -> RegistrationBehaviour: return RegistrationBehaviour.SHADOW_BAN class BanBadIdPUser(TestSpamChecker): async def check_registration_for_spam( - self, email_threepid, username, request_info, auth_provider_id=None - ): + self, + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str] = None, + ) -> RegistrationBehaviour: # Reject any user coming from CAS and whose username contains profanity - if auth_provider_id == "cas" and "flimflob" in username: + if auth_provider_id == "cas" and username and "flimflob" in username: return RegistrationBehaviour.DENY return RegistrationBehaviour.ALLOW class TestLegacyRegistrationSpamChecker: - def __init__(self, config, api): + def __init__(self, config: None, api: ModuleApi): pass async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: pass class LegacyAllowAll(TestLegacyRegistrationSpamChecker): async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: return RegistrationBehaviour.ALLOW class LegacyDenyAll(TestLegacyRegistrationSpamChecker): async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: return RegistrationBehaviour.DENY class RegistrationTestCase(unittest.HomeserverTestCase): """Tests the RegistrationHandler.""" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs_config = self.default_config() # some of the tests rely on us having a user consent version @@ -145,7 +162,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = self.hs.get_registration_handler() self.store = self.hs.get_datastores().main self.lots_of_users = 100 @@ -153,7 +170,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.requester = create_requester("@requester:test") - def test_user_is_created_and_logged_in_if_doesnt_exist(self): + def test_user_is_created_and_logged_in_if_doesnt_exist(self) -> None: frank = UserID.from_string("@frank:test") user_id = frank.to_string() requester = create_requester(user_id) @@ -164,7 +181,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertIsInstance(result_token, str) self.assertGreater(len(result_token), 20) - def test_if_user_exists(self): + def test_if_user_exists(self) -> None: store = self.hs.get_datastores().main frank = UserID.from_string("@frank:test") self.get_success( @@ -180,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(result_token is not None) @override_config({"limit_usage_by_mau": False}) - def test_mau_limits_when_disabled(self): + def test_mau_limits_when_disabled(self) -> None: # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) @override_config({"limit_usage_by_mau": True}) - def test_get_or_create_user_mau_not_blocked(self): + def test_get_or_create_user_mau_not_blocked(self) -> None: self.store.count_monthly_users = Mock( return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) @@ -193,7 +210,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.get_success(self.get_or_create_user(self.requester, "c", "User")) @override_config({"limit_usage_by_mau": True}) - def test_get_or_create_user_mau_blocked(self): + def test_get_or_create_user_mau_blocked(self) -> None: self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -211,7 +228,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) @override_config({"limit_usage_by_mau": True}) - def test_register_mau_blocked(self): + def test_register_mau_blocked(self) -> None: self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -229,7 +246,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config( {"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False} ) - def test_auto_join_rooms_for_guests(self): + def test_auto_join_rooms_for_guests(self) -> None: user_id = self.get_success( self.handler.register_user(localpart="jeff", make_guest=True), ) @@ -237,7 +254,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(rooms), 0) @override_config({"auto_join_rooms": ["#room:test"]}) - def test_auto_create_auto_join_rooms(self): + def test_auto_create_auto_join_rooms(self) -> None: room_alias_str = "#room:test" user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -249,7 +266,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(rooms), 1) @override_config({"auto_join_rooms": []}) - def test_auto_create_auto_join_rooms_with_no_rooms(self): + def test_auto_create_auto_join_rooms_with_no_rooms(self) -> None: frank = UserID.from_string("@frank:test") user_id = self.get_success(self.handler.register_user(frank.localpart)) self.assertEqual(user_id, frank.to_string()) @@ -257,7 +274,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(rooms), 0) @override_config({"auto_join_rooms": ["#room:another"]}) - def test_auto_create_auto_join_where_room_is_another_domain(self): + def test_auto_create_auto_join_where_room_is_another_domain(self) -> None: frank = UserID.from_string("@frank:test") user_id = self.get_success(self.handler.register_user(frank.localpart)) self.assertEqual(user_id, frank.to_string()) @@ -267,13 +284,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config( {"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False} ) - def test_auto_create_auto_join_where_auto_create_is_false(self): + def test_auto_create_auto_join_where_auto_create_is_false(self) -> None: user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @override_config({"auto_join_rooms": ["#room:test"]}) - def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): + def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None: room_alias_str = "#room:test" self.store.is_real_user = Mock(return_value=make_awaitable(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) @@ -284,7 +301,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.get_failure(directory_handler.get_association(room_alias), SynapseError) @override_config({"auto_join_rooms": ["#room:test"]}) - def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): + def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: room_alias_str = "#room:test" self.store.count_real_users = Mock(return_value=make_awaitable(1)) @@ -299,7 +316,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(rooms), 1) @override_config({"auto_join_rooms": ["#room:test"]}) - def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): + def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( + self, + ) -> None: self.store.count_real_users = Mock(return_value=make_awaitable(2)) self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) @@ -312,7 +331,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): "autocreate_auto_join_rooms_federated": False, } ) - def test_auto_create_auto_join_rooms_federated(self): + def test_auto_create_auto_join_rooms_federated(self) -> None: """ Auto-created rooms that are private require an invite to go to the user (instead of directly joining it). @@ -339,7 +358,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config( {"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"} ) - def test_auto_join_mxid_localpart(self): + def test_auto_join_mxid_localpart(self) -> None: """ Ensure the user still needs up in the room created by a different user. """ @@ -376,7 +395,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): "auto_join_mxid_localpart": "support", } ) - def test_auto_create_auto_join_room_preset(self): + def test_auto_create_auto_join_room_preset(self) -> None: """ Auto-created rooms that are private require an invite to go to the user (instead of directly joining it). @@ -416,7 +435,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): "auto_join_mxid_localpart": "support", } ) - def test_auto_create_auto_join_room_preset_guest(self): + def test_auto_create_auto_join_room_preset_guest(self) -> None: """ Auto-created rooms that are private require an invite to go to the user (instead of directly joining it). @@ -454,7 +473,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): "auto_join_mxid_localpart": "support", } ) - def test_auto_create_auto_join_room_preset_invalid_permissions(self): + def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None: """ Auto-created rooms that are private require an invite, check that registration doesn't completely break if the inviter doesn't have proper @@ -525,7 +544,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): "auto_join_rooms": ["#room:test"], }, ) - def test_auto_create_auto_join_where_no_consent(self): + def test_auto_create_auto_join_where_no_consent(self) -> None: """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. """ @@ -550,19 +569,19 @@ class RegistrationTestCase(unittest.HomeserverTestCase): rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 1) - def test_register_support_user(self): + def test_register_support_user(self) -> None: user_id = self.get_success( self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT) ) d = self.store.is_support_user(user_id) self.assertTrue(self.get_success(d)) - def test_register_not_support_user(self): + def test_register_not_support_user(self) -> None: user_id = self.get_success(self.handler.register_user(localpart="user")) d = self.store.is_support_user(user_id) self.assertFalse(self.get_success(d)) - def test_invalid_user_id_length(self): + def test_invalid_user_id_length(self) -> None: invalid_user_id = "x" * 256 self.get_failure( self.handler.register_user(localpart=invalid_user_id), SynapseError @@ -577,7 +596,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_deny(self): + def test_spam_checker_deny(self) -> None: """A spam checker can deny registration, which results in an error.""" self.get_failure(self.handler.register_user(localpart="user"), SynapseError) @@ -590,7 +609,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_legacy_allow(self): + def test_spam_checker_legacy_allow(self) -> None: """Tests that a legacy spam checker implementing the legacy 3-arg version of the check_registration_for_spam callback is correctly called. @@ -610,7 +629,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_legacy_deny(self): + def test_spam_checker_legacy_deny(self) -> None: """Tests that a legacy spam checker implementing the legacy 3-arg version of the check_registration_for_spam callback is correctly called. @@ -630,7 +649,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_shadow_ban(self): + def test_spam_checker_shadow_ban(self) -> None: """A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" user_id = self.get_success(self.handler.register_user(localpart="user")) @@ -660,7 +679,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_receives_sso_type(self): + def test_spam_checker_receives_sso_type(self) -> None: """Test rejecting registration based on SSO type""" f = self.get_failure( self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"), @@ -678,8 +697,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) async def get_or_create_user( - self, requester, localpart, displayname, password_hash=None - ): + self, + requester: Requester, + localpart: str, + displayname: Optional[str], + password_hash: Optional[str] = None, + ) -> Tuple[str, str]: """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -734,13 +757,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase): class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): """Tests auto-join on remote rooms.""" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.room_id = "!roomid:remotetest" - async def update_membership(*args, **kwargs): + async def update_membership(*args: Any, **kwargs: Any) -> None: pass - async def lookup_room_alias(*args, **kwargs): + async def lookup_room_alias( + *args: Any, **kwargs: Any + ) -> Tuple[RoomID, List[str]]: return RoomID.from_string(self.room_id), ["remotetest"] self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"]) @@ -750,12 +775,12 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = self.hs.get_registration_handler() self.store = self.hs.get_datastores().main @override_config({"auto_join_rooms": ["#room:remotetest"]}) - def test_auto_create_auto_join_remote_room(self): + def test_auto_create_auto_join_remote_room(self) -> None: """Tests that we don't attempt to create remote rooms, and that we don't attempt to invite ourselves to rooms we're not in.""" diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index fcde5dab72..df95490d3b 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -14,7 +14,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): ] @override_config({"encryption_enabled_by_default_for_room_type": "all"}) - def test_encrypted_by_default_config_option_all(self): + def test_encrypted_by_default_config_option_all(self) -> None: """Tests that invite-only and non-invite-only rooms have encryption enabled by default when the config option encryption_enabled_by_default_for_room_type is "all". """ @@ -45,7 +45,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) @override_config({"encryption_enabled_by_default_for_room_type": "invite"}) - def test_encrypted_by_default_config_option_invite(self): + def test_encrypted_by_default_config_option_invite(self) -> None: """Tests that only new, invite-only rooms have encryption enabled by default when the config option encryption_enabled_by_default_for_room_type is "invite". """ @@ -76,7 +76,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): ) @override_config({"encryption_enabled_by_default_for_room_type": "off"}) - def test_encrypted_by_default_config_option_off(self): + def test_encrypted_by_default_config_option_off(self) -> None: """Tests that neither new invite-only nor non-invite-only rooms have encryption enabled by default when the config option encryption_enabled_by_default_for_room_type is "off". diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index aa650756e4..d907fcaf04 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from unittest import mock from twisted.internet.defer import ensureDeferred +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import ( EventContentFields, @@ -34,11 +35,14 @@ from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import JsonDict, UserID, create_requester +from synapse.util import Clock from tests import unittest -def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0): +def _create_event( + room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0 +) -> mock.Mock: result = mock.Mock(name=room_id) result.room_id = room_id result.content = {} @@ -48,40 +52,40 @@ def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: i return result -def _order(*events): +def _order(*events: mock.Mock) -> List[mock.Mock]: return sorted(events, key=_child_events_comparison_key) class TestSpaceSummarySort(unittest.TestCase): - def test_no_order_last(self): + def test_no_order_last(self) -> None: """An event with no ordering is placed behind those with an ordering.""" ev1 = _create_event("!abc:test") ev2 = _create_event("!xyz:test", "xyz") self.assertEqual([ev2, ev1], _order(ev1, ev2)) - def test_order(self): + def test_order(self) -> None: """The ordering should be used.""" ev1 = _create_event("!abc:test", "xyz") ev2 = _create_event("!xyz:test", "abc") self.assertEqual([ev2, ev1], _order(ev1, ev2)) - def test_order_origin_server_ts(self): + def test_order_origin_server_ts(self) -> None: """Origin server is a tie-breaker for ordering.""" ev1 = _create_event("!abc:test", origin_server_ts=10) ev2 = _create_event("!xyz:test", origin_server_ts=30) self.assertEqual([ev1, ev2], _order(ev1, ev2)) - def test_order_room_id(self): + def test_order_room_id(self) -> None: """Room ID is a final tie-breaker for ordering.""" ev1 = _create_event("!abc:test") ev2 = _create_event("!xyz:test") self.assertEqual([ev1, ev2], _order(ev1, ev2)) - def test_invalid_ordering_type(self): + def test_invalid_ordering_type(self) -> None: """Invalid orderings are considered the same as missing.""" ev1 = _create_event("!abc:test", 1) ev2 = _create_event("!xyz:test", "xyz") @@ -97,7 +101,7 @@ class TestSpaceSummarySort(unittest.TestCase): ev1 = _create_event("!abc:test", True) self.assertEqual([ev2, ev1], _order(ev1, ev2)) - def test_invalid_ordering_value(self): + def test_invalid_ordering_value(self) -> None: """Invalid orderings are considered the same as missing.""" ev1 = _create_event("!abc:test", "foo\n") ev2 = _create_event("!xyz:test", "xyz") @@ -115,7 +119,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.hs = hs self.handler = self.hs.get_room_summary_handler() @@ -223,7 +227,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) ) - def test_simple_space(self): + def test_simple_space(self) -> None: """Test a simple space with a single room.""" # The result should have the space and the room in it, along with a link # from space -> room. @@ -234,7 +238,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_large_space(self): + def test_large_space(self) -> None: """Test a space with a large number of rooms.""" rooms = [self.room] # Make at least 51 rooms that are part of the space. @@ -260,7 +264,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result["rooms"] += result2["rooms"] self._assert_hierarchy(result, expected) - def test_visibility(self): + def test_visibility(self) -> None: """A user not in a space cannot inspect it.""" user2 = self.register_user("user2", "pass") token2 = self.login("user2", "pass") @@ -380,7 +384,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_hierarchy(result2, [(self.space, [self.room])]) def _create_room_with_join_rule( - self, join_rule: str, room_version: Optional[str] = None, **extra_content + self, join_rule: str, room_version: Optional[str] = None, **extra_content: Any ) -> str: """Create a room with the given join rule and add it to the space.""" room_id = self.helper.create_room_as( @@ -403,7 +407,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._add_child(self.space, room_id, self.token) return room_id - def test_filtering(self): + def test_filtering(self) -> None: """ Rooms should be properly filtered to only include rooms the user has access to. """ @@ -476,7 +480,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_complex_space(self): + def test_complex_space(self) -> None: """ Create a "complex" space to see how it handles things like loops and subspaces. """ @@ -516,7 +520,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_pagination(self): + def test_pagination(self) -> None: """Test simple pagination works.""" room_ids = [] for i in range(1, 10): @@ -553,7 +557,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_hierarchy(result, expected) self.assertNotIn("next_batch", result) - def test_invalid_pagination_token(self): + def test_invalid_pagination_token(self) -> None: """An invalid pagination token, or changing other parameters, shoudl be rejected.""" room_ids = [] for i in range(1, 10): @@ -604,7 +608,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_max_depth(self): + def test_max_depth(self) -> None: """Create a deep tree to test the max depth against.""" spaces = [self.space] rooms = [self.room] @@ -659,7 +663,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ] self._assert_hierarchy(result, expected) - def test_unknown_room_version(self): + def test_unknown_room_version(self) -> None: """ If a room with an unknown room version is encountered it should not cause the entire summary to skip. @@ -685,7 +689,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_fed_complex(self): + def test_fed_complex(self) -> None: """ Return data over federation and ensure that it is handled properly. """ @@ -722,7 +726,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "world_readable": True, } - async def summarize_remote_room_hierarchy(_self, room, suggested_only): + async def summarize_remote_room_hierarchy( + _self: Any, room: Any, suggested_only: bool + ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: return requested_room_entry, {subroom: child_room}, set() # Add a room to the space which is on another server. @@ -744,7 +750,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_fed_filtering(self): + def test_fed_filtering(self) -> None: """ Rooms returned over federation should be properly filtered to only include rooms the user has access to. @@ -853,7 +859,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ], ) - async def summarize_remote_room_hierarchy(_self, room, suggested_only): + async def summarize_remote_room_hierarchy( + _self: Any, room: Any, suggested_only: bool + ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: return subspace_room_entry, dict(children_rooms), set() # Add a room to the space which is on another server. @@ -892,7 +900,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_fed_invited(self): + def test_fed_invited(self) -> None: """ A room which the user was invited to should be included in the response. @@ -915,7 +923,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): }, ) - async def summarize_remote_room_hierarchy(_self, room, suggested_only): + async def summarize_remote_room_hierarchy( + _self: Any, room: Any, suggested_only: bool + ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: return fed_room_entry, {}, set() # Add a room to the space which is on another server. @@ -936,7 +946,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_fed_caching(self): + def test_fed_caching(self) -> None: """ Federation `/hierarchy` responses should be cached. """ @@ -1023,7 +1033,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.hs = hs self.handler = self.hs.get_room_summary_handler() @@ -1040,12 +1050,12 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) - def test_own_room(self): + def test_own_room(self) -> None: """Test a simple room created by the requester.""" result = self.get_success(self.handler.get_room_summary(self.user, self.room)) self.assertEqual(result.get("room_id"), self.room) - def test_visibility(self): + def test_visibility(self) -> None: """A user not in a private room cannot get its summary.""" user2 = self.register_user("user2", "pass") token2 = self.login("user2", "pass") @@ -1093,7 +1103,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_room_summary(user2, self.room)) self.assertEqual(result.get("room_id"), self.room) - def test_fed(self): + def test_fed(self) -> None: """ Return data over federation and ensure that it is handled properly. """ @@ -1105,7 +1115,9 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): {"room_id": fed_room, "world_readable": True}, ) - async def summarize_remote_room_hierarchy(_self, room, suggested_only): + async def summarize_remote_room_hierarchy( + _self: Any, room: Any, suggested_only: bool + ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: return requested_room_entry, {}, set() with mock.patch( diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index a0f84e2940..9b1b8b9f13 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Set, Tuple from unittest.mock import Mock import attr @@ -20,7 +20,9 @@ import attr from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import RedirectException +from synapse.module_api import ModuleApi from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock from tests.test_utils import simple_async_mock @@ -29,6 +31,7 @@ from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. try: import saml2.config + import saml2.response from saml2.sigver import SigverError has_saml2 = True @@ -56,31 +59,39 @@ class FakeAuthnResponse: class TestMappingProvider: - def __init__(self, config, module): + def __init__(self, config: None, module: ModuleApi): pass @staticmethod - def parse_config(config): - return + def parse_config(config: JsonDict) -> None: + return None @staticmethod - def get_saml_attributes(config): + def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]: return {"uid"}, {"displayName"} - def get_remote_user_id(self, saml_response, client_redirect_url): + def get_remote_user_id( + self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str + ) -> str: return saml_response.ava["uid"] def saml_response_to_user_attributes( - self, saml_response, failures, client_redirect_url - ): + self, + saml_response: "saml2.response.AuthnResponse", + failures: int, + client_redirect_url: str, + ) -> dict: localpart = saml_response.ava["username"] + (str(failures) if failures else "") return {"mxid_localpart": localpart, "displayname": None} class TestRedirectMappingProvider(TestMappingProvider): def saml_response_to_user_attributes( - self, saml_response, failures, client_redirect_url - ): + self, + saml_response: "saml2.response.AuthnResponse", + failures: int, + client_redirect_url: str, + ) -> dict: raise RedirectException(b"https://custom-saml-redirect/") @@ -347,7 +358,7 @@ class SamlHandlerTestCase(HomeserverTestCase): ) -def _mock_request(): +def _mock_request() -> Mock: """Returns a mock which will stand in as a SynapseRequest""" mock = Mock( spec=[ diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py index da4bf8b582..8b6e4a40b6 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import List, Tuple +from typing import Callable, List, Tuple from zope.interface import implementer @@ -28,20 +28,27 @@ from tests.unittest import HomeserverTestCase, override_config @implementer(interfaces.IMessageDelivery) class _DummyMessageDelivery: - def __init__(self): + def __init__(self) -> None: # (recipient, message) tuples self.messages: List[Tuple[smtp.Address, bytes]] = [] - def receivedHeader(self, helo, origin, recipients): + def receivedHeader( + self, + helo: Tuple[bytes, bytes], + origin: smtp.Address, + recipients: List[smtp.User], + ) -> None: return None - def validateFrom(self, helo, origin): + def validateFrom( + self, helo: Tuple[bytes, bytes], origin: smtp.Address + ) -> smtp.Address: return origin - def record_message(self, recipient: smtp.Address, message: bytes): + def record_message(self, recipient: smtp.Address, message: bytes) -> None: self.messages.append((recipient, message)) - def validateTo(self, user: smtp.User): + def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]: return lambda: _DummyMessage(self, user) @@ -56,20 +63,20 @@ class _DummyMessage: self._user = user self._buffer: List[bytes] = [] - def lineReceived(self, line): + def lineReceived(self, line: bytes) -> None: self._buffer.append(line) - def eomReceived(self): + def eomReceived(self) -> "defer.Deferred[bytes]": message = b"\n".join(self._buffer) + b"\n" self._delivery.record_message(self._user.dest, message) return defer.succeed(b"saved") - def connectionLost(self): + def connectionLost(self) -> None: pass class SendEmailHandlerTestCase(HomeserverTestCase): - def test_send_email(self): + def test_send_email(self) -> None: """Happy-path test that we can send email to a non-TLS server.""" h = self.hs.get_send_email_handler() d = ensureDeferred( @@ -119,7 +126,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase): }, } ) - def test_send_email_force_tls(self): + def test_send_email_force_tls(self) -> None: """Happy-path test that we can send email to an Implicit TLS server.""" h = self.hs.get_send_email_handler() d = ensureDeferred( diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 05f9ec3c51..f1a50c5bcb 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -12,9 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List, Optional + +from twisted.test.proto_helpers import MemoryReactor + from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.storage.databases.main import stats +from synapse.util import Clock from tests import unittest @@ -32,11 +38,11 @@ class StatsRoomTests(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = self.hs.get_stats_handler() - def _add_background_updates(self): + def _add_background_updates(self) -> None: """ Add the background updates we need to run. """ @@ -63,12 +69,14 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - async def get_all_room_state(self): + async def get_all_room_state(self) -> List[Dict[str, Any]]: return await self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) - def _get_current_stats(self, stats_type, stat_id): + def _get_current_stats( + self, stats_type: str, stat_id: str + ) -> Optional[Dict[str, Any]]: table, id_col = stats.TYPE_TO_TABLE[stats_type] cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) @@ -82,13 +90,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - def _perform_background_initial_update(self): + def _perform_background_initial_update(self) -> None: # Do the initial population of the stats via the background update self._add_background_updates() self.wait_for_background_updates() - def test_initial_room(self): + def test_initial_room(self) -> None: """ The background updates will build the table from scratch. """ @@ -125,7 +133,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(len(r), 1) self.assertEqual(r[0]["topic"], "foo") - def test_create_user(self): + def test_create_user(self) -> None: """ When we create a user, it should have statistics already ready. """ @@ -134,12 +142,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): u1stats = self._get_current_stats("user", u1) - self.assertIsNotNone(u1stats) + assert u1stats is not None # not in any rooms by default self.assertEqual(u1stats["joined_rooms"], 0) - def test_create_room(self): + def test_create_room(self) -> None: """ When we create a room, it should have statistics already ready. """ @@ -153,8 +161,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False) r2stats = self._get_current_stats("room", r2) - self.assertIsNotNone(r1stats) - self.assertIsNotNone(r2stats) + assert r1stats is not None + assert r2stats is not None self.assertEqual( r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM @@ -171,7 +179,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(r2stats["invited_members"], 0) self.assertEqual(r2stats["banned_members"], 0) - def test_updating_profile_information_does_not_increase_joined_members_count(self): + def test_updating_profile_information_does_not_increase_joined_members_count( + self, + ) -> None: """ Check that the joined_members count does not increase when a user changes their profile information (which is done by sending another join membership event into @@ -186,6 +196,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Get the current room stats r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None # Send a profile update into the room new_profile = {"displayname": "bob"} @@ -195,6 +206,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Get the new room stats r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None # Ensure that the user count did not changed self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"]) @@ -202,7 +214,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"] ) - def test_send_state_event_nonoverwriting(self): + def test_send_state_event_nonoverwriting(self) -> None: """ When we send a non-overwriting state event, it increments current_state_events """ @@ -218,19 +230,21 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.send_state( r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy" ) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, ) - def test_join_first_time(self): + def test_join_first_time(self) -> None: """ When a user joins a room for the first time, current_state_events and joined_members should increase by exactly 1. @@ -246,10 +260,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): u2token = self.login("u2", "pass") r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.join(r1, u2, tok=u2token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -259,7 +275,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1 ) - def test_join_after_leave(self): + def test_join_after_leave(self) -> None: """ When a user joins a room after being previously left, joined_members should increase by exactly 1. @@ -280,10 +296,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.leave(r1, u2, tok=u2token) r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.join(r1, u2, tok=u2token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -296,7 +314,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["left_members"] - r1stats_ante["left_members"], -1 ) - def test_invited(self): + def test_invited(self) -> None: """ When a user invites another user, current_state_events and invited_members should increase by exactly 1. @@ -311,10 +329,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): u2 = self.register_user("u2", "pass") r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.invite(r1, u1, u2, tok=u1token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -324,7 +344,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1 ) - def test_join_after_invite(self): + def test_join_after_invite(self) -> None: """ When a user joins a room after being invited and joined_members should increase by exactly 1. @@ -344,10 +364,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.invite(r1, u1, u2, tok=u1token) r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.join(r1, u2, tok=u2token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -360,7 +382,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1 ) - def test_left(self): + def test_left(self) -> None: """ When a user leaves a room after joining and left_members should increase by exactly 1. @@ -380,10 +402,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.join(r1, u2, tok=u2token) r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.leave(r1, u2, tok=u2token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -396,7 +420,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1 ) - def test_banned(self): + def test_banned(self) -> None: """ When a user is banned from a room after joining and left_members should increase by exactly 1. @@ -416,10 +440,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.join(r1, u2, tok=u2token) r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.change_membership(r1, u1, u2, "ban", tok=u1token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -432,7 +458,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1 ) - def test_initial_background_update(self): + def test_initial_background_update(self) -> None: """ Test that statistics can be generated by the initial background update handler. @@ -462,6 +488,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats = self._get_current_stats("room", r1) u1stats = self._get_current_stats("user", u1) + assert r1stats is not None + assert u1stats is not None + self.assertEqual(r1stats["joined_members"], 1) self.assertEqual( r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM @@ -469,7 +498,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(u1stats["joined_rooms"], 1) - def test_incomplete_stats(self): + def test_incomplete_stats(self) -> None: """ This tests that we track incomplete statistics. @@ -533,8 +562,11 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.wait_for_background_updates() r1stats_complete = self._get_current_stats("room", r1) + assert r1stats_complete is not None u1stats_complete = self._get_current_stats("user", u1) + assert u1stats_complete is not None u2stats_complete = self._get_current_stats("user", u2) + assert u2stats_complete is not None # now we make our assertions diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index ab5c101eb7..0d9a3de92a 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -14,6 +14,8 @@ from typing import Optional from unittest.mock import MagicMock, Mock, patch +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import Filtering @@ -23,6 +25,7 @@ from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer from synapse.types import UserID, create_requester +from synapse.util import Clock import tests.unittest import tests.utils @@ -39,7 +42,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastores().main @@ -47,7 +50,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # modify its config instead of the hs' self.auth_blocking = self.hs.get_auth_blocking() - def test_wait_for_sync_for_user_auth_blocking(self): + def test_wait_for_sync_for_user_auth_blocking(self) -> None: user_id1 = "@user1:test" user_id2 = "@user2:test" sync_config = generate_sync_config(user_id1) @@ -82,7 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def test_unknown_room_version(self): + def test_unknown_room_version(self) -> None: """ A room with an unknown room version should not break sync (and should be excluded). """ @@ -186,7 +189,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.assertNotIn(invite_room, [r.room_id for r in result.invited]) self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) - def test_ban_wins_race_with_join(self): + def test_ban_wins_race_with_join(self) -> None: """Rooms shouldn't appear under "joined" if a join loses a race to a ban. A complicated edge case. Imagine the following scenario: