mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-15 17:51:10 +00:00
Sliding Sync: Add cache to get_tags_for_room(...)
(#17730)
Add cache to `get_tags_for_room(...)` This helps Sliding Sync because `get_tags_for_room(...)` is going to be used in https://github.com/element-hq/synapse/pull/17695 Essentially, we're just trying to match `get_account_data_for_room(...)` which already has a tree cache.
This commit is contained in:
parent
a9c0e27eb7
commit
83fc225030
5 changed files with 21 additions and 6 deletions
1
changelog.d/17730.misc
Normal file
1
changelog.d/17730.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add cache to `get_tags_for_room(...)`.
|
|
@ -33,7 +33,7 @@ from synapse.replication.http.account_data import (
|
||||||
ReplicationRemoveUserAccountDataRestServlet,
|
ReplicationRemoveUserAccountDataRestServlet,
|
||||||
)
|
)
|
||||||
from synapse.streams import EventSource
|
from synapse.streams import EventSource
|
||||||
from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
|
from synapse.types import JsonDict, JsonMapping, StrCollection, StreamKeyType, UserID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -253,7 +253,7 @@ class AccountDataHandler:
|
||||||
return response["max_stream_id"]
|
return response["max_stream_id"]
|
||||||
|
|
||||||
async def add_tag_to_room(
|
async def add_tag_to_room(
|
||||||
self, user_id: str, room_id: str, tag: str, content: JsonDict
|
self, user_id: str, room_id: str, tag: str, content: JsonMapping
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Add a tag to a room for a user.
|
"""Add a tag to a room for a user.
|
||||||
|
|
||||||
|
|
|
@ -471,6 +471,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
self._attempt_to_invalidate_cache("get_account_data_for_room", None)
|
self._attempt_to_invalidate_cache("get_account_data_for_room", None)
|
||||||
self._attempt_to_invalidate_cache("get_account_data_for_room_and_type", None)
|
self._attempt_to_invalidate_cache("get_account_data_for_room_and_type", None)
|
||||||
|
self._attempt_to_invalidate_cache("get_tags_for_room", None)
|
||||||
self._attempt_to_invalidate_cache("get_aliases_for_room", (room_id,))
|
self._attempt_to_invalidate_cache("get_aliases_for_room", (room_id,))
|
||||||
self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,))
|
self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,))
|
||||||
self._attempt_to_invalidate_cache("_get_forward_extremeties_for_room", None)
|
self._attempt_to_invalidate_cache("_get_forward_extremeties_for_room", None)
|
||||||
|
|
|
@ -158,9 +158,10 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@cached(num_args=2, tree=True)
|
||||||
async def get_tags_for_room(
|
async def get_tags_for_room(
|
||||||
self, user_id: str, room_id: str
|
self, user_id: str, room_id: str
|
||||||
) -> Dict[str, JsonDict]:
|
) -> Mapping[str, JsonMapping]:
|
||||||
"""Get all the tags for the given room
|
"""Get all the tags for the given room
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -182,7 +183,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
return {tag: db_to_json(content) for tag, content in rows}
|
return {tag: db_to_json(content) for tag, content in rows}
|
||||||
|
|
||||||
async def add_tag_to_room(
|
async def add_tag_to_room(
|
||||||
self, user_id: str, room_id: str, tag: str, content: JsonDict
|
self, user_id: str, room_id: str, tag: str, content: JsonMapping
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Add a tag to a room for a user.
|
"""Add a tag to a room for a user.
|
||||||
|
|
||||||
|
@ -213,6 +214,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
||||||
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
self.get_tags_for_user.invalidate((user_id,))
|
||||||
|
self.get_tags_for_room.invalidate((user_id, room_id))
|
||||||
|
|
||||||
return self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
|
|
||||||
|
@ -237,6 +239,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
||||||
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
self.get_tags_for_user.invalidate((user_id,))
|
||||||
|
self.get_tags_for_room.invalidate((user_id, room_id))
|
||||||
|
|
||||||
return self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
|
|
||||||
|
@ -290,9 +293,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
rows: Iterable[Any],
|
rows: Iterable[Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
if stream_name == AccountDataStream.NAME:
|
if stream_name == AccountDataStream.NAME:
|
||||||
for row in rows:
|
# Cast is safe because the `AccountDataStream` should only be giving us
|
||||||
|
# `AccountDataStreamRow`
|
||||||
|
account_data_stream_rows: List[AccountDataStream.AccountDataStreamRow] = (
|
||||||
|
cast(List[AccountDataStream.AccountDataStreamRow], rows)
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in account_data_stream_rows:
|
||||||
if row.data_type == AccountDataTypes.TAG:
|
if row.data_type == AccountDataTypes.TAG:
|
||||||
self.get_tags_for_user.invalidate((row.user_id,))
|
self.get_tags_for_user.invalidate((row.user_id,))
|
||||||
|
if row.room_id:
|
||||||
|
self.get_tags_for_room.invalidate((row.user_id, row.room_id))
|
||||||
|
else:
|
||||||
|
self.get_tags_for_room.invalidate((row.user_id,))
|
||||||
self._account_data_stream_cache.entity_has_changed(
|
self._account_data_stream_cache.entity_has_changed(
|
||||||
row.user_id, token
|
row.user_id, token
|
||||||
)
|
)
|
||||||
|
|
|
@ -89,7 +89,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
return_value="!something:localhost"
|
return_value="!something:localhost"
|
||||||
)
|
)
|
||||||
self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[method-assign]
|
self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[method-assign]
|
||||||
self._rlsn._store.get_tags_for_room = AsyncMock(return_value={}) # type: ignore[method-assign]
|
self._rlsn._store.get_tags_for_room = AsyncMock(return_value={})
|
||||||
|
|
||||||
@override_config({"hs_disabled": True})
|
@override_config({"hs_disabled": True})
|
||||||
def test_maybe_send_server_notice_disabled_hs(self) -> None:
|
def test_maybe_send_server_notice_disabled_hs(self) -> None:
|
||||||
|
|
Loading…
Reference in a new issue