Add type hints to application services. (#8655)

This commit is contained in:
Patrick Cloke 2020-10-28 11:12:21 -04:00 committed by GitHub
parent 2239813278
commit 31d721fbf6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 122 additions and 79 deletions

1
changelog.d/8655.misc Normal file
View file

@ -0,0 +1 @@
Add more type hints to the application services code.

View file

@ -57,6 +57,7 @@ files =
synapse/server_notices, synapse/server_notices,
synapse/spam_checker_api, synapse/spam_checker_api,
synapse/state, synapse/state,
synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py, synapse/storage/databases/main/events.py,
synapse/storage/databases/main/registration.py, synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py, synapse/storage/databases/main/stream.py,
@ -82,6 +83,9 @@ ignore_missing_imports = True
[mypy-zope] [mypy-zope]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-bcrypt]
ignore_missing_imports = True
[mypy-constantly] [mypy-constantly]
ignore_missing_imports = True ignore_missing_imports = True

View file

@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, List, Optional, Union from typing import TYPE_CHECKING, Dict, List, Optional, Union
from prometheus_client import Counter from prometheus_client import Counter
@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process, run_as_background_process,
wrap_as_background_process, wrap_as_background_process,
) )
from synapse.types import Collection, JsonDict, RoomStreamToken, UserID from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
class ApplicationServicesHandler: class ApplicationServicesHandler:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.appservice_api = hs.get_application_service_api() self.appservice_api = hs.get_application_service_api()
@ -247,7 +250,9 @@ class ApplicationServicesHandler:
service, "presence", new_token service, "presence", new_token
) )
async def _handle_typing(self, service: ApplicationService, new_token: int): async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
# Get the typing events from just before current # Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as( typing, _ = await typing_source.get_new_events_as(
@ -259,7 +264,7 @@ class ApplicationServicesHandler:
) )
return typing return typing
async def _handle_receipts(self, service: ApplicationService): async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
from_key = await self.store.get_type_stream_id_for_appservice( from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt" service, "read_receipt"
) )
@ -271,7 +276,7 @@ class ApplicationServicesHandler:
async def _handle_presence( async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]] self, service: ApplicationService, users: Collection[Union[str, UserID]]
): ) -> List[JsonDict]:
events = [] # type: List[JsonDict] events = [] # type: List[JsonDict]
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice( from_key = await self.store.get_type_stream_id_for_appservice(
@ -301,11 +306,11 @@ class ApplicationServicesHandler:
return events return events
async def query_user_exists(self, user_id): async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists. """Check if any application service knows this user_id exists.
Args: Args:
user_id(str): The user to query if they exist on any AS. user_id: The user to query if they exist on any AS.
Returns: Returns:
True if this user exists on at least one application service. True if this user exists on at least one application service.
""" """
@ -316,11 +321,13 @@ class ApplicationServicesHandler:
return True return True
return False return False
async def query_room_alias_exists(self, room_alias): async def query_room_alias_exists(
self, room_alias: RoomAlias
) -> Optional[RoomAliasMapping]:
"""Check if an application service knows this room alias exists. """Check if an application service knows this room alias exists.
Args: Args:
room_alias(RoomAlias): The room alias to query. room_alias: The room alias to query.
Returns: Returns:
namedtuple: with keys "room_id" and "servers" or None if no namedtuple: with keys "room_id" and "servers" or None if no
association can be found. association can be found.
@ -336,10 +343,13 @@ class ApplicationServicesHandler:
) )
if is_known_alias: if is_known_alias:
# the alias exists now so don't query more ASes. # the alias exists now so don't query more ASes.
result = await self.store.get_association_from_room_alias(room_alias) return await self.store.get_association_from_room_alias(room_alias)
return result
async def query_3pe(self, kind, protocol, fields): return None
async def query_3pe(
self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
) -> List[JsonDict]:
services = self._get_services_for_3pn(protocol) services = self._get_services_for_3pn(protocol)
results = await make_deferred_yieldable( results = await make_deferred_yieldable(
@ -361,7 +371,9 @@ class ApplicationServicesHandler:
return ret return ret
async def get_3pe_protocols(self, only_protocol=None): async def get_3pe_protocols(
self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]:
services = self.store.get_app_services() services = self.store.get_app_services()
protocols = {} # type: Dict[str, List[JsonDict]] protocols = {} # type: Dict[str, List[JsonDict]]
@ -379,7 +391,7 @@ class ApplicationServicesHandler:
if info is not None: if info is not None:
protocols[p].append(info) protocols[p].append(info)
def _merge_instances(infos): def _merge_instances(infos: List[JsonDict]) -> JsonDict:
if not infos: if not infos:
return {} return {}
@ -394,19 +406,17 @@ class ApplicationServicesHandler:
return combined return combined
for p in protocols.keys(): return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
protocols[p] = _merge_instances(protocols[p])
return protocols async def _get_services_for_event(
self, event: EventBase
async def _get_services_for_event(self, event): ) -> List[ApplicationService]:
"""Retrieve a list of application services interested in this event. """Retrieve a list of application services interested in this event.
Args: Args:
event(Event): The event to check. Can be None if alias_list is not. event: The event to check. Can be None if alias_list is not.
Returns: Returns:
list<ApplicationService>: A list of services interested in this A list of services interested in this event based on the service regex.
event based on the service regex.
""" """
services = self.store.get_app_services() services = self.store.get_app_services()
@ -420,17 +430,15 @@ class ApplicationServicesHandler:
return interested_list return interested_list
def _get_services_for_user(self, user_id): def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
services = self.store.get_app_services() services = self.store.get_app_services()
interested_list = [s for s in services if (s.is_interested_in_user(user_id))] return [s for s in services if (s.is_interested_in_user(user_id))]
return interested_list
def _get_services_for_3pn(self, protocol): def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
services = self.store.get_app_services() services = self.store.get_app_services()
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] return [s for s in services if s.is_interested_in_protocol(protocol)]
return interested_list
async def _is_unknown_user(self, user_id): async def _is_unknown_user(self, user_id: str) -> bool:
if not self.is_mine_id(user_id): if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our # we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes. # users. We can't poke ASes.
@ -445,9 +453,8 @@ class ApplicationServicesHandler:
service_list = [s for s in services if s.sender == user_id] service_list = [s for s in services if s.sender == user_id]
return len(service_list) == 0 return len(service_list) == 0
async def _check_user_exists(self, user_id): async def _check_user_exists(self, user_id: str) -> bool:
unknown_user = await self._is_unknown_user(user_id) unknown_user = await self._is_unknown_user(user_id)
if unknown_user: if unknown_user:
exists = await self.query_user_exists(user_id) return await self.query_user_exists(user_id)
return exists
return True return True

View file

@ -18,10 +18,20 @@ import logging
import time import time
import unicodedata import unicodedata
import urllib.parse import urllib.parse
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
import attr import attr
import bcrypt # type: ignore[import] import bcrypt
import pymacaroons import pymacaroons
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
class AuthHandler(BaseHandler): class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer):
"""
super().__init__(hs) super().__init__(hs)
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]

View file

@ -15,21 +15,31 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import List from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
from synapse.appservice import ApplicationService, AppServiceTransaction from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
AppServiceTransaction,
)
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.types import Connection
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _make_exclusive_regex(services_cache): def _make_exclusive_regex(
services_cache: List[ApplicationService],
) -> Optional[Pattern]:
# We precompile a regex constructed from all the regexes that the AS's # We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users. # have registered for exclusive users.
exclusive_user_regexes = [ exclusive_user_regexes = [
@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
] ]
if exclusive_user_regexes: if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
exclusive_user_regex = re.compile(exclusive_user_regex) exclusive_user_pattern = re.compile(
exclusive_user_regex
) # type: Optional[Pattern]
else: else:
# We handle this case specially otherwise the constructed regex # We handle this case specially otherwise the constructed regex
# will always match # will always match
exclusive_user_regex = None exclusive_user_pattern = None
return exclusive_user_regex return exclusive_user_pattern
class ApplicationServiceWorkerStore(SQLBaseStore): class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.services_cache = load_appservices( self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files hs.hostname, hs.config.app_service_config_files
) )
@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
def get_app_services(self): def get_app_services(self):
return self.services_cache return self.services_cache
def get_if_app_services_interested_in_user(self, user_id): def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
"""Check if the user is one associated with an app service (exclusively) """Check if the user is one associated with an app service (exclusively)
""" """
if self.exclusive_user_regex: if self.exclusive_user_regex:
@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
else: else:
return False return False
def get_app_service_by_user_id(self, user_id): def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
"""Retrieve an application service from their user ID. """Retrieve an application service from their user ID.
All application services have associated with them a particular user ID. All application services have associated with them a particular user ID.
@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
a user ID to an application service. a user ID to an application service.
Args: Args:
user_id(str): The user ID to see if it is an application service. user_id: The user ID to see if it is an application service.
Returns: Returns:
synapse.appservice.ApplicationService or None. The application service or None.
""" """
for service in self.services_cache: for service in self.services_cache:
if service.sender == user_id: if service.sender == user_id:
return service return service
return None return None
def get_app_service_by_token(self, token): def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice token. """Get the application service with the given appservice token.
Args: Args:
token (str): The application service token. token: The application service token.
Returns: Returns:
synapse.appservice.ApplicationService or None. The application service or None.
""" """
for service in self.services_cache: for service in self.services_cache:
if service.token == token: if service.token == token:
return service return service
return None return None
def get_app_service_by_id(self, as_id): def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice ID. """Get the application service with the given appservice ID.
Args: Args:
as_id (str): The application service ID. as_id: The application service ID.
Returns: Returns:
synapse.appservice.ApplicationService or None. The application service or None.
""" """
for service in self.services_cache: for service in self.services_cache:
if service.id == as_id: if service.id == as_id:
@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore( class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore ApplicationServiceWorkerStore, EventsWorkerStore
): ):
async def get_appservices_by_state(self, state): async def get_appservices_by_state(
self, state: ApplicationServiceState
) -> List[ApplicationService]:
"""Get a list of application services based on their state. """Get a list of application services based on their state.
Args: Args:
state(ApplicationServiceState): The state to filter on. state: The state to filter on.
Returns: Returns:
A list of ApplicationServices, which may be empty. A list of ApplicationServices, which may be empty.
""" """
@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
services.append(service) services.append(service)
return services return services
async def get_appservice_state(self, service): async def get_appservice_state(
self, service: ApplicationService
) -> Optional[ApplicationServiceState]:
"""Get the application service state. """Get the application service state.
Args: Args:
service(ApplicationService): The service whose state to set. service: The service whose state to set.
Returns: Returns:
An ApplicationServiceState. An ApplicationServiceState or none.
""" """
result = await self.db_pool.simple_select_one( result = await self.db_pool.simple_select_one(
"application_services_state", "application_services_state",
@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state") return result.get("state")
return None return None
async def set_appservice_state(self, service, state) -> None: async def set_appservice_state(
self, service: ApplicationService, state: ApplicationServiceState
) -> None:
"""Set the application service state. """Set the application service state.
Args: Args:
service(ApplicationService): The service whose state to set. service: The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply. state: The connectivity state to apply.
""" """
await self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state} "application_services_state", {"as_id": service.id}, {"state": state}
@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
"create_appservice_txn", _create_appservice_txn "create_appservice_txn", _create_appservice_txn
) )
async def complete_appservice_txn(self, txn_id, service) -> None: async def complete_appservice_txn(
self, txn_id: int, service: ApplicationService
) -> None:
"""Completes an application service transaction. """Completes an application service transaction.
Args: Args:
txn_id(str): The transaction ID being completed. txn_id: The transaction ID being completed.
service(ApplicationService): The application service which was sent service: The application service which was sent this transaction.
this transaction.
""" """
txn_id = int(txn_id) txn_id = int(txn_id)
@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
# has probably missed some events), so whine loudly but still continue, # has probably missed some events), so whine loudly but still continue,
# since it shouldn't fail completion of the transaction. # since it shouldn't fail completion of the transaction.
last_txn_id = self._get_last_txn(txn, service.id) last_txn_id = self._get_last_txn(txn, service.id)
if (last_txn_id + 1) != txn_id: if (txn_id + 1) != txn_id:
logger.error( logger.error(
"appservice: Completing a transaction which has an ID > 1 from " "appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or " "the last ID sent to this AS. We've either dropped events or "
@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
"complete_appservice_txn", _complete_appservice_txn "complete_appservice_txn", _complete_appservice_txn
) )
async def get_oldest_unsent_txn(self, service): async def get_oldest_unsent_txn(
"""Get the oldest transaction which has not been sent for this self, service: ApplicationService
service. ) -> Optional[AppServiceTransaction]:
"""Get the oldest transaction which has not been sent for this service.
Args: Args:
service(ApplicationService): The app service to get the oldest txn. service: The app service to get the oldest txn.
Returns: Returns:
An AppServiceTransaction or None. An AppServiceTransaction or None.
""" """
@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
service=service, id=entry["txn_id"], events=events, ephemeral=[] service=service, id=entry["txn_id"], events=events, ephemeral=[]
) )
def _get_last_txn(self, txn, service_id): def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
txn.execute( txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?", "SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,), (service_id,),
@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
else: else:
return int(last_txn_id[0]) # select 'last_txn' col return int(last_txn_id[0]) # select 'last_txn' col
async def set_appservice_last_pos(self, pos) -> None: async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn): def set_appservice_last_pos_txn(txn):
txn.execute( txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
"set_appservice_last_pos", set_appservice_last_pos_txn "set_appservice_last_pos", set_appservice_last_pos_txn
) )
async def get_new_events_for_appservice(self, current_id, limit): async def get_new_events_for_appservice(
self, current_id: int, limit: int
) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice""" """Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn): def get_new_events_for_appservice_txn(txn):
@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
) )
async def set_type_stream_id_for_appservice( async def set_type_stream_id_for_appservice(
self, service: ApplicationService, type: str, pos: int self, service: ApplicationService, type: str, pos: Optional[int]
) -> None: ) -> None:
if type not in ("read_receipt", "presence"): if type not in ("read_receipt", "presence"):
raise ValueError( raise ValueError(