Improvements to admin redact api (#17792)

- better validation on user input
- fix an early task completion
- when checking membership in rooms, check for rooms user has been
banned from as well
This commit is contained in:
Shay 2024-10-08 06:23:21 -07:00 committed by GitHub
parent 006251a5d0
commit a5986ac229
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 107 additions and 41 deletions

1
changelog.d/17792.bugfix Normal file
View file

@ -0,0 +1 @@
Improve input validation and room membership checks in admin redaction API.

View file

@ -443,8 +443,8 @@ class AdminHandler:
["m.room.member", "m.room.message"], ["m.room.member", "m.room.message"],
) )
if not event_ids: if not event_ids:
# there's nothing to redact # nothing to redact in this room
return TaskStatus.COMPLETE, result, None continue
events = await self._store.get_events_as_list(event_ids) events = await self._store.get_events_as_list(event_ids)
for event in events: for event in events:

View file

@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr import attr
from synapse._pydantic_compat import StrictBool from synapse._pydantic_compat import StrictBool, StrictInt, StrictStr
from synapse.api.constants import Direction, UserTypes from synapse.api.constants import Direction, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -1421,40 +1421,39 @@ class RedactUser(RestServlet):
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
class PostBody(RequestBodyModel):
rooms: List[StrictStr]
reason: Optional[StrictStr]
limit: Optional[StrictInt]
async def on_POST( async def on_POST(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request) requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester) await assert_user_is_admin(self._auth, requester)
body = parse_json_object_from_request(request, allow_empty_body=True) # parse provided user id to check that it is valid
rooms = body.get("rooms") UserID.from_string(user_id)
if rooms is None:
body = parse_and_validate_json_object_from_request(request, self.PostBody)
limit = body.limit
if limit and limit <= 0:
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, "Must provide a value for rooms." HTTPStatus.BAD_REQUEST,
"If limit is provided it must be a non-negative integer greater than 0.",
) )
reason = body.get("reason") rooms = body.rooms
if reason:
if not isinstance(reason, str):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"If a reason is provided it must be a string.",
)
limit = body.get("limit")
if limit:
if not isinstance(limit, int) or limit <= 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"If limit is provided it must be a non-negative integer greater than 0.",
)
if not rooms: if not rooms:
rooms = await self._store.get_rooms_for_user(user_id) current_rooms = list(await self._store.get_rooms_for_user(user_id))
banned_rooms = list(
await self._store.get_rooms_user_currently_banned_from(user_id)
)
rooms = current_rooms + banned_rooms
redact_id = await self.admin_handler.start_redact_events( redact_id = await self.admin_handler.start_redact_events(
user_id, list(rooms), requester.serialize(), reason, limit user_id, rooms, requester.serialize(), body.reason, limit
) )
return HTTPStatus.OK, {"redact_id": redact_id} return HTTPStatus.OK, {"redact_id": redact_id}

View file

@ -711,6 +711,27 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
return {row[0] for row in txn} return {row[0] for row in txn}
async def get_rooms_user_currently_banned_from(
self, user_id: str
) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently banned from.
If a remote user only returns rooms this server is currently
participating in.
"""
room_ids = await self.db_pool.simple_select_onecol(
table="current_state_events",
keyvalues={
"type": EventTypes.Member,
"membership": Membership.BAN,
"state_key": user_id,
},
retcol="room_id",
desc="get_rooms_user_currently_banned_from",
)
return frozenset(room_ids)
@cached(max_entries=500000, iterable=True) @cached(max_entries=500000, iterable=True)
async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to. """Returns a set of room_ids the user is currently joined to.

View file

@ -5288,19 +5288,26 @@ class UserRedactionTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(matched), len(rm2_originals)) self.assertEqual(len(matched), len(rm2_originals))
def test_admin_redact_works_if_user_kicked_or_banned(self) -> None: def test_admin_redact_works_if_user_kicked_or_banned(self) -> None:
originals = [] originals1 = []
originals2 = []
for rm in [self.rm1, self.rm2, self.rm3]: for rm in [self.rm1, self.rm2, self.rm3]:
join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok) join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
originals.append(join["event_id"]) if rm in [self.rm1, self.rm3]:
originals1.append(join["event_id"])
else:
originals2.append(join["event_id"])
for i in range(5): for i in range(5):
event = {"body": f"hello{i}", "msgtype": "m.text"} event = {"body": f"hello{i}", "msgtype": "m.text"}
res = self.helper.send_event( res = self.helper.send_event(
rm, "m.room.message", event, tok=self.bad_user_tok rm, "m.room.message", event, tok=self.bad_user_tok
) )
originals.append(res["event_id"]) if rm in [self.rm1, self.rm3]:
originals1.append(res["event_id"])
else:
originals2.append(res["event_id"])
# kick user from rooms 1 and 3 # kick user from rooms 1 and 3
for r in [self.rm1, self.rm2]: for r in [self.rm1, self.rm3]:
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/_matrix/client/r0/rooms/{r}/kick", f"/_matrix/client/r0/rooms/{r}/kick",
@ -5330,32 +5337,70 @@ class UserRedactionTestCase(unittest.HomeserverTestCase):
failed_redactions = channel2.json_body.get("failed_redactions") failed_redactions = channel2.json_body.get("failed_redactions")
self.assertEqual(failed_redactions, {}) self.assertEqual(failed_redactions, {})
# ban user # double check
channel3 = self.make_request( for rm in [self.rm1, self.rm3]:
filter = json.dumps({"types": [EventTypes.Redaction]})
channel3 = self.make_request(
"GET",
f"rooms/{rm}/messages?filter={filter}&limit=50",
access_token=self.admin_tok,
)
self.assertEqual(channel3.code, 200)
matches = []
for event in channel3.json_body["chunk"]:
for event_id in originals1:
if (
event["type"] == "m.room.redaction"
and event["redacts"] == event_id
):
matches.append((event_id, event))
# we redacted 6 messages
self.assertEqual(len(matches), 6)
# ban user from room 2
channel4 = self.make_request(
"POST", "POST",
f"/_matrix/client/r0/rooms/{self.rm2}/ban", f"/_matrix/client/r0/rooms/{self.rm2}/ban",
content={"reason": "being a bummer", "user_id": self.bad_user}, content={"reason": "being a bummer", "user_id": self.bad_user},
access_token=self.admin_tok, access_token=self.admin_tok,
) )
self.assertEqual(channel3.code, HTTPStatus.OK, channel3.result) self.assertEqual(channel4.code, HTTPStatus.OK, channel4.result)
# redact messages in room 2 # make a request to ban all user's messages
channel4 = self.make_request( channel5 = self.make_request(
"POST", "POST",
f"/_synapse/admin/v1/user/{self.bad_user}/redact", f"/_synapse/admin/v1/user/{self.bad_user}/redact",
content={"rooms": [self.rm2]}, content={"rooms": []},
access_token=self.admin_tok, access_token=self.admin_tok,
) )
self.assertEqual(channel4.code, 200) self.assertEqual(channel5.code, 200)
id2 = channel1.json_body.get("redact_id") id2 = channel5.json_body.get("redact_id")
# check that there were no failed redactions in room 2 # check that there were no failed redactions in room 2
channel5 = self.make_request( channel6 = self.make_request(
"GET", "GET",
f"/_synapse/admin/v1/user/redact_status/{id2}", f"/_synapse/admin/v1/user/redact_status/{id2}",
access_token=self.admin_tok, access_token=self.admin_tok,
) )
self.assertEqual(channel5.code, 200) self.assertEqual(channel6.code, 200)
self.assertEqual(channel5.json_body.get("status"), "complete") self.assertEqual(channel6.json_body.get("status"), "complete")
failed_redactions = channel5.json_body.get("failed_redactions") failed_redactions = channel6.json_body.get("failed_redactions")
self.assertEqual(failed_redactions, {}) self.assertEqual(failed_redactions, {})
# double check messages in room 2 were redacted
filter = json.dumps({"types": [EventTypes.Redaction]})
channel7 = self.make_request(
"GET",
f"rooms/{self.rm2}/messages?filter={filter}&limit=50",
access_token=self.admin_tok,
)
self.assertEqual(channel7.code, 200)
matches = []
for event in channel7.json_body["chunk"]:
for event_id in originals2:
if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
matches.append((event_id, event))
# we redacted 6 messages
self.assertEqual(len(matches), 6)