Fix typing for SyncHandler (#8237)

This commit is contained in:
Erik Johnston 2020-09-03 12:54:10 +01:00 committed by GitHub
parent 6f6f371a87
commit 5bfc79486d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 13 additions and 10 deletions

1
changelog.d/8237.misc Normal file
View file

@ -0,0 +1 @@
Fix type hints in `SyncHandler`.

View file

@ -16,7 +16,7 @@
import itertools import itertools
import logging import logging
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
@ -44,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache
from synapse.util.metrics import Measure, measure_func from synapse.util.metrics import Measure, measure_func
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Debug logger for https://github.com/matrix-org/synapse/issues/4422 # Debug logger for https://github.com/matrix-org/synapse/issues/4422
@ -244,7 +247,7 @@ class SyncResult:
class SyncHandler(object): class SyncHandler(object):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config self.hs_config = hs.config
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@ -717,9 +720,8 @@ class SyncHandler(object):
] ]
missing_hero_state = await self.store.get_events(missing_hero_event_ids) missing_hero_state = await self.store.get_events(missing_hero_event_ids)
missing_hero_state = missing_hero_state.values()
for s in missing_hero_state: for s in missing_hero_state.values():
cache.set(s.state_key, s.event_id) cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s state[(EventTypes.Member, s.state_key)] = s
@ -1771,7 +1773,7 @@ class SyncHandler(object):
ignored_users: Set[str], ignored_users: Set[str],
room_builder: "RoomSyncResultBuilder", room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
tags: Optional[List[JsonDict]], tags: Optional[Dict[str, Dict[str, Any]]],
account_data: Dict[str, JsonDict], account_data: Dict[str, JsonDict],
always_include: bool = False, always_include: bool = False,
): ):

View file

@ -298,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None return None
async def get_rooms_for_local_user_where_membership_is( async def get_rooms_for_local_user_where_membership_is(
self, user_id: str, membership_list: List[str] self, user_id: str, membership_list: Collection[str]
) -> Optional[List[RoomsForUser]]: ) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user """Get all the rooms for this *local* user where the membership for this user
matches one in the membership list. matches one in the membership list.
@ -314,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The RoomsForUser that the user matches the membership types. The RoomsForUser that the user matches the membership types.
""" """
if not membership_list: if not membership_list:
return None return []
rooms = await self.db_pool.runInteraction( rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is", "get_rooms_for_local_user_where_membership_is",

View file

@ -43,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
) )
tags_by_room = {} tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
for row in rows: for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {}) room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"]) room_tags[row["tag"]] = db_to_json(row["content"])
@ -123,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags( async def get_updated_tags(
self, user_id: str, stream_id: int self, user_id: str, stream_id: int
) -> Dict[str, List[str]]: ) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the """Get all the tags for the rooms where the tags have changed since the
given version given version