1
0
Fork 0
mirror of https://github.com/element-hq/synapse.git synced 2025-04-08 15:24:00 +00:00
This commit is contained in:
Will Hunt 2025-03-28 15:56:03 +00:00 committed by GitHub
commit 561e2c52ac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 86 additions and 5 deletions

View file

@ -0,0 +1 @@
Add ETag header to the /rooms/$room_id/state endpoint, and return no content on cache hit.

View file

@ -19,6 +19,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
import hashlib
import logging
import random
from http import HTTPStatus
@ -173,7 +174,8 @@ class MessageHandler:
room_id: str,
state_filter: Optional[StateFilter] = None,
at_token: Optional[StreamToken] = None,
) -> List[dict]:
last_hash: Optional[str] = None,
) -> Tuple[List[dict], Optional[str]]:
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
left the room return the state events from when they left. If an explicit
@ -190,6 +192,8 @@ class MessageHandler:
state based on the current_state_events table.
Returns:
A list of dicts representing state events. [{}, {}, {}]
A hash of the state IDs representing the state events. This is only calculated if
no at_token is given and the user is joined to the room.
Raises:
NotFoundError (404) if the at token does not yield an event
@ -200,6 +204,7 @@ class MessageHandler:
state_filter = StateFilter.all()
user_id = requester.user.to_string()
hash = None
if at_token:
last_event_id = (
@ -239,6 +244,14 @@ class MessageHandler:
state_ids = await self._state_storage_controller.get_current_state_ids(
room_id, state_filter=state_filter
)
hash = hashlib.sha1(
",".join(state_ids.values()).encode("utf-8")
).hexdigest()
# If the requester's hash matches ours, their cache is up to date and we can skip
# fetching events.
if last_hash == hash:
return [], hash
room_state = await self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
# If the membership is not JOIN, then the event ID should exist.
@ -257,7 +270,7 @@ class MessageHandler:
self.clock.time_msec(),
config=SerializeEventConfig(requester=requester),
)
return events
return events, hash
async def _user_can_see_state_at_event(
self, user_id: str, room_id: str, event_id: str

View file

@ -715,7 +715,7 @@ class RoomMemberListRestServlet(RestServlet):
membership = parse_string(request, "membership")
not_membership = parse_string(request, "not_membership")
events = await handler.get_state_events(
events, _ = await handler.get_state_events(
room_id=room_id,
requester=requester,
at_token=at_token,
@ -835,13 +835,22 @@ class RoomStateRestServlet(RestServlet):
@cancellable
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, List[JsonDict]]:
) -> Tuple[int, Optional[List[JsonDict]]]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
existing_hash = request.getHeader(b"If-None-Match")
# Get all the current state for this room
events = await self.message_handler.get_state_events(
events, hash = await self.message_handler.get_state_events(
room_id=room_id,
requester=requester,
# Trim quotes from hash.
last_hash=existing_hash.decode("ascii")[1:-1] if existing_hash else None,
)
request.setHeader(b"ETag", f'"{hash}"'.encode("ascii"))
if len(events) == 0:
return 304, None
return 200, events

View file

@ -552,6 +552,64 @@ class RoomStateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(channel.json_body, {"membership": "join"})
def test_get_state_etag_match(self) -> None:
"""Test that `/rooms/$room_id/state` returns a NOT_MODIFIED response when provided with the correct ETag."""
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
"GET",
"/rooms/%s/state" % room_id,
)
etagheader = channel.headers.getRawHeaders(b"ETag")
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
assert etagheader
channel2 = self.make_request(
"GET",
"/rooms/%s/state" % room_id,
custom_headers=[
(
b"If-None-Match",
etagheader[0],
),
],
)
self.assertEqual(
HTTPStatus.NOT_MODIFIED,
channel2.code,
"Responds with not modified when provided with the correct ETag",
)
self.assertEqual(
etagheader,
channel2.headers.getRawHeaders(b"ETag"),
"returns the same etag",
)
def test_get_state_etag_nonmatch(self) -> None:
"""Test that `/rooms/$room_id/state` returns a normal response to an unrecognised ETag."""
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
"GET",
"/rooms/%s/state" % room_id,
custom_headers=(
(
b"If-None-Match",
'"notavalidetag"',
),
),
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertCountEqual(
[state_event["type"] for state_event in channel.json_list],
{
"m.room.create",
"m.room.power_levels",
"m.room.join_rules",
"m.room.member",
"m.room.history_visibility",
},
)
class RoomsMemberListTestCase(RoomBase):
"""Tests /rooms/$room_id/members/list REST events."""