mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Add missing type hints to tests.handlers. (#14680)
And do not allow untyped defs in tests.handlers.
This commit is contained in:
parent
54c012c5a8
commit
652d1669c5
22 changed files with 527 additions and 378 deletions
1
changelog.d/14680.misc
Normal file
1
changelog.d/14680.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add missing type hints.
|
5
mypy.ini
5
mypy.ini
|
@ -95,10 +95,7 @@ disallow_untyped_defs = True
|
|||
[mypy-tests.federation.transport.test_client]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.handlers.test_sso]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.handlers.test_user_directory]
|
||||
[mypy-tests.handlers.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.metrics.test_background_process_metrics]
|
||||
|
|
|
@ -2031,7 +2031,7 @@ class PasswordAuthProvider:
|
|||
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
|
||||
|
||||
# Mapping from login type to login parameters
|
||||
self._supported_login_types: Dict[str, Iterable[str]] = {}
|
||||
self._supported_login_types: Dict[str, Tuple[str, ...]] = {}
|
||||
|
||||
# Mapping from login type to auth checker callbacks
|
||||
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
|
||||
|
|
|
@ -31,7 +31,7 @@ from synapse.appservice import (
|
|||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.rest.client import login, receipts, register, room, sendtodevice
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.types import JsonDict, RoomStreamToken
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
|
@ -44,7 +44,7 @@ from tests.utils import MockClock
|
|||
class AppServiceHandlerTestCase(unittest.TestCase):
|
||||
"""Tests the ApplicationServicesHandler."""
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.mock_store = Mock()
|
||||
self.mock_as_api = Mock()
|
||||
self.mock_scheduler = Mock()
|
||||
|
@ -61,7 +61,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
self.handler = ApplicationServicesHandler(hs)
|
||||
self.event_source = hs.get_event_sources()
|
||||
|
||||
def test_notify_interested_services(self):
|
||||
def test_notify_interested_services(self) -> None:
|
||||
interested_service = self._mkservice(is_interested_in_event=True)
|
||||
services = [
|
||||
self._mkservice(is_interested_in_event=False),
|
||||
|
@ -90,7 +90,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
interested_service, events=[event]
|
||||
)
|
||||
|
||||
def test_query_user_exists_unknown_user(self):
|
||||
def test_query_user_exists_unknown_user(self) -> None:
|
||||
user_id = "@someone:anywhere"
|
||||
services = [self._mkservice(is_interested_in_event=True)]
|
||||
services[0].is_interested_in_user.return_value = True
|
||||
|
@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
|
||||
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
|
||||
|
||||
def test_query_user_exists_known_user(self):
|
||||
def test_query_user_exists_known_user(self) -> None:
|
||||
user_id = "@someone:anywhere"
|
||||
services = [self._mkservice(is_interested_in_event=True)]
|
||||
services[0].is_interested_in_user.return_value = True
|
||||
|
@ -127,7 +127,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
"query_user called when it shouldn't have been.",
|
||||
)
|
||||
|
||||
def test_query_room_alias_exists(self):
|
||||
def test_query_room_alias_exists(self) -> None:
|
||||
room_alias_str = "#foo:bar"
|
||||
room_alias = Mock()
|
||||
room_alias.to_string.return_value = room_alias_str
|
||||
|
@ -157,7 +157,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
self.assertEqual(result.room_id, room_id)
|
||||
self.assertEqual(result.servers, servers)
|
||||
|
||||
def test_get_3pe_protocols_no_appservices(self):
|
||||
def test_get_3pe_protocols_no_appservices(self) -> None:
|
||||
self.mock_store.get_app_services.return_value = []
|
||||
response = self.successResultOf(
|
||||
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
|
||||
|
@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
self.mock_as_api.get_3pe_protocol.assert_not_called()
|
||||
self.assertEqual(response, {})
|
||||
|
||||
def test_get_3pe_protocols_no_protocols(self):
|
||||
def test_get_3pe_protocols_no_protocols(self) -> None:
|
||||
service = self._mkservice(False, [])
|
||||
self.mock_store.get_app_services.return_value = [service]
|
||||
response = self.successResultOf(
|
||||
|
@ -174,7 +174,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
self.mock_as_api.get_3pe_protocol.assert_not_called()
|
||||
self.assertEqual(response, {})
|
||||
|
||||
def test_get_3pe_protocols_protocol_no_response(self):
|
||||
def test_get_3pe_protocols_protocol_no_response(self) -> None:
|
||||
service = self._mkservice(False, ["my-protocol"])
|
||||
self.mock_store.get_app_services.return_value = [service]
|
||||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
|
||||
|
@ -186,7 +186,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
)
|
||||
self.assertEqual(response, {})
|
||||
|
||||
def test_get_3pe_protocols_select_one_protocol(self):
|
||||
def test_get_3pe_protocols_select_one_protocol(self) -> None:
|
||||
service = self._mkservice(False, ["my-protocol"])
|
||||
self.mock_store.get_app_services.return_value = [service]
|
||||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
|
||||
|
@ -202,7 +202,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
|
||||
)
|
||||
|
||||
def test_get_3pe_protocols_one_protocol(self):
|
||||
def test_get_3pe_protocols_one_protocol(self) -> None:
|
||||
service = self._mkservice(False, ["my-protocol"])
|
||||
self.mock_store.get_app_services.return_value = [service]
|
||||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
|
||||
|
@ -218,7 +218,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
|
||||
)
|
||||
|
||||
def test_get_3pe_protocols_multiple_protocol(self):
|
||||
def test_get_3pe_protocols_multiple_protocol(self) -> None:
|
||||
service_one = self._mkservice(False, ["my-protocol"])
|
||||
service_two = self._mkservice(False, ["other-protocol"])
|
||||
self.mock_store.get_app_services.return_value = [service_one, service_two]
|
||||
|
@ -237,11 +237,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def test_get_3pe_protocols_multiple_info(self):
|
||||
def test_get_3pe_protocols_multiple_info(self) -> None:
|
||||
service_one = self._mkservice(False, ["my-protocol"])
|
||||
service_two = self._mkservice(False, ["my-protocol"])
|
||||
|
||||
async def get_3pe_protocol(service, unusedProtocol):
|
||||
async def get_3pe_protocol(
|
||||
service: ApplicationService, protocol: str
|
||||
) -> Optional[JsonDict]:
|
||||
if service == service_one:
|
||||
return {
|
||||
"x-protocol-data": 42,
|
||||
|
@ -276,7 +278,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def test_notify_interested_services_ephemeral(self):
|
||||
def test_notify_interested_services_ephemeral(self) -> None:
|
||||
"""
|
||||
Test sending ephemeral events to the appservice handler are scheduled
|
||||
to be pushed out to interested appservices, and that the stream ID is
|
||||
|
@ -306,7 +308,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
580,
|
||||
)
|
||||
|
||||
def test_notify_interested_services_ephemeral_out_of_order(self):
|
||||
def test_notify_interested_services_ephemeral_out_of_order(self) -> None:
|
||||
"""
|
||||
Test sending out of order ephemeral events to the appservice handler
|
||||
are ignored.
|
||||
|
@ -390,7 +392,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
|||
receipts.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.hs = hs
|
||||
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
|
||||
# we can track any outgoing ephemeral events
|
||||
|
@ -417,7 +419,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
|||
"exclusive_as_user", "password", self.exclusive_as_user_device_id
|
||||
)
|
||||
|
||||
def _notify_interested_services(self):
|
||||
def _notify_interested_services(self) -> None:
|
||||
# This is normally set in `notify_interested_services` but we need to call the
|
||||
# internal async version so the reactor gets pushed to completion.
|
||||
self.hs.get_application_service_handler().current_max += 1
|
||||
|
@ -443,7 +445,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
def test_match_interesting_room_members(
|
||||
self, interesting_user: str, should_notify: bool
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Test to make sure that a interesting user (local or remote) in the room is
|
||||
notified as expected when someone else in the room sends a message.
|
||||
|
@ -512,7 +514,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
|||
else:
|
||||
self.send_mock.assert_not_called()
|
||||
|
||||
def test_application_services_receive_events_sent_by_interesting_local_user(self):
|
||||
def test_application_services_receive_events_sent_by_interesting_local_user(
|
||||
self,
|
||||
) -> None:
|
||||
"""
|
||||
Test to make sure that a messages sent from a local user can be interesting and
|
||||
picked up by the appservice.
|
||||
|
@ -568,7 +572,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(events[0]["type"], "m.room.message")
|
||||
self.assertEqual(events[0]["sender"], alice)
|
||||
|
||||
def test_sending_read_receipt_batches_to_application_services(self):
|
||||
def test_sending_read_receipt_batches_to_application_services(self) -> None:
|
||||
"""Tests that a large batch of read receipts are sent correctly to
|
||||
interested application services.
|
||||
"""
|
||||
|
@ -644,7 +648,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
|||
@unittest.override_config(
|
||||
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
|
||||
)
|
||||
def test_application_services_receive_local_to_device(self):
|
||||
def test_application_services_receive_local_to_device(self) -> None:
|
||||
"""
|
||||
Test that when a user sends a to-device message to another user
|
||||
that is an application service's user namespace, the
|
||||
|
@ -722,7 +726,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
|||
@unittest.override_config(
|
||||
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
|
||||
)
|
||||
def test_application_services_receive_bursts_of_to_device(self):
|
||||
def test_application_services_receive_bursts_of_to_device(self) -> None:
|
||||
"""
|
||||
Test that when a user sends >100 to-device messages at once, any
|
||||
interested AS's will receive them in separate transactions.
|
||||
|
@ -913,7 +917,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
|
|||
experimental_feature_enabled: bool,
|
||||
as_supports_txn_extensions: bool,
|
||||
as_should_receive_device_list_updates: bool,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Tests that an application service receives notice of changed device
|
||||
lists for a user, when a user changes their device lists.
|
||||
|
@ -1070,7 +1074,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
|
|||
and a room for the users to talk in.
|
||||
"""
|
||||
|
||||
async def preparation():
|
||||
async def preparation() -> None:
|
||||
await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
|
||||
await self._add_fallback_key_for_device(
|
||||
self._sender_user, self._sender_device, used=True
|
||||
|
|
|
@ -199,7 +199,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
|
||||
def _mock_request():
|
||||
def _mock_request() -> Mock:
|
||||
"""Returns a mock which will stand in as a SynapseRequest"""
|
||||
mock = Mock(
|
||||
spec=[
|
||||
|
|
|
@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
import synapse.api.errors
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.rest.client import directory, login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, RoomAlias, create_requester
|
||||
|
@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
|||
self.test_user_tok = self.login("user", "pass")
|
||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||
|
||||
def _create_alias(self, user) -> None:
|
||||
def _create_alias(self, user: str) -> None:
|
||||
# Create a new alias to this room.
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
|
@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
return room_alias
|
||||
|
||||
def _set_canonical_alias(self, content) -> None:
|
||||
def _set_canonical_alias(self, content: JsonDict) -> None:
|
||||
"""Configure the canonical alias state on the room."""
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
|
@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
tok=self.admin_user_tok,
|
||||
)
|
||||
|
||||
def _get_canonical_alias(self):
|
||||
def _get_canonical_alias(self) -> EventBase:
|
||||
"""Get the canonical alias state of the room."""
|
||||
return self.get_success(
|
||||
result = self.get_success(
|
||||
self._storage_controllers.state.get_current_state_event(
|
||||
self.room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
)
|
||||
assert result is not None
|
||||
return result
|
||||
|
||||
def test_remove_alias(self) -> None:
|
||||
"""Removing an alias that is the canonical alias should remove it there too."""
|
||||
|
@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
data = self._get_canonical_alias()
|
||||
self.assertEqual(data["content"]["alias"], self.test_alias)
|
||||
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
|
||||
self.assertEqual(data.content["alias"], self.test_alias)
|
||||
self.assertEqual(data.content["alt_aliases"], [self.test_alias])
|
||||
|
||||
# Finally, delete the alias.
|
||||
self.get_success(
|
||||
|
@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
data = self._get_canonical_alias()
|
||||
self.assertNotIn("alias", data["content"])
|
||||
self.assertNotIn("alt_aliases", data["content"])
|
||||
self.assertNotIn("alias", data.content)
|
||||
self.assertNotIn("alt_aliases", data.content)
|
||||
|
||||
def test_remove_other_alias(self) -> None:
|
||||
"""Removing an alias listed as in alt_aliases should remove it there too."""
|
||||
|
@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
data = self._get_canonical_alias()
|
||||
self.assertEqual(data["content"]["alias"], self.test_alias)
|
||||
self.assertEqual(data.content["alias"], self.test_alias)
|
||||
self.assertEqual(
|
||||
data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
|
||||
data.content["alt_aliases"], [self.test_alias, other_test_alias]
|
||||
)
|
||||
|
||||
# Delete the second alias.
|
||||
|
@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
data = self._get_canonical_alias()
|
||||
self.assertEqual(data["content"]["alias"], self.test_alias)
|
||||
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
|
||||
self.assertEqual(data.content["alias"], self.test_alias)
|
||||
self.assertEqual(data.content["alt_aliases"], [self.test_alias])
|
||||
|
||||
|
||||
class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
|
|
|
@ -17,7 +17,11 @@
|
|||
import copy
|
||||
from unittest import mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
@ -39,14 +43,14 @@ room_keys = {
|
|||
|
||||
|
||||
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
return self.setup_test_homeserver(replication_layer=mock.Mock())
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.handler = hs.get_e2e_room_keys_handler()
|
||||
self.local_user = "@boris:" + hs.hostname
|
||||
|
||||
def test_get_missing_current_version_info(self):
|
||||
def test_get_missing_current_version_info(self) -> None:
|
||||
"""Check that we get a 404 if we ask for info about the current version
|
||||
if there is no version.
|
||||
"""
|
||||
|
@ -56,7 +60,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
res = e.value.code
|
||||
self.assertEqual(res, 404)
|
||||
|
||||
def test_get_missing_version_info(self):
|
||||
def test_get_missing_version_info(self) -> None:
|
||||
"""Check that we get a 404 if we ask for info about a specific version
|
||||
if it doesn't exist.
|
||||
"""
|
||||
|
@ -67,9 +71,9 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
res = e.value.code
|
||||
self.assertEqual(res, 404)
|
||||
|
||||
def test_create_version(self):
|
||||
def test_create_version(self) -> None:
|
||||
"""Check that we can create and then retrieve versions."""
|
||||
res = self.get_success(
|
||||
version = self.get_success(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
|
@ -78,7 +82,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(res, "1")
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
# check we can retrieve it as the current version
|
||||
res = self.get_success(self.handler.get_version_info(self.local_user))
|
||||
|
@ -110,7 +114,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
# upload a new one...
|
||||
res = self.get_success(
|
||||
version = self.get_success(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
|
@ -119,7 +123,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(res, "2")
|
||||
self.assertEqual(version, "2")
|
||||
|
||||
# check we can retrieve it as the current version
|
||||
res = self.get_success(self.handler.get_version_info(self.local_user))
|
||||
|
@ -134,7 +138,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def test_update_version(self):
|
||||
def test_update_version(self) -> None:
|
||||
"""Check that we can update versions."""
|
||||
version = self.get_success(
|
||||
self.handler.create_version(
|
||||
|
@ -173,7 +177,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def test_update_missing_version(self):
|
||||
def test_update_missing_version(self) -> None:
|
||||
"""Check that we get a 404 on updating nonexistent versions"""
|
||||
e = self.get_failure(
|
||||
self.handler.update_version(
|
||||
|
@ -190,7 +194,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
res = e.value.code
|
||||
self.assertEqual(res, 404)
|
||||
|
||||
def test_update_omitted_version(self):
|
||||
def test_update_omitted_version(self) -> None:
|
||||
"""Check that the update succeeds if the version is missing from the body"""
|
||||
version = self.get_success(
|
||||
self.handler.create_version(
|
||||
|
@ -227,7 +231,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def test_update_bad_version(self):
|
||||
def test_update_bad_version(self) -> None:
|
||||
"""Check that we get a 400 if the version in the body doesn't match"""
|
||||
version = self.get_success(
|
||||
self.handler.create_version(
|
||||
|
@ -255,7 +259,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
res = e.value.code
|
||||
self.assertEqual(res, 400)
|
||||
|
||||
def test_delete_missing_version(self):
|
||||
def test_delete_missing_version(self) -> None:
|
||||
"""Check that we get a 404 on deleting nonexistent versions"""
|
||||
e = self.get_failure(
|
||||
self.handler.delete_version(self.local_user, "1"), SynapseError
|
||||
|
@ -263,15 +267,15 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
res = e.value.code
|
||||
self.assertEqual(res, 404)
|
||||
|
||||
def test_delete_missing_current_version(self):
|
||||
def test_delete_missing_current_version(self) -> None:
|
||||
"""Check that we get a 404 on deleting nonexistent current version"""
|
||||
e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
|
||||
res = e.value.code
|
||||
self.assertEqual(res, 404)
|
||||
|
||||
def test_delete_version(self):
|
||||
def test_delete_version(self) -> None:
|
||||
"""Check that we can create and then delete versions."""
|
||||
res = self.get_success(
|
||||
version = self.get_success(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
|
@ -280,7 +284,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(res, "1")
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
# check we can delete it
|
||||
self.get_success(self.handler.delete_version(self.local_user, "1"))
|
||||
|
@ -292,7 +296,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
res = e.value.code
|
||||
self.assertEqual(res, 404)
|
||||
|
||||
def test_get_missing_backup(self):
|
||||
def test_get_missing_backup(self) -> None:
|
||||
"""Check that we get a 404 on querying missing backup"""
|
||||
e = self.get_failure(
|
||||
self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
|
||||
|
@ -300,7 +304,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
res = e.value.code
|
||||
self.assertEqual(res, 404)
|
||||
|
||||
def test_get_missing_room_keys(self):
|
||||
def test_get_missing_room_keys(self) -> None:
|
||||
"""Check we get an empty response from an empty backup"""
|
||||
version = self.get_success(
|
||||
self.handler.create_version(
|
||||
|
@ -319,7 +323,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# TODO: test the locking semantics when uploading room_keys,
|
||||
# although this is probably best done in sytest
|
||||
|
||||
def test_upload_room_keys_no_versions(self):
|
||||
def test_upload_room_keys_no_versions(self) -> None:
|
||||
"""Check that we get a 404 on uploading keys when no versions are defined"""
|
||||
e = self.get_failure(
|
||||
self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
|
||||
|
@ -328,7 +332,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
res = e.value.code
|
||||
self.assertEqual(res, 404)
|
||||
|
||||
def test_upload_room_keys_bogus_version(self):
|
||||
def test_upload_room_keys_bogus_version(self) -> None:
|
||||
"""Check that we get a 404 on uploading keys when an nonexistent version
|
||||
is specified
|
||||
"""
|
||||
|
@ -350,7 +354,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
res = e.value.code
|
||||
self.assertEqual(res, 404)
|
||||
|
||||
def test_upload_room_keys_wrong_version(self):
|
||||
def test_upload_room_keys_wrong_version(self) -> None:
|
||||
"""Check that we get a 403 on uploading keys for an old version"""
|
||||
version = self.get_success(
|
||||
self.handler.create_version(
|
||||
|
@ -380,7 +384,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
res = e.value.code
|
||||
self.assertEqual(res, 403)
|
||||
|
||||
def test_upload_room_keys_insert(self):
|
||||
def test_upload_room_keys_insert(self) -> None:
|
||||
"""Check that we can insert and retrieve keys for a session"""
|
||||
version = self.get_success(
|
||||
self.handler.create_version(
|
||||
|
@ -416,7 +420,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self.assertDictEqual(res, room_keys)
|
||||
|
||||
def test_upload_room_keys_merge(self):
|
||||
def test_upload_room_keys_merge(self) -> None:
|
||||
"""Check that we can upload a new room_key for an existing session and
|
||||
have it correctly merged"""
|
||||
version = self.get_success(
|
||||
|
@ -449,9 +453,11 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||
)
|
||||
|
||||
res = self.get_success(self.handler.get_room_keys(self.local_user, version))
|
||||
res_keys = self.get_success(
|
||||
self.handler.get_room_keys(self.local_user, version)
|
||||
)
|
||||
self.assertEqual(
|
||||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
|
||||
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
|
||||
"SSBBTSBBIEZJU0gK",
|
||||
)
|
||||
|
||||
|
@ -465,9 +471,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||
)
|
||||
|
||||
res = self.get_success(self.handler.get_room_keys(self.local_user, version))
|
||||
res_keys = self.get_success(
|
||||
self.handler.get_room_keys(self.local_user, version)
|
||||
)
|
||||
self.assertEqual(
|
||||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
|
||||
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
|
||||
"new",
|
||||
)
|
||||
|
||||
# the etag should NOT be equal now, since the key changed
|
||||
|
@ -483,9 +492,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||
)
|
||||
|
||||
res = self.get_success(self.handler.get_room_keys(self.local_user, version))
|
||||
res_keys = self.get_success(
|
||||
self.handler.get_room_keys(self.local_user, version)
|
||||
)
|
||||
self.assertEqual(
|
||||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
|
||||
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
|
||||
"new",
|
||||
)
|
||||
|
||||
# the etag should be the same since the session did not change
|
||||
|
@ -494,7 +506,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
# TODO: check edge cases as well as the common variations here
|
||||
|
||||
def test_delete_room_keys(self):
|
||||
def test_delete_room_keys(self) -> None:
|
||||
"""Check that we can insert and delete keys for a session"""
|
||||
version = self.get_success(
|
||||
self.handler.create_version(
|
||||
|
|
|
@ -439,7 +439,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
|||
user_id = self.register_user("kermit", "test")
|
||||
tok = self.login("kermit", "test")
|
||||
|
||||
def create_invite():
|
||||
def create_invite() -> EventBase:
|
||||
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
|
||||
room_version = self.get_success(self.store.get_room_version(room_id))
|
||||
return event_from_pdu_json(
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import AuthError, StoreError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.event_auth import (
|
||||
|
@ -26,8 +28,10 @@ from synapse.federation.transport.client import StateRequestResponse
|
|||
from synapse.logging.context import LoggingContext
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import event_injection, make_awaitable
|
||||
|
@ -40,7 +44,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
|||
room.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
# mock out the federation transport client
|
||||
self.mock_federation_transport_client = mock.Mock(
|
||||
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
|
||||
|
@ -165,7 +169,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
|||
)
|
||||
else:
|
||||
|
||||
async def get_event(destination: str, event_id: str, timeout=None):
|
||||
async def get_event(
|
||||
destination: str, event_id: str, timeout: Optional[int] = None
|
||||
) -> JsonDict:
|
||||
self.assertEqual(destination, self.OTHER_SERVER_NAME)
|
||||
self.assertEqual(event_id, prev_event.event_id)
|
||||
return {"pdus": [prev_event.get_pdu_json()]}
|
||||
|
|
|
@ -14,12 +14,16 @@
|
|||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import create_requester
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from tests import unittest
|
||||
|
@ -35,7 +39,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.handler = self.hs.get_event_creation_handler()
|
||||
self._persist_event_storage_controller = (
|
||||
self.hs.get_storage_controllers().persistence
|
||||
|
@ -94,7 +98,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
def test_duplicated_txn_id(self):
|
||||
def test_duplicated_txn_id(self) -> None:
|
||||
"""Test that attempting to handle/persist an event with a transaction ID
|
||||
that has already been persisted correctly returns the old event and does
|
||||
*not* produce duplicate messages.
|
||||
|
@ -161,7 +165,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
# rather than the new one.
|
||||
self.assertEqual(ret_event1.event_id, ret_event4.event_id)
|
||||
|
||||
def test_duplicated_txn_id_one_call(self):
|
||||
def test_duplicated_txn_id_one_call(self) -> None:
|
||||
"""Test that we correctly handle duplicates that we try and persist at
|
||||
the same time.
|
||||
"""
|
||||
|
@ -185,7 +189,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[0].event_id, events[1].event_id)
|
||||
|
||||
def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self):
|
||||
def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(
|
||||
self,
|
||||
) -> None:
|
||||
"""When we set allow_no_prev_events=True, should be able to create a
|
||||
event without any prev_events (only auth_events).
|
||||
"""
|
||||
|
@ -214,7 +220,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
|
||||
self,
|
||||
):
|
||||
) -> None:
|
||||
"""When we set allow_no_prev_events=False, shouldn't be able to create a
|
||||
event without any prev_events even if it has auth_events. Expect an
|
||||
exception to be raised.
|
||||
|
@ -245,7 +251,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
|
||||
self,
|
||||
):
|
||||
) -> None:
|
||||
"""When we set allow_no_prev_events=True, should be able to create a
|
||||
event without any prev_events or auth_events. Expect an exception to be
|
||||
raised.
|
||||
|
@ -277,12 +283,12 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
|
|||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user_id = self.register_user("tester", "foobar")
|
||||
self.access_token = self.login("tester", "foobar")
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||
|
||||
def test_allow_server_acl(self):
|
||||
def test_allow_server_acl(self) -> None:
|
||||
"""Test that sending an ACL that blocks everyone but ourselves works."""
|
||||
|
||||
self.helper.send_state(
|
||||
|
@ -293,7 +299,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
|
|||
expect_code=200,
|
||||
)
|
||||
|
||||
def test_deny_server_acl_block_outselves(self):
|
||||
def test_deny_server_acl_block_outselves(self) -> None:
|
||||
"""Test that sending an ACL that blocks ourselves does not work."""
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
|
@ -303,7 +309,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
|
|||
expect_code=400,
|
||||
)
|
||||
|
||||
def test_deny_redact_server_acl(self):
|
||||
def test_deny_redact_server_acl(self) -> None:
|
||||
"""Test that attempting to redact an ACL is blocked."""
|
||||
|
||||
body = self.helper.send_state(
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
|
@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
from synapse.handlers.sso import MappingException
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
from synapse.util.macaroons import get_value_from_macaroon
|
||||
from synapse.util.stringutils import random_string
|
||||
|
@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config
|
|||
|
||||
try:
|
||||
import authlib # noqa: F401
|
||||
from authlib.oidc.core import UserInfo
|
||||
from authlib.oidc.discovery import OpenIDProviderMetadata
|
||||
|
||||
from synapse.handlers.oidc import Token, UserAttributeDict
|
||||
|
||||
HAS_OIDC = True
|
||||
except ImportError:
|
||||
|
@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = {
|
|||
|
||||
class TestMappingProvider:
|
||||
@staticmethod
|
||||
def parse_config(config):
|
||||
return
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
return None
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: None):
|
||||
pass
|
||||
|
||||
def get_remote_user_id(self, userinfo):
|
||||
def get_remote_user_id(self, userinfo: "UserInfo") -> str:
|
||||
return userinfo["sub"]
|
||||
|
||||
async def map_user_attributes(self, userinfo, token):
|
||||
return {"localpart": userinfo["username"], "display_name": None}
|
||||
async def map_user_attributes(
|
||||
self, userinfo: "UserInfo", token: "Token"
|
||||
) -> "UserAttributeDict":
|
||||
# This is testing not providing the full map.
|
||||
return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item]
|
||||
|
||||
# Do not include get_extra_attributes to test backwards compatibility paths.
|
||||
|
||||
|
||||
class TestMappingProviderExtra(TestMappingProvider):
|
||||
async def get_extra_attributes(self, userinfo, token):
|
||||
async def get_extra_attributes(
|
||||
self, userinfo: "UserInfo", token: "Token"
|
||||
) -> JsonDict:
|
||||
return {"phone": userinfo["phone"]}
|
||||
|
||||
|
||||
class TestMappingProviderFailures(TestMappingProvider):
|
||||
async def map_user_attributes(self, userinfo, token, failures):
|
||||
return {
|
||||
# Superclass is testing the legacy interface for map_user_attributes.
|
||||
async def map_user_attributes( # type: ignore[override]
|
||||
self, userinfo: "UserInfo", token: "Token", failures: int
|
||||
) -> "UserAttributeDict":
|
||||
return { # type: ignore[typeddict-item]
|
||||
"localpart": userinfo["username"] + (str(failures) if failures else ""),
|
||||
"display_name": None,
|
||||
}
|
||||
|
@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
self.hs_patcher.stop()
|
||||
return super().tearDown()
|
||||
|
||||
def reset_mocks(self):
|
||||
def reset_mocks(self) -> None:
|
||||
"""Reset all the Mocks."""
|
||||
self.fake_server.reset_mocks()
|
||||
self.render_error.reset_mock()
|
||||
self.complete_sso_login.reset_mock()
|
||||
|
||||
def metadata_edit(self, values):
|
||||
def metadata_edit(self, values: dict) -> ContextManager[Mock]:
|
||||
"""Modify the result that will be returned by the well-known query"""
|
||||
|
||||
metadata = self.fake_server.get_metadata()
|
||||
|
@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
|
||||
return _build_callback_request(code, state, session), grant
|
||||
|
||||
def assertRenderedError(self, error, error_description=None):
|
||||
def assertRenderedError(
|
||||
self, error: str, error_description: Optional[str] = None
|
||||
) -> Tuple[Any, ...]:
|
||||
self.render_error.assert_called_once()
|
||||
args = self.render_error.call_args[0]
|
||||
self.assertEqual(args[1], error)
|
||||
|
@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
"""Provider metadatas are extensively validated."""
|
||||
h = self.provider
|
||||
|
||||
def force_load_metadata():
|
||||
async def force_load():
|
||||
def force_load_metadata() -> Awaitable[None]:
|
||||
async def force_load() -> "OpenIDProviderMetadata":
|
||||
return await h.load_metadata(force=True)
|
||||
|
||||
return get_awaitable_result(force_load())
|
||||
|
@ -1198,7 +1212,7 @@ def _build_callback_request(
|
|||
state: str,
|
||||
session: str,
|
||||
ip_address: str = "10.0.0.1",
|
||||
):
|
||||
) -> Mock:
|
||||
"""Builds a fake SynapseRequest to mock the browser callback
|
||||
|
||||
Returns a Mock object which looks like the SynapseRequest we get from a browser
|
||||
|
|
|
@ -15,12 +15,13 @@
|
|||
"""Tests for the password_auth_provider interface"""
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
import synapse
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.handlers.account import AccountHandler
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.rest.client import account, devices, login, logout, register
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider:
|
|||
"""A legacy password_provider which only implements `check_password`."""
|
||||
|
||||
@staticmethod
|
||||
def parse_config(self):
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
pass
|
||||
|
||||
def __init__(self, config, account_handler):
|
||||
def __init__(self, config: None, account_handler: AccountHandler):
|
||||
pass
|
||||
|
||||
def check_password(self, *args):
|
||||
def check_password(self, *args: str) -> Mock:
|
||||
return mock_password_provider.check_password(*args)
|
||||
|
||||
|
||||
|
@ -58,16 +59,16 @@ class LegacyCustomAuthProvider:
|
|||
"""A legacy password_provider which implements a custom login type."""
|
||||
|
||||
@staticmethod
|
||||
def parse_config(self):
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
pass
|
||||
|
||||
def __init__(self, config, account_handler):
|
||||
def __init__(self, config: None, account_handler: AccountHandler):
|
||||
pass
|
||||
|
||||
def get_supported_login_types(self):
|
||||
def get_supported_login_types(self) -> Dict[str, List[str]]:
|
||||
return {"test.login_type": ["test_field"]}
|
||||
|
||||
def check_auth(self, *args):
|
||||
def check_auth(self, *args: str) -> Mock:
|
||||
return mock_password_provider.check_auth(*args)
|
||||
|
||||
|
||||
|
@ -75,15 +76,15 @@ class CustomAuthProvider:
|
|||
"""A module which registers password_auth_provider callbacks for a custom login type."""
|
||||
|
||||
@staticmethod
|
||||
def parse_config(self):
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
pass
|
||||
|
||||
def __init__(self, config, api: ModuleApi):
|
||||
def __init__(self, config: None, api: ModuleApi):
|
||||
api.register_password_auth_provider_callbacks(
|
||||
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
|
||||
)
|
||||
|
||||
def check_auth(self, *args):
|
||||
def check_auth(self, *args: Any) -> Mock:
|
||||
return mock_password_provider.check_auth(*args)
|
||||
|
||||
|
||||
|
@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider:
|
|||
as a custom type."""
|
||||
|
||||
@staticmethod
|
||||
def parse_config(self):
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
pass
|
||||
|
||||
def __init__(self, config, account_handler):
|
||||
def __init__(self, config: None, account_handler: AccountHandler):
|
||||
pass
|
||||
|
||||
def get_supported_login_types(self):
|
||||
def get_supported_login_types(self) -> Dict[str, List[str]]:
|
||||
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
|
||||
|
||||
def check_auth(self, *args):
|
||||
def check_auth(self, *args: str) -> Mock:
|
||||
return mock_password_provider.check_auth(*args)
|
||||
|
||||
|
||||
|
@ -110,10 +111,10 @@ class PasswordCustomAuthProvider:
|
|||
as well as a password login"""
|
||||
|
||||
@staticmethod
|
||||
def parse_config(self):
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
pass
|
||||
|
||||
def __init__(self, config, api: ModuleApi):
|
||||
def __init__(self, config: None, api: ModuleApi):
|
||||
api.register_password_auth_provider_callbacks(
|
||||
auth_checkers={
|
||||
("test.login_type", ("test_field",)): self.check_auth,
|
||||
|
@ -121,10 +122,10 @@ class PasswordCustomAuthProvider:
|
|||
}
|
||||
)
|
||||
|
||||
def check_auth(self, *args):
|
||||
def check_auth(self, *args: Any) -> Mock:
|
||||
return mock_password_provider.check_auth(*args)
|
||||
|
||||
def check_pass(self, *args):
|
||||
def check_pass(self, *args: str) -> Mock:
|
||||
return mock_password_provider.check_password(*args)
|
||||
|
||||
|
||||
|
@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
CALLBACK_USERNAME = "get_username_for_registration"
|
||||
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
# we use a global mock device, so make sure we are starting with a clean slate
|
||||
mock_password_provider.reset_mock()
|
||||
super().setUp()
|
||||
|
||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||
def test_password_only_auth_progiver_login_legacy(self):
|
||||
def test_password_only_auth_progiver_login_legacy(self) -> None:
|
||||
self.password_only_auth_provider_login_test_body()
|
||||
|
||||
def password_only_auth_provider_login_test_body(self):
|
||||
def password_only_auth_provider_login_test_body(self) -> None:
|
||||
# login flows should only have m.login.password
|
||||
flows = self._get_login_flows()
|
||||
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||
|
@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||
def test_password_only_auth_provider_ui_auth_legacy(self):
|
||||
def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
|
||||
self.password_only_auth_provider_ui_auth_test_body()
|
||||
|
||||
def password_only_auth_provider_ui_auth_test_body(self):
|
||||
def password_only_auth_provider_ui_auth_test_body(self) -> None:
|
||||
"""UI Auth should delegate correctly to the password provider"""
|
||||
|
||||
# create the user, otherwise access doesn't work
|
||||
|
@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||
|
||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||
def test_local_user_fallback_login_legacy(self):
|
||||
def test_local_user_fallback_login_legacy(self) -> None:
|
||||
self.local_user_fallback_login_test_body()
|
||||
|
||||
def local_user_fallback_login_test_body(self):
|
||||
def local_user_fallback_login_test_body(self) -> None:
|
||||
"""rejected login should fall back to local db"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
|
@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
self.assertEqual("@localuser:test", channel.json_body["user_id"])
|
||||
|
||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||
def test_local_user_fallback_ui_auth_legacy(self):
|
||||
def test_local_user_fallback_ui_auth_legacy(self) -> None:
|
||||
self.local_user_fallback_ui_auth_test_body()
|
||||
|
||||
def local_user_fallback_ui_auth_test_body(self):
|
||||
def local_user_fallback_ui_auth_test_body(self) -> None:
|
||||
"""rejected login should fall back to local db"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
|
@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"localdb_enabled": False},
|
||||
}
|
||||
)
|
||||
def test_no_local_user_fallback_login_legacy(self):
|
||||
def test_no_local_user_fallback_login_legacy(self) -> None:
|
||||
self.no_local_user_fallback_login_test_body()
|
||||
|
||||
def no_local_user_fallback_login_test_body(self):
|
||||
def no_local_user_fallback_login_test_body(self) -> None:
|
||||
"""localdb_enabled can block login with the local password"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
|
@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"localdb_enabled": False},
|
||||
}
|
||||
)
|
||||
def test_no_local_user_fallback_ui_auth_legacy(self):
|
||||
def test_no_local_user_fallback_ui_auth_legacy(self) -> None:
|
||||
self.no_local_user_fallback_ui_auth_test_body()
|
||||
|
||||
def no_local_user_fallback_ui_auth_test_body(self):
|
||||
def no_local_user_fallback_ui_auth_test_body(self) -> None:
|
||||
"""localdb_enabled can block ui auth with the local password"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
|
@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"enabled": False},
|
||||
}
|
||||
)
|
||||
def test_password_auth_disabled_legacy(self):
|
||||
def test_password_auth_disabled_legacy(self) -> None:
|
||||
self.password_auth_disabled_test_body()
|
||||
|
||||
def password_auth_disabled_test_body(self):
|
||||
def password_auth_disabled_test_body(self) -> None:
|
||||
"""password auth doesn't work if it's disabled across the board"""
|
||||
# login flows should be empty
|
||||
flows = self._get_login_flows()
|
||||
|
@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
mock_password_provider.check_password.assert_not_called()
|
||||
|
||||
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||
def test_custom_auth_provider_login_legacy(self):
|
||||
def test_custom_auth_provider_login_legacy(self) -> None:
|
||||
self.custom_auth_provider_login_test_body()
|
||||
|
||||
@override_config(providers_config(CustomAuthProvider))
|
||||
def test_custom_auth_provider_login(self):
|
||||
def test_custom_auth_provider_login(self) -> None:
|
||||
self.custom_auth_provider_login_test_body()
|
||||
|
||||
def custom_auth_provider_login_test_body(self):
|
||||
def custom_auth_provider_login_test_body(self) -> None:
|
||||
# login flows should have the custom flow and m.login.password, since we
|
||||
# haven't disabled local password lookup.
|
||||
# (password must come first, because reasons)
|
||||
|
@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||
def test_custom_auth_provider_ui_auth_legacy(self):
|
||||
def test_custom_auth_provider_ui_auth_legacy(self) -> None:
|
||||
self.custom_auth_provider_ui_auth_test_body()
|
||||
|
||||
@override_config(providers_config(CustomAuthProvider))
|
||||
def test_custom_auth_provider_ui_auth(self):
|
||||
def test_custom_auth_provider_ui_auth(self) -> None:
|
||||
self.custom_auth_provider_ui_auth_test_body()
|
||||
|
||||
def custom_auth_provider_ui_auth_test_body(self):
|
||||
def custom_auth_provider_ui_auth_test_body(self) -> None:
|
||||
# register the user and log in twice, to get two devices
|
||||
self.register_user("localuser", "localpass")
|
||||
tok1 = self.login("localuser", "localpass")
|
||||
|
@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||
def test_custom_auth_provider_callback_legacy(self):
|
||||
def test_custom_auth_provider_callback_legacy(self) -> None:
|
||||
self.custom_auth_provider_callback_test_body()
|
||||
|
||||
@override_config(providers_config(CustomAuthProvider))
|
||||
def test_custom_auth_provider_callback(self):
|
||||
def test_custom_auth_provider_callback(self) -> None:
|
||||
self.custom_auth_provider_callback_test_body()
|
||||
|
||||
def custom_auth_provider_callback_test_body(self):
|
||||
def custom_auth_provider_callback_test_body(self) -> None:
|
||||
callback = Mock(return_value=make_awaitable(None))
|
||||
|
||||
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||
|
@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"enabled": False},
|
||||
}
|
||||
)
|
||||
def test_custom_auth_password_disabled_legacy(self):
|
||||
def test_custom_auth_password_disabled_legacy(self) -> None:
|
||||
self.custom_auth_password_disabled_test_body()
|
||||
|
||||
@override_config(
|
||||
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
|
||||
)
|
||||
def test_custom_auth_password_disabled(self):
|
||||
def test_custom_auth_password_disabled(self) -> None:
|
||||
self.custom_auth_password_disabled_test_body()
|
||||
|
||||
def custom_auth_password_disabled_test_body(self):
|
||||
def custom_auth_password_disabled_test_body(self) -> None:
|
||||
"""Test login with a custom auth provider where password login is disabled"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
|
@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"enabled": False, "localdb_enabled": False},
|
||||
}
|
||||
)
|
||||
def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
|
||||
def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None:
|
||||
self.custom_auth_password_disabled_localdb_enabled_test_body()
|
||||
|
||||
@override_config(
|
||||
|
@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"enabled": False, "localdb_enabled": False},
|
||||
}
|
||||
)
|
||||
def test_custom_auth_password_disabled_localdb_enabled(self):
|
||||
def test_custom_auth_password_disabled_localdb_enabled(self) -> None:
|
||||
self.custom_auth_password_disabled_localdb_enabled_test_body()
|
||||
|
||||
def custom_auth_password_disabled_localdb_enabled_test_body(self):
|
||||
def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None:
|
||||
"""Check the localdb_enabled == enabled == False
|
||||
|
||||
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
|
||||
|
@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"enabled": False},
|
||||
}
|
||||
)
|
||||
def test_password_custom_auth_password_disabled_login_legacy(self):
|
||||
def test_password_custom_auth_password_disabled_login_legacy(self) -> None:
|
||||
self.password_custom_auth_password_disabled_login_test_body()
|
||||
|
||||
@override_config(
|
||||
|
@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"enabled": False},
|
||||
}
|
||||
)
|
||||
def test_password_custom_auth_password_disabled_login(self):
|
||||
def test_password_custom_auth_password_disabled_login(self) -> None:
|
||||
self.password_custom_auth_password_disabled_login_test_body()
|
||||
|
||||
def password_custom_auth_password_disabled_login_test_body(self):
|
||||
def password_custom_auth_password_disabled_login_test_body(self) -> None:
|
||||
"""log in with a custom auth provider which implements password, but password
|
||||
login is disabled"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"enabled": False},
|
||||
}
|
||||
)
|
||||
def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
|
||||
def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None:
|
||||
self.password_custom_auth_password_disabled_ui_auth_test_body()
|
||||
|
||||
@override_config(
|
||||
|
@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"enabled": False},
|
||||
}
|
||||
)
|
||||
def test_password_custom_auth_password_disabled_ui_auth(self):
|
||||
def test_password_custom_auth_password_disabled_ui_auth(self) -> None:
|
||||
self.password_custom_auth_password_disabled_ui_auth_test_body()
|
||||
|
||||
def password_custom_auth_password_disabled_ui_auth_test_body(self):
|
||||
def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None:
|
||||
"""UI Auth with a custom auth provider which implements password, but password
|
||||
login is disabled"""
|
||||
# register the user and log in twice via the test login type to get two devices,
|
||||
|
@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"localdb_enabled": False},
|
||||
}
|
||||
)
|
||||
def test_custom_auth_no_local_user_fallback_legacy(self):
|
||||
def test_custom_auth_no_local_user_fallback_legacy(self) -> None:
|
||||
self.custom_auth_no_local_user_fallback_test_body()
|
||||
|
||||
@override_config(
|
||||
|
@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"password_config": {"localdb_enabled": False},
|
||||
}
|
||||
)
|
||||
def test_custom_auth_no_local_user_fallback(self):
|
||||
def test_custom_auth_no_local_user_fallback(self) -> None:
|
||||
self.custom_auth_no_local_user_fallback_test_body()
|
||||
|
||||
def custom_auth_no_local_user_fallback_test_body(self):
|
||||
def custom_auth_no_local_user_fallback_test_body(self) -> None:
|
||||
"""Test login with a custom auth provider where the local db is disabled"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
|
@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
|
||||
def test_on_logged_out(self):
|
||||
def test_on_logged_out(self) -> None:
|
||||
"""Tests that the on_logged_out callback is called when the user logs out."""
|
||||
self.register_user("rin", "password")
|
||||
tok = self.login("rin", "password")
|
||||
|
||||
self.called = False
|
||||
|
||||
async def on_logged_out(user_id, device_id, access_token):
|
||||
async def on_logged_out(
|
||||
user_id: str, device_id: Optional[str], access_token: str
|
||||
) -> None:
|
||||
self.called = True
|
||||
|
||||
on_logged_out = Mock(side_effect=on_logged_out)
|
||||
|
@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
on_logged_out.assert_called_once()
|
||||
self.assertTrue(self.called)
|
||||
|
||||
def test_username(self):
|
||||
def test_username(self) -> None:
|
||||
"""Tests that the get_username_for_registration callback can define the username
|
||||
of a user when registering.
|
||||
"""
|
||||
|
@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
mxid = channel.json_body["user_id"]
|
||||
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
|
||||
|
||||
def test_username_uia(self):
|
||||
def test_username_uia(self) -> None:
|
||||
"""Tests that the get_username_for_registration callback is only called at the
|
||||
end of the UIA flow.
|
||||
"""
|
||||
|
@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
|
||||
# Set some email configuration so the test doesn't fail because of its absence.
|
||||
@override_config({"email": {"notif_from": "noreply@test"}})
|
||||
def test_3pid_allowed(self):
|
||||
def test_3pid_allowed(self) -> None:
|
||||
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
|
||||
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
|
||||
the 3PID. Also checks that the module is passed a boolean indicating whether the
|
||||
|
@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
self._test_3pid_allowed("rin", False)
|
||||
self._test_3pid_allowed("kitay", True)
|
||||
|
||||
def test_displayname(self):
|
||||
def test_displayname(self) -> None:
|
||||
"""Tests that the get_displayname_for_registration callback can define the
|
||||
display name of a user when registering.
|
||||
"""
|
||||
|
@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(display_name, username + "-foo")
|
||||
|
||||
def test_displayname_uia(self):
|
||||
def test_displayname_uia(self) -> None:
|
||||
"""Tests that the get_displayname_for_registration callback is only called at the
|
||||
end of the UIA flow.
|
||||
"""
|
||||
|
@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
# Check that the callback has been called.
|
||||
m.assert_called_once()
|
||||
|
||||
def _test_3pid_allowed(self, username: str, registration: bool):
|
||||
def _test_3pid_allowed(self, username: str, registration: bool) -> None:
|
||||
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
|
||||
either /register or /account URLs depending on the arguments.
|
||||
|
||||
|
@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
client is trying to register.
|
||||
"""
|
||||
|
||||
async def callback(uia_results, params):
|
||||
async def callback(uia_results: JsonDict, params: JsonDict) -> str:
|
||||
self.assertIn(LoginType.DUMMY, uia_results)
|
||||
username = params["username"]
|
||||
return username + "-foo"
|
||||
|
@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
def _send_password_login(self, user: str, password: str) -> FakeChannel:
|
||||
return self._send_login(type="m.login.password", user=user, password=password)
|
||||
|
||||
def _send_login(self, type, user, **params) -> FakeChannel:
|
||||
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
|
||||
def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel:
|
||||
params = {"identifier": {"type": "m.id.user", "user": user}, "type": type}
|
||||
params.update(extra_params)
|
||||
channel = self.make_request("POST", "/_matrix/client/r0/login", params)
|
||||
return channel
|
||||
|
||||
def _start_delete_device_session(self, access_token, device_id) -> str:
|
||||
def _start_delete_device_session(self, access_token: str, device_id: str) -> str:
|
||||
"""Make an initial delete device request, and return the UI Auth session ID"""
|
||||
channel = self._delete_device(access_token, device_id)
|
||||
self.assertEqual(channel.code, 401)
|
||||
|
|
|
@ -12,12 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
from unittest.mock import Mock, call
|
||||
|
||||
from parameterized import parameterized
|
||||
from signedjson.key import generate_signing_key
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, PresenceState
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
|
@ -35,7 +37,9 @@ from synapse.handlers.presence import (
|
|||
)
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import room
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
|
@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
|
|||
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [admin.register_servlets]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
self.store = homeserver.get_datastores().main
|
||||
|
||||
def test_offline_to_online(self):
|
||||
def test_offline_to_online(self) -> None:
|
||||
wheel_timer = Mock()
|
||||
user_id = "@foo:bar"
|
||||
now = 5000000
|
||||
|
@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
|||
any_order=True,
|
||||
)
|
||||
|
||||
def test_online_to_online(self):
|
||||
def test_online_to_online(self) -> None:
|
||||
wheel_timer = Mock()
|
||||
user_id = "@foo:bar"
|
||||
now = 5000000
|
||||
|
@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
|||
any_order=True,
|
||||
)
|
||||
|
||||
def test_online_to_online_last_active_noop(self):
|
||||
def test_online_to_online_last_active_noop(self) -> None:
|
||||
wheel_timer = Mock()
|
||||
user_id = "@foo:bar"
|
||||
now = 5000000
|
||||
|
@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
|||
any_order=True,
|
||||
)
|
||||
|
||||
def test_online_to_online_last_active(self):
|
||||
def test_online_to_online_last_active(self) -> None:
|
||||
wheel_timer = Mock()
|
||||
user_id = "@foo:bar"
|
||||
now = 5000000
|
||||
|
@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
|||
any_order=True,
|
||||
)
|
||||
|
||||
def test_remote_ping_timer(self):
|
||||
def test_remote_ping_timer(self) -> None:
|
||||
wheel_timer = Mock()
|
||||
user_id = "@foo:bar"
|
||||
now = 5000000
|
||||
|
@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
|||
any_order=True,
|
||||
)
|
||||
|
||||
def test_online_to_offline(self):
|
||||
def test_online_to_offline(self) -> None:
|
||||
wheel_timer = Mock()
|
||||
user_id = "@foo:bar"
|
||||
now = 5000000
|
||||
|
@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(wheel_timer.insert.call_count, 0)
|
||||
|
||||
def test_online_to_idle(self):
|
||||
def test_online_to_idle(self) -> None:
|
||||
wheel_timer = Mock()
|
||||
user_id = "@foo:bar"
|
||||
now = 5000000
|
||||
|
@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
|||
any_order=True,
|
||||
)
|
||||
|
||||
def test_persisting_presence_updates(self):
|
||||
def test_persisting_presence_updates(self) -> None:
|
||||
"""Tests that the latest presence state for each user is persisted correctly"""
|
||||
# Create some test users and presence states for them
|
||||
presence_states = []
|
||||
|
@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
|||
self.get_success(self.store.update_presence(presence_states))
|
||||
|
||||
# Check that each update is present in the database
|
||||
db_presence_states = self.get_success(
|
||||
db_presence_states_raw = self.get_success(
|
||||
self.store.get_all_presence_updates(
|
||||
instance_name="master",
|
||||
last_id=0,
|
||||
|
@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
# Extract presence update user ID and state information into lists of tuples
|
||||
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
|
||||
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]]
|
||||
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
|
||||
|
||||
# Compare what we put into the storage with what we got out.
|
||||
|
@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
|||
class PresenceTimeoutTestCase(unittest.TestCase):
|
||||
"""Tests different timers and that the timer does not change `status_msg` of user."""
|
||||
|
||||
def test_idle_timer(self):
|
||||
def test_idle_timer(self) -> None:
|
||||
user_id = "@foo:bar"
|
||||
status_msg = "I'm here!"
|
||||
now = 5000000
|
||||
|
@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
|
||||
self.assertEqual(new_state.status_msg, status_msg)
|
||||
|
||||
def test_busy_no_idle(self):
|
||||
def test_busy_no_idle(self) -> None:
|
||||
"""
|
||||
Tests that a user setting their presence to busy but idling doesn't turn their
|
||||
presence state into unavailable.
|
||||
|
@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
self.assertEqual(new_state.state, PresenceState.BUSY)
|
||||
self.assertEqual(new_state.status_msg, status_msg)
|
||||
|
||||
def test_sync_timeout(self):
|
||||
def test_sync_timeout(self) -> None:
|
||||
user_id = "@foo:bar"
|
||||
status_msg = "I'm here!"
|
||||
now = 5000000
|
||||
|
@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
self.assertEqual(new_state.state, PresenceState.OFFLINE)
|
||||
self.assertEqual(new_state.status_msg, status_msg)
|
||||
|
||||
def test_sync_online(self):
|
||||
def test_sync_online(self) -> None:
|
||||
user_id = "@foo:bar"
|
||||
status_msg = "I'm here!"
|
||||
now = 5000000
|
||||
|
@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
self.assertEqual(new_state.state, PresenceState.ONLINE)
|
||||
self.assertEqual(new_state.status_msg, status_msg)
|
||||
|
||||
def test_federation_ping(self):
|
||||
def test_federation_ping(self) -> None:
|
||||
user_id = "@foo:bar"
|
||||
status_msg = "I'm here!"
|
||||
now = 5000000
|
||||
|
@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
self.assertIsNotNone(new_state)
|
||||
self.assertEqual(state, new_state)
|
||||
|
||||
def test_no_timeout(self):
|
||||
def test_no_timeout(self) -> None:
|
||||
user_id = "@foo:bar"
|
||||
now = 5000000
|
||||
|
||||
|
@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
|
||||
self.assertIsNone(new_state)
|
||||
|
||||
def test_federation_timeout(self):
|
||||
def test_federation_timeout(self) -> None:
|
||||
user_id = "@foo:bar"
|
||||
status_msg = "I'm here!"
|
||||
now = 5000000
|
||||
|
@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
self.assertEqual(new_state.state, PresenceState.OFFLINE)
|
||||
self.assertEqual(new_state.status_msg, status_msg)
|
||||
|
||||
def test_last_active(self):
|
||||
def test_last_active(self) -> None:
|
||||
user_id = "@foo:bar"
|
||||
status_msg = "I'm here!"
|
||||
now = 5000000
|
||||
|
@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
|
||||
|
||||
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def test_external_process_timeout(self):
|
||||
def test_external_process_timeout(self) -> None:
|
||||
"""Test that if an external process doesn't update the records for a while
|
||||
we time out their syncing users presence.
|
||||
"""
|
||||
process_id = 1
|
||||
process_id = "1"
|
||||
user_id = "@test:server"
|
||||
|
||||
# Notify handler that a user is now syncing.
|
||||
|
@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
)
|
||||
self.assertEqual(state.state, PresenceState.OFFLINE)
|
||||
|
||||
def test_user_goes_offline_by_timeout_status_msg_remain(self):
|
||||
def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
|
||||
"""Test that if a user doesn't update the records for a while
|
||||
users presence goes `OFFLINE` because of timeout and `status_msg` remains.
|
||||
"""
|
||||
|
@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertEqual(state.state, PresenceState.OFFLINE)
|
||||
self.assertEqual(state.status_msg, status_msg)
|
||||
|
||||
def test_user_goes_offline_manually_with_no_status_msg(self):
|
||||
def test_user_goes_offline_manually_with_no_status_msg(self) -> None:
|
||||
"""Test that if a user change presence manually to `OFFLINE`
|
||||
and no status is set, that `status_msg` is `None`.
|
||||
"""
|
||||
|
@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertEqual(state.state, PresenceState.OFFLINE)
|
||||
self.assertEqual(state.status_msg, None)
|
||||
|
||||
def test_user_goes_offline_manually_with_status_msg(self):
|
||||
def test_user_goes_offline_manually_with_status_msg(self) -> None:
|
||||
"""Test that if a user change presence manually to `OFFLINE`
|
||||
and a status is set, that `status_msg` appears.
|
||||
"""
|
||||
|
@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
user_id, PresenceState.OFFLINE, "And now here."
|
||||
)
|
||||
|
||||
def test_user_reset_online_with_no_status(self):
|
||||
def test_user_reset_online_with_no_status(self) -> None:
|
||||
"""Test that if a user set again the presence manually
|
||||
and no status is set, that `status_msg` is `None`.
|
||||
"""
|
||||
|
@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertEqual(state.state, PresenceState.ONLINE)
|
||||
self.assertEqual(state.status_msg, None)
|
||||
|
||||
def test_set_presence_with_status_msg_none(self):
|
||||
def test_set_presence_with_status_msg_none(self) -> None:
|
||||
"""Test that if a user set again the presence manually
|
||||
and status is `None`, that `status_msg` is `None`.
|
||||
"""
|
||||
|
@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
# Mark user as online and `status_msg = None`
|
||||
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
|
||||
|
||||
def test_set_presence_from_syncing_not_set(self):
|
||||
def test_set_presence_from_syncing_not_set(self) -> None:
|
||||
"""Test that presence is not set by syncing if affect_presence is false"""
|
||||
user_id = "@test:server"
|
||||
status_msg = "I'm here!"
|
||||
|
@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
# and status message should still be the same
|
||||
self.assertEqual(state.status_msg, status_msg)
|
||||
|
||||
def test_set_presence_from_syncing_is_set(self):
|
||||
def test_set_presence_from_syncing_is_set(self) -> None:
|
||||
"""Test that presence is set by syncing if affect_presence is true"""
|
||||
user_id = "@test:server"
|
||||
status_msg = "I'm here!"
|
||||
|
@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
# we should now be online
|
||||
self.assertEqual(state.state, PresenceState.ONLINE)
|
||||
|
||||
def test_set_presence_from_syncing_keeps_status(self):
|
||||
def test_set_presence_from_syncing_keeps_status(self) -> None:
|
||||
"""Test that presence set by syncing retains status message"""
|
||||
user_id = "@test:server"
|
||||
status_msg = "I'm here!"
|
||||
|
@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
},
|
||||
}
|
||||
)
|
||||
def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool):
|
||||
def test_set_presence_from_syncing_keeps_busy(
|
||||
self, test_with_workers: bool
|
||||
) -> None:
|
||||
"""Test that presence set by syncing doesn't affect busy status
|
||||
|
||||
Args:
|
||||
|
@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
|
||||
def _set_presencestate_with_status_msg(
|
||||
self, user_id: str, state: str, status_msg: Optional[str]
|
||||
):
|
||||
) -> None:
|
||||
"""Set a PresenceState and status_msg and check the result.
|
||||
|
||||
Args:
|
||||
|
@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
|
|||
|
||||
|
||||
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.instance_name = hs.get_instance_name()
|
||||
|
||||
self.queue = self.presence_handler.get_federation_queue()
|
||||
|
||||
def test_send_and_get(self):
|
||||
def test_send_and_get(self) -> None:
|
||||
state1 = UserPresenceState.default("@user1:test")
|
||||
state2 = UserPresenceState.default("@user2:test")
|
||||
state3 = UserPresenceState.default("@user3:test")
|
||||
|
@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
|
|||
self.assertFalse(limited)
|
||||
self.assertCountEqual(rows, [])
|
||||
|
||||
def test_send_and_get_split(self):
|
||||
def test_send_and_get_split(self) -> None:
|
||||
state1 = UserPresenceState.default("@user1:test")
|
||||
state2 = UserPresenceState.default("@user2:test")
|
||||
state3 = UserPresenceState.default("@user3:test")
|
||||
|
@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertCountEqual(rows, expected_rows)
|
||||
|
||||
def test_clear_queue_all(self):
|
||||
def test_clear_queue_all(self) -> None:
|
||||
state1 = UserPresenceState.default("@user1:test")
|
||||
state2 = UserPresenceState.default("@user2:test")
|
||||
state3 = UserPresenceState.default("@user3:test")
|
||||
|
@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertCountEqual(rows, expected_rows)
|
||||
|
||||
def test_partially_clear_queue(self):
|
||||
def test_partially_clear_queue(self) -> None:
|
||||
state1 = UserPresenceState.default("@user1:test")
|
||||
state2 = UserPresenceState.default("@user2:test")
|
||||
state3 = UserPresenceState.default("@user3:test")
|
||||
|
@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
servlets = [room.register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
hs = self.setup_test_homeserver(
|
||||
"server",
|
||||
federation_http_client=None,
|
||||
|
@ -990,14 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
return hs
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
# Enable federation sending on the main process.
|
||||
config["federation_sender_instances"] = None
|
||||
return config
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.federation_sender = cast(Mock, hs.get_federation_sender())
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.federation_event_handler = hs.get_federation_event_handler()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
@ -1013,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
|||
# random key to use.
|
||||
self.random_signing_key = generate_signing_key("ver")
|
||||
|
||||
def test_remote_joins(self):
|
||||
def test_remote_joins(self) -> None:
|
||||
# We advance time to something that isn't 0, as we use 0 as a special
|
||||
# value.
|
||||
self.reactor.advance(1000000000000)
|
||||
|
@ -1061,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
|||
destinations={"server3"}, states=[expected_state]
|
||||
)
|
||||
|
||||
def test_remote_gets_presence_when_local_user_joins(self):
|
||||
def test_remote_gets_presence_when_local_user_joins(self) -> None:
|
||||
# We advance time to something that isn't 0, as we use 0 as a special
|
||||
# value.
|
||||
self.reactor.advance(1000000000000)
|
||||
|
@ -1110,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
|||
destinations={"server2", "server3"}, states=[expected_state]
|
||||
)
|
||||
|
||||
def _add_new_user(self, room_id, user_id):
|
||||
def _add_new_user(self, room_id: str, user_id: str) -> None:
|
||||
"""Add new user to the room by creating an event and poking the federation API."""
|
||||
|
||||
hostname = get_domain_from_id(user_id)
|
||||
|
|
|
@ -332,7 +332,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||
@unittest.override_config(
|
||||
{"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
|
||||
)
|
||||
def test_avatar_constraint_on_local_server_with_port(self):
|
||||
def test_avatar_constraint_on_local_server_with_port(self) -> None:
|
||||
"""Test that avatar metadata is correctly fetched when the media is on a local
|
||||
server and the server has an explicit port.
|
||||
|
||||
|
@ -376,7 +376,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||
self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
|
||||
)
|
||||
|
||||
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
|
||||
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
|
||||
"""Stores metadata about files in the database.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -15,14 +15,18 @@
|
|||
from copy import deepcopy
|
||||
from typing import List
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EduTypes, ReceiptTypes
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class ReceiptsTestCase(unittest.HomeserverTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.event_source = hs.get_event_sources().sources.receipt
|
||||
|
||||
def test_filters_out_private_receipt(self) -> None:
|
||||
|
|
|
@ -12,8 +12,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Collection, List, Optional, Tuple
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import (
|
||||
|
@ -22,8 +25,18 @@ from synapse.api.errors import (
|
|||
ResourceLimitError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.server import HomeServer
|
||||
from synapse.spam_checker_api import RegistrationBehaviour
|
||||
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
UserID,
|
||||
create_requester,
|
||||
)
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.test_utils import make_awaitable
|
||||
from tests.unittest import override_config
|
||||
|
@ -33,94 +46,98 @@ from .. import unittest
|
|||
|
||||
|
||||
class TestSpamChecker:
|
||||
def __init__(self, config, api):
|
||||
def __init__(self, config: None, api: ModuleApi):
|
||||
api.register_spam_checker_callbacks(
|
||||
check_registration_for_spam=self.check_registration_for_spam,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config):
|
||||
return config
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
return None
|
||||
|
||||
async def check_registration_for_spam(
|
||||
self,
|
||||
email_threepid,
|
||||
username,
|
||||
request_info,
|
||||
auth_provider_id,
|
||||
):
|
||||
email_threepid: Optional[dict],
|
||||
username: Optional[str],
|
||||
request_info: Collection[Tuple[str, str]],
|
||||
auth_provider_id: Optional[str],
|
||||
) -> RegistrationBehaviour:
|
||||
pass
|
||||
|
||||
|
||||
class DenyAll(TestSpamChecker):
|
||||
async def check_registration_for_spam(
|
||||
self,
|
||||
email_threepid,
|
||||
username,
|
||||
request_info,
|
||||
auth_provider_id,
|
||||
):
|
||||
email_threepid: Optional[dict],
|
||||
username: Optional[str],
|
||||
request_info: Collection[Tuple[str, str]],
|
||||
auth_provider_id: Optional[str],
|
||||
) -> RegistrationBehaviour:
|
||||
return RegistrationBehaviour.DENY
|
||||
|
||||
|
||||
class BanAll(TestSpamChecker):
|
||||
async def check_registration_for_spam(
|
||||
self,
|
||||
email_threepid,
|
||||
username,
|
||||
request_info,
|
||||
auth_provider_id,
|
||||
):
|
||||
email_threepid: Optional[dict],
|
||||
username: Optional[str],
|
||||
request_info: Collection[Tuple[str, str]],
|
||||
auth_provider_id: Optional[str],
|
||||
) -> RegistrationBehaviour:
|
||||
return RegistrationBehaviour.SHADOW_BAN
|
||||
|
||||
|
||||
class BanBadIdPUser(TestSpamChecker):
|
||||
async def check_registration_for_spam(
|
||||
self, email_threepid, username, request_info, auth_provider_id=None
|
||||
):
|
||||
self,
|
||||
email_threepid: Optional[dict],
|
||||
username: Optional[str],
|
||||
request_info: Collection[Tuple[str, str]],
|
||||
auth_provider_id: Optional[str] = None,
|
||||
) -> RegistrationBehaviour:
|
||||
# Reject any user coming from CAS and whose username contains profanity
|
||||
if auth_provider_id == "cas" and "flimflob" in username:
|
||||
if auth_provider_id == "cas" and username and "flimflob" in username:
|
||||
return RegistrationBehaviour.DENY
|
||||
return RegistrationBehaviour.ALLOW
|
||||
|
||||
|
||||
class TestLegacyRegistrationSpamChecker:
|
||||
def __init__(self, config, api):
|
||||
def __init__(self, config: None, api: ModuleApi):
|
||||
pass
|
||||
|
||||
async def check_registration_for_spam(
|
||||
self,
|
||||
email_threepid,
|
||||
username,
|
||||
request_info,
|
||||
):
|
||||
email_threepid: Optional[dict],
|
||||
username: Optional[str],
|
||||
request_info: Collection[Tuple[str, str]],
|
||||
) -> RegistrationBehaviour:
|
||||
pass
|
||||
|
||||
|
||||
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
|
||||
async def check_registration_for_spam(
|
||||
self,
|
||||
email_threepid,
|
||||
username,
|
||||
request_info,
|
||||
):
|
||||
email_threepid: Optional[dict],
|
||||
username: Optional[str],
|
||||
request_info: Collection[Tuple[str, str]],
|
||||
) -> RegistrationBehaviour:
|
||||
return RegistrationBehaviour.ALLOW
|
||||
|
||||
|
||||
class LegacyDenyAll(TestLegacyRegistrationSpamChecker):
|
||||
async def check_registration_for_spam(
|
||||
self,
|
||||
email_threepid,
|
||||
username,
|
||||
request_info,
|
||||
):
|
||||
email_threepid: Optional[dict],
|
||||
username: Optional[str],
|
||||
request_info: Collection[Tuple[str, str]],
|
||||
) -> RegistrationBehaviour:
|
||||
return RegistrationBehaviour.DENY
|
||||
|
||||
|
||||
class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||
"""Tests the RegistrationHandler."""
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
hs_config = self.default_config()
|
||||
|
||||
# some of the tests rely on us having a user consent version
|
||||
|
@ -145,7 +162,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.handler = self.hs.get_registration_handler()
|
||||
self.store = self.hs.get_datastores().main
|
||||
self.lots_of_users = 100
|
||||
|
@ -153,7 +170,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.requester = create_requester("@requester:test")
|
||||
|
||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self) -> None:
|
||||
frank = UserID.from_string("@frank:test")
|
||||
user_id = frank.to_string()
|
||||
requester = create_requester(user_id)
|
||||
|
@ -164,7 +181,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
self.assertIsInstance(result_token, str)
|
||||
self.assertGreater(len(result_token), 20)
|
||||
|
||||
def test_if_user_exists(self):
|
||||
def test_if_user_exists(self) -> None:
|
||||
store = self.hs.get_datastores().main
|
||||
frank = UserID.from_string("@frank:test")
|
||||
self.get_success(
|
||||
|
@ -180,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
self.assertTrue(result_token is not None)
|
||||
|
||||
@override_config({"limit_usage_by_mau": False})
|
||||
def test_mau_limits_when_disabled(self):
|
||||
def test_mau_limits_when_disabled(self) -> None:
|
||||
# Ensure does not throw exception
|
||||
self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
|
||||
|
||||
@override_config({"limit_usage_by_mau": True})
|
||||
def test_get_or_create_user_mau_not_blocked(self):
|
||||
def test_get_or_create_user_mau_not_blocked(self) -> None:
|
||||
self.store.count_monthly_users = Mock(
|
||||
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
|
||||
)
|
||||
|
@ -193,7 +210,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
|
||||
|
||||
@override_config({"limit_usage_by_mau": True})
|
||||
def test_get_or_create_user_mau_blocked(self):
|
||||
def test_get_or_create_user_mau_blocked(self) -> None:
|
||||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=make_awaitable(self.lots_of_users)
|
||||
)
|
||||
|
@ -211,7 +228,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
@override_config({"limit_usage_by_mau": True})
|
||||
def test_register_mau_blocked(self):
|
||||
def test_register_mau_blocked(self) -> None:
|
||||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=make_awaitable(self.lots_of_users)
|
||||
)
|
||||
|
@ -229,7 +246,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
@override_config(
|
||||
{"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False}
|
||||
)
|
||||
def test_auto_join_rooms_for_guests(self):
|
||||
def test_auto_join_rooms_for_guests(self) -> None:
|
||||
user_id = self.get_success(
|
||||
self.handler.register_user(localpart="jeff", make_guest=True),
|
||||
)
|
||||
|
@ -237,7 +254,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(len(rooms), 0)
|
||||
|
||||
@override_config({"auto_join_rooms": ["#room:test"]})
|
||||
def test_auto_create_auto_join_rooms(self):
|
||||
def test_auto_create_auto_join_rooms(self) -> None:
|
||||
room_alias_str = "#room:test"
|
||||
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||
|
@ -249,7 +266,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(len(rooms), 1)
|
||||
|
||||
@override_config({"auto_join_rooms": []})
|
||||
def test_auto_create_auto_join_rooms_with_no_rooms(self):
|
||||
def test_auto_create_auto_join_rooms_with_no_rooms(self) -> None:
|
||||
frank = UserID.from_string("@frank:test")
|
||||
user_id = self.get_success(self.handler.register_user(frank.localpart))
|
||||
self.assertEqual(user_id, frank.to_string())
|
||||
|
@ -257,7 +274,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(len(rooms), 0)
|
||||
|
||||
@override_config({"auto_join_rooms": ["#room:another"]})
|
||||
def test_auto_create_auto_join_where_room_is_another_domain(self):
|
||||
def test_auto_create_auto_join_where_room_is_another_domain(self) -> None:
|
||||
frank = UserID.from_string("@frank:test")
|
||||
user_id = self.get_success(self.handler.register_user(frank.localpart))
|
||||
self.assertEqual(user_id, frank.to_string())
|
||||
|
@ -267,13 +284,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
@override_config(
|
||||
{"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False}
|
||||
)
|
||||
def test_auto_create_auto_join_where_auto_create_is_false(self):
|
||||
def test_auto_create_auto_join_where_auto_create_is_false(self) -> None:
|
||||
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||
self.assertEqual(len(rooms), 0)
|
||||
|
||||
@override_config({"auto_join_rooms": ["#room:test"]})
|
||||
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self):
|
||||
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
|
||||
room_alias_str = "#room:test"
|
||||
self.store.is_real_user = Mock(return_value=make_awaitable(False))
|
||||
user_id = self.get_success(self.handler.register_user(localpart="support"))
|
||||
|
@ -284,7 +301,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
|
||||
|
||||
@override_config({"auto_join_rooms": ["#room:test"]})
|
||||
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
|
||||
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
|
||||
room_alias_str = "#room:test"
|
||||
|
||||
self.store.count_real_users = Mock(return_value=make_awaitable(1))
|
||||
|
@ -299,7 +316,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(len(rooms), 1)
|
||||
|
||||
@override_config({"auto_join_rooms": ["#room:test"]})
|
||||
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self):
|
||||
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
|
||||
self,
|
||||
) -> None:
|
||||
self.store.count_real_users = Mock(return_value=make_awaitable(2))
|
||||
self.store.is_real_user = Mock(return_value=make_awaitable(True))
|
||||
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
||||
|
@ -312,7 +331,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
"autocreate_auto_join_rooms_federated": False,
|
||||
}
|
||||
)
|
||||
def test_auto_create_auto_join_rooms_federated(self):
|
||||
def test_auto_create_auto_join_rooms_federated(self) -> None:
|
||||
"""
|
||||
Auto-created rooms that are private require an invite to go to the user
|
||||
(instead of directly joining it).
|
||||
|
@ -339,7 +358,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
@override_config(
|
||||
{"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"}
|
||||
)
|
||||
def test_auto_join_mxid_localpart(self):
|
||||
def test_auto_join_mxid_localpart(self) -> None:
|
||||
"""
|
||||
Ensure the user still needs up in the room created by a different user.
|
||||
"""
|
||||
|
@ -376,7 +395,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
"auto_join_mxid_localpart": "support",
|
||||
}
|
||||
)
|
||||
def test_auto_create_auto_join_room_preset(self):
|
||||
def test_auto_create_auto_join_room_preset(self) -> None:
|
||||
"""
|
||||
Auto-created rooms that are private require an invite to go to the user
|
||||
(instead of directly joining it).
|
||||
|
@ -416,7 +435,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
"auto_join_mxid_localpart": "support",
|
||||
}
|
||||
)
|
||||
def test_auto_create_auto_join_room_preset_guest(self):
|
||||
def test_auto_create_auto_join_room_preset_guest(self) -> None:
|
||||
"""
|
||||
Auto-created rooms that are private require an invite to go to the user
|
||||
(instead of directly joining it).
|
||||
|
@ -454,7 +473,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
"auto_join_mxid_localpart": "support",
|
||||
}
|
||||
)
|
||||
def test_auto_create_auto_join_room_preset_invalid_permissions(self):
|
||||
def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None:
|
||||
"""
|
||||
Auto-created rooms that are private require an invite, check that
|
||||
registration doesn't completely break if the inviter doesn't have proper
|
||||
|
@ -525,7 +544,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
"auto_join_rooms": ["#room:test"],
|
||||
},
|
||||
)
|
||||
def test_auto_create_auto_join_where_no_consent(self):
|
||||
def test_auto_create_auto_join_where_no_consent(self) -> None:
|
||||
"""Test to ensure that the first user is not auto-joined to a room if
|
||||
they have not given general consent.
|
||||
"""
|
||||
|
@ -550,19 +569,19 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||
self.assertEqual(len(rooms), 1)
|
||||
|
||||
def test_register_support_user(self):
|
||||
def test_register_support_user(self) -> None:
|
||||
user_id = self.get_success(
|
||||
self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
|
||||
)
|
||||
d = self.store.is_support_user(user_id)
|
||||
self.assertTrue(self.get_success(d))
|
||||
|
||||
def test_register_not_support_user(self):
|
||||
def test_register_not_support_user(self) -> None:
|
||||
user_id = self.get_success(self.handler.register_user(localpart="user"))
|
||||
d = self.store.is_support_user(user_id)
|
||||
self.assertFalse(self.get_success(d))
|
||||
|
||||
def test_invalid_user_id_length(self):
|
||||
def test_invalid_user_id_length(self) -> None:
|
||||
invalid_user_id = "x" * 256
|
||||
self.get_failure(
|
||||
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
||||
|
@ -577,7 +596,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
}
|
||||
)
|
||||
def test_spam_checker_deny(self):
|
||||
def test_spam_checker_deny(self) -> None:
|
||||
"""A spam checker can deny registration, which results in an error."""
|
||||
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
|
||||
|
||||
|
@ -590,7 +609,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
}
|
||||
)
|
||||
def test_spam_checker_legacy_allow(self):
|
||||
def test_spam_checker_legacy_allow(self) -> None:
|
||||
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
|
||||
check_registration_for_spam callback is correctly called.
|
||||
|
||||
|
@ -610,7 +629,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
}
|
||||
)
|
||||
def test_spam_checker_legacy_deny(self):
|
||||
def test_spam_checker_legacy_deny(self) -> None:
|
||||
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
|
||||
check_registration_for_spam callback is correctly called.
|
||||
|
||||
|
@ -630,7 +649,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
}
|
||||
)
|
||||
def test_spam_checker_shadow_ban(self):
|
||||
def test_spam_checker_shadow_ban(self) -> None:
|
||||
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
|
||||
user_id = self.get_success(self.handler.register_user(localpart="user"))
|
||||
|
||||
|
@ -660,7 +679,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
}
|
||||
)
|
||||
def test_spam_checker_receives_sso_type(self):
|
||||
def test_spam_checker_receives_sso_type(self) -> None:
|
||||
"""Test rejecting registration based on SSO type"""
|
||||
f = self.get_failure(
|
||||
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
|
||||
|
@ -678,8 +697,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
async def get_or_create_user(
|
||||
self, requester, localpart, displayname, password_hash=None
|
||||
):
|
||||
self,
|
||||
requester: Requester,
|
||||
localpart: str,
|
||||
displayname: Optional[str],
|
||||
password_hash: Optional[str] = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""Creates a new user if the user does not exist,
|
||||
else revokes all previous access tokens and generates a new one.
|
||||
|
||||
|
@ -734,13 +757,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
|
||||
"""Tests auto-join on remote rooms."""
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.room_id = "!roomid:remotetest"
|
||||
|
||||
async def update_membership(*args, **kwargs):
|
||||
async def update_membership(*args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
async def lookup_room_alias(*args, **kwargs):
|
||||
async def lookup_room_alias(
|
||||
*args: Any, **kwargs: Any
|
||||
) -> Tuple[RoomID, List[str]]:
|
||||
return RoomID.from_string(self.room_id), ["remotetest"]
|
||||
|
||||
self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"])
|
||||
|
@ -750,12 +775,12 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
|
|||
hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler)
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.handler = self.hs.get_registration_handler()
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
@override_config({"auto_join_rooms": ["#room:remotetest"]})
|
||||
def test_auto_create_auto_join_remote_room(self):
|
||||
def test_auto_create_auto_join_remote_room(self) -> None:
|
||||
"""Tests that we don't attempt to create remote rooms, and that we don't attempt
|
||||
to invite ourselves to rooms we're not in."""
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
|
||||
@override_config({"encryption_enabled_by_default_for_room_type": "all"})
|
||||
def test_encrypted_by_default_config_option_all(self):
|
||||
def test_encrypted_by_default_config_option_all(self) -> None:
|
||||
"""Tests that invite-only and non-invite-only rooms have encryption enabled by
|
||||
default when the config option encryption_enabled_by_default_for_room_type is "all".
|
||||
"""
|
||||
|
@ -45,7 +45,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
|
||||
|
||||
@override_config({"encryption_enabled_by_default_for_room_type": "invite"})
|
||||
def test_encrypted_by_default_config_option_invite(self):
|
||||
def test_encrypted_by_default_config_option_invite(self) -> None:
|
||||
"""Tests that only new, invite-only rooms have encryption enabled by default when
|
||||
the config option encryption_enabled_by_default_for_room_type is "invite".
|
||||
"""
|
||||
|
@ -76,7 +76,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
@override_config({"encryption_enabled_by_default_for_room_type": "off"})
|
||||
def test_encrypted_by_default_config_option_off(self):
|
||||
def test_encrypted_by_default_config_option_off(self) -> None:
|
||||
"""Tests that neither new invite-only nor non-invite-only rooms have encryption
|
||||
enabled by default when the config option
|
||||
encryption_enabled_by_default_for_room_type is "off".
|
||||
|
|
|
@ -11,10 +11,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from unittest import mock
|
||||
|
||||
from twisted.internet.defer import ensureDeferred
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import (
|
||||
EventContentFields,
|
||||
|
@ -34,11 +35,14 @@ from synapse.rest import admin
|
|||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0):
|
||||
def _create_event(
|
||||
room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0
|
||||
) -> mock.Mock:
|
||||
result = mock.Mock(name=room_id)
|
||||
result.room_id = room_id
|
||||
result.content = {}
|
||||
|
@ -48,40 +52,40 @@ def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: i
|
|||
return result
|
||||
|
||||
|
||||
def _order(*events):
|
||||
def _order(*events: mock.Mock) -> List[mock.Mock]:
|
||||
return sorted(events, key=_child_events_comparison_key)
|
||||
|
||||
|
||||
class TestSpaceSummarySort(unittest.TestCase):
|
||||
def test_no_order_last(self):
|
||||
def test_no_order_last(self) -> None:
|
||||
"""An event with no ordering is placed behind those with an ordering."""
|
||||
ev1 = _create_event("!abc:test")
|
||||
ev2 = _create_event("!xyz:test", "xyz")
|
||||
|
||||
self.assertEqual([ev2, ev1], _order(ev1, ev2))
|
||||
|
||||
def test_order(self):
|
||||
def test_order(self) -> None:
|
||||
"""The ordering should be used."""
|
||||
ev1 = _create_event("!abc:test", "xyz")
|
||||
ev2 = _create_event("!xyz:test", "abc")
|
||||
|
||||
self.assertEqual([ev2, ev1], _order(ev1, ev2))
|
||||
|
||||
def test_order_origin_server_ts(self):
|
||||
def test_order_origin_server_ts(self) -> None:
|
||||
"""Origin server is a tie-breaker for ordering."""
|
||||
ev1 = _create_event("!abc:test", origin_server_ts=10)
|
||||
ev2 = _create_event("!xyz:test", origin_server_ts=30)
|
||||
|
||||
self.assertEqual([ev1, ev2], _order(ev1, ev2))
|
||||
|
||||
def test_order_room_id(self):
|
||||
def test_order_room_id(self) -> None:
|
||||
"""Room ID is a final tie-breaker for ordering."""
|
||||
ev1 = _create_event("!abc:test")
|
||||
ev2 = _create_event("!xyz:test")
|
||||
|
||||
self.assertEqual([ev1, ev2], _order(ev1, ev2))
|
||||
|
||||
def test_invalid_ordering_type(self):
|
||||
def test_invalid_ordering_type(self) -> None:
|
||||
"""Invalid orderings are considered the same as missing."""
|
||||
ev1 = _create_event("!abc:test", 1)
|
||||
ev2 = _create_event("!xyz:test", "xyz")
|
||||
|
@ -97,7 +101,7 @@ class TestSpaceSummarySort(unittest.TestCase):
|
|||
ev1 = _create_event("!abc:test", True)
|
||||
self.assertEqual([ev2, ev1], _order(ev1, ev2))
|
||||
|
||||
def test_invalid_ordering_value(self):
|
||||
def test_invalid_ordering_value(self) -> None:
|
||||
"""Invalid orderings are considered the same as missing."""
|
||||
ev1 = _create_event("!abc:test", "foo\n")
|
||||
ev2 = _create_event("!xyz:test", "xyz")
|
||||
|
@ -115,7 +119,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
login.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs: HomeServer):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.hs = hs
|
||||
self.handler = self.hs.get_room_summary_handler()
|
||||
|
||||
|
@ -223,7 +227,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
|
||||
)
|
||||
|
||||
def test_simple_space(self):
|
||||
def test_simple_space(self) -> None:
|
||||
"""Test a simple space with a single room."""
|
||||
# The result should have the space and the room in it, along with a link
|
||||
# from space -> room.
|
||||
|
@ -234,7 +238,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self._assert_hierarchy(result, expected)
|
||||
|
||||
def test_large_space(self):
|
||||
def test_large_space(self) -> None:
|
||||
"""Test a space with a large number of rooms."""
|
||||
rooms = [self.room]
|
||||
# Make at least 51 rooms that are part of the space.
|
||||
|
@ -260,7 +264,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
result["rooms"] += result2["rooms"]
|
||||
self._assert_hierarchy(result, expected)
|
||||
|
||||
def test_visibility(self):
|
||||
def test_visibility(self) -> None:
|
||||
"""A user not in a space cannot inspect it."""
|
||||
user2 = self.register_user("user2", "pass")
|
||||
token2 = self.login("user2", "pass")
|
||||
|
@ -380,7 +384,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
self._assert_hierarchy(result2, [(self.space, [self.room])])
|
||||
|
||||
def _create_room_with_join_rule(
|
||||
self, join_rule: str, room_version: Optional[str] = None, **extra_content
|
||||
self, join_rule: str, room_version: Optional[str] = None, **extra_content: Any
|
||||
) -> str:
|
||||
"""Create a room with the given join rule and add it to the space."""
|
||||
room_id = self.helper.create_room_as(
|
||||
|
@ -403,7 +407,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
self._add_child(self.space, room_id, self.token)
|
||||
return room_id
|
||||
|
||||
def test_filtering(self):
|
||||
def test_filtering(self) -> None:
|
||||
"""
|
||||
Rooms should be properly filtered to only include rooms the user has access to.
|
||||
"""
|
||||
|
@ -476,7 +480,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self._assert_hierarchy(result, expected)
|
||||
|
||||
def test_complex_space(self):
|
||||
def test_complex_space(self) -> None:
|
||||
"""
|
||||
Create a "complex" space to see how it handles things like loops and subspaces.
|
||||
"""
|
||||
|
@ -516,7 +520,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self._assert_hierarchy(result, expected)
|
||||
|
||||
def test_pagination(self):
|
||||
def test_pagination(self) -> None:
|
||||
"""Test simple pagination works."""
|
||||
room_ids = []
|
||||
for i in range(1, 10):
|
||||
|
@ -553,7 +557,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
self._assert_hierarchy(result, expected)
|
||||
self.assertNotIn("next_batch", result)
|
||||
|
||||
def test_invalid_pagination_token(self):
|
||||
def test_invalid_pagination_token(self) -> None:
|
||||
"""An invalid pagination token, or changing other parameters, shoudl be rejected."""
|
||||
room_ids = []
|
||||
for i in range(1, 10):
|
||||
|
@ -604,7 +608,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
SynapseError,
|
||||
)
|
||||
|
||||
def test_max_depth(self):
|
||||
def test_max_depth(self) -> None:
|
||||
"""Create a deep tree to test the max depth against."""
|
||||
spaces = [self.space]
|
||||
rooms = [self.room]
|
||||
|
@ -659,7 +663,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
self._assert_hierarchy(result, expected)
|
||||
|
||||
def test_unknown_room_version(self):
|
||||
def test_unknown_room_version(self) -> None:
|
||||
"""
|
||||
If a room with an unknown room version is encountered it should not cause
|
||||
the entire summary to skip.
|
||||
|
@ -685,7 +689,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self._assert_hierarchy(result, expected)
|
||||
|
||||
def test_fed_complex(self):
|
||||
def test_fed_complex(self) -> None:
|
||||
"""
|
||||
Return data over federation and ensure that it is handled properly.
|
||||
"""
|
||||
|
@ -722,7 +726,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
"world_readable": True,
|
||||
}
|
||||
|
||||
async def summarize_remote_room_hierarchy(_self, room, suggested_only):
|
||||
async def summarize_remote_room_hierarchy(
|
||||
_self: Any, room: Any, suggested_only: bool
|
||||
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
|
||||
return requested_room_entry, {subroom: child_room}, set()
|
||||
|
||||
# Add a room to the space which is on another server.
|
||||
|
@ -744,7 +750,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self._assert_hierarchy(result, expected)
|
||||
|
||||
def test_fed_filtering(self):
|
||||
def test_fed_filtering(self) -> None:
|
||||
"""
|
||||
Rooms returned over federation should be properly filtered to only include
|
||||
rooms the user has access to.
|
||||
|
@ -853,7 +859,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
],
|
||||
)
|
||||
|
||||
async def summarize_remote_room_hierarchy(_self, room, suggested_only):
|
||||
async def summarize_remote_room_hierarchy(
|
||||
_self: Any, room: Any, suggested_only: bool
|
||||
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
|
||||
return subspace_room_entry, dict(children_rooms), set()
|
||||
|
||||
# Add a room to the space which is on another server.
|
||||
|
@ -892,7 +900,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self._assert_hierarchy(result, expected)
|
||||
|
||||
def test_fed_invited(self):
|
||||
def test_fed_invited(self) -> None:
|
||||
"""
|
||||
A room which the user was invited to should be included in the response.
|
||||
|
||||
|
@ -915,7 +923,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
async def summarize_remote_room_hierarchy(_self, room, suggested_only):
|
||||
async def summarize_remote_room_hierarchy(
|
||||
_self: Any, room: Any, suggested_only: bool
|
||||
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
|
||||
return fed_room_entry, {}, set()
|
||||
|
||||
# Add a room to the space which is on another server.
|
||||
|
@ -936,7 +946,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self._assert_hierarchy(result, expected)
|
||||
|
||||
def test_fed_caching(self):
|
||||
def test_fed_caching(self) -> None:
|
||||
"""
|
||||
Federation `/hierarchy` responses should be cached.
|
||||
"""
|
||||
|
@ -1023,7 +1033,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
|
|||
login.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs: HomeServer):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.hs = hs
|
||||
self.handler = self.hs.get_room_summary_handler()
|
||||
|
||||
|
@ -1040,12 +1050,12 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
|
|||
tok=self.token,
|
||||
)
|
||||
|
||||
def test_own_room(self):
|
||||
def test_own_room(self) -> None:
|
||||
"""Test a simple room created by the requester."""
|
||||
result = self.get_success(self.handler.get_room_summary(self.user, self.room))
|
||||
self.assertEqual(result.get("room_id"), self.room)
|
||||
|
||||
def test_visibility(self):
|
||||
def test_visibility(self) -> None:
|
||||
"""A user not in a private room cannot get its summary."""
|
||||
user2 = self.register_user("user2", "pass")
|
||||
token2 = self.login("user2", "pass")
|
||||
|
@ -1093,7 +1103,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
|
|||
result = self.get_success(self.handler.get_room_summary(user2, self.room))
|
||||
self.assertEqual(result.get("room_id"), self.room)
|
||||
|
||||
def test_fed(self):
|
||||
def test_fed(self) -> None:
|
||||
"""
|
||||
Return data over federation and ensure that it is handled properly.
|
||||
"""
|
||||
|
@ -1105,7 +1115,9 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
|
|||
{"room_id": fed_room, "world_readable": True},
|
||||
)
|
||||
|
||||
async def summarize_remote_room_hierarchy(_self, room, suggested_only):
|
||||
async def summarize_remote_room_hierarchy(
|
||||
_self: Any, room: Any, suggested_only: bool
|
||||
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
|
||||
return requested_room_entry, {}, set()
|
||||
|
||||
with mock.patch(
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Set, Tuple
|
||||
from unittest.mock import Mock
|
||||
|
||||
import attr
|
||||
|
@ -20,7 +20,9 @@ import attr
|
|||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import RedirectException
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.test_utils import simple_async_mock
|
||||
|
@ -29,6 +31,7 @@ from tests.unittest import HomeserverTestCase, override_config
|
|||
# Check if we have the dependencies to run the tests.
|
||||
try:
|
||||
import saml2.config
|
||||
import saml2.response
|
||||
from saml2.sigver import SigverError
|
||||
|
||||
has_saml2 = True
|
||||
|
@ -56,31 +59,39 @@ class FakeAuthnResponse:
|
|||
|
||||
|
||||
class TestMappingProvider:
|
||||
def __init__(self, config, module):
|
||||
def __init__(self, config: None, module: ModuleApi):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config):
|
||||
return
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_saml_attributes(config):
|
||||
def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]:
|
||||
return {"uid"}, {"displayName"}
|
||||
|
||||
def get_remote_user_id(self, saml_response, client_redirect_url):
|
||||
def get_remote_user_id(
|
||||
self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str
|
||||
) -> str:
|
||||
return saml_response.ava["uid"]
|
||||
|
||||
def saml_response_to_user_attributes(
|
||||
self, saml_response, failures, client_redirect_url
|
||||
):
|
||||
self,
|
||||
saml_response: "saml2.response.AuthnResponse",
|
||||
failures: int,
|
||||
client_redirect_url: str,
|
||||
) -> dict:
|
||||
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
|
||||
return {"mxid_localpart": localpart, "displayname": None}
|
||||
|
||||
|
||||
class TestRedirectMappingProvider(TestMappingProvider):
|
||||
def saml_response_to_user_attributes(
|
||||
self, saml_response, failures, client_redirect_url
|
||||
):
|
||||
self,
|
||||
saml_response: "saml2.response.AuthnResponse",
|
||||
failures: int,
|
||||
client_redirect_url: str,
|
||||
) -> dict:
|
||||
raise RedirectException(b"https://custom-saml-redirect/")
|
||||
|
||||
|
||||
|
@ -347,7 +358,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
|
||||
def _mock_request():
|
||||
def _mock_request() -> Mock:
|
||||
"""Returns a mock which will stand in as a SynapseRequest"""
|
||||
mock = Mock(
|
||||
spec=[
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List, Tuple
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
|
@ -28,20 +28,27 @@ from tests.unittest import HomeserverTestCase, override_config
|
|||
|
||||
@implementer(interfaces.IMessageDelivery)
|
||||
class _DummyMessageDelivery:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# (recipient, message) tuples
|
||||
self.messages: List[Tuple[smtp.Address, bytes]] = []
|
||||
|
||||
def receivedHeader(self, helo, origin, recipients):
|
||||
def receivedHeader(
|
||||
self,
|
||||
helo: Tuple[bytes, bytes],
|
||||
origin: smtp.Address,
|
||||
recipients: List[smtp.User],
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
def validateFrom(self, helo, origin):
|
||||
def validateFrom(
|
||||
self, helo: Tuple[bytes, bytes], origin: smtp.Address
|
||||
) -> smtp.Address:
|
||||
return origin
|
||||
|
||||
def record_message(self, recipient: smtp.Address, message: bytes):
|
||||
def record_message(self, recipient: smtp.Address, message: bytes) -> None:
|
||||
self.messages.append((recipient, message))
|
||||
|
||||
def validateTo(self, user: smtp.User):
|
||||
def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]:
|
||||
return lambda: _DummyMessage(self, user)
|
||||
|
||||
|
||||
|
@ -56,20 +63,20 @@ class _DummyMessage:
|
|||
self._user = user
|
||||
self._buffer: List[bytes] = []
|
||||
|
||||
def lineReceived(self, line):
|
||||
def lineReceived(self, line: bytes) -> None:
|
||||
self._buffer.append(line)
|
||||
|
||||
def eomReceived(self):
|
||||
def eomReceived(self) -> "defer.Deferred[bytes]":
|
||||
message = b"\n".join(self._buffer) + b"\n"
|
||||
self._delivery.record_message(self._user.dest, message)
|
||||
return defer.succeed(b"saved")
|
||||
|
||||
def connectionLost(self):
|
||||
def connectionLost(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SendEmailHandlerTestCase(HomeserverTestCase):
|
||||
def test_send_email(self):
|
||||
def test_send_email(self) -> None:
|
||||
"""Happy-path test that we can send email to a non-TLS server."""
|
||||
h = self.hs.get_send_email_handler()
|
||||
d = ensureDeferred(
|
||||
|
@ -119,7 +126,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
|
|||
},
|
||||
}
|
||||
)
|
||||
def test_send_email_force_tls(self):
|
||||
def test_send_email_force_tls(self) -> None:
|
||||
"""Happy-path test that we can send email to an Implicit TLS server."""
|
||||
h = self.hs.get_send_email_handler()
|
||||
d = ensureDeferred(
|
||||
|
|
|
@ -12,9 +12,15 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main import stats
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
@ -32,11 +38,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
login.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.handler = self.hs.get_stats_handler()
|
||||
|
||||
def _add_background_updates(self):
|
||||
def _add_background_updates(self) -> None:
|
||||
"""
|
||||
Add the background updates we need to run.
|
||||
"""
|
||||
|
@ -63,12 +69,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
async def get_all_room_state(self):
|
||||
async def get_all_room_state(self) -> List[Dict[str, Any]]:
|
||||
return await self.store.db_pool.simple_select_list(
|
||||
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
|
||||
)
|
||||
|
||||
def _get_current_stats(self, stats_type, stat_id):
|
||||
def _get_current_stats(
|
||||
self, stats_type: str, stat_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
table, id_col = stats.TYPE_TO_TABLE[stats_type]
|
||||
|
||||
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
|
||||
|
@ -82,13 +90,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
def _perform_background_initial_update(self):
|
||||
def _perform_background_initial_update(self) -> None:
|
||||
# Do the initial population of the stats via the background update
|
||||
self._add_background_updates()
|
||||
|
||||
self.wait_for_background_updates()
|
||||
|
||||
def test_initial_room(self):
|
||||
def test_initial_room(self) -> None:
|
||||
"""
|
||||
The background updates will build the table from scratch.
|
||||
"""
|
||||
|
@ -125,7 +133,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
self.assertEqual(len(r), 1)
|
||||
self.assertEqual(r[0]["topic"], "foo")
|
||||
|
||||
def test_create_user(self):
|
||||
def test_create_user(self) -> None:
|
||||
"""
|
||||
When we create a user, it should have statistics already ready.
|
||||
"""
|
||||
|
@ -134,12 +142,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
|
||||
u1stats = self._get_current_stats("user", u1)
|
||||
|
||||
self.assertIsNotNone(u1stats)
|
||||
assert u1stats is not None
|
||||
|
||||
# not in any rooms by default
|
||||
self.assertEqual(u1stats["joined_rooms"], 0)
|
||||
|
||||
def test_create_room(self):
|
||||
def test_create_room(self) -> None:
|
||||
"""
|
||||
When we create a room, it should have statistics already ready.
|
||||
"""
|
||||
|
@ -153,8 +161,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
|
||||
r2stats = self._get_current_stats("room", r2)
|
||||
|
||||
self.assertIsNotNone(r1stats)
|
||||
self.assertIsNotNone(r2stats)
|
||||
assert r1stats is not None
|
||||
assert r2stats is not None
|
||||
|
||||
self.assertEqual(
|
||||
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
|
||||
|
@ -171,7 +179,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
self.assertEqual(r2stats["invited_members"], 0)
|
||||
self.assertEqual(r2stats["banned_members"], 0)
|
||||
|
||||
def test_updating_profile_information_does_not_increase_joined_members_count(self):
|
||||
def test_updating_profile_information_does_not_increase_joined_members_count(
|
||||
self,
|
||||
) -> None:
|
||||
"""
|
||||
Check that the joined_members count does not increase when a user changes their
|
||||
profile information (which is done by sending another join membership event into
|
||||
|
@ -186,6 +196,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
|
||||
# Get the current room stats
|
||||
r1stats_ante = self._get_current_stats("room", r1)
|
||||
assert r1stats_ante is not None
|
||||
|
||||
# Send a profile update into the room
|
||||
new_profile = {"displayname": "bob"}
|
||||
|
@ -195,6 +206,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
|
||||
# Get the new room stats
|
||||
r1stats_post = self._get_current_stats("room", r1)
|
||||
assert r1stats_post is not None
|
||||
|
||||
# Ensure that the user count did not changed
|
||||
self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
|
||||
|
@ -202,7 +214,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
|
||||
)
|
||||
|
||||
def test_send_state_event_nonoverwriting(self):
|
||||
def test_send_state_event_nonoverwriting(self) -> None:
|
||||
"""
|
||||
When we send a non-overwriting state event, it increments current_state_events
|
||||
"""
|
||||
|
@ -218,19 +230,21 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
r1stats_ante = self._get_current_stats("room", r1)
|
||||
assert r1stats_ante is not None
|
||||
|
||||
self.helper.send_state(
|
||||
r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy"
|
||||
)
|
||||
|
||||
r1stats_post = self._get_current_stats("room", r1)
|
||||
assert r1stats_post is not None
|
||||
|
||||
self.assertEqual(
|
||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||
1,
|
||||
)
|
||||
|
||||
def test_join_first_time(self):
|
||||
def test_join_first_time(self) -> None:
|
||||
"""
|
||||
When a user joins a room for the first time, current_state_events and
|
||||
joined_members should increase by exactly 1.
|
||||
|
@ -246,10 +260,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
u2token = self.login("u2", "pass")
|
||||
|
||||
r1stats_ante = self._get_current_stats("room", r1)
|
||||
assert r1stats_ante is not None
|
||||
|
||||
self.helper.join(r1, u2, tok=u2token)
|
||||
|
||||
r1stats_post = self._get_current_stats("room", r1)
|
||||
assert r1stats_post is not None
|
||||
|
||||
self.assertEqual(
|
||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||
|
@ -259,7 +275,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1
|
||||
)
|
||||
|
||||
def test_join_after_leave(self):
|
||||
def test_join_after_leave(self) -> None:
|
||||
"""
|
||||
When a user joins a room after being previously left,
|
||||
joined_members should increase by exactly 1.
|
||||
|
@ -280,10 +296,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
self.helper.leave(r1, u2, tok=u2token)
|
||||
|
||||
r1stats_ante = self._get_current_stats("room", r1)
|
||||
assert r1stats_ante is not None
|
||||
|
||||
self.helper.join(r1, u2, tok=u2token)
|
||||
|
||||
r1stats_post = self._get_current_stats("room", r1)
|
||||
assert r1stats_post is not None
|
||||
|
||||
self.assertEqual(
|
||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||
|
@ -296,7 +314,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
r1stats_post["left_members"] - r1stats_ante["left_members"], -1
|
||||
)
|
||||
|
||||
def test_invited(self):
|
||||
def test_invited(self) -> None:
|
||||
"""
|
||||
When a user invites another user, current_state_events and
|
||||
invited_members should increase by exactly 1.
|
||||
|
@ -311,10 +329,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
u2 = self.register_user("u2", "pass")
|
||||
|
||||
r1stats_ante = self._get_current_stats("room", r1)
|
||||
assert r1stats_ante is not None
|
||||
|
||||
self.helper.invite(r1, u1, u2, tok=u1token)
|
||||
|
||||
r1stats_post = self._get_current_stats("room", r1)
|
||||
assert r1stats_post is not None
|
||||
|
||||
self.assertEqual(
|
||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||
|
@ -324,7 +344,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1
|
||||
)
|
||||
|
||||
def test_join_after_invite(self):
|
||||
def test_join_after_invite(self) -> None:
|
||||
"""
|
||||
When a user joins a room after being invited and
|
||||
joined_members should increase by exactly 1.
|
||||
|
@ -344,10 +364,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
self.helper.invite(r1, u1, u2, tok=u1token)
|
||||
|
||||
r1stats_ante = self._get_current_stats("room", r1)
|
||||
assert r1stats_ante is not None
|
||||
|
||||
self.helper.join(r1, u2, tok=u2token)
|
||||
|
||||
r1stats_post = self._get_current_stats("room", r1)
|
||||
assert r1stats_post is not None
|
||||
|
||||
self.assertEqual(
|
||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||
|
@ -360,7 +382,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1
|
||||
)
|
||||
|
||||
def test_left(self):
|
||||
def test_left(self) -> None:
|
||||
"""
|
||||
When a user leaves a room after joining and
|
||||
left_members should increase by exactly 1.
|
||||
|
@ -380,10 +402,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
self.helper.join(r1, u2, tok=u2token)
|
||||
|
||||
r1stats_ante = self._get_current_stats("room", r1)
|
||||
assert r1stats_ante is not None
|
||||
|
||||
self.helper.leave(r1, u2, tok=u2token)
|
||||
|
||||
r1stats_post = self._get_current_stats("room", r1)
|
||||
assert r1stats_post is not None
|
||||
|
||||
self.assertEqual(
|
||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||
|
@ -396,7 +420,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
|
||||
)
|
||||
|
||||
def test_banned(self):
|
||||
def test_banned(self) -> None:
|
||||
"""
|
||||
When a user is banned from a room after joining and
|
||||
left_members should increase by exactly 1.
|
||||
|
@ -416,10 +440,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
self.helper.join(r1, u2, tok=u2token)
|
||||
|
||||
r1stats_ante = self._get_current_stats("room", r1)
|
||||
assert r1stats_ante is not None
|
||||
|
||||
self.helper.change_membership(r1, u1, u2, "ban", tok=u1token)
|
||||
|
||||
r1stats_post = self._get_current_stats("room", r1)
|
||||
assert r1stats_post is not None
|
||||
|
||||
self.assertEqual(
|
||||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
|
||||
|
@ -432,7 +458,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
|
||||
)
|
||||
|
||||
def test_initial_background_update(self):
|
||||
def test_initial_background_update(self) -> None:
|
||||
"""
|
||||
Test that statistics can be generated by the initial background update
|
||||
handler.
|
||||
|
@ -462,6 +488,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
r1stats = self._get_current_stats("room", r1)
|
||||
u1stats = self._get_current_stats("user", u1)
|
||||
|
||||
assert r1stats is not None
|
||||
assert u1stats is not None
|
||||
|
||||
self.assertEqual(r1stats["joined_members"], 1)
|
||||
self.assertEqual(
|
||||
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
|
||||
|
@ -469,7 +498,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(u1stats["joined_rooms"], 1)
|
||||
|
||||
def test_incomplete_stats(self):
|
||||
def test_incomplete_stats(self) -> None:
|
||||
"""
|
||||
This tests that we track incomplete statistics.
|
||||
|
||||
|
@ -533,8 +562,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
self.wait_for_background_updates()
|
||||
|
||||
r1stats_complete = self._get_current_stats("room", r1)
|
||||
assert r1stats_complete is not None
|
||||
u1stats_complete = self._get_current_stats("user", u1)
|
||||
assert u1stats_complete is not None
|
||||
u2stats_complete = self._get_current_stats("user", u2)
|
||||
assert u2stats_complete is not None
|
||||
|
||||
# now we make our assertions
|
||||
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
from typing import Optional
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules
|
||||
from synapse.api.errors import Codes, ResourceLimitError
|
||||
from synapse.api.filtering import Filtering
|
||||
|
@ -23,6 +25,7 @@ from synapse.rest import admin
|
|||
from synapse.rest.client import knock, login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
|
||||
import tests.unittest
|
||||
import tests.utils
|
||||
|
@ -39,7 +42,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
|||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs: HomeServer):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.sync_handler = self.hs.get_sync_handler()
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
|
@ -47,7 +50,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
|||
# modify its config instead of the hs'
|
||||
self.auth_blocking = self.hs.get_auth_blocking()
|
||||
|
||||
def test_wait_for_sync_for_user_auth_blocking(self):
|
||||
def test_wait_for_sync_for_user_auth_blocking(self) -> None:
|
||||
user_id1 = "@user1:test"
|
||||
user_id2 = "@user2:test"
|
||||
sync_config = generate_sync_config(user_id1)
|
||||
|
@ -82,7 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
|||
)
|
||||
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
def test_unknown_room_version(self):
|
||||
def test_unknown_room_version(self) -> None:
|
||||
"""
|
||||
A room with an unknown room version should not break sync (and should be excluded).
|
||||
"""
|
||||
|
@ -186,7 +189,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
|||
self.assertNotIn(invite_room, [r.room_id for r in result.invited])
|
||||
self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
|
||||
|
||||
def test_ban_wins_race_with_join(self):
|
||||
def test_ban_wins_race_with_join(self) -> None:
|
||||
"""Rooms shouldn't appear under "joined" if a join loses a race to a ban.
|
||||
|
||||
A complicated edge case. Imagine the following scenario:
|
||||
|
|
Loading…
Reference in a new issue