From 1f773eec912e4908ab60f7823f5c0a024261af4d Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Wed, 26 Feb 2020 15:33:26 +0000
Subject: [PATCH] Port PresenceHandler to async/await (#6991)

---
 changelog.d/6991.misc               |   1 +
 synapse/handlers/message.py         |   5 +-
 synapse/handlers/presence.py        | 192 +++++++++++++---------------
 synapse/replication/tcp/resource.py |   6 +-
 synapse/server.pyi                  |   5 +
 tests/handlers/test_presence.py     |  18 ++-
 tox.ini                             |   1 +
 7 files changed, 113 insertions(+), 115 deletions(-)
 create mode 100644 changelog.d/6991.misc

diff --git a/changelog.d/6991.misc b/changelog.d/6991.misc
new file mode 100644
index 0000000000..5130f4e8af
--- /dev/null
+++ b/changelog.d/6991.misc
@@ -0,0 +1 @@
+Port `synapse.handlers.presence` to async/await.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d6be280952..a0103addd3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1016,11 +1016,10 @@ class EventCreationHandler(object):
             # matters as sometimes presence code can take a while.
             run_in_background(self._bump_active_time, requester.user)
 
-    @defer.inlineCallbacks
-    def _bump_active_time(self, user):
+    async def _bump_active_time(self, user):
         try:
             presence = self.hs.get_presence_handler()
-            yield presence.bump_presence_active_time(user)
+            await presence.bump_presence_active_time(user)
         except Exception:
             logger.exception("Error bumping presence active time")
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 0d6cf2b008..5526015ddb 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -24,11 +24,12 @@ The methods that define policy are:
 
 import logging
 from contextlib import contextmanager
-from typing import Dict, Set
+from typing import Dict, List, Set
 
 from six import iteritems, itervalues
 
 from prometheus_client import Counter
+from typing_extensions import ContextManager
 
 from twisted.internet import defer
 
@@ -42,10 +43,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.presence import UserPresenceState
 from synapse.types import UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 
+MYPY = False
+if MYPY:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 
@@ -97,7 +102,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
 class PresenceHandler(object):
     def __init__(self, hs: "synapse.server.HomeServer"):
         self.hs = hs
-        self.is_mine = hs.is_mine
         self.is_mine_id = hs.is_mine_id
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
@@ -150,7 +154,7 @@ class PresenceHandler(object):
 
         # Set of users who have presence in the `user_to_current_state` that
         # have not yet been persisted
-        self.unpersisted_users_changes = set()
+        self.unpersisted_users_changes = set()  # type: Set[str]
 
         hs.get_reactor().addSystemEventTrigger(
             "before",
@@ -160,12 +164,11 @@ class PresenceHandler(object):
             self._on_shutdown,
         )
 
-        self.serial_to_user = {}
         self._next_serial = 1
 
         # Keeps track of the number of *ongoing* syncs on this process. While
         # this is non zero a user will never go offline.
-        self.user_to_num_current_syncs = {}
+        self.user_to_num_current_syncs = {}  # type: Dict[str, int]
 
         # Keeps track of the number of *ongoing* syncs on other processes.
         # While any sync is ongoing on another process the user will never
@@ -213,8 +216,7 @@ class PresenceHandler(object):
         self._event_pos = self.store.get_current_events_token()
         self._event_processing = False
 
-    @defer.inlineCallbacks
-    def _on_shutdown(self):
+    async def _on_shutdown(self):
         """Gets called when shutting down. This lets us persist any updates that
         we haven't yet persisted, e.g. updates that only changes some internal
         timers. This allows changes to persist across startup without having to
@@ -235,7 +237,7 @@ class PresenceHandler(object):
 
         if self.unpersisted_users_changes:
 
-            yield self.store.update_presence(
+            await self.store.update_presence(
                 [
                     self.user_to_current_state[user_id]
                     for user_id in self.unpersisted_users_changes
@@ -243,8 +245,7 @@ class PresenceHandler(object):
             )
         logger.info("Finished _on_shutdown")
 
-    @defer.inlineCallbacks
-    def _persist_unpersisted_changes(self):
+    async def _persist_unpersisted_changes(self):
         """We periodically persist the unpersisted changes, as otherwise they
         may stack up and slow down shutdown times.
         """
@@ -253,12 +254,11 @@ class PresenceHandler(object):
 
         if unpersisted:
             logger.info("Persisting %d unpersisted presence updates", len(unpersisted))
-            yield self.store.update_presence(
+            await self.store.update_presence(
                 [self.user_to_current_state[user_id] for user_id in unpersisted]
             )
 
-    @defer.inlineCallbacks
-    def _update_states(self, new_states):
+    async def _update_states(self, new_states):
         """Updates presence of users. Sets the appropriate timeouts. Pokes
         the notifier and federation if and only if the changed presence state
         should be sent to clients/servers.
@@ -267,7 +267,7 @@ class PresenceHandler(object):
 
         with Measure(self.clock, "presence_update_states"):
 
-            # NOTE: We purposefully don't yield between now and when we've
+            # NOTE: We purposefully don't await between now and when we've
             # calculated what we want to do with the new states, to avoid races.
 
             to_notify = {}  # Changes we want to notify everyone about
@@ -311,7 +311,7 @@ class PresenceHandler(object):
 
             if to_notify:
                 notified_presence_counter.inc(len(to_notify))
-                yield self._persist_and_notify(list(to_notify.values()))
+                await self._persist_and_notify(list(to_notify.values()))
 
             self.unpersisted_users_changes |= {s.user_id for s in new_states}
             self.unpersisted_users_changes -= set(to_notify.keys())
@@ -326,7 +326,7 @@ class PresenceHandler(object):
 
                 self._push_to_remotes(to_federation_ping.values())
 
-    def _handle_timeouts(self):
+    async def _handle_timeouts(self):
         """Checks the presence of users that have timed out and updates as
         appropriate.
         """
@@ -368,10 +368,9 @@ class PresenceHandler(object):
             now=now,
         )
 
-        return self._update_states(changes)
+        return await self._update_states(changes)
 
-    @defer.inlineCallbacks
-    def bump_presence_active_time(self, user):
+    async def bump_presence_active_time(self, user):
         """We've seen the user do something that indicates they're interacting
         with the app.
         """
@@ -383,16 +382,17 @@ class PresenceHandler(object):
 
         bump_active_time_counter.inc()
 
-        prev_state = yield self.current_state_for_user(user_id)
+        prev_state = await self.current_state_for_user(user_id)
 
         new_fields = {"last_active_ts": self.clock.time_msec()}
         if prev_state.state == PresenceState.UNAVAILABLE:
             new_fields["state"] = PresenceState.ONLINE
 
-        yield self._update_states([prev_state.copy_and_replace(**new_fields)])
+        await self._update_states([prev_state.copy_and_replace(**new_fields)])
 
-    @defer.inlineCallbacks
-    def user_syncing(self, user_id, affect_presence=True):
+    async def user_syncing(
+        self, user_id: str, affect_presence: bool = True
+    ) -> ContextManager[None]:
         """Returns a context manager that should surround any stream requests
         from the user.
 
@@ -415,11 +415,11 @@ class PresenceHandler(object):
             curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
             self.user_to_num_current_syncs[user_id] = curr_sync + 1
 
-            prev_state = yield self.current_state_for_user(user_id)
+            prev_state = await self.current_state_for_user(user_id)
             if prev_state.state == PresenceState.OFFLINE:
                 # If they're currently offline then bring them online, otherwise
                 # just update the last sync times.
-                yield self._update_states(
+                await self._update_states(
                     [
                         prev_state.copy_and_replace(
                             state=PresenceState.ONLINE,
@@ -429,7 +429,7 @@ class PresenceHandler(object):
                     ]
                 )
             else:
-                yield self._update_states(
+                await self._update_states(
                     [
                         prev_state.copy_and_replace(
                             last_user_sync_ts=self.clock.time_msec()
@@ -437,13 +437,12 @@ class PresenceHandler(object):
                     ]
                 )
 
-        @defer.inlineCallbacks
-        def _end():
+        async def _end():
             try:
                 self.user_to_num_current_syncs[user_id] -= 1
 
-                prev_state = yield self.current_state_for_user(user_id)
-                yield self._update_states(
+                prev_state = await self.current_state_for_user(user_id)
+                await self._update_states(
                     [
                         prev_state.copy_and_replace(
                             last_user_sync_ts=self.clock.time_msec()
@@ -480,8 +479,7 @@ class PresenceHandler(object):
         else:
             return set()
 
-    @defer.inlineCallbacks
-    def update_external_syncs_row(
+    async def update_external_syncs_row(
         self, process_id, user_id, is_syncing, sync_time_msec
     ):
         """Update the syncing users for an external process as a delta.
@@ -494,8 +492,8 @@ class PresenceHandler(object):
             is_syncing (bool): Whether or not the user is now syncing
             sync_time_msec(int): Time in ms when the user was last syncing
         """
-        with (yield self.external_sync_linearizer.queue(process_id)):
-            prev_state = yield self.current_state_for_user(user_id)
+        with (await self.external_sync_linearizer.queue(process_id)):
+            prev_state = await self.current_state_for_user(user_id)
 
             process_presence = self.external_process_to_current_syncs.setdefault(
                 process_id, set()
@@ -525,25 +523,24 @@ class PresenceHandler(object):
                 process_presence.discard(user_id)
 
             if updates:
-                yield self._update_states(updates)
+                await self._update_states(updates)
 
             self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
 
-    @defer.inlineCallbacks
-    def update_external_syncs_clear(self, process_id):
+    async def update_external_syncs_clear(self, process_id):
         """Marks all users that had been marked as syncing by a given process
         as offline.
 
         Used when the process has stopped/disappeared.
         """
-        with (yield self.external_sync_linearizer.queue(process_id)):
+        with (await self.external_sync_linearizer.queue(process_id)):
             process_presence = self.external_process_to_current_syncs.pop(
                 process_id, set()
             )
-            prev_states = yield self.current_state_for_users(process_presence)
+            prev_states = await self.current_state_for_users(process_presence)
             time_now_ms = self.clock.time_msec()
 
-            yield self._update_states(
+            await self._update_states(
                 [
                     prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
                     for prev_state in itervalues(prev_states)
@@ -551,15 +548,13 @@ class PresenceHandler(object):
             )
             self.external_process_last_updated_ms.pop(process_id, None)
 
-    @defer.inlineCallbacks
-    def current_state_for_user(self, user_id):
+    async def current_state_for_user(self, user_id):
         """Get the current presence state for a user.
         """
-        res = yield self.current_state_for_users([user_id])
+        res = await self.current_state_for_users([user_id])
         return res[user_id]
 
-    @defer.inlineCallbacks
-    def current_state_for_users(self, user_ids):
+    async def current_state_for_users(self, user_ids):
         """Get the current presence state for multiple users.
 
         Returns:
@@ -574,7 +569,7 @@ class PresenceHandler(object):
         if missing:
             # There are things not in our in memory cache. Lets pull them out of
             # the database.
-            res = yield self.store.get_presence_for_users(missing)
+            res = await self.store.get_presence_for_users(missing)
             states.update(res)
 
             missing = [user_id for user_id, state in iteritems(states) if not state]
@@ -587,14 +582,13 @@ class PresenceHandler(object):
 
         return states
 
-    @defer.inlineCallbacks
-    def _persist_and_notify(self, states):
+    async def _persist_and_notify(self, states):
         """Persist states in the database, poke the notifier and send to
         interested remote servers
         """
-        stream_id, max_token = yield self.store.update_presence(states)
+        stream_id, max_token = await self.store.update_presence(states)
 
-        parties = yield get_interested_parties(self.store, states)
+        parties = await get_interested_parties(self.store, states)
         room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
@@ -606,9 +600,8 @@ class PresenceHandler(object):
 
         self._push_to_remotes(states)
 
-    @defer.inlineCallbacks
-    def notify_for_states(self, state, stream_id):
-        parties = yield get_interested_parties(self.store, [state])
+    async def notify_for_states(self, state, stream_id):
+        parties = await get_interested_parties(self.store, [state])
         room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
@@ -626,8 +619,7 @@ class PresenceHandler(object):
         """
         self.federation.send_presence(states)
 
-    @defer.inlineCallbacks
-    def incoming_presence(self, origin, content):
+    async def incoming_presence(self, origin, content):
         """Called when we receive a `m.presence` EDU from a remote server.
         """
         now = self.clock.time_msec()
@@ -670,21 +662,19 @@ class PresenceHandler(object):
             new_fields["status_msg"] = push.get("status_msg", None)
             new_fields["currently_active"] = push.get("currently_active", False)
 
-            prev_state = yield self.current_state_for_user(user_id)
+            prev_state = await self.current_state_for_user(user_id)
             updates.append(prev_state.copy_and_replace(**new_fields))
 
         if updates:
             federation_presence_counter.inc(len(updates))
-            yield self._update_states(updates)
+            await self._update_states(updates)
 
-    @defer.inlineCallbacks
-    def get_state(self, target_user, as_event=False):
-        results = yield self.get_states([target_user.to_string()], as_event=as_event)
+    async def get_state(self, target_user, as_event=False):
+        results = await self.get_states([target_user.to_string()], as_event=as_event)
 
         return results[0]
 
-    @defer.inlineCallbacks
-    def get_states(self, target_user_ids, as_event=False):
+    async def get_states(self, target_user_ids, as_event=False):
         """Get the presence state for users.
 
         Args:
@@ -695,7 +685,7 @@ class PresenceHandler(object):
             list
         """
 
-        updates = yield self.current_state_for_users(target_user_ids)
+        updates = await self.current_state_for_users(target_user_ids)
         updates = list(updates.values())
 
         for user_id in set(target_user_ids) - {u.user_id for u in updates}:
@@ -713,8 +703,7 @@ class PresenceHandler(object):
         else:
             return updates
 
-    @defer.inlineCallbacks
-    def set_state(self, target_user, state, ignore_status_msg=False):
+    async def set_state(self, target_user, state, ignore_status_msg=False):
         """Set the presence state of the user.
         """
         status_msg = state.get("status_msg", None)
@@ -730,7 +719,7 @@ class PresenceHandler(object):
 
         user_id = target_user.to_string()
 
-        prev_state = yield self.current_state_for_user(user_id)
+        prev_state = await self.current_state_for_user(user_id)
 
         new_fields = {"state": presence}
 
@@ -741,16 +730,15 @@ class PresenceHandler(object):
         if presence == PresenceState.ONLINE:
             new_fields["last_active_ts"] = self.clock.time_msec()
 
-        yield self._update_states([prev_state.copy_and_replace(**new_fields)])
+        await self._update_states([prev_state.copy_and_replace(**new_fields)])
 
-    @defer.inlineCallbacks
-    def is_visible(self, observed_user, observer_user):
+    async def is_visible(self, observed_user, observer_user):
         """Returns whether a user can see another user's presence.
         """
-        observer_room_ids = yield self.store.get_rooms_for_user(
+        observer_room_ids = await self.store.get_rooms_for_user(
             observer_user.to_string()
         )
-        observed_room_ids = yield self.store.get_rooms_for_user(
+        observed_room_ids = await self.store.get_rooms_for_user(
             observed_user.to_string()
         )
 
@@ -759,8 +747,7 @@ class PresenceHandler(object):
 
         return False
 
-    @defer.inlineCallbacks
-    def get_all_presence_updates(self, last_id, current_id):
+    async def get_all_presence_updates(self, last_id, current_id):
         """
         Gets a list of presence update rows from between the given stream ids.
         Each row has:
@@ -775,7 +762,7 @@ class PresenceHandler(object):
         """
         # TODO(markjh): replicate the unpersisted changes.
         # This could use the in-memory stores for recent changes.
-        rows = yield self.store.get_all_presence_updates(last_id, current_id)
+        rows = await self.store.get_all_presence_updates(last_id, current_id)
         return rows
 
     def notify_new_event(self):
@@ -786,20 +773,18 @@ class PresenceHandler(object):
         if self._event_processing:
             return
 
-        @defer.inlineCallbacks
-        def _process_presence():
+        async def _process_presence():
             assert not self._event_processing
 
             self._event_processing = True
             try:
-                yield self._unsafe_process()
+                await self._unsafe_process()
             finally:
                 self._event_processing = False
 
         run_as_background_process("presence.notify_new_event", _process_presence)
 
-    @defer.inlineCallbacks
-    def _unsafe_process(self):
+    async def _unsafe_process(self):
         # Loop round handling deltas until we're up to date
         while True:
             with Measure(self.clock, "presence_delta"):
@@ -812,10 +797,10 @@ class PresenceHandler(object):
                     self._event_pos,
                     room_max_stream_ordering,
                 )
-                max_pos, deltas = yield self.store.get_current_state_deltas(
+                max_pos, deltas = await self.store.get_current_state_deltas(
                     self._event_pos, room_max_stream_ordering
                 )
-                yield self._handle_state_delta(deltas)
+                await self._handle_state_delta(deltas)
 
                 self._event_pos = max_pos
 
@@ -824,8 +809,7 @@ class PresenceHandler(object):
                     max_pos
                 )
 
-    @defer.inlineCallbacks
-    def _handle_state_delta(self, deltas):
+    async def _handle_state_delta(self, deltas):
         """Process current state deltas to find new joins that need to be
         handled.
         """
@@ -846,13 +830,13 @@ class PresenceHandler(object):
                 # joins.
                 continue
 
-            event = yield self.store.get_event(event_id, allow_none=True)
+            event = await self.store.get_event(event_id, allow_none=True)
             if not event or event.content.get("membership") != Membership.JOIN:
                 # We only care about joins
                 continue
 
             if prev_event_id:
-                prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+                prev_event = await self.store.get_event(prev_event_id, allow_none=True)
                 if (
                     prev_event
                     and prev_event.content.get("membership") == Membership.JOIN
@@ -860,10 +844,9 @@ class PresenceHandler(object):
                     # Ignore changes to join events.
                     continue
 
-            yield self._on_user_joined_room(room_id, state_key)
+            await self._on_user_joined_room(room_id, state_key)
 
-    @defer.inlineCallbacks
-    def _on_user_joined_room(self, room_id, user_id):
+    async def _on_user_joined_room(self, room_id, user_id):
         """Called when we detect a user joining the room via the current state
         delta stream.
 
@@ -882,8 +865,8 @@ class PresenceHandler(object):
             # TODO: We should be able to filter the hosts down to those that
             # haven't previously seen the user
 
-            state = yield self.current_state_for_user(user_id)
-            hosts = yield self.state.get_current_hosts_in_room(room_id)
+            state = await self.current_state_for_user(user_id)
+            hosts = await self.state.get_current_hosts_in_room(room_id)
 
             # Filter out ourselves.
             hosts = {host for host in hosts if host != self.server_name}
@@ -903,10 +886,10 @@ class PresenceHandler(object):
             # TODO: Check that this is actually a new server joining the
             # room.
 
-            user_ids = yield self.state.get_current_users_in_room(room_id)
+            user_ids = await self.state.get_current_users_in_room(room_id)
             user_ids = list(filter(self.is_mine_id, user_ids))
 
-            states = yield self.current_state_for_users(user_ids)
+            states = await self.current_state_for_users(user_ids)
 
             # Filter out old presence, i.e. offline presence states where
             # the user hasn't been active for a week. We can change this
@@ -996,9 +979,8 @@ class PresenceEventSource(object):
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
 
-    @defer.inlineCallbacks
     @log_function
-    def get_new_events(
+    async def get_new_events(
         self,
         user,
         from_key,
@@ -1045,7 +1027,7 @@ class PresenceEventSource(object):
             presence = self.get_presence_handler()
             stream_change_cache = self.store.presence_stream_cache
 
-            users_interested_in = yield self._get_interested_in(user, explicit_room_id)
+            users_interested_in = await self._get_interested_in(user, explicit_room_id)
 
             user_ids_changed = set()
             changed = None
@@ -1071,7 +1053,7 @@ class PresenceEventSource(object):
                 else:
                     user_ids_changed = users_interested_in
 
-            updates = yield presence.current_state_for_users(user_ids_changed)
+            updates = await presence.current_state_for_users(user_ids_changed)
 
         if include_offline:
             return (list(updates.values()), max_token)
@@ -1084,11 +1066,11 @@ class PresenceEventSource(object):
     def get_current_key(self):
         return self.store.get_current_presence_token()
 
-    def get_pagination_rows(self, user, pagination_config, key):
-        return self.get_new_events(user, from_key=None, include_offline=False)
+    async def get_pagination_rows(self, user, pagination_config, key):
+        return await self.get_new_events(user, from_key=None, include_offline=False)
 
-    @cachedInlineCallbacks(num_args=2, cache_context=True)
-    def _get_interested_in(self, user, explicit_room_id, cache_context):
+    @cached(num_args=2, cache_context=True)
+    async def _get_interested_in(self, user, explicit_room_id, cache_context):
         """Returns the set of users that the given user should see presence
         updates for
         """
@@ -1096,13 +1078,13 @@ class PresenceEventSource(object):
         users_interested_in = set()
         users_interested_in.add(user_id)  # So that we receive our own presence
 
-        users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+        users_who_share_room = await self.store.get_users_who_share_room_with_user(
             user_id, on_invalidate=cache_context.invalidate
         )
         users_interested_in.update(users_who_share_room)
 
         if explicit_room_id:
-            user_ids = yield self.store.get_users_in_room(
+            user_ids = await self.store.get_users_in_room(
                 explicit_room_id, on_invalidate=cache_context.invalidate
             )
             users_interested_in.update(user_ids)
@@ -1277,8 +1259,8 @@ def get_interested_parties(store, states):
         2-tuple: `(room_ids_to_states, users_to_states)`,
         with each item being a dict of `entity_name` -> `[UserPresenceState]`
     """
-    room_ids_to_states = {}
-    users_to_states = {}
+    room_ids_to_states = {}  # type: Dict[str, List[UserPresenceState]]
+    users_to_states = {}  # type: Dict[str, List[UserPresenceState]]
     for state in states:
         room_ids = yield store.get_rooms_for_user(state.user_id)
         for room_id in room_ids:
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ce60ae2e07..ce9d1fae12 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -323,7 +323,11 @@ class ReplicationStreamer(object):
 
         # We need to tell the presence handler that the connection has been
         # lost so that it can handle any ongoing syncs on that connection.
-        self.presence_handler.update_external_syncs_clear(connection.conn_id)
+        run_as_background_process(
+            "update_external_syncs_clear",
+            self.presence_handler.update_external_syncs_clear,
+            connection.conn_id,
+        )
 
 
 def _batch_updates(updates):
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 40eabfe5d9..3844f0e12f 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -3,6 +3,7 @@ import twisted.internet
 import synapse.api.auth
 import synapse.config.homeserver
 import synapse.crypto.keyring
+import synapse.federation.federation_server
 import synapse.federation.sender
 import synapse.federation.transport.client
 import synapse.handlers
@@ -107,5 +108,9 @@ class HomeServer(object):
         self,
     ) -> synapse.replication.tcp.client.ReplicationClientHandler:
         pass
+    def get_federation_registry(
+        self,
+    ) -> synapse.federation.federation_server.FederationHandlerRegistry:
+        pass
     def is_mine_id(self, domain_id: str) -> bool:
         pass
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 64915bafcd..05ea40a7de 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -494,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         self.helper.join(room_id, "@test2:server")
 
         # Mark test2 as online, test will be offline with a last_active of 0
-        self.presence_handler.set_state(
-            UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+        self.get_success(
+            self.presence_handler.set_state(
+                UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+            )
         )
         self.reactor.pump([0])  # Wait for presence updates to be handled
 
@@ -543,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         room_id = self.helper.create_room_as(self.user_id)
 
         # Mark test as online
-        self.presence_handler.set_state(
-            UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+        self.get_success(
+            self.presence_handler.set_state(
+                UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+            )
         )
 
         # Mark test2 as online, test will be offline with a last_active of 0.
         # Note we don't join them to the room yet
-        self.presence_handler.set_state(
-            UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+        self.get_success(
+            self.presence_handler.set_state(
+                UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+            )
         )
 
         # Add servers to the room
diff --git a/tox.ini b/tox.ini
index b715ea0bff..4ccfde01b5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -183,6 +183,7 @@ commands = mypy \
             synapse/events/spamcheck.py \
             synapse/federation/sender \
             synapse/federation/transport \
+            synapse/handlers/presence.py \
             synapse/handlers/sync.py \
             synapse/handlers/ui_auth \
             synapse/logging/ \