Add is_dm room field to Sliding Sync /sync (#17429)

Based on
[MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575):
Sliding Sync
This commit is contained in:
Eric Eastwood 2024-07-11 18:19:26 -05:00 committed by GitHub
parent 5a97bbd895
commit fb66e938b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 70 additions and 29 deletions

View file

@ -0,0 +1 @@
Populate `is_dm` room field in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.

View file

@ -291,6 +291,7 @@ class _RoomMembershipForUser:
sender: The person who sent the membership event
newly_joined: Whether the user newly joined the room during the given token
range
is_dm: Whether this user considers this room as a direct-message (DM) room
"""
room_id: str
@ -299,6 +300,7 @@ class _RoomMembershipForUser:
membership: str
sender: Optional[str]
newly_joined: bool
is_dm: bool
def copy_and_replace(self, **kwds: Any) -> "_RoomMembershipForUser":
return attr.evolve(self, **kwds)
@ -613,6 +615,7 @@ class SlidingSyncHandler:
membership=room_for_user.membership,
sender=room_for_user.sender,
newly_joined=False,
is_dm=False,
)
for room_for_user in room_for_user_list
}
@ -652,6 +655,7 @@ class SlidingSyncHandler:
# - 1c) Update room membership events to the point in time of the `to_token`
# - 2) Add back newly_left rooms (> `from_token` and <= `to_token`)
# - 3) Figure out which rooms are `newly_joined`
# - 4) Figure out which rooms are DM's
# 1) -----------------------------------------------------
@ -714,6 +718,7 @@ class SlidingSyncHandler:
membership=first_membership_change_after_to_token.prev_membership,
sender=first_membership_change_after_to_token.prev_sender,
newly_joined=False,
is_dm=False,
)
else:
# If we can't find the previous membership event, we shouldn't
@ -809,6 +814,7 @@ class SlidingSyncHandler:
membership=last_membership_change_in_from_to_range.membership,
sender=last_membership_change_in_from_to_range.sender,
newly_joined=False,
is_dm=False,
)
# 3) Figure out `newly_joined`
@ -846,6 +852,35 @@ class SlidingSyncHandler:
room_id
].copy_and_replace(newly_joined=True)
# 4) Figure out which rooms the user considers to be direct-message (DM) rooms
#
# We're using global account data (`m.direct`) instead of checking for
# `is_direct` on membership events because that property only appears for
# the invitee membership event (doesn't show up for the inviter).
#
# We're unable to take `to_token` into account for global account data since
# we only keep track of the latest account data for the user.
dm_map = await self.store.get_global_account_data_by_type_for_user(
user_id, AccountDataTypes.DIRECT
)
# Flatten out the map. Account data is set by the client so it needs to be
# scrutinized.
dm_room_id_set = set()
if isinstance(dm_map, dict):
for room_ids in dm_map.values():
# Account data should be a list of room IDs. Ignore anything else
if isinstance(room_ids, list):
for room_id in room_ids:
if isinstance(room_id, str):
dm_room_id_set.add(room_id)
# 4) Fixup
for room_id in filtered_sync_room_id_set:
filtered_sync_room_id_set[room_id] = filtered_sync_room_id_set[
room_id
].copy_and_replace(is_dm=room_id in dm_room_id_set)
return filtered_sync_room_id_set
async def filter_rooms(
@ -869,41 +904,24 @@ class SlidingSyncHandler:
A filtered dictionary of room IDs along with membership information in the
room at the time of `to_token`.
"""
user_id = user.to_string()
# TODO: Apply filters
filtered_room_id_set = set(sync_room_map.keys())
# Filter for Direct-Message (DM) rooms
if filters.is_dm is not None:
# We're using global account data (`m.direct`) instead of checking for
# `is_direct` on membership events because that property only appears for
# the invitee membership event (doesn't show up for the inviter). Account
# data is set by the client so it needs to be scrutinized.
#
# We're unable to take `to_token` into account for global account data since
# we only keep track of the latest account data for the user.
dm_map = await self.store.get_global_account_data_by_type_for_user(
user_id, AccountDataTypes.DIRECT
)
# Flatten out the map
dm_room_id_set = set()
if isinstance(dm_map, dict):
for room_ids in dm_map.values():
# Account data should be a list of room IDs. Ignore anything else
if isinstance(room_ids, list):
for room_id in room_ids:
if isinstance(room_id, str):
dm_room_id_set.add(room_id)
if filters.is_dm:
# Only DM rooms please
filtered_room_id_set = filtered_room_id_set.intersection(dm_room_id_set)
filtered_room_id_set = {
room_id
for room_id in filtered_room_id_set
if sync_room_map[room_id].is_dm
}
else:
# Only non-DM rooms please
filtered_room_id_set = filtered_room_id_set.difference(dm_room_id_set)
filtered_room_id_set = {
room_id
for room_id in filtered_room_id_set
if not sync_room_map[room_id].is_dm
}
if filters.spaces:
raise NotImplementedError()
@ -1538,8 +1556,7 @@ class SlidingSyncHandler:
name=room_name,
avatar=room_avatar,
heroes=heroes,
# TODO: Dummy value
is_dm=False,
is_dm=room_membership_for_user_at_to_token.is_dm,
initial=initial,
required_state=list(required_room_state.values()),
timeline_events=timeline_events,

View file

@ -1662,6 +1662,20 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
list(channel.json_body["lists"]["room-invites"]),
)
# Ensure DM's are correctly marked
self.assertDictEqual(
{
room_id: room.get("is_dm")
for room_id, room in channel.json_body["rooms"].items()
},
{
invite_room_id: None,
room_id: None,
invited_dm_room_id: True,
joined_dm_room_id: True,
},
)
def test_sort_list(self) -> None:
"""
Test that the `lists` are sorted by `stream_ordering`
@ -1874,6 +1888,9 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
channel.json_body["rooms"][room_id1]["invited_count"],
0,
)
self.assertIsNone(
channel.json_body["rooms"][room_id1].get("is_dm"),
)
def test_rooms_meta_when_invited(self) -> None:
"""
@ -1955,6 +1972,9 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
channel.json_body["rooms"][room_id1]["invited_count"],
1,
)
self.assertIsNone(
channel.json_body["rooms"][room_id1].get("is_dm"),
)
def test_rooms_meta_when_banned(self) -> None:
"""
@ -2037,6 +2057,9 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
channel.json_body["rooms"][room_id1]["invited_count"],
0,
)
self.assertIsNone(
channel.json_body["rooms"][room_id1].get("is_dm"),
)
def test_rooms_meta_heroes(self) -> None:
"""