Add a primitive helper script for listing worker endpoints. (#15243)

Co-authored-by: Patrick Cloke <patrickc@matrix.org>
This commit is contained in:
reivilibre 2023-03-23 12:11:14 +00:00 committed by GitHub
parent 3b0083c92a
commit 98fd558382
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 424 additions and 12 deletions

View file

@ -0,0 +1 @@
Add a primitive helper script for listing worker endpoints.

View file

@ -0,0 +1,302 @@
#!/usr/bin/env python
# Copyright 2022-2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Pattern, Set, Tuple
import yaml
from synapse.config.homeserver import HomeServerConfig
from synapse.federation.transport.server import (
TransportLayerServer,
register_servlets as register_federation_servlets,
)
from synapse.http.server import HttpServer, ServletCallback
from synapse.rest import ClientRestResource
from synapse.rest.key.v2 import RemoteKey
from synapse.server import HomeServer
from synapse.storage import DataStore
logger = logging.getLogger("generate_workers_map")
class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore # type: ignore
def __init__(self, config: HomeServerConfig, worker_app: Optional[str]) -> None:
super().__init__(config.server.server_name, config=config)
self.config.worker.worker_app = worker_app
GROUP_PATTERN = re.compile(r"\(\?P<[^>]+?>(.+?)\)")
@dataclass
class EndpointDescription:
"""
Describes an endpoint and how it should be routed.
"""
# The servlet class that handles this endpoint
servlet_class: object
# The category of this endpoint. Is read from the `CATEGORY` constant in the servlet
# class.
category: Optional[str]
# TODO:
# - does it need to be routed based on a stream writer config?
# - does it benefit from any optimised, but optional, routing?
# - what 'opinionated synapse worker class' (event_creator, synchrotron, etc) does
# it go in?
class EnumerationResource(HttpServer):
"""
Accepts servlet registrations for the purposes of building up a description of
all endpoints.
"""
def __init__(self, is_worker: bool) -> None:
self.registrations: Dict[Tuple[str, str], EndpointDescription] = {}
self._is_worker = is_worker
def register_paths(
self,
method: str,
path_patterns: Iterable[Pattern],
callback: ServletCallback,
servlet_classname: str,
) -> None:
# federation servlet callbacks are wrapped, so unwrap them.
callback = getattr(callback, "__wrapped__", callback)
# fish out the servlet class
servlet_class = callback.__self__.__class__ # type: ignore
if self._is_worker and method in getattr(
servlet_class, "WORKERS_DENIED_METHODS", ()
):
# This endpoint would cause an error if called on a worker, so pretend it
# was never registered!
return
sd = EndpointDescription(
servlet_class=servlet_class,
category=getattr(servlet_class, "CATEGORY", None),
)
for pat in path_patterns:
self.registrations[(method, pat.pattern)] = sd
def get_registered_paths_for_hs(
hs: HomeServer,
) -> Dict[Tuple[str, str], EndpointDescription]:
"""
Given a homeserver, get all registered endpoints and their descriptions.
"""
enumerator = EnumerationResource(is_worker=hs.config.worker.worker_app is not None)
ClientRestResource.register_servlets(enumerator, hs)
federation_server = TransportLayerServer(hs)
# we can't use `federation_server.register_servlets` but this line does the
# same thing, only it uses this enumerator
register_federation_servlets(
federation_server.hs,
resource=enumerator,
ratelimiter=federation_server.ratelimiter,
authenticator=federation_server.authenticator,
servlet_groups=federation_server.servlet_groups,
)
# the key server endpoints are separate again
RemoteKey(hs).register(enumerator)
return enumerator.registrations
def get_registered_paths_for_default(
worker_app: Optional[str], base_config: HomeServerConfig
) -> Dict[Tuple[str, str], EndpointDescription]:
"""
Given the name of a worker application and a base homeserver configuration,
returns:
Dict from (method, path) to EndpointDescription
TODO Don't require passing in a config
"""
hs = MockHomeserver(base_config, worker_app)
# TODO We only do this to avoid an error, but don't need the database etc
hs.setup()
return get_registered_paths_for_hs(hs)
def elide_http_methods_if_unconflicting(
registrations: Dict[Tuple[str, str], EndpointDescription],
all_possible_registrations: Dict[Tuple[str, str], EndpointDescription],
) -> Dict[Tuple[str, str], EndpointDescription]:
"""
Elides HTTP methods (by replacing them with `*`) if all possible registered methods
can be handled by the worker whose registration map is `registrations`.
i.e. the only endpoints left with methods (other than `*`) should be the ones where
the worker can't handle all possible methods for that path.
"""
def paths_to_methods_dict(
methods_and_paths: Iterable[Tuple[str, str]]
) -> Dict[str, Set[str]]:
"""
Given (method, path) pairs, produces a dict from path to set of methods
available at that path.
"""
result: Dict[str, Set[str]] = {}
for method, path in methods_and_paths:
result.setdefault(path, set()).add(method)
return result
all_possible_reg_methods = paths_to_methods_dict(all_possible_registrations)
reg_methods = paths_to_methods_dict(registrations)
output = {}
for path, handleable_methods in reg_methods.items():
if handleable_methods == all_possible_reg_methods[path]:
any_method = next(iter(handleable_methods))
# TODO This assumes that all methods have the same servlet.
# I suppose that's possibly dubious?
output[("*", path)] = registrations[(any_method, path)]
else:
for method in handleable_methods:
output[(method, path)] = registrations[(method, path)]
return output
def simplify_path_regexes(
registrations: Dict[Tuple[str, str], EndpointDescription]
) -> Dict[Tuple[str, str], EndpointDescription]:
"""
Simplify all the path regexes for the dict of endpoint descriptions,
so that we don't use the Python-specific regex extensions
(and also to remove needlessly specific detail).
"""
def simplify_path_regex(path: str) -> str:
"""
Given a regex pattern, replaces all named capturing groups (e.g. `(?P<blah>xyz)`)
with a simpler version available in more common regex dialects (e.g. `.*`).
"""
# TODO it's hard to choose between these two;
# `.*` is a vague simplification
# return GROUP_PATTERN.sub(r"\1", path)
return GROUP_PATTERN.sub(r".*", path)
return {(m, simplify_path_regex(p)): v for (m, p), v in registrations.items()}
def main() -> None:
parser = argparse.ArgumentParser(
description=(
"Updates a synapse database to the latest schema and optionally runs background updates"
" on it."
)
)
parser.add_argument("-v", action="store_true")
parser.add_argument(
"--config-path",
type=argparse.FileType("r"),
required=True,
help="Synapse configuration file",
)
args = parser.parse_args()
# TODO
# logging.basicConfig(**logging_config)
# Load, process and sanity-check the config.
hs_config = yaml.safe_load(args.config_path)
config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "")
master_paths = get_registered_paths_for_default(None, config)
worker_paths = get_registered_paths_for_default(
"synapse.app.generic_worker", config
)
all_paths = {**master_paths, **worker_paths}
elided_worker_paths = elide_http_methods_if_unconflicting(worker_paths, all_paths)
elide_http_methods_if_unconflicting(master_paths, all_paths)
# TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT
categories_to_methods_and_paths: Dict[
Optional[str], Dict[Tuple[str, str], EndpointDescription]
] = defaultdict(dict)
for (method, path), desc in elided_worker_paths.items():
categories_to_methods_and_paths[desc.category][method, path] = desc
for category, contents in categories_to_methods_and_paths.items():
print_category(category, contents)
def print_category(
category_name: Optional[str],
elided_worker_paths: Dict[Tuple[str, str], EndpointDescription],
) -> None:
"""
Prints out a category, in documentation page style.
Example:
```
# Category name
/path/xyz
GET /path/abc
```
"""
if category_name:
print(f"# {category_name}")
else:
print("# (Uncategorised requests)")
for ln in sorted(
p for m, p in simplify_path_regexes(elided_worker_paths) if m == "*"
):
print(ln)
print()
for ln in sorted(
f"{m:6} {p}" for m, p in simplify_path_regexes(elided_worker_paths) if m != "*"
):
print(ln)
print()
if __name__ == "__main__":
main()

View file

@ -108,6 +108,7 @@ class PublicRoomList(BaseFederationServlet):
"""
PATH = "/publicRooms"
CATEGORY = "Federation requests"
def __init__(
self,
@ -212,6 +213,7 @@ class OpenIdUserInfo(BaseFederationServlet):
"""
PATH = "/openid/userinfo"
CATEGORY = "Federation requests"
REQUIRE_AUTH = False

View file

@ -70,6 +70,7 @@ class BaseFederationServerServlet(BaseFederationServlet):
class FederationSendServlet(BaseFederationServerServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/?"
CATEGORY = "Inbound federation transaction request"
# We ratelimit manually in the handler as we queue up the requests and we
# don't want to fill up the ratelimiter with blocked requests.
@ -138,6 +139,7 @@ class FederationSendServlet(BaseFederationServerServlet):
class FederationEventServlet(BaseFederationServerServlet):
PATH = "/event/(?P<event_id>[^/]*)/?"
CATEGORY = "Federation requests"
# This is when someone asks for a data item for a given server data_id pair.
async def on_GET(
@ -152,6 +154,7 @@ class FederationEventServlet(BaseFederationServerServlet):
class FederationStateV1Servlet(BaseFederationServerServlet):
PATH = "/state/(?P<room_id>[^/]*)/?"
CATEGORY = "Federation requests"
# This is when someone asks for all data for a given room.
async def on_GET(
@ -170,6 +173,7 @@ class FederationStateV1Servlet(BaseFederationServerServlet):
class FederationStateIdsServlet(BaseFederationServerServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/?"
CATEGORY = "Federation requests"
async def on_GET(
self,
@ -187,6 +191,7 @@ class FederationStateIdsServlet(BaseFederationServerServlet):
class FederationBackfillServlet(BaseFederationServerServlet):
PATH = "/backfill/(?P<room_id>[^/]*)/?"
CATEGORY = "Federation requests"
async def on_GET(
self,
@ -225,6 +230,7 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
"""
PATH = "/timestamp_to_event/(?P<room_id>[^/]*)/?"
CATEGORY = "Federation requests"
async def on_GET(
self,
@ -246,6 +252,7 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
class FederationQueryServlet(BaseFederationServerServlet):
PATH = "/query/(?P<query_type>[^/]*)"
CATEGORY = "Federation requests"
# This is when we receive a server-server Query
async def on_GET(
@ -262,6 +269,7 @@ class FederationQueryServlet(BaseFederationServerServlet):
class FederationMakeJoinServlet(BaseFederationServerServlet):
PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_GET(
self,
@ -297,6 +305,7 @@ class FederationMakeJoinServlet(BaseFederationServerServlet):
class FederationMakeLeaveServlet(BaseFederationServerServlet):
PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_GET(
self,
@ -312,6 +321,7 @@ class FederationMakeLeaveServlet(BaseFederationServerServlet):
class FederationV1SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_PUT(
self,
@ -327,6 +337,7 @@ class FederationV1SendLeaveServlet(BaseFederationServerServlet):
class FederationV2SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
PREFIX = FEDERATION_V2_PREFIX
@ -344,6 +355,7 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet):
class FederationMakeKnockServlet(BaseFederationServerServlet):
PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_GET(
self,
@ -366,6 +378,7 @@ class FederationMakeKnockServlet(BaseFederationServerServlet):
class FederationV1SendKnockServlet(BaseFederationServerServlet):
PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_PUT(
self,
@ -381,6 +394,7 @@ class FederationV1SendKnockServlet(BaseFederationServerServlet):
class FederationEventAuthServlet(BaseFederationServerServlet):
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_GET(
self,
@ -395,6 +409,7 @@ class FederationEventAuthServlet(BaseFederationServerServlet):
class FederationV1SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_PUT(
self,
@ -412,6 +427,7 @@ class FederationV1SendJoinServlet(BaseFederationServerServlet):
class FederationV2SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
PREFIX = FEDERATION_V2_PREFIX
@ -455,6 +471,7 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
class FederationV1InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_PUT(
self,
@ -479,6 +496,7 @@ class FederationV1InviteServlet(BaseFederationServerServlet):
class FederationV2InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
CATEGORY = "Federation requests"
PREFIX = FEDERATION_V2_PREFIX
@ -515,6 +533,7 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_PUT(
self,
@ -529,6 +548,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
class FederationClientKeysQueryServlet(BaseFederationServerServlet):
PATH = "/user/keys/query"
CATEGORY = "Federation requests"
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
@ -538,6 +558,7 @@ class FederationClientKeysQueryServlet(BaseFederationServerServlet):
class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_GET(
self,
@ -551,6 +572,7 @@ class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
class FederationClientKeysClaimServlet(BaseFederationServerServlet):
PATH = "/user/keys/claim"
CATEGORY = "Federation requests"
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
@ -561,6 +583,7 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
class FederationGetMissingEventsServlet(BaseFederationServerServlet):
PATH = "/get_missing_events/(?P<room_id>[^/]*)"
CATEGORY = "Federation requests"
async def on_POST(
self,
@ -586,6 +609,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
class On3pidBindServlet(BaseFederationServerServlet):
PATH = "/3pid/onbind"
CATEGORY = "Federation requests"
REQUIRE_AUTH = False
@ -618,6 +642,7 @@ class On3pidBindServlet(BaseFederationServerServlet):
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
CATEGORY = "Federation requests"
REQUIRE_AUTH = False
@ -640,6 +665,7 @@ class FederationVersionServlet(BaseFederationServlet):
class FederationRoomHierarchyServlet(BaseFederationServlet):
PATH = "/hierarchy/(?P<room_id>[^/]*)"
CATEGORY = "Federation requests"
def __init__(
self,
@ -672,6 +698,7 @@ class RoomComplexityServlet(BaseFederationServlet):
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
CATEGORY = "Federation requests (unstable)"
def __init__(
self,

View file

@ -43,19 +43,22 @@ def client_patterns(
Returns:
An iterable of patterns.
"""
patterns = []
versions = []
if unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
if v1:
v1_prefix = CLIENT_API_PREFIX + "/api/v1"
patterns.append(re.compile("^" + v1_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_API_PREFIX + f"/{release}"
patterns.append(re.compile("^" + new_prefix + path_regex))
versions.append("api/v1")
versions.extend(releases)
if unstable:
versions.append("unstable")
return patterns
if len(versions) == 1:
versions_str = versions[0]
elif len(versions) > 1:
versions_str = "(" + "|".join(versions) + ")"
else:
raise RuntimeError("Must have at least one version for a URL")
return [re.compile("^" + CLIENT_API_PREFIX + "/" + versions_str + path_regex)]
def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None:

View file

@ -576,6 +576,9 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")
# This is used as a proxy for all the 3pid endpoints.
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -834,6 +837,7 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -38,6 +38,7 @@ class AccountDataServlet(RestServlet):
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -136,6 +137,7 @@ class RoomAccountDataServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)"
)
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -40,6 +40,7 @@ logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/devices$")
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -123,6 +124,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet):
PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
class EventStreamRestServlet(RestServlet):
PATTERNS = client_patterns("/events$", v1=True)
CATEGORY = "Sync requests"
DEFAULT_LONGPOLL_TIME_MS = 30000
@ -76,6 +77,7 @@ class EventStreamRestServlet(RestServlet):
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -69,6 +70,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -28,6 +28,7 @@ if TYPE_CHECKING:
# TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True)
CATEGORY = "Sync requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -89,6 +89,7 @@ class KeyUploadServlet(RestServlet):
"""
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -182,6 +183,7 @@ class KeyQueryServlet(RestServlet):
"""
PATTERNS = client_patterns("/keys/query$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -225,6 +227,7 @@ class KeyChangesServlet(RestServlet):
"""
PATTERNS = client_patterns("/keys/changes$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -274,6 +277,7 @@ class OneTimeKeyServlet(RestServlet):
"""
PATTERNS = client_patterns("/keys/claim$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -40,6 +40,7 @@ class KnockRoomAliasServlet(RestServlet):
"""
PATTERNS = client_patterns("/knock/(?P<room_identifier>[^/]*)")
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -72,6 +72,8 @@ class LoginResponse(TypedDict, total=False):
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CATEGORY = "Registration/login requests"
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
@ -537,6 +539,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
class RefreshTokenServlet(RestServlet):
PATTERNS = client_patterns("/refresh$")
CATEGORY = "Registration/login requests"
def __init__(self, hs: "HomeServer"):
self._auth_handler = hs.get_auth_handler()
@ -590,6 +593,7 @@ class SsoRedirectServlet(RestServlet):
+ "/(r0|v3)/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
)
]
CATEGORY = "SSO requests needed for all SSO providers"
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they

View file

@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
CATEGORY = "Presence requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -29,6 +29,7 @@ if TYPE_CHECKING:
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -86,6 +87,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -142,6 +144,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -44,6 +44,9 @@ class PushRuleRestServlet(RestServlet):
"Unrecognised request: You probably wanted a trailing slash"
)
WORKERS_DENIED_METHODS = ["PUT", "DELETE"]
CATEGORY = "Push rule requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()

View file

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
CATEGORY = "Receipts requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -36,6 +36,7 @@ class ReceiptRestServlet(RestServlet):
"/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$"
)
CATEGORY = "Receipts requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -367,6 +367,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
f"/register/{LoginType.REGISTRATION_TOKEN}/validity",
releases=("v1",),
)
CATEGORY = "Registration/login requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -395,6 +396,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
CATEGORY = "Registration/login requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -42,6 +42,7 @@ class RelationPaginationServlet(RestServlet):
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=("v1",),
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -84,6 +85,7 @@ class RelationPaginationServlet(RestServlet):
class ThreadsServlet(RestServlet):
PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -140,7 +140,7 @@ class TransactionRestServlet(RestServlet):
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@ -180,6 +180,8 @@ class RoomCreateRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(RestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.event_creation_handler = hs.get_event_creation_handler()
@ -323,6 +325,8 @@ class RoomStateEventRestServlet(RestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
@ -398,6 +402,8 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up
@ -460,6 +466,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
# TODO: Needs unit testing
class PublicRoomListRestServlet(RestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -578,6 +585,7 @@ class PublicRoomListRestServlet(RestServlet):
# TODO: Needs unit testing
class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -633,6 +641,7 @@ class RoomMemberListRestServlet(RestServlet):
# except it does custom AS logic and has a simpler return format
class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -654,6 +663,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
# TODO: Needs better unit testing
class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
# TODO The routing information should be exposed programatically.
# I want to do this but for now I felt bad about leaving this without
# at least a visible warning on it.
CATEGORY = "Client API requests (ALL FOR SAME ROOM MUST GO TO SAME WORKER)"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -720,6 +733,7 @@ class RoomMessageListRestServlet(RestServlet):
# TODO: Needs unit testing
class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -742,6 +756,7 @@ class RoomStateRestServlet(RestServlet):
# TODO: Needs unit testing
class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
CATEGORY = "Sync requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -766,6 +781,7 @@ class RoomEventServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -858,6 +874,7 @@ class RoomEventContextServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -958,6 +975,8 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
@ -1071,6 +1090,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
@ -1164,6 +1185,7 @@ class RoomTypingRestServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
)
CATEGORY = "The typing stream"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1195,7 +1217,7 @@ class RoomTypingRestServlet(RestServlet):
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
# Defer getting the typing handler since it will raise on workers.
# Defer getting the typing handler since it will raise on WORKER_PATTERNS.
typing_handler = self.hs.get_typing_writer_handler()
try:
@ -1224,6 +1246,7 @@ class RoomAliasListServlet(RestServlet):
r"/rooms/(?P<room_id>[^/]*)/aliases"
),
] + list(client_patterns("/rooms/(?P<room_id>[^/]*)/aliases$", unstable=False))
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1244,6 +1267,7 @@ class RoomAliasListServlet(RestServlet):
class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1263,6 +1287,7 @@ class SearchRestServlet(RestServlet):
class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1334,6 +1359,7 @@ class TimestampLookupRestServlet(RestServlet):
PATTERNS = (
re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/timestamp_to_event$"),
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1365,6 +1391,8 @@ class TimestampLookupRestServlet(RestServlet):
class RoomHierarchyRestServlet(RestServlet):
PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/hierarchy$"),)
WORKERS = PATTERNS
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1405,6 +1433,7 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
"/rooms/(?P<room_identifier>[^/]*)/summary$"
),
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

View file

@ -69,6 +69,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/batch_send$"
),
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -37,6 +37,7 @@ class RoomKeysServlet(RestServlet):
PATTERNS = client_patterns(
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
)
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -253,6 +254,7 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -328,6 +330,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version/(?P<version>[^/]+)$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -35,6 +35,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
)
CATEGORY = "The to_device stream"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -87,6 +87,7 @@ class SyncRestServlet(RestServlet):
PATTERNS = client_patterns("/sync$")
ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
CATEGORY = "Sync requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -37,6 +37,7 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags$"
)
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -64,6 +65,7 @@ class TagServlet(RestServlet):
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_patterns("/user_directory/search$")
CATEGORY = "User directory search requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -34,6 +34,7 @@ logger = logging.getLogger(__name__)
class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -29,6 +29,7 @@ if TYPE_CHECKING:
class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -93,6 +93,8 @@ class RemoteKey(RestServlet):
}
"""
CATEGORY = "Federation requests"
def __init__(self, hs: "HomeServer"):
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main