mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Add type hints to more handlers (#8244)
This commit is contained in:
parent
4535e849d7
commit
be16ee59a8
5 changed files with 110 additions and 79 deletions
1
changelog.d/8244.misc
Normal file
1
changelog.d/8244.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add type hints to pagination, initial sync and events handlers.
|
3
mypy.ini
3
mypy.ini
|
@ -17,10 +17,13 @@ files =
|
|||
synapse/handlers/auth.py,
|
||||
synapse/handlers/cas_handler.py,
|
||||
synapse/handlers/directory.py,
|
||||
synapse/handlers/events.py,
|
||||
synapse/handlers/federation.py,
|
||||
synapse/handlers/identity.py,
|
||||
synapse/handlers/initial_sync.py,
|
||||
synapse/handlers/message.py,
|
||||
synapse/handlers/oidc_handler.py,
|
||||
synapse/handlers/pagination.py,
|
||||
synapse/handlers/presence.py,
|
||||
synapse/handlers/room.py,
|
||||
synapse/handlers/room_member.py,
|
||||
|
|
|
@ -15,29 +15,30 @@
|
|||
|
||||
import logging
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.types import UserID
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventStreamHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super(EventStreamHandler, self).__init__(hs)
|
||||
|
||||
# Count of active streams per user
|
||||
self._streams_per_user = {}
|
||||
# Grace timers per user to delay the "stopped" signal
|
||||
self._stop_timer_per_user = {}
|
||||
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("started_user_eventstream")
|
||||
self.distributor.declare("stopped_user_eventstream")
|
||||
|
@ -52,14 +53,14 @@ class EventStreamHandler(BaseHandler):
|
|||
@log_function
|
||||
async def get_stream(
|
||||
self,
|
||||
auth_user_id,
|
||||
pagin_config,
|
||||
timeout=0,
|
||||
as_client_event=True,
|
||||
affect_presence=True,
|
||||
room_id=None,
|
||||
is_guest=False,
|
||||
):
|
||||
auth_user_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
timeout: int = 0,
|
||||
as_client_event: bool = True,
|
||||
affect_presence: bool = True,
|
||||
room_id: Optional[str] = None,
|
||||
is_guest: bool = False,
|
||||
) -> JsonDict:
|
||||
"""Fetches the events stream for a given user.
|
||||
"""
|
||||
|
||||
|
@ -98,7 +99,7 @@ class EventStreamHandler(BaseHandler):
|
|||
|
||||
# When the user joins a new room, or another user joins a currently
|
||||
# joined room, we need to send down presence for those users.
|
||||
to_add = []
|
||||
to_add = [] # type: List[JsonDict]
|
||||
for event in events:
|
||||
if not isinstance(event, EventBase):
|
||||
continue
|
||||
|
@ -110,7 +111,7 @@ class EventStreamHandler(BaseHandler):
|
|||
# Send down presence for everyone in the room.
|
||||
users = await self.state.get_current_users_in_room(
|
||||
event.room_id
|
||||
)
|
||||
) # type: Iterable[str]
|
||||
else:
|
||||
users = [event.state_key]
|
||||
|
||||
|
@ -144,20 +145,22 @@ class EventStreamHandler(BaseHandler):
|
|||
|
||||
|
||||
class EventHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super(EventHandler, self).__init__(hs)
|
||||
self.storage = hs.get_storage()
|
||||
|
||||
async def get_event(self, user, room_id, event_id):
|
||||
async def get_event(
|
||||
self, user: UserID, room_id: Optional[str], event_id: str
|
||||
) -> Optional[EventBase]:
|
||||
"""Retrieve a single specified event.
|
||||
|
||||
Args:
|
||||
user (synapse.types.UserID): The user requesting the event
|
||||
room_id (str|None): The expected room id. We'll return None if the
|
||||
user: The user requesting the event
|
||||
room_id: The expected room id. We'll return None if the
|
||||
event's room does not match.
|
||||
event_id (str): The event ID to obtain.
|
||||
event_id: The event ID to obtain.
|
||||
Returns:
|
||||
dict: An event, or None if there is no event matching this ID.
|
||||
An event, or None if there is no event matching this ID.
|
||||
Raises:
|
||||
SynapseError if there was a problem retrieving this event, or
|
||||
AuthError if the user does not have the rights to inspect this
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -22,8 +23,9 @@ from synapse.api.errors import SynapseError
|
|||
from synapse.events.validator import EventValidator
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import StreamToken, UserID
|
||||
from synapse.types import JsonDict, Requester, StreamToken, UserID
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
@ -31,11 +33,15 @@ from synapse.visibility import filter_events_for_client
|
|||
|
||||
from ._base import BaseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InitialSyncHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super(InitialSyncHandler, self).__init__(hs)
|
||||
self.hs = hs
|
||||
self.state = hs.get_state_handler()
|
||||
|
@ -48,27 +54,25 @@ class InitialSyncHandler(BaseHandler):
|
|||
|
||||
def snapshot_all_rooms(
|
||||
self,
|
||||
user_id=None,
|
||||
pagin_config=None,
|
||||
as_client_event=True,
|
||||
include_archived=False,
|
||||
):
|
||||
user_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
as_client_event: bool = True,
|
||||
include_archived: bool = False,
|
||||
) -> JsonDict:
|
||||
"""Retrieve a snapshot of all rooms the user is invited or has joined.
|
||||
|
||||
This snapshot may include messages for all rooms where the user is
|
||||
joined, depending on the pagination config.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user making the request.
|
||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||
config used to determine how many messages *PER ROOM* to return.
|
||||
as_client_event (bool): True to get events in client-server format.
|
||||
include_archived (bool): True to get rooms that the user has left
|
||||
user_id: The ID of the user making the request.
|
||||
pagin_config: The pagination config used to determine how many
|
||||
messages *PER ROOM* to return.
|
||||
as_client_event: True to get events in client-server format.
|
||||
include_archived: True to get rooms that the user has left
|
||||
Returns:
|
||||
A list of dicts with "room_id" and "membership" keys for all rooms
|
||||
the user is currently invited or joined in on. Rooms where the user
|
||||
is joined on, may return a "messages" key with messages, depending
|
||||
on the specified PaginationConfig.
|
||||
A JsonDict with the same format as the response to `/intialSync`
|
||||
API
|
||||
"""
|
||||
key = (
|
||||
user_id,
|
||||
|
@ -91,11 +95,11 @@ class InitialSyncHandler(BaseHandler):
|
|||
|
||||
async def _snapshot_all_rooms(
|
||||
self,
|
||||
user_id=None,
|
||||
pagin_config=None,
|
||||
as_client_event=True,
|
||||
include_archived=False,
|
||||
):
|
||||
user_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
as_client_event: bool = True,
|
||||
include_archived: bool = False,
|
||||
) -> JsonDict:
|
||||
|
||||
memberships = [Membership.INVITE, Membership.JOIN]
|
||||
if include_archived:
|
||||
|
@ -134,7 +138,7 @@ class InitialSyncHandler(BaseHandler):
|
|||
if limit is None:
|
||||
limit = 10
|
||||
|
||||
async def handle_room(event):
|
||||
async def handle_room(event: RoomsForUser):
|
||||
d = {
|
||||
"room_id": event.room_id,
|
||||
"membership": event.membership,
|
||||
|
@ -251,17 +255,18 @@ class InitialSyncHandler(BaseHandler):
|
|||
|
||||
return ret
|
||||
|
||||
async def room_initial_sync(self, requester, room_id, pagin_config=None):
|
||||
async def room_initial_sync(
|
||||
self, requester: Requester, room_id: str, pagin_config: PaginationConfig
|
||||
) -> JsonDict:
|
||||
"""Capture the a snapshot of a room. If user is currently a member of
|
||||
the room this will be what is currently in the room. If the user left
|
||||
the room this will be what was in the room when they left.
|
||||
|
||||
Args:
|
||||
requester(Requester): The user to get a snapshot for.
|
||||
room_id(str): The room to get a snapshot of.
|
||||
pagin_config(synapse.streams.config.PaginationConfig):
|
||||
The pagination config used to determine how many messages to
|
||||
return.
|
||||
requester: The user to get a snapshot for.
|
||||
room_id: The room to get a snapshot of.
|
||||
pagin_config: The pagination config used to determine how many
|
||||
messages to return.
|
||||
Raises:
|
||||
AuthError if the user wasn't in the room.
|
||||
Returns:
|
||||
|
@ -305,8 +310,14 @@ class InitialSyncHandler(BaseHandler):
|
|||
return result
|
||||
|
||||
async def _room_initial_sync_parted(
|
||||
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
|
||||
):
|
||||
self,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
membership: Membership,
|
||||
member_event_id: str,
|
||||
is_peeking: bool,
|
||||
) -> JsonDict:
|
||||
room_state = await self.state_store.get_state_for_events([member_event_id])
|
||||
|
||||
room_state = room_state[member_event_id]
|
||||
|
@ -350,8 +361,13 @@ class InitialSyncHandler(BaseHandler):
|
|||
}
|
||||
|
||||
async def _room_initial_sync_joined(
|
||||
self, user_id, room_id, pagin_config, membership, is_peeking
|
||||
):
|
||||
self,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
membership: Membership,
|
||||
is_peeking: bool,
|
||||
) -> JsonDict:
|
||||
current_state = await self.state.get_current_state(room_id=room_id)
|
||||
|
||||
# TODO: These concurrently
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
|
@ -30,6 +30,10 @@ from synapse.util.async_helpers import ReadWriteLock
|
|||
from synapse.util.stringutils import random_string
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -68,7 +72,7 @@ class PaginationHandler(object):
|
|||
paginating during a purge.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -78,9 +82,9 @@ class PaginationHandler(object):
|
|||
self._server_name = hs.hostname
|
||||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
self._purges_in_progress_by_room = set()
|
||||
self._purges_in_progress_by_room = set() # type: Set[str]
|
||||
# map from purge id to PurgeStatus
|
||||
self._purges_by_id = {}
|
||||
self._purges_by_id = {} # type: Dict[str, PurgeStatus]
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
||||
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
|
||||
|
@ -102,7 +106,9 @@ class PaginationHandler(object):
|
|||
job["longest_max_lifetime"],
|
||||
)
|
||||
|
||||
async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
|
||||
async def purge_history_for_rooms_in_range(
|
||||
self, min_ms: Optional[int], max_ms: Optional[int]
|
||||
):
|
||||
"""Purge outdated events from rooms within the given retention range.
|
||||
|
||||
If a default retention policy is defined in the server's configuration and its
|
||||
|
@ -110,10 +116,10 @@ class PaginationHandler(object):
|
|||
retention policy.
|
||||
|
||||
Args:
|
||||
min_ms (int|None): Duration in milliseconds that define the lower limit of
|
||||
min_ms: Duration in milliseconds that define the lower limit of
|
||||
the range to handle (exclusive). If None, it means that the range has no
|
||||
lower limit.
|
||||
max_ms (int|None): Duration in milliseconds that define the upper limit of
|
||||
max_ms: Duration in milliseconds that define the upper limit of
|
||||
the range to handle (inclusive). If None, it means that the range has no
|
||||
upper limit.
|
||||
"""
|
||||
|
@ -220,18 +226,19 @@ class PaginationHandler(object):
|
|||
"_purge_history", self._purge_history, purge_id, room_id, token, True,
|
||||
)
|
||||
|
||||
def start_purge_history(self, room_id, token, delete_local_events=False):
|
||||
def start_purge_history(
|
||||
self, room_id: str, token: str, delete_local_events: bool = False
|
||||
) -> str:
|
||||
"""Start off a history purge on a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The room to purge from
|
||||
|
||||
token (str): topological token to delete events before
|
||||
delete_local_events (bool): True to delete local events as well as
|
||||
room_id: The room to purge from
|
||||
token: topological token to delete events before
|
||||
delete_local_events: True to delete local events as well as
|
||||
remote ones
|
||||
|
||||
Returns:
|
||||
str: unique ID for this purge transaction.
|
||||
unique ID for this purge transaction.
|
||||
"""
|
||||
if room_id in self._purges_in_progress_by_room:
|
||||
raise SynapseError(
|
||||
|
@ -284,14 +291,11 @@ class PaginationHandler(object):
|
|||
|
||||
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
|
||||
|
||||
def get_purge_status(self, purge_id):
|
||||
def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]:
|
||||
"""Get the current status of an active purge
|
||||
|
||||
Args:
|
||||
purge_id (str): purge_id returned by start_purge_history
|
||||
|
||||
Returns:
|
||||
PurgeStatus|None
|
||||
purge_id: purge_id returned by start_purge_history
|
||||
"""
|
||||
return self._purges_by_id.get(purge_id)
|
||||
|
||||
|
@ -312,8 +316,8 @@ class PaginationHandler(object):
|
|||
async def get_messages(
|
||||
self,
|
||||
requester: Requester,
|
||||
room_id: Optional[str] = None,
|
||||
pagin_config: Optional[PaginationConfig] = None,
|
||||
room_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
as_client_event: bool = True,
|
||||
event_filter: Optional[Filter] = None,
|
||||
) -> Dict[str, Any]:
|
||||
|
@ -368,11 +372,15 @@ class PaginationHandler(object):
|
|||
# If they have left the room then clamp the token to be before
|
||||
# they left the room, to save the effort of loading from the
|
||||
# database.
|
||||
|
||||
# This is only None if the room is world_readable, in which
|
||||
# case "JOIN" would have been returned.
|
||||
assert member_event_id
|
||||
|
||||
leave_token = await self.store.get_topological_token_for_event(
|
||||
member_event_id
|
||||
)
|
||||
leave_token = RoomStreamToken.parse(leave_token)
|
||||
if leave_token.topological < max_topo:
|
||||
if RoomStreamToken.parse(leave_token).topological < max_topo:
|
||||
source_config.from_key = str(leave_token)
|
||||
|
||||
await self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
|
@ -419,8 +427,8 @@ class PaginationHandler(object):
|
|||
)
|
||||
|
||||
if state_ids:
|
||||
state = await self.store.get_events(list(state_ids.values()))
|
||||
state = state.values()
|
||||
state_dict = await self.store.get_events(list(state_ids.values()))
|
||||
state = state_dict.values()
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
|
|
Loading…
Reference in a new issue