mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Add type hints to application services. (#8655)
This commit is contained in:
parent
2239813278
commit
31d721fbf6
5 changed files with 122 additions and 79 deletions
1
changelog.d/8655.misc
Normal file
1
changelog.d/8655.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add more type hints to the application services code.
|
4
mypy.ini
4
mypy.ini
|
@ -57,6 +57,7 @@ files =
|
||||||
synapse/server_notices,
|
synapse/server_notices,
|
||||||
synapse/spam_checker_api,
|
synapse/spam_checker_api,
|
||||||
synapse/state,
|
synapse/state,
|
||||||
|
synapse/storage/databases/main/appservice.py,
|
||||||
synapse/storage/databases/main/events.py,
|
synapse/storage/databases/main/events.py,
|
||||||
synapse/storage/databases/main/registration.py,
|
synapse/storage/databases/main/registration.py,
|
||||||
synapse/storage/databases/main/stream.py,
|
synapse/storage/databases/main/stream.py,
|
||||||
|
@ -82,6 +83,9 @@ ignore_missing_imports = True
|
||||||
[mypy-zope]
|
[mypy-zope]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[mypy-bcrypt]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
|
||||||
[mypy-constantly]
|
[mypy-constantly]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
|
|
@ -12,9 +12,8 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
|
@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import (
|
||||||
run_as_background_process,
|
run_as_background_process,
|
||||||
wrap_as_background_process,
|
wrap_as_background_process,
|
||||||
)
|
)
|
||||||
from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
|
from synapse.storage.databases.main.directory import RoomAliasMapping
|
||||||
|
from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
|
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServicesHandler:
|
class ApplicationServicesHandler:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.appservice_api = hs.get_application_service_api()
|
self.appservice_api = hs.get_application_service_api()
|
||||||
|
@ -247,7 +250,9 @@ class ApplicationServicesHandler:
|
||||||
service, "presence", new_token
|
service, "presence", new_token
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_typing(self, service: ApplicationService, new_token: int):
|
async def _handle_typing(
|
||||||
|
self, service: ApplicationService, new_token: int
|
||||||
|
) -> List[JsonDict]:
|
||||||
typing_source = self.event_sources.sources["typing"]
|
typing_source = self.event_sources.sources["typing"]
|
||||||
# Get the typing events from just before current
|
# Get the typing events from just before current
|
||||||
typing, _ = await typing_source.get_new_events_as(
|
typing, _ = await typing_source.get_new_events_as(
|
||||||
|
@ -259,7 +264,7 @@ class ApplicationServicesHandler:
|
||||||
)
|
)
|
||||||
return typing
|
return typing
|
||||||
|
|
||||||
async def _handle_receipts(self, service: ApplicationService):
|
async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
|
||||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||||
service, "read_receipt"
|
service, "read_receipt"
|
||||||
)
|
)
|
||||||
|
@ -271,7 +276,7 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
async def _handle_presence(
|
async def _handle_presence(
|
||||||
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
||||||
):
|
) -> List[JsonDict]:
|
||||||
events = [] # type: List[JsonDict]
|
events = [] # type: List[JsonDict]
|
||||||
presence_source = self.event_sources.sources["presence"]
|
presence_source = self.event_sources.sources["presence"]
|
||||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||||
|
@ -301,11 +306,11 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
return events
|
return events
|
||||||
|
|
||||||
async def query_user_exists(self, user_id):
|
async def query_user_exists(self, user_id: str) -> bool:
|
||||||
"""Check if any application service knows this user_id exists.
|
"""Check if any application service knows this user_id exists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to query if they exist on any AS.
|
user_id: The user to query if they exist on any AS.
|
||||||
Returns:
|
Returns:
|
||||||
True if this user exists on at least one application service.
|
True if this user exists on at least one application service.
|
||||||
"""
|
"""
|
||||||
|
@ -316,11 +321,13 @@ class ApplicationServicesHandler:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def query_room_alias_exists(self, room_alias):
|
async def query_room_alias_exists(
|
||||||
|
self, room_alias: RoomAlias
|
||||||
|
) -> Optional[RoomAliasMapping]:
|
||||||
"""Check if an application service knows this room alias exists.
|
"""Check if an application service knows this room alias exists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_alias(RoomAlias): The room alias to query.
|
room_alias: The room alias to query.
|
||||||
Returns:
|
Returns:
|
||||||
namedtuple: with keys "room_id" and "servers" or None if no
|
namedtuple: with keys "room_id" and "servers" or None if no
|
||||||
association can be found.
|
association can be found.
|
||||||
|
@ -336,10 +343,13 @@ class ApplicationServicesHandler:
|
||||||
)
|
)
|
||||||
if is_known_alias:
|
if is_known_alias:
|
||||||
# the alias exists now so don't query more ASes.
|
# the alias exists now so don't query more ASes.
|
||||||
result = await self.store.get_association_from_room_alias(room_alias)
|
return await self.store.get_association_from_room_alias(room_alias)
|
||||||
return result
|
|
||||||
|
|
||||||
async def query_3pe(self, kind, protocol, fields):
|
return None
|
||||||
|
|
||||||
|
async def query_3pe(
|
||||||
|
self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
|
||||||
|
) -> List[JsonDict]:
|
||||||
services = self._get_services_for_3pn(protocol)
|
services = self._get_services_for_3pn(protocol)
|
||||||
|
|
||||||
results = await make_deferred_yieldable(
|
results = await make_deferred_yieldable(
|
||||||
|
@ -361,7 +371,9 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def get_3pe_protocols(self, only_protocol=None):
|
async def get_3pe_protocols(
|
||||||
|
self, only_protocol: Optional[str] = None
|
||||||
|
) -> Dict[str, JsonDict]:
|
||||||
services = self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
protocols = {} # type: Dict[str, List[JsonDict]]
|
protocols = {} # type: Dict[str, List[JsonDict]]
|
||||||
|
|
||||||
|
@ -379,7 +391,7 @@ class ApplicationServicesHandler:
|
||||||
if info is not None:
|
if info is not None:
|
||||||
protocols[p].append(info)
|
protocols[p].append(info)
|
||||||
|
|
||||||
def _merge_instances(infos):
|
def _merge_instances(infos: List[JsonDict]) -> JsonDict:
|
||||||
if not infos:
|
if not infos:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@ -394,19 +406,17 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
return combined
|
return combined
|
||||||
|
|
||||||
for p in protocols.keys():
|
return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
|
||||||
protocols[p] = _merge_instances(protocols[p])
|
|
||||||
|
|
||||||
return protocols
|
async def _get_services_for_event(
|
||||||
|
self, event: EventBase
|
||||||
async def _get_services_for_event(self, event):
|
) -> List[ApplicationService]:
|
||||||
"""Retrieve a list of application services interested in this event.
|
"""Retrieve a list of application services interested in this event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event(Event): The event to check. Can be None if alias_list is not.
|
event: The event to check. Can be None if alias_list is not.
|
||||||
Returns:
|
Returns:
|
||||||
list<ApplicationService>: A list of services interested in this
|
A list of services interested in this event based on the service regex.
|
||||||
event based on the service regex.
|
|
||||||
"""
|
"""
|
||||||
services = self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
|
|
||||||
|
@ -420,17 +430,15 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
return interested_list
|
return interested_list
|
||||||
|
|
||||||
def _get_services_for_user(self, user_id):
|
def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
|
||||||
services = self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
|
return [s for s in services if (s.is_interested_in_user(user_id))]
|
||||||
return interested_list
|
|
||||||
|
|
||||||
def _get_services_for_3pn(self, protocol):
|
def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
|
||||||
services = self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
|
return [s for s in services if s.is_interested_in_protocol(protocol)]
|
||||||
return interested_list
|
|
||||||
|
|
||||||
async def _is_unknown_user(self, user_id):
|
async def _is_unknown_user(self, user_id: str) -> bool:
|
||||||
if not self.is_mine_id(user_id):
|
if not self.is_mine_id(user_id):
|
||||||
# we don't know if they are unknown or not since it isn't one of our
|
# we don't know if they are unknown or not since it isn't one of our
|
||||||
# users. We can't poke ASes.
|
# users. We can't poke ASes.
|
||||||
|
@ -445,9 +453,8 @@ class ApplicationServicesHandler:
|
||||||
service_list = [s for s in services if s.sender == user_id]
|
service_list = [s for s in services if s.sender == user_id]
|
||||||
return len(service_list) == 0
|
return len(service_list) == 0
|
||||||
|
|
||||||
async def _check_user_exists(self, user_id):
|
async def _check_user_exists(self, user_id: str) -> bool:
|
||||||
unknown_user = await self._is_unknown_user(user_id)
|
unknown_user = await self._is_unknown_user(user_id)
|
||||||
if unknown_user:
|
if unknown_user:
|
||||||
exists = await self.query_user_exists(user_id)
|
return await self.query_user_exists(user_id)
|
||||||
return exists
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -18,10 +18,20 @@ import logging
|
||||||
import time
|
import time
|
||||||
import unicodedata
|
import unicodedata
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import bcrypt # type: ignore[import]
|
import bcrypt
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
|
@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
|
||||||
class AuthHandler(BaseHandler):
|
class AuthHandler(BaseHandler):
|
||||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hs (synapse.server.HomeServer):
|
|
||||||
"""
|
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
|
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
|
||||||
|
|
|
@ -15,21 +15,31 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
|
||||||
|
|
||||||
from synapse.appservice import ApplicationService, AppServiceTransaction
|
from synapse.appservice import (
|
||||||
|
ApplicationService,
|
||||||
|
ApplicationServiceState,
|
||||||
|
AppServiceTransaction,
|
||||||
|
)
|
||||||
from synapse.config.appservice import load_appservices
|
from synapse.config.appservice import load_appservices
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
|
from synapse.storage.types import Connection
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _make_exclusive_regex(services_cache):
|
def _make_exclusive_regex(
|
||||||
|
services_cache: List[ApplicationService],
|
||||||
|
) -> Optional[Pattern]:
|
||||||
# We precompile a regex constructed from all the regexes that the AS's
|
# We precompile a regex constructed from all the regexes that the AS's
|
||||||
# have registered for exclusive users.
|
# have registered for exclusive users.
|
||||||
exclusive_user_regexes = [
|
exclusive_user_regexes = [
|
||||||
|
@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
|
||||||
]
|
]
|
||||||
if exclusive_user_regexes:
|
if exclusive_user_regexes:
|
||||||
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
|
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
|
||||||
exclusive_user_regex = re.compile(exclusive_user_regex)
|
exclusive_user_pattern = re.compile(
|
||||||
|
exclusive_user_regex
|
||||||
|
) # type: Optional[Pattern]
|
||||||
else:
|
else:
|
||||||
# We handle this case specially otherwise the constructed regex
|
# We handle this case specially otherwise the constructed regex
|
||||||
# will always match
|
# will always match
|
||||||
exclusive_user_regex = None
|
exclusive_user_pattern = None
|
||||||
|
|
||||||
return exclusive_user_regex
|
return exclusive_user_pattern
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceWorkerStore(SQLBaseStore):
|
class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||||
self.services_cache = load_appservices(
|
self.services_cache = load_appservices(
|
||||||
hs.hostname, hs.config.app_service_config_files
|
hs.hostname, hs.config.app_service_config_files
|
||||||
)
|
)
|
||||||
|
@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||||
def get_app_services(self):
|
def get_app_services(self):
|
||||||
return self.services_cache
|
return self.services_cache
|
||||||
|
|
||||||
def get_if_app_services_interested_in_user(self, user_id):
|
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
|
||||||
"""Check if the user is one associated with an app service (exclusively)
|
"""Check if the user is one associated with an app service (exclusively)
|
||||||
"""
|
"""
|
||||||
if self.exclusive_user_regex:
|
if self.exclusive_user_regex:
|
||||||
|
@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_app_service_by_user_id(self, user_id):
|
def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
|
||||||
"""Retrieve an application service from their user ID.
|
"""Retrieve an application service from their user ID.
|
||||||
|
|
||||||
All application services have associated with them a particular user ID.
|
All application services have associated with them a particular user ID.
|
||||||
|
@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||||
a user ID to an application service.
|
a user ID to an application service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user ID to see if it is an application service.
|
user_id: The user ID to see if it is an application service.
|
||||||
Returns:
|
Returns:
|
||||||
synapse.appservice.ApplicationService or None.
|
The application service or None.
|
||||||
"""
|
"""
|
||||||
for service in self.services_cache:
|
for service in self.services_cache:
|
||||||
if service.sender == user_id:
|
if service.sender == user_id:
|
||||||
return service
|
return service
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_app_service_by_token(self, token):
|
def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
|
||||||
"""Get the application service with the given appservice token.
|
"""Get the application service with the given appservice token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token (str): The application service token.
|
token: The application service token.
|
||||||
Returns:
|
Returns:
|
||||||
synapse.appservice.ApplicationService or None.
|
The application service or None.
|
||||||
"""
|
"""
|
||||||
for service in self.services_cache:
|
for service in self.services_cache:
|
||||||
if service.token == token:
|
if service.token == token:
|
||||||
return service
|
return service
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_app_service_by_id(self, as_id):
|
def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
|
||||||
"""Get the application service with the given appservice ID.
|
"""Get the application service with the given appservice ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
as_id (str): The application service ID.
|
as_id: The application service ID.
|
||||||
Returns:
|
Returns:
|
||||||
synapse.appservice.ApplicationService or None.
|
The application service or None.
|
||||||
"""
|
"""
|
||||||
for service in self.services_cache:
|
for service in self.services_cache:
|
||||||
if service.id == as_id:
|
if service.id == as_id:
|
||||||
|
@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
|
||||||
class ApplicationServiceTransactionWorkerStore(
|
class ApplicationServiceTransactionWorkerStore(
|
||||||
ApplicationServiceWorkerStore, EventsWorkerStore
|
ApplicationServiceWorkerStore, EventsWorkerStore
|
||||||
):
|
):
|
||||||
async def get_appservices_by_state(self, state):
|
async def get_appservices_by_state(
|
||||||
|
self, state: ApplicationServiceState
|
||||||
|
) -> List[ApplicationService]:
|
||||||
"""Get a list of application services based on their state.
|
"""Get a list of application services based on their state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state(ApplicationServiceState): The state to filter on.
|
state: The state to filter on.
|
||||||
Returns:
|
Returns:
|
||||||
A list of ApplicationServices, which may be empty.
|
A list of ApplicationServices, which may be empty.
|
||||||
"""
|
"""
|
||||||
|
@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
services.append(service)
|
services.append(service)
|
||||||
return services
|
return services
|
||||||
|
|
||||||
async def get_appservice_state(self, service):
|
async def get_appservice_state(
|
||||||
|
self, service: ApplicationService
|
||||||
|
) -> Optional[ApplicationServiceState]:
|
||||||
"""Get the application service state.
|
"""Get the application service state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service(ApplicationService): The service whose state to set.
|
service: The service whose state to set.
|
||||||
Returns:
|
Returns:
|
||||||
An ApplicationServiceState.
|
An ApplicationServiceState or none.
|
||||||
"""
|
"""
|
||||||
result = await self.db_pool.simple_select_one(
|
result = await self.db_pool.simple_select_one(
|
||||||
"application_services_state",
|
"application_services_state",
|
||||||
|
@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
return result.get("state")
|
return result.get("state")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def set_appservice_state(self, service, state) -> None:
|
async def set_appservice_state(
|
||||||
|
self, service: ApplicationService, state: ApplicationServiceState
|
||||||
|
) -> None:
|
||||||
"""Set the application service state.
|
"""Set the application service state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service(ApplicationService): The service whose state to set.
|
service: The service whose state to set.
|
||||||
state(ApplicationServiceState): The connectivity state to apply.
|
state: The connectivity state to apply.
|
||||||
"""
|
"""
|
||||||
await self.db_pool.simple_upsert(
|
await self.db_pool.simple_upsert(
|
||||||
"application_services_state", {"as_id": service.id}, {"state": state}
|
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||||
|
@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
"create_appservice_txn", _create_appservice_txn
|
"create_appservice_txn", _create_appservice_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def complete_appservice_txn(self, txn_id, service) -> None:
|
async def complete_appservice_txn(
|
||||||
|
self, txn_id: int, service: ApplicationService
|
||||||
|
) -> None:
|
||||||
"""Completes an application service transaction.
|
"""Completes an application service transaction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn_id(str): The transaction ID being completed.
|
txn_id: The transaction ID being completed.
|
||||||
service(ApplicationService): The application service which was sent
|
service: The application service which was sent this transaction.
|
||||||
this transaction.
|
|
||||||
"""
|
"""
|
||||||
txn_id = int(txn_id)
|
txn_id = int(txn_id)
|
||||||
|
|
||||||
|
@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
# has probably missed some events), so whine loudly but still continue,
|
# has probably missed some events), so whine loudly but still continue,
|
||||||
# since it shouldn't fail completion of the transaction.
|
# since it shouldn't fail completion of the transaction.
|
||||||
last_txn_id = self._get_last_txn(txn, service.id)
|
last_txn_id = self._get_last_txn(txn, service.id)
|
||||||
if (last_txn_id + 1) != txn_id:
|
if (txn_id + 1) != txn_id:
|
||||||
logger.error(
|
logger.error(
|
||||||
"appservice: Completing a transaction which has an ID > 1 from "
|
"appservice: Completing a transaction which has an ID > 1 from "
|
||||||
"the last ID sent to this AS. We've either dropped events or "
|
"the last ID sent to this AS. We've either dropped events or "
|
||||||
|
@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
"complete_appservice_txn", _complete_appservice_txn
|
"complete_appservice_txn", _complete_appservice_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_oldest_unsent_txn(self, service):
|
async def get_oldest_unsent_txn(
|
||||||
"""Get the oldest transaction which has not been sent for this
|
self, service: ApplicationService
|
||||||
service.
|
) -> Optional[AppServiceTransaction]:
|
||||||
|
"""Get the oldest transaction which has not been sent for this service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service(ApplicationService): The app service to get the oldest txn.
|
service: The app service to get the oldest txn.
|
||||||
Returns:
|
Returns:
|
||||||
An AppServiceTransaction or None.
|
An AppServiceTransaction or None.
|
||||||
"""
|
"""
|
||||||
|
@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
service=service, id=entry["txn_id"], events=events, ephemeral=[]
|
service=service, id=entry["txn_id"], events=events, ephemeral=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_last_txn(self, txn, service_id):
|
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||||
(service_id,),
|
(service_id,),
|
||||||
|
@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
else:
|
else:
|
||||||
return int(last_txn_id[0]) # select 'last_txn' col
|
return int(last_txn_id[0]) # select 'last_txn' col
|
||||||
|
|
||||||
async def set_appservice_last_pos(self, pos) -> None:
|
async def set_appservice_last_pos(self, pos: int) -> None:
|
||||||
def set_appservice_last_pos_txn(txn):
|
def set_appservice_last_pos_txn(txn):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||||
|
@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_new_events_for_appservice(self, current_id, limit):
|
async def get_new_events_for_appservice(
|
||||||
|
self, current_id: int, limit: int
|
||||||
|
) -> Tuple[int, List[EventBase]]:
|
||||||
"""Get all new events for an appservice"""
|
"""Get all new events for an appservice"""
|
||||||
|
|
||||||
def get_new_events_for_appservice_txn(txn):
|
def get_new_events_for_appservice_txn(txn):
|
||||||
|
@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_type_stream_id_for_appservice(
|
async def set_type_stream_id_for_appservice(
|
||||||
self, service: ApplicationService, type: str, pos: int
|
self, service: ApplicationService, type: str, pos: Optional[int]
|
||||||
) -> None:
|
) -> None:
|
||||||
if type not in ("read_receipt", "presence"):
|
if type not in ("read_receipt", "presence"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
Loading…
Reference in a new issue