mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Convert simple_select_list and simple_select_list_txn to return lists of tuples (#16505)
This should use fewer allocations and improves type hints.
This commit is contained in:
parent
c14a7de6af
commit
9407d5ba78
31 changed files with 607 additions and 507 deletions
1
changelog.d/16505.misc
Normal file
1
changelog.d/16505.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Reduce memory allocations.
|
|
@ -103,10 +103,10 @@ class DeactivateAccountHandler:
|
|||
# Attempt to unbind any known bound threepids to this account from identity
|
||||
# server(s).
|
||||
bound_threepids = await self.store.user_get_bound_threepids(user_id)
|
||||
for threepid in bound_threepids:
|
||||
for medium, address in bound_threepids:
|
||||
try:
|
||||
result = await self._identity_handler.try_unbind_threepid(
|
||||
user_id, threepid["medium"], threepid["address"], id_server
|
||||
user_id, medium, address, id_server
|
||||
)
|
||||
except Exception:
|
||||
# Do we want this to be a fatal error or should we carry on?
|
||||
|
|
|
@ -1206,10 +1206,7 @@ class SsoHandler:
|
|||
# We have no guarantee that all the devices of that session are for the same
|
||||
# `user_id`. Hence, we have to iterate over the list of devices and log them out
|
||||
# one by one.
|
||||
for device in devices:
|
||||
user_id = device["user_id"]
|
||||
device_id = device["device_id"]
|
||||
|
||||
for user_id, device_id in devices:
|
||||
# If the user_id associated with that device/session is not the one we got
|
||||
# out of the `sub` claim, skip that device and show log an error.
|
||||
if expected_user_id is not None and user_id != expected_user_id:
|
||||
|
|
|
@ -606,13 +606,16 @@ class DatabasePool:
|
|||
|
||||
If the background updates have not completed, wait 15 sec and check again.
|
||||
"""
|
||||
updates = await self.simple_select_list(
|
||||
updates = cast(
|
||||
List[Tuple[str]],
|
||||
await self.simple_select_list(
|
||||
"background_updates",
|
||||
keyvalues=None,
|
||||
retcols=["update_name"],
|
||||
desc="check_background_updates",
|
||||
),
|
||||
)
|
||||
background_update_names = [x["update_name"] for x in updates]
|
||||
background_update_names = [x[0] for x in updates]
|
||||
|
||||
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
|
||||
if update_name not in background_update_names:
|
||||
|
@ -1804,9 +1807,9 @@ class DatabasePool:
|
|||
keyvalues: Optional[Dict[str, Any]],
|
||||
retcols: Collection[str],
|
||||
desc: str = "simple_select_list",
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> List[Tuple[Any, ...]]:
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
more rows, returning the result as a list of tuples.
|
||||
|
||||
Args:
|
||||
table: the table name
|
||||
|
@ -1817,8 +1820,7 @@ class DatabasePool:
|
|||
desc: description of the transaction, for logging and metrics
|
||||
|
||||
Returns:
|
||||
A list of dictionaries, one per result row, each a mapping between the
|
||||
column names from `retcols` and that column's value for the row.
|
||||
A list of tuples, one per result row, each the retcolumn's value for the row.
|
||||
"""
|
||||
return await self.runInteraction(
|
||||
desc,
|
||||
|
@ -1836,9 +1838,9 @@ class DatabasePool:
|
|||
table: str,
|
||||
keyvalues: Optional[Dict[str, Any]],
|
||||
retcols: Iterable[str],
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> List[Tuple[Any, ...]]:
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
more rows, returning the result as a list of tuples.
|
||||
|
||||
Args:
|
||||
txn: Transaction object
|
||||
|
@ -1849,8 +1851,7 @@ class DatabasePool:
|
|||
retcols: the names of the columns to return
|
||||
|
||||
Returns:
|
||||
A list of dictionaries, one per result row, each a mapping between the
|
||||
column names from `retcols` and that column's value for the row.
|
||||
A list of tuples, one per result row, each the retcolumn's value for the row.
|
||||
"""
|
||||
if keyvalues:
|
||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
|
@ -1863,7 +1864,7 @@ class DatabasePool:
|
|||
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
||||
txn.execute(sql)
|
||||
|
||||
return cls.cursor_to_dict(txn)
|
||||
return txn.fetchall()
|
||||
|
||||
async def simple_select_many_batch(
|
||||
self,
|
||||
|
|
|
@ -286,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
|||
|
||||
def get_account_data_for_room_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[str, JsonDict]:
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
) -> Dict[str, JsonMapping]:
|
||||
rows = cast(
|
||||
List[Tuple[str, str]],
|
||||
self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
"room_account_data",
|
||||
{"user_id": user_id, "room_id": room_id},
|
||||
["account_data_type", "content"],
|
||||
table="room_account_data",
|
||||
keyvalues={"user_id": user_id, "room_id": room_id},
|
||||
retcols=["account_data_type", "content"],
|
||||
),
|
||||
)
|
||||
|
||||
return {
|
||||
row["account_data_type"]: db_to_json(row["content"]) for row in rows
|
||||
account_data_type: db_to_json(content)
|
||||
for account_data_type, content in rows
|
||||
}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
|
|
@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
Returns:
|
||||
A list of ApplicationServices, which may be empty.
|
||||
"""
|
||||
results = await self.db_pool.simple_select_list(
|
||||
"application_services_state", {"state": state.value}, ["as_id"]
|
||||
results = cast(
|
||||
List[Tuple[str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="application_services_state",
|
||||
keyvalues={"state": state.value},
|
||||
retcols=("as_id",),
|
||||
),
|
||||
)
|
||||
# NB: This assumes this class is linked with ApplicationServiceStore
|
||||
as_list = self.get_app_services()
|
||||
services = []
|
||||
|
||||
for res in results:
|
||||
for (as_id,) in results:
|
||||
for service in as_list:
|
||||
if service.id == res["as_id"]:
|
||||
if service.id == as_id:
|
||||
services.append(service)
|
||||
return services
|
||||
|
||||
|
|
|
@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
|
|||
if device_id is not None:
|
||||
keyvalues["device_id"] = device_id
|
||||
|
||||
res = await self.db_pool.simple_select_list(
|
||||
res = cast(
|
||||
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="devices",
|
||||
keyvalues=keyvalues,
|
||||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
|
||||
),
|
||||
)
|
||||
|
||||
return {
|
||||
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
|
||||
user_id=d["user_id"],
|
||||
device_id=d["device_id"],
|
||||
ip=d["ip"],
|
||||
user_agent=d["user_agent"],
|
||||
last_seen=d["last_seen"],
|
||||
(user_id, device_id): DeviceLastConnectionInfo(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
ip=ip,
|
||||
user_agent=user_agent,
|
||||
last_seen=last_seen,
|
||||
)
|
||||
for d in res
|
||||
for user_id, ip, user_agent, device_id, last_seen in res
|
||||
}
|
||||
|
||||
async def _get_user_ip_and_agents_from_database(
|
||||
|
|
|
@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
allow_none=True,
|
||||
)
|
||||
|
||||
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
|
||||
async def get_devices_by_user(
|
||||
self, user_id: str
|
||||
) -> Dict[str, Dict[str, Optional[str]]]:
|
||||
"""Retrieve all of a user's registered devices. Only returns devices
|
||||
that are not marked as hidden.
|
||||
|
||||
|
@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
user_id:
|
||||
Returns:
|
||||
A mapping from device_id to a dict containing "device_id", "user_id"
|
||||
and "display_name" for each device.
|
||||
and "display_name" for each device. Display name may be null.
|
||||
"""
|
||||
devices = await self.db_pool.simple_select_list(
|
||||
devices = cast(
|
||||
List[Tuple[str, str, Optional[str]]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "hidden": False},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_devices_by_user",
|
||||
),
|
||||
)
|
||||
|
||||
return {d["device_id"]: d for d in devices}
|
||||
return {
|
||||
d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
|
||||
for d in devices
|
||||
}
|
||||
|
||||
async def get_devices_by_auth_provider_session_id(
|
||||
self, auth_provider_id: str, auth_provider_session_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Retrieve the list of devices associated with a SSO IdP session ID.
|
||||
|
||||
Args:
|
||||
|
@ -313,7 +321,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
Returns:
|
||||
A list of dicts containing the device_id and the user_id of each device
|
||||
"""
|
||||
return await self.db_pool.simple_select_list(
|
||||
return cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="device_auth_providers",
|
||||
keyvalues={
|
||||
"auth_provider_id": auth_provider_id,
|
||||
|
@ -321,6 +331,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
},
|
||||
retcols=("user_id", "device_id"),
|
||||
desc="get_devices_by_auth_provider_session_id",
|
||||
),
|
||||
)
|
||||
|
||||
@trace
|
||||
|
@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
async def get_cached_devices_for_user(
|
||||
self, user_id: str
|
||||
) -> Mapping[str, JsonMapping]:
|
||||
devices = await self.db_pool.simple_select_list(
|
||||
devices = cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("device_id", "content"),
|
||||
desc="get_cached_devices_for_user",
|
||||
),
|
||||
)
|
||||
return {
|
||||
device["device_id"]: db_to_json(device["content"]) for device in devices
|
||||
}
|
||||
return {device[0]: db_to_json(device[1]) for device in devices}
|
||||
|
||||
def get_cached_device_list_changes(
|
||||
self,
|
||||
|
@ -1080,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
The IDs of users whose device lists need resync.
|
||||
"""
|
||||
if user_ids:
|
||||
row_tuples = cast(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="device_lists_remote_resync",
|
||||
|
@ -1090,11 +1102,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
||||
),
|
||||
)
|
||||
|
||||
return {row[0] for row in row_tuples}
|
||||
else:
|
||||
rows = cast(
|
||||
List[Dict[str, str]],
|
||||
List[Tuple[str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="device_lists_remote_resync",
|
||||
keyvalues=None,
|
||||
|
@ -1103,7 +1113,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
),
|
||||
)
|
||||
|
||||
return {row["user_id"] for row in rows}
|
||||
return {row[0] for row in rows}
|
||||
|
||||
async def mark_remote_users_device_caches_as_stale(
|
||||
self, user_ids: StrCollection
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
|
||||
|
||||
from typing_extensions import Literal, TypedDict
|
||||
|
||||
|
@ -274,11 +274,12 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
|
|||
if session_id:
|
||||
keyvalues["session_id"] = session_id
|
||||
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, str, int, int, int, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="e2e_room_keys",
|
||||
keyvalues=keyvalues,
|
||||
retcols=(
|
||||
"user_id",
|
||||
"room_id",
|
||||
"session_id",
|
||||
"first_message_index",
|
||||
|
@ -287,19 +288,27 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
|
|||
"session_data",
|
||||
),
|
||||
desc="get_e2e_room_keys",
|
||||
),
|
||||
)
|
||||
|
||||
sessions: Dict[
|
||||
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
|
||||
] = {"rooms": {}}
|
||||
for row in rows:
|
||||
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
|
||||
room_entry["sessions"][row["session_id"]] = {
|
||||
"first_message_index": row["first_message_index"],
|
||||
"forwarded_count": row["forwarded_count"],
|
||||
for (
|
||||
room_id,
|
||||
session_id,
|
||||
first_message_index,
|
||||
forwarded_count,
|
||||
is_verified,
|
||||
session_data,
|
||||
) in rows:
|
||||
room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
|
||||
room_entry["sessions"][session_id] = {
|
||||
"first_message_index": first_message_index,
|
||||
"forwarded_count": forwarded_count,
|
||||
# is_verified must be returned to the client as a boolean
|
||||
"is_verified": bool(row["is_verified"]),
|
||||
"session_data": db_to_json(row["session_data"]),
|
||||
"is_verified": bool(is_verified),
|
||||
"session_data": db_to_json(session_data),
|
||||
}
|
||||
|
||||
return sessions
|
||||
|
|
|
@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
# keeping only the forward extremities (i.e. the events not referenced
|
||||
# by other events in the queue). We do this so that we can always
|
||||
# backpaginate in all the events we have dropped.
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="federation_inbound_events_staging",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("event_id", "event_json"),
|
||||
desc="prune_staged_events_in_room_fetch",
|
||||
),
|
||||
)
|
||||
|
||||
# Find the set of events referenced by those in the queue, as well as
|
||||
# collecting all the event IDs in the queue.
|
||||
referenced_events: Set[str] = set()
|
||||
seen_events: Set[str] = set()
|
||||
for row in rows:
|
||||
event_id = row["event_id"]
|
||||
for event_id, event_json in rows:
|
||||
seen_events.add(event_id)
|
||||
event_d = db_to_json(row["event_json"])
|
||||
event_d = db_to_json(event_json)
|
||||
|
||||
# We don't bother parsing the dicts into full blown event objects,
|
||||
# as that is needlessly expensive.
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, FrozenSet
|
||||
from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
|
||||
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
||||
|
@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
|
|||
Returns:
|
||||
the features currently enabled for the user
|
||||
"""
|
||||
enabled = await self.db_pool.simple_select_list(
|
||||
"per_user_experimental_features",
|
||||
{"user_id": user_id, "enabled": True},
|
||||
["feature"],
|
||||
enabled = cast(
|
||||
List[Tuple[str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="per_user_experimental_features",
|
||||
keyvalues={"user_id": user_id, "enabled": True},
|
||||
retcols=("feature",),
|
||||
),
|
||||
)
|
||||
|
||||
return frozenset(feature["feature"] for feature in enabled)
|
||||
return frozenset(feature[0] for feature in enabled)
|
||||
|
||||
async def set_features_for_user(
|
||||
self,
|
||||
|
|
|
@ -248,7 +248,9 @@ class KeyStore(CacheInvalidationWorkerStore):
|
|||
|
||||
If we have multiple entries for a given key ID, returns the most recent.
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="server_keys_json",
|
||||
keyvalues={"server_name": server_name},
|
||||
retcols=(
|
||||
|
@ -259,6 +261,7 @@ class KeyStore(CacheInvalidationWorkerStore):
|
|||
"key_json",
|
||||
),
|
||||
desc="get_server_keys_json_for_remote",
|
||||
),
|
||||
)
|
||||
|
||||
if not rows:
|
||||
|
@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore):
|
|||
|
||||
# We sort the rows by ts_added_ms so that the most recently added entry
|
||||
# will stomp over older entries in the dictionary.
|
||||
rows.sort(key=lambda r: r["ts_added_ms"])
|
||||
rows.sort(key=lambda r: r[2])
|
||||
|
||||
return {
|
||||
row["key_id"]: FetchKeyResultForRemote(
|
||||
key_id: FetchKeyResultForRemote(
|
||||
# Cast to bytes since postgresql returns a memoryview.
|
||||
key_json=bytes(row["key_json"]),
|
||||
valid_until_ts=row["ts_valid_until_ms"],
|
||||
added_ts=row["ts_added_ms"],
|
||||
key_json=bytes(key_json),
|
||||
valid_until_ts=ts_valid_until_ms,
|
||||
added_ts=ts_added_ms,
|
||||
)
|
||||
for row in rows
|
||||
for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
|
||||
}
|
||||
|
|
|
@ -437,7 +437,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[int, int, str, str, int]],
|
||||
await self.db_pool.simple_select_list(
|
||||
"local_media_repository_thumbnails",
|
||||
{"media_id": media_id},
|
||||
(
|
||||
|
@ -448,14 +450,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"thumbnail_length",
|
||||
),
|
||||
desc="get_local_media_thumbnails",
|
||||
),
|
||||
)
|
||||
return [
|
||||
ThumbnailInfo(
|
||||
width=row["thumbnail_width"],
|
||||
height=row["thumbnail_height"],
|
||||
method=row["thumbnail_method"],
|
||||
type=row["thumbnail_type"],
|
||||
length=row["thumbnail_length"],
|
||||
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
@ -568,7 +567,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
async def get_remote_media_thumbnails(
|
||||
self, origin: str, media_id: str
|
||||
) -> List[ThumbnailInfo]:
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[int, int, str, str, int]],
|
||||
await self.db_pool.simple_select_list(
|
||||
"remote_media_cache_thumbnails",
|
||||
{"media_origin": origin, "media_id": media_id},
|
||||
(
|
||||
|
@ -579,14 +580,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"thumbnail_length",
|
||||
),
|
||||
desc="get_remote_media_thumbnails",
|
||||
),
|
||||
)
|
||||
return [
|
||||
ThumbnailInfo(
|
||||
width=row["thumbnail_width"],
|
||||
height=row["thumbnail_height"],
|
||||
method=row["thumbnail_method"],
|
||||
type=row["thumbnail_type"],
|
||||
length=row["thumbnail_length"],
|
||||
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
|
|
@ -179,11 +179,12 @@ class PushRulesWorkerStore(
|
|||
|
||||
@cached(max_entries=5000)
|
||||
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, int, int, str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="push_rules",
|
||||
keyvalues={"user_name": user_id},
|
||||
retcols=(
|
||||
"user_name",
|
||||
"rule_id",
|
||||
"priority_class",
|
||||
"priority",
|
||||
|
@ -191,34 +192,31 @@ class PushRulesWorkerStore(
|
|||
"actions",
|
||||
),
|
||||
desc="get_push_rules_for_user",
|
||||
),
|
||||
)
|
||||
|
||||
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
|
||||
# Sort by highest priority_class, then highest priority.
|
||||
rows.sort(key=lambda row: (-int(row[1]), -int(row[2])))
|
||||
|
||||
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
|
||||
|
||||
return _load_rules(
|
||||
[
|
||||
(
|
||||
row["rule_id"],
|
||||
row["priority_class"],
|
||||
row["conditions"],
|
||||
row["actions"],
|
||||
)
|
||||
for row in rows
|
||||
],
|
||||
[(row[0], row[1], row[3], row[4]) for row in rows],
|
||||
enabled_map,
|
||||
self.hs.config.experimental,
|
||||
)
|
||||
|
||||
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
|
||||
results = await self.db_pool.simple_select_list(
|
||||
results = cast(
|
||||
List[Tuple[str, Optional[Union[int, bool]]]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="push_rules_enable",
|
||||
keyvalues={"user_name": user_id},
|
||||
retcols=("rule_id", "enabled"),
|
||||
desc="get_push_rules_enabled_for_user",
|
||||
),
|
||||
)
|
||||
return {r["rule_id"]: bool(r["enabled"]) for r in results}
|
||||
return {r[0]: bool(r[1]) for r in results}
|
||||
|
||||
async def have_push_rules_changed_for_user(
|
||||
self, user_id: str, last_id: int
|
||||
|
|
|
@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
async def get_throttle_params_by_room(
|
||||
self, pusher_id: int
|
||||
) -> Dict[str, ThrottleParams]:
|
||||
res = await self.db_pool.simple_select_list(
|
||||
res = cast(
|
||||
List[Tuple[str, Optional[int], Optional[int]]],
|
||||
await self.db_pool.simple_select_list(
|
||||
"pusher_throttle",
|
||||
{"pusher": pusher_id},
|
||||
["room_id", "last_sent_ts", "throttle_ms"],
|
||||
desc="get_throttle_params_by_room",
|
||||
),
|
||||
)
|
||||
|
||||
params_by_room = {}
|
||||
for row in res:
|
||||
params_by_room[row["room_id"]] = ThrottleParams(
|
||||
row["last_sent_ts"],
|
||||
row["throttle_ms"],
|
||||
for room_id, last_sent_ts, throttle_ms in res:
|
||||
params_by_room[room_id] = ThrottleParams(
|
||||
last_sent_ts or 0, throttle_ms or 0
|
||||
)
|
||||
|
||||
return params_by_room
|
||||
|
|
|
@ -855,13 +855,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
Returns:
|
||||
Tuples of (auth_provider, external_id)
|
||||
"""
|
||||
res = await self.db_pool.simple_select_list(
|
||||
return cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="user_external_ids",
|
||||
keyvalues={"user_id": mxid},
|
||||
retcols=("auth_provider", "external_id"),
|
||||
desc="get_external_ids_by_user",
|
||||
),
|
||||
)
|
||||
return [(r["auth_provider"], r["external_id"]) for r in res]
|
||||
|
||||
async def count_all_users(self) -> int:
|
||||
"""Counts all users registered on the homeserver."""
|
||||
|
@ -997,13 +999,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
)
|
||||
|
||||
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
|
||||
results = await self.db_pool.simple_select_list(
|
||||
results = cast(
|
||||
List[Tuple[str, str, int, int]],
|
||||
await self.db_pool.simple_select_list(
|
||||
"user_threepids",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["medium", "address", "validated_at", "added_at"],
|
||||
desc="user_get_threepids",
|
||||
),
|
||||
)
|
||||
return [ThreepidResult(**r) for r in results]
|
||||
return [
|
||||
ThreepidResult(
|
||||
medium=r[0],
|
||||
address=r[1],
|
||||
validated_at=r[2],
|
||||
added_at=r[3],
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
async def user_delete_threepid(
|
||||
self, user_id: str, medium: str, address: str
|
||||
|
@ -1042,7 +1055,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
desc="add_user_bound_threepid",
|
||||
)
|
||||
|
||||
async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
|
||||
"""Get the threepids that a user has bound to an identity server through the homeserver
|
||||
The homeserver remembers where binds to an identity server occurred. Using this
|
||||
method can retrieve those threepids.
|
||||
|
@ -1051,15 +1064,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
user_id: The ID of the user to retrieve threepids for
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing the following keys:
|
||||
medium (str): The medium of the threepid (e.g "email")
|
||||
address (str): The address of the threepid (e.g "bob@example.com")
|
||||
List of tuples of two strings:
|
||||
medium: The medium of the threepid (e.g "email")
|
||||
address: The address of the threepid (e.g "bob@example.com")
|
||||
"""
|
||||
return await self.db_pool.simple_select_list(
|
||||
return cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="user_threepid_id_server",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["medium", "address"],
|
||||
desc="user_get_bound_threepids",
|
||||
),
|
||||
)
|
||||
|
||||
async def remove_user_bound_threepid(
|
||||
|
|
|
@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
def get_all_relation_ids_for_event_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[str]:
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
self.db_pool.simple_select_list_txn(
|
||||
txn=txn,
|
||||
table="event_relations",
|
||||
keyvalues={"relates_to_id": event_id},
|
||||
retcols=["event_id"],
|
||||
),
|
||||
)
|
||||
|
||||
return [row["event_id"] for row in rows]
|
||||
return [row[0] for row in rows]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
desc="get_all_relation_ids_for_event",
|
||||
|
|
|
@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
"""
|
||||
room_servers: Dict[str, PartialStateResyncInfo] = {}
|
||||
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="partial_state_rooms",
|
||||
keyvalues={},
|
||||
retcols=("room_id", "joined_via"),
|
||||
desc="get_server_which_served_partial_join",
|
||||
),
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
room_id = row["room_id"]
|
||||
joined_via = row["joined_via"]
|
||||
for room_id, joined_via in rows:
|
||||
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
|
||||
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
"partial_state_rooms_servers",
|
||||
keyvalues=None,
|
||||
retcols=("room_id", "server_name"),
|
||||
desc="get_partial_state_rooms",
|
||||
),
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
room_id = row["room_id"]
|
||||
server_name = row["server_name"]
|
||||
for room_id, server_name in rows:
|
||||
entry = room_servers.get(room_id)
|
||||
if entry is None:
|
||||
# There is a foreign key constraint which enforces that every room_id in
|
||||
|
|
|
@ -1070,13 +1070,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
|||
for fully-joined rooms.
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, Optional[str]]],
|
||||
await self.db_pool.simple_select_list(
|
||||
"current_state_events",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("event_id", "membership"),
|
||||
desc="has_completed_background_updates",
|
||||
),
|
||||
)
|
||||
return {row["event_id"]: row["membership"] for row in rows}
|
||||
return dict(rows)
|
||||
|
||||
# TODO This returns a mutable object, which is generally confusing when using a cache.
|
||||
@cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]
|
||||
|
|
|
@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
tag content.
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
|
||||
),
|
||||
)
|
||||
|
||||
tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
|
||||
for row in rows:
|
||||
room_tags = tags_by_room.setdefault(row["room_id"], {})
|
||||
room_tags[row["tag"]] = db_to_json(row["content"])
|
||||
for room_id, tag, content in rows:
|
||||
room_tags = tags_by_room.setdefault(room_id, {})
|
||||
room_tags[tag] = db_to_json(content)
|
||||
return tags_by_room
|
||||
|
||||
async def get_all_updated_tags(
|
||||
|
@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
Returns:
|
||||
A mapping of tags to tag content.
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="room_tags",
|
||||
keyvalues={"user_id": user_id, "room_id": room_id},
|
||||
retcols=("tag", "content"),
|
||||
desc="get_tags_for_room",
|
||||
),
|
||||
)
|
||||
return {row["tag"]: db_to_json(row["content"]) for row in rows}
|
||||
return {tag: db_to_json(content) for tag, content in rows}
|
||||
|
||||
async def add_tag_to_room(
|
||||
self, user_id: str, room_id: str, tag: str, content: JsonDict
|
||||
|
|
|
@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
that auth-type.
|
||||
"""
|
||||
results = {}
|
||||
for row in await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="ui_auth_sessions_credentials",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("stage_type", "result"),
|
||||
desc="get_completed_ui_auth_stages",
|
||||
):
|
||||
results[row["stage_type"]] = db_to_json(row["result"])
|
||||
),
|
||||
)
|
||||
for stage_type, result in rows:
|
||||
results[stage_type] = db_to_json(result)
|
||||
|
||||
return results
|
||||
|
||||
|
@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
Returns:
|
||||
List of user_agent/ip pairs
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
return cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="ui_auth_sessions_ips",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("user_agent", "ip"),
|
||||
desc="get_user_agents_ips_to_ui_auth_session",
|
||||
),
|
||||
)
|
||||
return [(row["user_agent"], row["ip"]) for row in rows]
|
||||
|
||||
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
|
||||
"""
|
||||
|
|
|
@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
if not prev_group:
|
||||
return _GetStateGroupDelta(None, None)
|
||||
|
||||
delta_ids = self.db_pool.simple_select_list_txn(
|
||||
delta_ids = cast(
|
||||
List[Tuple[str, str, str]],
|
||||
self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
keyvalues={"state_group": state_group},
|
||||
retcols=("type", "state_key", "event_id"),
|
||||
),
|
||||
)
|
||||
|
||||
return _GetStateGroupDelta(
|
||||
prev_group,
|
||||
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
|
||||
{
|
||||
(event_type, state_key): event_id
|
||||
for event_type, state_key, event_id in delta_ids
|
||||
},
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
|
@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
async def get_all_room_state(self) -> List[Dict[str, Any]]:
|
||||
return await self.store.db_pool.simple_select_list(
|
||||
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
|
||||
async def get_all_room_state(self) -> List[Optional[str]]:
|
||||
rows = cast(
|
||||
List[Tuple[Optional[str]]],
|
||||
await self.store.db_pool.simple_select_list(
|
||||
"room_stats_state", None, retcols=("topic",)
|
||||
),
|
||||
)
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def _get_current_stats(
|
||||
self, stats_type: str, stat_id: str
|
||||
|
@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
|||
r = self.get_success(self.get_all_room_state())
|
||||
|
||||
self.assertEqual(len(r), 1)
|
||||
self.assertEqual(r[0]["topic"], "foo")
|
||||
self.assertEqual(r[0], "foo")
|
||||
|
||||
def test_create_user(self) -> None:
|
||||
"""
|
||||
|
|
|
@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
|||
if expected_row is not None:
|
||||
columns += expected_row.keys()
|
||||
|
||||
rows = self.get_success(
|
||||
row_tuples = self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table=table,
|
||||
keyvalues={
|
||||
|
@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
|||
|
||||
if expected_row is not None:
|
||||
self.assertEqual(
|
||||
len(rows),
|
||||
len(row_tuples),
|
||||
1,
|
||||
f"Background update did not leave behind latest receipt in {table}",
|
||||
)
|
||||
self.assertEqual(
|
||||
rows[0],
|
||||
{
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
**expected_row,
|
||||
},
|
||||
row_tuples[0],
|
||||
(
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id,
|
||||
*expected_row.values(),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.assertEqual(
|
||||
len(rows),
|
||||
len(row_tuples),
|
||||
0,
|
||||
f"Background update did not remove all duplicate receipts from {table}",
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import secrets
|
||||
from typing import Generator, Tuple
|
||||
from typing import Generator, List, Tuple, cast
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
|
@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
|
||||
res = self.get_success(
|
||||
yield from cast(
|
||||
List[Tuple[int, str, str]],
|
||||
self.get_success(
|
||||
self.storage.db_pool.simple_select_list(
|
||||
self.table_name, None, ["id, username, value"]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
for i in res:
|
||||
yield (i["id"], i["username"], i["value"])
|
||||
|
||||
def test_upsert_many(self) -> None:
|
||||
"""
|
||||
Upsert_many will perform the upsert operation across a batch of data.
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import List, Tuple, cast
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import yaml
|
||||
|
@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
|
|||
self.wait_for_background_updates()
|
||||
|
||||
# Check the correct values are in the new table.
|
||||
rows = self.get_success(
|
||||
rows = cast(
|
||||
List[Tuple[int, int]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="test_constraint",
|
||||
keyvalues={},
|
||||
retcols=("a", "b"),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
|
||||
self.assertCountEqual(rows, [(1, 1), (3, 3)])
|
||||
|
||||
# And check that invalid rows get correctly rejected.
|
||||
self.get_failure(
|
||||
|
@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
|
|||
self.wait_for_background_updates()
|
||||
|
||||
# Check the correct values are in the new table.
|
||||
rows = self.get_success(
|
||||
rows = cast(
|
||||
List[Tuple[int, int]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="test_constraint",
|
||||
keyvalues={},
|
||||
retcols=("a", "b"),
|
||||
)
|
||||
),
|
||||
)
|
||||
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
|
||||
self.assertCountEqual(rows, [(1, 1), (3, 3)])
|
||||
|
||||
# And check that invalid rows get correctly rejected.
|
||||
self.get_failure(
|
||||
|
|
|
@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||
self.mock_txn.rowcount = 3
|
||||
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
|
||||
self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)]
|
||||
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
|
||||
|
||||
ret = yield defer.ensureDeferred(
|
||||
|
@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
|
||||
self.assertEqual([(1,), (2,), (3,)], ret)
|
||||
self.mock_txn.execute.assert_called_with(
|
||||
"SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
|
||||
)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from parameterized import parameterized
|
||||
|
@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
|||
self.reactor.advance(200)
|
||||
self.pump(0)
|
||||
|
||||
result = self.get_success(
|
||||
result = cast(
|
||||
List[Tuple[str, str, str, Optional[str], int]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
||||
retcols=[
|
||||
"access_token",
|
||||
"ip",
|
||||
"user_agent",
|
||||
"device_id",
|
||||
"last_seen",
|
||||
],
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
result,
|
||||
[
|
||||
{
|
||||
"access_token": "access_token",
|
||||
"ip": "ip",
|
||||
"user_agent": "user_agent",
|
||||
"device_id": None,
|
||||
"last_seen": 12345678000,
|
||||
}
|
||||
],
|
||||
result, [("access_token", "ip", "user_agent", None, 12345678000)]
|
||||
)
|
||||
|
||||
# Add another & trigger the storage loop
|
||||
|
@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
|||
self.reactor.advance(10)
|
||||
self.pump(0)
|
||||
|
||||
result = self.get_success(
|
||||
result = cast(
|
||||
List[Tuple[str, str, str, Optional[str], int]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
||||
retcols=[
|
||||
"access_token",
|
||||
"ip",
|
||||
"user_agent",
|
||||
"device_id",
|
||||
"last_seen",
|
||||
],
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
),
|
||||
)
|
||||
# Only one result, has been upserted.
|
||||
self.assertEqual(
|
||||
result,
|
||||
[
|
||||
{
|
||||
"access_token": "access_token",
|
||||
"ip": "ip",
|
||||
"user_agent": "user_agent",
|
||||
"device_id": None,
|
||||
"last_seen": 12345878000,
|
||||
}
|
||||
],
|
||||
result, [("access_token", "ip", "user_agent", None, 12345878000)]
|
||||
)
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
|
@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
|||
self.reactor.advance(10)
|
||||
else:
|
||||
# Check that the new IP and user agent has not been stored yet
|
||||
db_result = self.get_success(
|
||||
db_result = cast(
|
||||
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="devices",
|
||||
keyvalues={},
|
||||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
|
||||
retcols=(
|
||||
"user_id",
|
||||
"ip",
|
||||
"user_agent",
|
||||
"device_id",
|
||||
"last_seen",
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
db_result,
|
||||
[
|
||||
{
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"ip": None,
|
||||
"user_agent": None,
|
||||
"last_seen": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
self.assertEqual(db_result, [(user_id, None, None, device_id, None)])
|
||||
|
||||
result = self.get_success(
|
||||
self.store.get_last_client_ip_by_device(user_id, device_id)
|
||||
|
@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
# Check that the new IP and user agent has not been stored yet
|
||||
db_result = self.get_success(
|
||||
db_result = cast(
|
||||
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="devices",
|
||||
keyvalues={},
|
||||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
|
||||
),
|
||||
),
|
||||
)
|
||||
self.assertCountEqual(
|
||||
db_result,
|
||||
[
|
||||
{
|
||||
"user_id": user_id,
|
||||
"device_id": device_id_1,
|
||||
"ip": "ip_1",
|
||||
"user_agent": "user_agent_1",
|
||||
"last_seen": 12345678000,
|
||||
},
|
||||
{
|
||||
"user_id": user_id,
|
||||
"device_id": device_id_2,
|
||||
"ip": "ip_2",
|
||||
"user_agent": "user_agent_2",
|
||||
"last_seen": 12345678000,
|
||||
},
|
||||
(user_id, "ip_1", "user_agent_1", device_id_1, 12345678000),
|
||||
(user_id, "ip_2", "user_agent_2", device_id_2, 12345678000),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
# Check that the new IP and user agent has not been stored yet
|
||||
db_result = self.get_success(
|
||||
db_result = cast(
|
||||
List[Tuple[str, str, str, int]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={},
|
||||
retcols=("access_token", "ip", "user_agent", "last_seen"),
|
||||
),
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
db_result,
|
||||
[
|
||||
{
|
||||
"access_token": "access_token",
|
||||
"ip": "ip_1",
|
||||
"user_agent": "user_agent_1",
|
||||
"last_seen": 12345678000,
|
||||
},
|
||||
{
|
||||
"access_token": "access_token",
|
||||
"ip": "ip_2",
|
||||
"user_agent": "user_agent_2",
|
||||
"last_seen": 12345678000,
|
||||
},
|
||||
("access_token", "ip_1", "user_agent_1", 12345678000),
|
||||
("access_token", "ip_2", "user_agent_2", 12345678000),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
|||
self.reactor.advance(200)
|
||||
|
||||
# We should see that in the DB
|
||||
result = self.get_success(
|
||||
result = cast(
|
||||
List[Tuple[str, str, str, Optional[str], int]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
||||
retcols=[
|
||||
"access_token",
|
||||
"ip",
|
||||
"user_agent",
|
||||
"device_id",
|
||||
"last_seen",
|
||||
],
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
result,
|
||||
[
|
||||
{
|
||||
"access_token": "access_token",
|
||||
"ip": "ip",
|
||||
"user_agent": "user_agent",
|
||||
"device_id": device_id,
|
||||
"last_seen": 0,
|
||||
}
|
||||
],
|
||||
[("access_token", "ip", "user_agent", device_id, 0)],
|
||||
)
|
||||
|
||||
# Now advance by a couple of months
|
||||
self.reactor.advance(60 * 24 * 60 * 60)
|
||||
|
||||
# We should get no results.
|
||||
result = self.get_success(
|
||||
result = cast(
|
||||
List[Tuple[str, str, str, Optional[str], int]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
||||
retcols=[
|
||||
"access_token",
|
||||
"ip",
|
||||
"user_agent",
|
||||
"device_id",
|
||||
"last_seen",
|
||||
],
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(result, [])
|
||||
|
@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
|||
self.reactor.advance(200)
|
||||
|
||||
# We should see that in the DB
|
||||
result = self.get_success(
|
||||
result = cast(
|
||||
List[Tuple[str, str, str, Optional[str], int]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={},
|
||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
||||
retcols=[
|
||||
"access_token",
|
||||
"ip",
|
||||
"user_agent",
|
||||
"device_id",
|
||||
"last_seen",
|
||||
],
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# ensure user1 is filtered out
|
||||
self.assertEqual(
|
||||
result,
|
||||
[
|
||||
{
|
||||
"access_token": access_token2,
|
||||
"ip": "ip",
|
||||
"user_agent": "user_agent",
|
||||
"device_id": device_id2,
|
||||
"last_seen": 0,
|
||||
}
|
||||
],
|
||||
)
|
||||
self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)])
|
||||
|
||||
|
||||
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Tuple, cast
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
|
@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
|||
def test__null_byte_in_display_name_properly_handled(self) -> None:
|
||||
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
|
||||
|
||||
res = self.get_success(
|
||||
res = cast(
|
||||
List[Tuple[Optional[str], str]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
"room_memberships",
|
||||
{"user_id": "@alice:test"},
|
||||
["display_name", "event_id"],
|
||||
)
|
||||
),
|
||||
)
|
||||
# Check that we only got one result back
|
||||
self.assertEqual(len(res), 1)
|
||||
|
||||
# Check that alice's display name is "alice"
|
||||
self.assertEqual(res[0]["display_name"], "alice")
|
||||
self.assertEqual(res[0][0], "alice")
|
||||
|
||||
# Grab the event_id to use later
|
||||
event_id = res[0]["event_id"]
|
||||
event_id = res[0][1]
|
||||
|
||||
# Create a profile with the offending null byte in the display name
|
||||
new_profile = {"displayname": "ali\u0000ce"}
|
||||
|
@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
|||
tok=self.t_alice,
|
||||
)
|
||||
|
||||
res2 = self.get_success(
|
||||
res2 = cast(
|
||||
List[Tuple[Optional[str], str]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
"room_memberships",
|
||||
{"user_id": "@alice:test"},
|
||||
["display_name", "event_id"],
|
||||
)
|
||||
),
|
||||
)
|
||||
# Check that we only have two results
|
||||
self.assertEqual(len(res2), 2)
|
||||
|
||||
# Filter out the previous event using the event_id we grabbed above
|
||||
row = [row for row in res2 if row["event_id"] != event_id]
|
||||
row = [row for row in res2 if row[1] != event_id]
|
||||
|
||||
# Check that alice's display name is now None
|
||||
self.assertEqual(row[0]["display_name"], None)
|
||||
self.assertIsNone(row[0][0])
|
||||
|
||||
def test_room_is_locally_forgotten(self) -> None:
|
||||
"""Test that when the last local user has forgotten a room it is known as forgotten."""
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
from immutabledict import immutabledict
|
||||
|
||||
|
@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
# check that only state events are in state_groups, and all state events are in state_groups
|
||||
res = self.get_success(
|
||||
res = cast(
|
||||
List[Tuple[str]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="state_groups",
|
||||
keyvalues=None,
|
||||
retcols=("event_id",),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
events = []
|
||||
for result in res:
|
||||
self.assertNotIn(event3.event_id, result)
|
||||
events.append(result.get("event_id"))
|
||||
self.assertNotIn(event3.event_id, result) # XXX
|
||||
events.append(result[0])
|
||||
|
||||
for event, _ in processed_events_and_context:
|
||||
if event.is_state():
|
||||
|
@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase):
|
|||
# has an entry and prev event in state_group_edges
|
||||
for event, context in processed_events_and_context:
|
||||
if event.is_state():
|
||||
state = self.get_success(
|
||||
state = cast(
|
||||
List[Tuple[str, str]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="state_groups_state",
|
||||
keyvalues={"state_group": context.state_group_after_event},
|
||||
retcols=("type", "state_key"),
|
||||
)
|
||||
),
|
||||
)
|
||||
self.assertEqual(event.type, state[0].get("type"))
|
||||
self.assertEqual(event.state_key, state[0].get("state_key"))
|
||||
self.assertEqual(event.type, state[0][0])
|
||||
self.assertEqual(event.state_key, state[0][1])
|
||||
|
||||
groups = self.get_success(
|
||||
groups = cast(
|
||||
List[Tuple[str]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="state_group_edges",
|
||||
keyvalues={"state_group": str(context.state_group_after_event)},
|
||||
retcols=("*",),
|
||||
keyvalues={
|
||||
"state_group": str(context.state_group_after_event)
|
||||
},
|
||||
retcols=("prev_state_group",),
|
||||
)
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
context.state_group_before_event, groups[0].get("prev_state_group")
|
||||
)
|
||||
self.assertEqual(context.state_group_before_event, groups[0][0])
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
from typing import Any, Dict, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, cast
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
@ -62,14 +62,13 @@ class GetUserDirectoryTables:
|
|||
Returns a list of tuples (user_id, room_id) where room_id is public and
|
||||
contains the user with the given id.
|
||||
"""
|
||||
r = await self.store.db_pool.simple_select_list(
|
||||
r = cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.store.db_pool.simple_select_list(
|
||||
"users_in_public_rooms", None, ("user_id", "room_id")
|
||||
),
|
||||
)
|
||||
|
||||
retval = set()
|
||||
for i in r:
|
||||
retval.add((i["user_id"], i["room_id"]))
|
||||
return retval
|
||||
return set(r)
|
||||
|
||||
async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
|
||||
"""Fetch the entire `users_who_share_private_rooms` table.
|
||||
|
@ -78,27 +77,30 @@ class GetUserDirectoryTables:
|
|||
to the rows of `users_who_share_private_rooms`.
|
||||
"""
|
||||
|
||||
rows = await self.store.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, str, str]],
|
||||
await self.store.db_pool.simple_select_list(
|
||||
"users_who_share_private_rooms",
|
||||
None,
|
||||
["user_id", "other_user_id", "room_id"],
|
||||
),
|
||||
)
|
||||
rv = set()
|
||||
for row in rows:
|
||||
rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
|
||||
return rv
|
||||
return set(rows)
|
||||
|
||||
async def get_users_in_user_directory(self) -> Set[str]:
|
||||
"""Fetch the set of users in the `user_directory` table.
|
||||
|
||||
This is useful when checking we've correctly excluded users from the directory.
|
||||
"""
|
||||
result = await self.store.db_pool.simple_select_list(
|
||||
result = cast(
|
||||
List[Tuple[str]],
|
||||
await self.store.db_pool.simple_select_list(
|
||||
"user_directory",
|
||||
None,
|
||||
["user_id"],
|
||||
),
|
||||
)
|
||||
return {row["user_id"] for row in result}
|
||||
return {row[0] for row in result}
|
||||
|
||||
async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
|
||||
"""Fetch users and their profiles from the `user_directory` table.
|
||||
|
@ -107,16 +109,17 @@ class GetUserDirectoryTables:
|
|||
It's almost the entire contents of the `user_directory` table: the only
|
||||
thing missing is an unused room_id column.
|
||||
"""
|
||||
rows = await self.store.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Tuple[str, Optional[str], Optional[str]]],
|
||||
await self.store.db_pool.simple_select_list(
|
||||
"user_directory",
|
||||
None,
|
||||
("user_id", "display_name", "avatar_url"),
|
||||
),
|
||||
)
|
||||
return {
|
||||
row["user_id"]: ProfileInfo(
|
||||
display_name=row["display_name"], avatar_url=row["avatar_url"]
|
||||
)
|
||||
for row in rows
|
||||
user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url)
|
||||
for user_id, display_name, avatar_url in rows
|
||||
}
|
||||
|
||||
async def get_tables(
|
||||
|
|
Loading…
Reference in a new issue