1
0
Fork 0
mirror of https://github.com/element-hq/synapse.git synced 2025-04-09 03:04:00 +00:00

Store hashes of media files, and allow quarantining by hash. (#18277)

This PR makes a few radical changes to media. This now stores the SHA256
hash of each file stored in the database (excluding thumbnails, more on
that later). If a set of media is quarantined, any additional uploads of
the same file contents or any other files with the same hash will be
quarantined at the same time.

Currently this does NOT:
- De-duplicate media, although a future extension could be to do that.
- Run any background jobs to identify the hashes of older files. This
could also be a future extension, though the value of doing so is
limited to combating the abuse of recent media.
- Hash thumbnails. It's assumed that thumbnails are parented to some
form of media, so you'd likely want to quarantine the media and
the thumbnail at the same time.
This commit is contained in:
Will Hunt 2025-03-27 17:26:34 +00:00 committed by GitHub
parent a39b856cf0
commit d17295e5c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 579 additions and 62 deletions

View file

@ -0,0 +1 @@
Hashes of media files are now tracked by Synapse. Media quarantines will now apply to all files with the same hash.

View file

@ -59,7 +59,11 @@ from synapse.media._base import (
respond_with_responder,
)
from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage
from synapse.media.media_storage import (
MediaStorage,
SHA256TransparentIOReader,
SHA256TransparentIOWriter,
)
from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
@ -301,15 +305,26 @@ class MediaRepository:
auth_user: The user_id of the uploader
"""
file_info = FileInfo(server_name=None, file_id=media_id)
fname = await self.media_storage.store_file(content, file_info)
sha256reader = SHA256TransparentIOReader(content)
# This implements all of IO as it has a passthrough
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
sha256 = sha256reader.hexdigest()
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
logger.info("Stored local media in file %r", fname)
if should_quarantine:
logger.warn(
"Media has been automatically quarantined as it matched existing quarantined media"
)
await self.store.update_local_media(
media_id=media_id,
media_type=media_type,
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
quarantined_by="system" if should_quarantine else None,
)
try:
@ -342,11 +357,19 @@ class MediaRepository:
media_id = random_string(24)
file_info = FileInfo(server_name=None, file_id=media_id)
fname = await self.media_storage.store_file(content, file_info)
# This implements all of IO as it has a passthrough
sha256reader = SHA256TransparentIOReader(content)
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
sha256 = sha256reader.hexdigest()
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
logger.info("Stored local media in file %r", fname)
if should_quarantine:
logger.warn(
"Media has been automatically quarantined as it matched existing quarantined media"
)
await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
@ -354,6 +377,9 @@ class MediaRepository:
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
# TODO: Better name?
quarantined_by="system" if should_quarantine else None,
)
try:
@ -756,11 +782,13 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id)
async with self.media_storage.store_into_file(file_info) as (f, fname):
sha256writer = SHA256TransparentIOWriter(f)
try:
length, headers = await self.client.download_media(
server_name,
media_id,
output_stream=f,
# This implements all of BinaryIO as it has a passthrough
output_stream=sha256writer.wrap(),
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
@ -825,6 +853,7 @@ class MediaRepository:
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
sha256=sha256writer.hexdigest(),
)
logger.info("Stored remote media in file %r", fname)
@ -845,6 +874,7 @@ class MediaRepository:
last_access_ts=time_now_ms,
quarantined_by=None,
authenticated=authenticated,
sha256=sha256writer.hexdigest(),
)
async def _federation_download_remote_file(
@ -879,11 +909,13 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id)
async with self.media_storage.store_into_file(file_info) as (f, fname):
sha256writer = SHA256TransparentIOWriter(f)
try:
res = await self.client.federation_download_media(
server_name,
media_id,
output_stream=f,
# This implements all of BinaryIO as it has a passthrough
output_stream=sha256writer.wrap(),
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
@ -954,6 +986,7 @@ class MediaRepository:
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
sha256=sha256writer.hexdigest(),
)
logger.debug("Stored remote media in file %r", fname)
@ -974,6 +1007,7 @@ class MediaRepository:
last_access_ts=time_now_ms,
quarantined_by=None,
authenticated=authenticated,
sha256=sha256writer.hexdigest(),
)
def _get_thumbnail_requirements(

View file

@ -19,6 +19,7 @@
#
#
import contextlib
import hashlib
import json
import logging
import os
@ -70,6 +71,88 @@ logger = logging.getLogger(__name__)
CRLF = b"\r\n"
class SHA256TransparentIOWriter:
    """Will generate a SHA256 hash from a source stream transparently.

    Args:
        source: Source stream.
    """

    def __init__(self, source: BinaryIO):
        self._hash = hashlib.sha256()
        self._source = source

    def write(self, buffer: Union[bytes, bytearray]) -> int:
        """Wrapper for source.write()

        Args:
            buffer: Bytes to hash and forward to the underlying stream.

        Returns:
            the value of source.write()
        """
        res = self._source.write(buffer)
        self._hash.update(buffer)
        return res

    def hexdigest(self) -> str:
        """The digest of the written or read value.

        Returns:
            The digest in hex format.
        """
        return self._hash.hexdigest()

    def wrap(self) -> BinaryIO:
        # This class implements a subset of the IO interface and passes
        # through everything else via __getattr__
        return cast(BinaryIO, self)

    # Passthrough any other calls
    def __getattr__(self, attr_name: str) -> Any:
        return getattr(self._source, attr_name)
class SHA256TransparentIOReader:
    """Will generate a SHA256 hash from a source stream transparently.

    Args:
        source: Source IO stream.
    """

    def __init__(self, source: IO):
        self._hash = hashlib.sha256()
        self._source = source

    def read(self, n: int = -1) -> bytes:
        """Wrapper for source.read()

        Args:
            n: Maximum number of bytes to read, or -1 for all available.

        Returns:
            the value of source.read()
        """
        # Don't shadow the builtin `bytes` here; use a descriptive local name.
        data = self._source.read(n)
        self._hash.update(data)
        return data

    def hexdigest(self) -> str:
        """The digest of the written or read value.

        Returns:
            The digest in hex format.
        """
        return self._hash.hexdigest()

    def wrap(self) -> IO:
        # This class implements a subset of the IO interface and passes
        # through everything else via __getattr__
        return cast(IO, self)

    # Passthrough any other calls
    def __getattr__(self, attr_name: str) -> Any:
        return getattr(self._source, attr_name)
class MediaStorage:
"""Responsible for storing/fetching files from local sources.
@ -107,7 +190,6 @@ class MediaStorage:
Returns:
the file path written to in the primary media store
"""
async with self.store_into_file(file_info) as (f, fname):
# Write to the main media repository
await self.write_to_file(source, f)

View file

@ -19,6 +19,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
import logging
from enum import Enum
from typing import (
TYPE_CHECKING,
@ -51,6 +52,8 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
"media_repository_drop_index_wo_method_2"
)
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class LocalMedia:
@ -65,6 +68,7 @@ class LocalMedia:
safe_from_quarantine: bool
user_id: Optional[str]
authenticated: Optional[bool]
sha256: Optional[str]
@attr.s(slots=True, frozen=True, auto_attribs=True)
@ -79,6 +83,7 @@ class RemoteMedia:
last_access_ts: int
quarantined_by: Optional[str]
authenticated: Optional[bool]
sha256: Optional[str]
@attr.s(slots=True, frozen=True, auto_attribs=True)
@ -154,6 +159,26 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
unique=True,
)
self.db_pool.updates.register_background_index_update(
update_name="local_media_repository_sha256_idx",
index_name="local_media_repository_sha256",
table="local_media_repository",
where_clause="sha256 IS NOT NULL",
columns=[
"sha256",
],
)
self.db_pool.updates.register_background_index_update(
update_name="remote_media_cache_sha256_idx",
index_name="remote_media_cache_sha256",
table="remote_media_cache",
where_clause="sha256 IS NOT NULL",
columns=[
"sha256",
],
)
self.db_pool.updates.register_background_update_handler(
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
self._drop_media_index_without_method,
@ -221,6 +246,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"safe_from_quarantine",
"user_id",
"authenticated",
"sha256",
),
allow_none=True,
desc="get_local_media",
@ -239,6 +265,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
safe_from_quarantine=row[7],
user_id=row[8],
authenticated=row[9],
sha256=row[10],
)
async def get_local_media_by_user_paginate(
@ -295,7 +322,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
quarantined_by,
safe_from_quarantine,
user_id,
authenticated
authenticated,
sha256
FROM local_media_repository
WHERE user_id = ?
ORDER BY {order_by_column} {order}, media_id ASC
@ -320,6 +348,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
safe_from_quarantine=bool(row[8]),
user_id=row[9],
authenticated=row[10],
sha256=row[11],
)
for row in txn
]
@ -449,6 +478,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length: int,
user_id: UserID,
url_cache: Optional[str] = None,
sha256: Optional[str] = None,
quarantined_by: Optional[str] = None,
) -> None:
if self.hs.config.media.enable_authenticated_media:
authenticated = True
@ -466,6 +497,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"user_id": user_id.to_string(),
"url_cache": url_cache,
"authenticated": authenticated,
"sha256": sha256,
"quarantined_by": quarantined_by,
},
desc="store_local_media",
)
@ -477,20 +510,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name: Optional[str],
media_length: int,
user_id: UserID,
sha256: str,
url_cache: Optional[str] = None,
quarantined_by: Optional[str] = None,
) -> None:
updatevalues = {
"media_type": media_type,
"upload_name": upload_name,
"media_length": media_length,
"url_cache": url_cache,
"sha256": sha256,
}
# This should never be un-set by this function.
if quarantined_by is not None:
updatevalues["quarantined_by"] = quarantined_by
await self.db_pool.simple_update_one(
"local_media_repository",
keyvalues={
"user_id": user_id.to_string(),
"media_id": media_id,
},
updatevalues={
"media_type": media_type,
"upload_name": upload_name,
"media_length": media_length,
"url_cache": url_cache,
},
updatevalues=updatevalues,
desc="update_local_media",
)
@ -657,6 +698,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"last_access_ts",
"quarantined_by",
"authenticated",
"sha256",
),
allow_none=True,
desc="get_cached_remote_media",
@ -674,6 +716,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
last_access_ts=row[5],
quarantined_by=row[6],
authenticated=row[7],
sha256=row[8],
)
async def store_cached_remote_media(
@ -685,6 +728,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
time_now_ms: int,
upload_name: Optional[str],
filesystem_id: str,
sha256: Optional[str],
) -> None:
if self.hs.config.media.enable_authenticated_media:
authenticated = True
@ -703,6 +747,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"filesystem_id": filesystem_id,
"last_access_ts": time_now_ms,
"authenticated": authenticated,
"sha256": sha256,
},
desc="store_cached_remote_media",
)
@ -946,3 +991,37 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)
async def get_is_hash_quarantined(self, sha256: str) -> bool:
    """Get whether a specific sha256 hash digest matches any quarantined media.

    Args:
        sha256: The SHA256 hex digest of the media content.

    Returns:
        True if any local or remote media with this hash has been
        quarantined, False otherwise.
    """

    # NOTE: the previous version took an unused `table` parameter; the query
    # always checks both the local and remote tables in one statement.
    def get_matching_media_txn(txn: LoggingTransaction, sha256: str) -> bool:
        # Return on first match
        sql = """
        SELECT 1
        FROM local_media_repository
        WHERE sha256 = ? AND quarantined_by IS NOT NULL
        UNION ALL
        SELECT 1
        FROM remote_media_cache
        WHERE sha256 = ? AND quarantined_by IS NOT NULL
        LIMIT 1
        """
        txn.execute(sql, (sha256, sha256))
        row = txn.fetchone()
        return row is not None

    return await self.db_pool.runInteraction(
        "get_matching_media_txn",
        get_matching_media_txn,
        sha256,
    )

View file

@ -51,11 +51,15 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage._base import (
db_to_json,
make_in_list_sql_clause,
)
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_in_list_sql_clause,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
@ -1127,6 +1131,109 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return local_media_ids
def _quarantine_local_media_txn(
    self,
    txn: LoggingTransaction,
    hashes: Set[str],
    media_ids: Set[str],
    quarantined_by: Optional[str],
) -> int:
    """Quarantine and unquarantine local media items.

    Args:
        txn (cursor)
        hashes: A set of sha256 hashes for any media that should be quarantined
        media_ids: A set of media IDs for any media that should be quarantined
        quarantined_by: The ID of the user who initiated the quarantine request
            If it is `None` media will be removed from quarantine
    Returns:
        The total number of media items quarantined
    """
    total_media_quarantined = 0

    # The explicit media_id path (effectively legacy) and the hash path run
    # the same UPDATE, differing only in the matched column — do both in one
    # loop rather than duplicating the statement construction.
    for column, values in (("media_id", media_ids), ("sha256", hashes)):
        if not values:
            continue
        sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
            txn.database_engine, column, values
        )
        sql = f"""
            UPDATE local_media_repository
            SET quarantined_by = ?
            WHERE {sql_many_clause_sql}"""

        # Media marked safe_from_quarantine must never be quarantined, but
        # un-quarantining (quarantined_by is None) applies unconditionally.
        if quarantined_by is not None:
            sql += " AND safe_from_quarantine = FALSE"

        txn.execute(sql, [quarantined_by] + sql_many_clause_args)
        # Note that a rowcount of -1 can be used to indicate no rows were affected.
        total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0

    return total_media_quarantined
def _quarantine_remote_media_txn(
    self,
    txn: LoggingTransaction,
    hashes: Set[str],
    media: Set[Tuple[str, str]],
    quarantined_by: Optional[str],
) -> int:
    """Quarantine and unquarantine remote items

    Args:
        txn (cursor)
        hashes: A set of sha256 hashes for any media that should be quarantined
        media: A set of tuples (media_origin, media_id) for any media that
            should be quarantined
        quarantined_by: The ID of the user who initiated the quarantine request
            If it is `None` media will be removed from quarantine
    Returns:
        The total number of media items quarantined
    """
    total_media_quarantined = 0

    # Update any media that was explicitly named.
    if media:
        sql_in_list_clause, sql_args = make_tuple_in_list_sql_clause(
            txn.database_engine,
            ("media_origin", "media_id"),
            media,
        )
        sql = f"""
            UPDATE remote_media_cache
            SET quarantined_by = ?
            WHERE {sql_in_list_clause}"""

        txn.execute(sql, [quarantined_by] + sql_args)
        # Note that a rowcount of -1 can be used to indicate no rows were affected.
        total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0

    # BUG FIX: the previous version reset total_media_quarantined to 0 here,
    # discarding the count accumulated from the explicitly-named media above.

    # Update any media that was identified via hash.
    if hashes:
        sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
            txn.database_engine, "sha256", hashes
        )
        sql = f"""
            UPDATE remote_media_cache
            SET quarantined_by = ?
            WHERE {sql_many_clause_sql}"""

        txn.execute(sql, [quarantined_by] + sql_many_clause_args)
        total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0

    return total_media_quarantined
def _quarantine_media_txn(
self,
txn: LoggingTransaction,
@ -1146,40 +1253,49 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
Returns:
The total number of media items quarantined
"""
hashes = set()
media_ids = set()
remote_media = set()
# Update all the tables to set the quarantined_by flag
sql = """
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
"""
# set quarantine
if quarantined_by is not None:
sql += "AND safe_from_quarantine = FALSE"
txn.executemany(
sql, [(quarantined_by, media_id) for media_id in local_mxcs]
# First, determine the hashes of the media we want to delete.
# We also want the media_ids for any media that lacks a hash.
if local_mxcs:
hash_sql_many_clause_sql, hash_sql_many_clause_args = (
make_in_list_sql_clause(txn.database_engine, "media_id", local_mxcs)
)
# remove from quarantine
else:
txn.executemany(
sql, [(quarantined_by, media_id) for media_id in local_mxcs]
hash_sql = f"SELECT sha256, media_id FROM local_media_repository WHERE {hash_sql_many_clause_sql}"
if quarantined_by is not None:
hash_sql += " AND safe_from_quarantine = FALSE"
txn.execute(hash_sql, hash_sql_many_clause_args)
for sha256, media_id in txn:
if sha256:
hashes.add(sha256)
else:
media_ids.add(media_id)
# Do the same for remote media
if remote_mxcs:
hash_sql_in_list_clause, hash_sql_args = make_tuple_in_list_sql_clause(
txn.database_engine,
("media_origin", "media_id"),
remote_mxcs,
)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
hash_sql = f"SELECT sha256, media_origin, media_id FROM remote_media_cache WHERE {hash_sql_in_list_clause}"
txn.execute(hash_sql, hash_sql_args)
for sha256, media_origin, media_id in txn:
if sha256:
hashes.add(sha256)
else:
remote_media.add((media_origin, media_id))
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
[(quarantined_by, origin, media_id) for origin, media_id in remote_mxcs],
count = self._quarantine_local_media_txn(txn, hashes, media_ids, quarantined_by)
count += self._quarantine_remote_media_txn(
txn, hashes, remote_media, quarantined_by
)
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
return total_media_quarantined
return count
async def block_room(self, room_id: str, user_id: str) -> None:
"""Marks the room as blocked.

View file

@ -19,7 +19,7 @@
#
#
SCHEMA_VERSION = 90 # remember to update the list below when updating
SCHEMA_VERSION = 91 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the

View file

@ -0,0 +1,21 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Store the SHA256 content hash of media files.
ALTER TABLE local_media_repository ADD COLUMN sha256 TEXT;
ALTER TABLE remote_media_cache ADD COLUMN sha256 TEXT;
-- Add a background updates to handle creating the new index.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(9101, 'local_media_repository_sha256_idx', '{}'),
(9101, 'remote_media_cache_sha256_idx', '{}');

View file

@ -369,6 +369,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
time_now_ms=self.clock.time_msec(),
upload_name=None,
filesystem_id="xyz",
sha256="abcdefg12345",
)
)

View file

@ -31,6 +31,9 @@ from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock
from synapse.util.stringutils import (
random_string,
)
from tests import unittest
from tests.unittest import override_config
@ -65,7 +68,6 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
# quarantined media) into both the local store and the remote cache, plus
# one additional local media that is marked as protected from quarantine.
media_repository = hs.get_media_repository()
test_media_content = b"example string"
def _create_media_and_set_attributes(
last_accessed_ms: Optional[int],
@ -73,12 +75,14 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
is_protected: Optional[bool] = False,
) -> MXCUri:
# "Upload" some media to the local media store
# If the meda
random_content = bytes(random_string(24), "utf-8")
mxc_uri: MXCUri = self.get_success(
media_repository.create_content(
media_type="text/plain",
upload_name=None,
content=io.BytesIO(test_media_content),
content_length=len(test_media_content),
content=io.BytesIO(random_content),
content_length=len(random_content),
auth_user=UserID.from_string(test_user_id),
)
)
@ -129,6 +133,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="testfile.txt",
filesystem_id="abcdefg12345",
sha256=random_string(24),
)
)

View file

@ -42,6 +42,7 @@ from twisted.web.resource import Resource
from synapse.api.errors import Codes, HttpResponseException
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.http.client import ByteWriteable
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.media._base import FileInfo, ThumbnailInfo
@ -59,7 +60,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG
from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG, SMALL_PNG_SHA256
from tests.unittest import override_config
from tests.utils import default_config
@ -1257,3 +1258,107 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
)
assert channel.code == 502
assert channel.json_body["errcode"] == "M_TOO_LARGE"
def read_body(
    response: IResponse, stream: ByteWriteable, max_size: Optional[int]
) -> Deferred:
    # Test stand-in for `read_body_with_max_size`: ignores the (mocked)
    # federation response entirely, writes SMALL_PNG to the output stream and
    # immediately resolves with its length. `response` and `max_size` are
    # accepted only to match the patched function's signature.
    d: Deferred = defer.Deferred()
    stream.write(SMALL_PNG)
    d.callback(len(SMALL_PNG))
    return d
class MediaHashesTestCase(unittest.HomeserverTestCase):
    """Tests that uploaded and federated media have their SHA256 hash stored."""

    servlets = [
        admin.register_servlets,
        login.register_servlets,
        media.register_servlets,
    ]

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        self.user = self.register_user("user", "pass")
        self.tok = self.login("user", "pass")
        self.store = hs.get_datastores().main
        self.client = hs.get_federation_http_client()

    def create_resource_dict(self) -> Dict[str, Resource]:
        resources = super().create_resource_dict()
        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
        return resources

    def test_ensure_correct_sha256(self) -> None:
        """Check that the hash does not change"""
        media = self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
        mxc = media.get("content_uri")
        assert mxc
        # mxc[11:] strips the "mxc://<server>/" prefix — assumes the test
        # homeserver name is 4 characters ("test"); verify if renaming it.
        store_media = self.get_success(self.store.get_local_media(mxc[11:]))
        assert store_media
        self.assertEqual(
            store_media.sha256,
            SMALL_PNG_SHA256,
        )

    def test_ensure_multiple_correct_sha256(self) -> None:
        """Check that two media items have the same hash."""
        media_a = self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
        mxc_a = media_a.get("content_uri")
        assert mxc_a
        store_media_a = self.get_success(self.store.get_local_media(mxc_a[11:]))
        assert store_media_a

        media_b = self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
        mxc_b = media_b.get("content_uri")
        assert mxc_b
        store_media_b = self.get_success(self.store.get_local_media(mxc_b[11:]))
        assert store_media_b

        # Distinct uploads of identical content: different media IDs,
        # identical content hash.
        self.assertNotEqual(
            store_media_a.media_id,
            store_media_b.media_id,
        )
        self.assertEqual(
            store_media_a.sha256,
            store_media_b.sha256,
        )

    @override_config(
        {
            "enable_authenticated_media": False,
        }
    )
    # mock actually reading file body
    @patch(
        "synapse.http.matrixfederationclient.read_body_with_max_size",
        read_body,
    )
    def test_ensure_correct_sha256_federated(self) -> None:
        """Check that federated media have the same hash."""

        # Mock getting a file over federation
        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
            resp = MagicMock(spec=IResponse)
            resp.code = 200
            resp.length = 500
            resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
            resp.phrase = b"OK"
            return resp

        self.client._send_request = _send_request  # type: ignore

        # first request should go through
        channel = self.make_request(
            "GET",
            "/_matrix/media/v3/download/remote.org/abc",
            shorthand=False,
            access_token=self.tok,
        )
        assert channel.code == 200
        store_media = self.get_success(
            self.store.get_cached_remote_media("remote.org", "abc")
        )
        assert store_media
        self.assertEqual(
            store_media.sha256,
            SMALL_PNG_SHA256,
        )

View file

@ -20,7 +20,7 @@
#
import urllib.parse
from typing import Dict
from typing import Dict, cast
from parameterized import parameterized
@ -32,6 +32,7 @@ from synapse.http.server import JsonResource
from synapse.rest.admin import VersionServlet
from synapse.rest.client import login, media, room
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@ -227,10 +228,25 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
response_3 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
server_and_media_id_2 = response_2["content_uri"][6:]
server_and_media_id_3 = response_3["content_uri"][6:]
# Remove the hash from the media to simulate historic media.
self.get_success(
self.hs.get_datastores().main.update_local_media(
media_id=server_and_media_id_3.split("/")[1],
media_type="image/png",
upload_name=None,
media_length=123,
user_id=UserID.from_string(non_admin_user),
# Hack to force some media to have no hash.
sha256=cast(str, None),
)
)
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@ -244,12 +260,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.pump(1.0)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
channel.json_body, {"num_quarantined": 3}, "Expected 3 quarantined items"
)
# Attempt to access each piece of media
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
self._ensure_quarantined(admin_user_tok, server_and_media_id_3)
def test_cannot_quarantine_safe_media(self) -> None:
self.register_user("user_admin", "pass", admin=True)

View file

@ -35,7 +35,7 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.test_utils import SMALL_PNG
from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG
from tests.unittest import override_config
VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds
@ -598,23 +598,27 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
class QuarantineMediaByIDTestCase(_AdminMediaTests):
def upload_media_and_return_media_id(self, data: bytes) -> str:
# Upload some media into the room
response = self.helper.upload_media(
data,
tok=self.admin_user_tok,
expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
return server_and_media_id.split("/")[1]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.server_name = hs.hostname
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
# Upload some media into the room
response = self.helper.upload_media(
SMALL_PNG,
tok=self.admin_user_tok,
expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
self.media_id = server_and_media_id.split("/")[1]
self.media_id = self.upload_media_and_return_media_id(SMALL_PNG)
self.media_id_2 = self.upload_media_and_return_media_id(SMALL_PNG)
self.media_id_3 = self.upload_media_and_return_media_id(SMALL_PNG)
self.media_id_other = self.upload_media_and_return_media_id(SMALL_CMYK_JPEG)
self.url = "/_synapse/admin/v1/media/%s/%s/%s"
@parameterized.expand(["quarantine", "unquarantine"])
@ -686,6 +690,52 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
assert media_info is not None
self.assertFalse(media_info.quarantined_by)
def test_quarantine_media_match_hash(self) -> None:
"""
Tests that quarantining removes all media with the same hash
"""
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info.quarantined_by)
# quarantining
channel = self.make_request(
"POST",
self.url % ("quarantine", self.server_name, self.media_id),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
# Test that ALL similar media was quarantined.
for media in [self.media_id, self.media_id_2, self.media_id_3]:
media_info = self.get_success(self.store.get_local_media(media))
assert media_info is not None
self.assertTrue(media_info.quarantined_by)
# Test that other media was not.
media_info = self.get_success(self.store.get_local_media(self.media_id_other))
assert media_info is not None
self.assertFalse(media_info.quarantined_by)
# remove from quarantine
channel = self.make_request(
"POST",
self.url % ("unquarantine", self.server_name, self.media_id),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
# Test that ALL similar media is now reset.
for media in [self.media_id, self.media_id_2, self.media_id_3]:
media_info = self.get_success(self.store.get_local_media(media))
assert media_info is not None
self.assertFalse(media_info.quarantined_by)
def test_quarantine_protected_media(self) -> None:
"""
Tests that quarantining from protected media fails

View file

@ -137,6 +137,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="test.png",
filesystem_id=file_id,
sha256=file_id,
)
)
self.register_user("user", "password")
@ -2593,6 +2594,7 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
time_now_ms=self.clock.time_msec(),
upload_name="remote_test.png",
filesystem_id=file_id,
sha256=file_id,
)
)
@ -2725,6 +2727,7 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
time_now_ms=self.clock.time_msec(),
upload_name="remote_test.png",
filesystem_id=file_id,
sha256=file_id,
)
)

View file

@ -61,6 +61,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="test.png",
filesystem_id=file_id,
sha256=file_id,
)
)

View file

@ -139,6 +139,8 @@ SMALL_PNG = unhexlify(
b"0000001f15c4890000000a49444154789c63000100000500010d"
b"0a2db40000000049454e44ae426082"
)
# The SHA256 hexdigest for the above bytes.
SMALL_PNG_SHA256 = "ebf4f635a17d10d6eb46ba680b70142419aa3220f228001a036d311a22ee9d2a"
# A small CMYK-encoded JPEG image used in some tests.
#