mirror of
https://github.com/element-hq/synapse.git
synced 2025-04-09 03:04:00 +00:00
Store hashes of media files, and allow quarantining by hash. (#18277)
This PR makes a few radical changes to media. This now stores the SHA256 hash of each file stored in the database (excluding thumbnails, more on that later). If a set of media is quarantined, any additional uploads of the same file contents or any other files with the same hash will be quarantined at the same time. Currently this does NOT: - De-duplicate media, although a future extension could be to do that. - Run any background jobs to identify the hashes of older files. This could also be a future extension, though the value of doing so is limited to combat the abuse of recent media. - Hash thumbnails. It's assumed that thumbnails are parented to some form of media, so you'd likely be wanting to quarantine the media and the thumbnail at the same time.
This commit is contained in:
parent
a39b856cf0
commit
d17295e5c3
15 changed files with 579 additions and 62 deletions
1
changelog.d/18277.feature
Normal file
1
changelog.d/18277.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Hashes of media files are now tracked by Synapse. Media quarantines will now apply to all files with the same hash.
|
|
@ -59,7 +59,11 @@ from synapse.media._base import (
|
|||
respond_with_responder,
|
||||
)
|
||||
from synapse.media.filepath import MediaFilePaths
|
||||
from synapse.media.media_storage import MediaStorage
|
||||
from synapse.media.media_storage import (
|
||||
MediaStorage,
|
||||
SHA256TransparentIOReader,
|
||||
SHA256TransparentIOWriter,
|
||||
)
|
||||
from synapse.media.storage_provider import StorageProviderWrapper
|
||||
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
|
||||
from synapse.media.url_previewer import UrlPreviewer
|
||||
|
@ -301,15 +305,26 @@ class MediaRepository:
|
|||
auth_user: The user_id of the uploader
|
||||
"""
|
||||
file_info = FileInfo(server_name=None, file_id=media_id)
|
||||
fname = await self.media_storage.store_file(content, file_info)
|
||||
sha256reader = SHA256TransparentIOReader(content)
|
||||
# This implements all of IO as it has a passthrough
|
||||
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
|
||||
sha256 = sha256reader.hexdigest()
|
||||
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
|
||||
logger.info("Stored local media in file %r", fname)
|
||||
|
||||
if should_quarantine:
|
||||
logger.warn(
|
||||
"Media has been automatically quarantined as it matched existing quarantined media"
|
||||
)
|
||||
|
||||
await self.store.update_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
upload_name=upload_name,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
sha256=sha256,
|
||||
quarantined_by="system" if should_quarantine else None,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -342,11 +357,19 @@ class MediaRepository:
|
|||
media_id = random_string(24)
|
||||
|
||||
file_info = FileInfo(server_name=None, file_id=media_id)
|
||||
|
||||
fname = await self.media_storage.store_file(content, file_info)
|
||||
# This implements all of IO as it has a passthrough
|
||||
sha256reader = SHA256TransparentIOReader(content)
|
||||
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
|
||||
sha256 = sha256reader.hexdigest()
|
||||
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
|
||||
|
||||
logger.info("Stored local media in file %r", fname)
|
||||
|
||||
if should_quarantine:
|
||||
logger.warn(
|
||||
"Media has been automatically quarantined as it matched existing quarantined media"
|
||||
)
|
||||
|
||||
await self.store.store_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
|
@ -354,6 +377,9 @@ class MediaRepository:
|
|||
upload_name=upload_name,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
sha256=sha256,
|
||||
# TODO: Better name?
|
||||
quarantined_by="system" if should_quarantine else None,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -756,11 +782,13 @@ class MediaRepository:
|
|||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||
|
||||
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||
sha256writer = SHA256TransparentIOWriter(f)
|
||||
try:
|
||||
length, headers = await self.client.download_media(
|
||||
server_name,
|
||||
media_id,
|
||||
output_stream=f,
|
||||
# This implements all of BinaryIO as it has a passthrough
|
||||
output_stream=sha256writer.wrap(),
|
||||
max_size=self.max_upload_size,
|
||||
max_timeout_ms=max_timeout_ms,
|
||||
download_ratelimiter=download_ratelimiter,
|
||||
|
@ -825,6 +853,7 @@ class MediaRepository:
|
|||
upload_name=upload_name,
|
||||
media_length=length,
|
||||
filesystem_id=file_id,
|
||||
sha256=sha256writer.hexdigest(),
|
||||
)
|
||||
|
||||
logger.info("Stored remote media in file %r", fname)
|
||||
|
@ -845,6 +874,7 @@ class MediaRepository:
|
|||
last_access_ts=time_now_ms,
|
||||
quarantined_by=None,
|
||||
authenticated=authenticated,
|
||||
sha256=sha256writer.hexdigest(),
|
||||
)
|
||||
|
||||
async def _federation_download_remote_file(
|
||||
|
@ -879,11 +909,13 @@ class MediaRepository:
|
|||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||
|
||||
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||
sha256writer = SHA256TransparentIOWriter(f)
|
||||
try:
|
||||
res = await self.client.federation_download_media(
|
||||
server_name,
|
||||
media_id,
|
||||
output_stream=f,
|
||||
# This implements all of BinaryIO as it has a passthrough
|
||||
output_stream=sha256writer.wrap(),
|
||||
max_size=self.max_upload_size,
|
||||
max_timeout_ms=max_timeout_ms,
|
||||
download_ratelimiter=download_ratelimiter,
|
||||
|
@ -954,6 +986,7 @@ class MediaRepository:
|
|||
upload_name=upload_name,
|
||||
media_length=length,
|
||||
filesystem_id=file_id,
|
||||
sha256=sha256writer.hexdigest(),
|
||||
)
|
||||
|
||||
logger.debug("Stored remote media in file %r", fname)
|
||||
|
@ -974,6 +1007,7 @@ class MediaRepository:
|
|||
last_access_ts=time_now_ms,
|
||||
quarantined_by=None,
|
||||
authenticated=authenticated,
|
||||
sha256=sha256writer.hexdigest(),
|
||||
)
|
||||
|
||||
def _get_thumbnail_requirements(
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#
|
||||
#
|
||||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
@ -70,6 +71,88 @@ logger = logging.getLogger(__name__)
|
|||
CRLF = b"\r\n"
|
||||
|
||||
|
||||
class SHA256TransparentIOWriter:
|
||||
"""Will generate a SHA256 hash from a source stream transparently.
|
||||
|
||||
Args:
|
||||
source: Source stream.
|
||||
"""
|
||||
|
||||
def __init__(self, source: BinaryIO):
|
||||
self._hash = hashlib.sha256()
|
||||
self._source = source
|
||||
|
||||
def write(self, buffer: Union[bytes, bytearray]) -> int:
|
||||
"""Wrapper for source.write()
|
||||
|
||||
Args:
|
||||
buffer
|
||||
|
||||
Returns:
|
||||
the value of source.write()
|
||||
"""
|
||||
res = self._source.write(buffer)
|
||||
self._hash.update(buffer)
|
||||
return res
|
||||
|
||||
def hexdigest(self) -> str:
|
||||
"""The digest of the written or read value.
|
||||
|
||||
Returns:
|
||||
The digest in hex formaat.
|
||||
"""
|
||||
return self._hash.hexdigest()
|
||||
|
||||
def wrap(self) -> BinaryIO:
|
||||
# This class implements a subset the IO interface and passes through everything else via __getattr__
|
||||
return cast(BinaryIO, self)
|
||||
|
||||
# Passthrough any other calls
|
||||
def __getattr__(self, attr_name: str) -> Any:
|
||||
return getattr(self._source, attr_name)
|
||||
|
||||
|
||||
class SHA256TransparentIOReader:
|
||||
"""Will generate a SHA256 hash from a source stream transparently.
|
||||
|
||||
Args:
|
||||
source: Source IO stream.
|
||||
"""
|
||||
|
||||
def __init__(self, source: IO):
|
||||
self._hash = hashlib.sha256()
|
||||
self._source = source
|
||||
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
"""Wrapper for source.read()
|
||||
|
||||
Args:
|
||||
n
|
||||
|
||||
Returns:
|
||||
the value of source.read()
|
||||
"""
|
||||
bytes = self._source.read(n)
|
||||
self._hash.update(bytes)
|
||||
return bytes
|
||||
|
||||
def hexdigest(self) -> str:
|
||||
"""The digest of the written or read value.
|
||||
|
||||
Returns:
|
||||
The digest in hex formaat.
|
||||
"""
|
||||
return self._hash.hexdigest()
|
||||
|
||||
def wrap(self) -> IO:
|
||||
# This class implements a subset the IO interface and passes through everything else via __getattr__
|
||||
return cast(IO, self)
|
||||
|
||||
# Passthrough any other calls
|
||||
def __getattr__(self, attr_name: str) -> Any:
|
||||
return getattr(self._source, attr_name)
|
||||
|
||||
|
||||
class MediaStorage:
|
||||
"""Responsible for storing/fetching files from local sources.
|
||||
|
||||
|
@ -107,7 +190,6 @@ class MediaStorage:
|
|||
Returns:
|
||||
the file path written to in the primary media store
|
||||
"""
|
||||
|
||||
async with self.store_into_file(file_info) as (f, fname):
|
||||
# Write to the main media repository
|
||||
await self.write_to_file(source, f)
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
|
@ -51,6 +52,8 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
|
|||
"media_repository_drop_index_wo_method_2"
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class LocalMedia:
|
||||
|
@ -65,6 +68,7 @@ class LocalMedia:
|
|||
safe_from_quarantine: bool
|
||||
user_id: Optional[str]
|
||||
authenticated: Optional[bool]
|
||||
sha256: Optional[str]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
|
@ -79,6 +83,7 @@ class RemoteMedia:
|
|||
last_access_ts: int
|
||||
quarantined_by: Optional[str]
|
||||
authenticated: Optional[bool]
|
||||
sha256: Optional[str]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
|
@ -154,6 +159,26 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
|
|||
unique=True,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
update_name="local_media_repository_sha256_idx",
|
||||
index_name="local_media_repository_sha256",
|
||||
table="local_media_repository",
|
||||
where_clause="sha256 IS NOT NULL",
|
||||
columns=[
|
||||
"sha256",
|
||||
],
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
update_name="remote_media_cache_sha256_idx",
|
||||
index_name="remote_media_cache_sha256",
|
||||
table="remote_media_cache",
|
||||
where_clause="sha256 IS NOT NULL",
|
||||
columns=[
|
||||
"sha256",
|
||||
],
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
|
||||
self._drop_media_index_without_method,
|
||||
|
@ -221,6 +246,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"safe_from_quarantine",
|
||||
"user_id",
|
||||
"authenticated",
|
||||
"sha256",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_local_media",
|
||||
|
@ -239,6 +265,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
safe_from_quarantine=row[7],
|
||||
user_id=row[8],
|
||||
authenticated=row[9],
|
||||
sha256=row[10],
|
||||
)
|
||||
|
||||
async def get_local_media_by_user_paginate(
|
||||
|
@ -295,7 +322,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
quarantined_by,
|
||||
safe_from_quarantine,
|
||||
user_id,
|
||||
authenticated
|
||||
authenticated,
|
||||
sha256
|
||||
FROM local_media_repository
|
||||
WHERE user_id = ?
|
||||
ORDER BY {order_by_column} {order}, media_id ASC
|
||||
|
@ -320,6 +348,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
safe_from_quarantine=bool(row[8]),
|
||||
user_id=row[9],
|
||||
authenticated=row[10],
|
||||
sha256=row[11],
|
||||
)
|
||||
for row in txn
|
||||
]
|
||||
|
@ -449,6 +478,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
media_length: int,
|
||||
user_id: UserID,
|
||||
url_cache: Optional[str] = None,
|
||||
sha256: Optional[str] = None,
|
||||
quarantined_by: Optional[str] = None,
|
||||
) -> None:
|
||||
if self.hs.config.media.enable_authenticated_media:
|
||||
authenticated = True
|
||||
|
@ -466,6 +497,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"user_id": user_id.to_string(),
|
||||
"url_cache": url_cache,
|
||||
"authenticated": authenticated,
|
||||
"sha256": sha256,
|
||||
"quarantined_by": quarantined_by,
|
||||
},
|
||||
desc="store_local_media",
|
||||
)
|
||||
|
@ -477,20 +510,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
upload_name: Optional[str],
|
||||
media_length: int,
|
||||
user_id: UserID,
|
||||
sha256: str,
|
||||
url_cache: Optional[str] = None,
|
||||
quarantined_by: Optional[str] = None,
|
||||
) -> None:
|
||||
updatevalues = {
|
||||
"media_type": media_type,
|
||||
"upload_name": upload_name,
|
||||
"media_length": media_length,
|
||||
"url_cache": url_cache,
|
||||
"sha256": sha256,
|
||||
}
|
||||
|
||||
# This should never be un-set by this function.
|
||||
if quarantined_by is not None:
|
||||
updatevalues["quarantined_by"] = quarantined_by
|
||||
|
||||
await self.db_pool.simple_update_one(
|
||||
"local_media_repository",
|
||||
keyvalues={
|
||||
"user_id": user_id.to_string(),
|
||||
"media_id": media_id,
|
||||
},
|
||||
updatevalues={
|
||||
"media_type": media_type,
|
||||
"upload_name": upload_name,
|
||||
"media_length": media_length,
|
||||
"url_cache": url_cache,
|
||||
},
|
||||
updatevalues=updatevalues,
|
||||
desc="update_local_media",
|
||||
)
|
||||
|
||||
|
@ -657,6 +698,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"last_access_ts",
|
||||
"quarantined_by",
|
||||
"authenticated",
|
||||
"sha256",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_cached_remote_media",
|
||||
|
@ -674,6 +716,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
last_access_ts=row[5],
|
||||
quarantined_by=row[6],
|
||||
authenticated=row[7],
|
||||
sha256=row[8],
|
||||
)
|
||||
|
||||
async def store_cached_remote_media(
|
||||
|
@ -685,6 +728,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
time_now_ms: int,
|
||||
upload_name: Optional[str],
|
||||
filesystem_id: str,
|
||||
sha256: Optional[str],
|
||||
) -> None:
|
||||
if self.hs.config.media.enable_authenticated_media:
|
||||
authenticated = True
|
||||
|
@ -703,6 +747,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"filesystem_id": filesystem_id,
|
||||
"last_access_ts": time_now_ms,
|
||||
"authenticated": authenticated,
|
||||
"sha256": sha256,
|
||||
},
|
||||
desc="store_cached_remote_media",
|
||||
)
|
||||
|
@ -946,3 +991,37 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
await self.db_pool.runInteraction(
|
||||
"delete_url_cache_media", _delete_url_cache_media_txn
|
||||
)
|
||||
|
||||
async def get_is_hash_quarantined(self, sha256: str) -> bool:
|
||||
"""Get whether a specific sha256 hash digest matches any quarantined media.
|
||||
|
||||
Returns:
|
||||
None if the media_id doesn't exist.
|
||||
"""
|
||||
|
||||
def get_matching_media_txn(
|
||||
txn: LoggingTransaction, table: str, sha256: str
|
||||
) -> bool:
|
||||
# Return on first match
|
||||
sql = """
|
||||
SELECT 1
|
||||
FROM local_media_repository
|
||||
WHERE sha256 = ? AND quarantined_by IS NOT NULL
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT 1
|
||||
FROM remote_media_cache
|
||||
WHERE sha256 = ? AND quarantined_by IS NOT NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
txn.execute(sql, (sha256, sha256))
|
||||
row = txn.fetchone()
|
||||
return row is not None
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_matching_media_txn",
|
||||
get_matching_media_txn,
|
||||
"local_media_repository",
|
||||
sha256,
|
||||
)
|
||||
|
|
|
@ -51,11 +51,15 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
|
|||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.events import EventBase
|
||||
from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream
|
||||
from synapse.storage._base import db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage._base import (
|
||||
db_to_json,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
make_tuple_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.types import Cursor
|
||||
|
@ -1127,6 +1131,109 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
return local_media_ids
|
||||
|
||||
def _quarantine_local_media_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
hashes: Set[str],
|
||||
media_ids: Set[str],
|
||||
quarantined_by: Optional[str],
|
||||
) -> int:
|
||||
"""Quarantine and unquarantine local media items.
|
||||
|
||||
Args:
|
||||
txn (cursor)
|
||||
hashes: A set of sha256 hashes for any media that should be quarantined
|
||||
media_ids: A set of media IDs for any media that should be quarantined
|
||||
quarantined_by: The ID of the user who initiated the quarantine request
|
||||
If it is `None` media will be removed from quarantine
|
||||
Returns:
|
||||
The total number of media items quarantined
|
||||
"""
|
||||
total_media_quarantined = 0
|
||||
|
||||
# Effectively a legacy path, update any media that was explicitly named.
|
||||
if media_ids:
|
||||
sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
|
||||
txn.database_engine, "media_id", media_ids
|
||||
)
|
||||
sql = f"""
|
||||
UPDATE local_media_repository
|
||||
SET quarantined_by = ?
|
||||
WHERE {sql_many_clause_sql}"""
|
||||
|
||||
if quarantined_by is not None:
|
||||
sql += " AND safe_from_quarantine = FALSE"
|
||||
|
||||
txn.execute(sql, [quarantined_by] + sql_many_clause_args)
|
||||
# Note that a rowcount of -1 can be used to indicate no rows were affected.
|
||||
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
|
||||
|
||||
# Update any media that was identified via hash.
|
||||
if hashes:
|
||||
sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
|
||||
txn.database_engine, "sha256", hashes
|
||||
)
|
||||
sql = f"""
|
||||
UPDATE local_media_repository
|
||||
SET quarantined_by = ?
|
||||
WHERE {sql_many_clause_sql}"""
|
||||
|
||||
if quarantined_by is not None:
|
||||
sql += " AND safe_from_quarantine = FALSE"
|
||||
|
||||
txn.execute(sql, [quarantined_by] + sql_many_clause_args)
|
||||
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
|
||||
|
||||
return total_media_quarantined
|
||||
|
||||
def _quarantine_remote_media_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
hashes: Set[str],
|
||||
media: Set[Tuple[str, str]],
|
||||
quarantined_by: Optional[str],
|
||||
) -> int:
|
||||
"""Quarantine and unquarantine remote items
|
||||
|
||||
Args:
|
||||
txn (cursor)
|
||||
hashes: A set of sha256 hashes for any media that should be quarantined
|
||||
media_ids: A set of tuples (media_origin, media_id) for any media that should be quarantined
|
||||
quarantined_by: The ID of the user who initiated the quarantine request
|
||||
If it is `None` media will be removed from quarantine
|
||||
Returns:
|
||||
The total number of media items quarantined
|
||||
"""
|
||||
total_media_quarantined = 0
|
||||
|
||||
if media:
|
||||
sql_in_list_clause, sql_args = make_tuple_in_list_sql_clause(
|
||||
txn.database_engine,
|
||||
("media_origin", "media_id"),
|
||||
media,
|
||||
)
|
||||
sql = f"""
|
||||
UPDATE remote_media_cache
|
||||
SET quarantined_by = ?
|
||||
WHERE {sql_in_list_clause}"""
|
||||
|
||||
txn.execute(sql, [quarantined_by] + sql_args)
|
||||
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
|
||||
|
||||
total_media_quarantined = 0
|
||||
if hashes:
|
||||
sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
|
||||
txn.database_engine, "sha256", hashes
|
||||
)
|
||||
sql = f"""
|
||||
UPDATE remote_media_cache
|
||||
SET quarantined_by = ?
|
||||
WHERE {sql_many_clause_sql}"""
|
||||
txn.execute(sql, [quarantined_by] + sql_many_clause_args)
|
||||
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
|
||||
|
||||
return total_media_quarantined
|
||||
|
||||
def _quarantine_media_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
|
@ -1146,40 +1253,49 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
Returns:
|
||||
The total number of media items quarantined
|
||||
"""
|
||||
hashes = set()
|
||||
media_ids = set()
|
||||
remote_media = set()
|
||||
|
||||
# Update all the tables to set the quarantined_by flag
|
||||
sql = """
|
||||
UPDATE local_media_repository
|
||||
SET quarantined_by = ?
|
||||
WHERE media_id = ?
|
||||
"""
|
||||
|
||||
# set quarantine
|
||||
if quarantined_by is not None:
|
||||
sql += "AND safe_from_quarantine = FALSE"
|
||||
txn.executemany(
|
||||
sql, [(quarantined_by, media_id) for media_id in local_mxcs]
|
||||
# First, determine the hashes of the media we want to delete.
|
||||
# We also want the media_ids for any media that lacks a hash.
|
||||
if local_mxcs:
|
||||
hash_sql_many_clause_sql, hash_sql_many_clause_args = (
|
||||
make_in_list_sql_clause(txn.database_engine, "media_id", local_mxcs)
|
||||
)
|
||||
# remove from quarantine
|
||||
else:
|
||||
txn.executemany(
|
||||
sql, [(quarantined_by, media_id) for media_id in local_mxcs]
|
||||
hash_sql = f"SELECT sha256, media_id FROM local_media_repository WHERE {hash_sql_many_clause_sql}"
|
||||
if quarantined_by is not None:
|
||||
hash_sql += " AND safe_from_quarantine = FALSE"
|
||||
|
||||
txn.execute(hash_sql, hash_sql_many_clause_args)
|
||||
for sha256, media_id in txn:
|
||||
if sha256:
|
||||
hashes.add(sha256)
|
||||
else:
|
||||
media_ids.add(media_id)
|
||||
|
||||
# Do the same for remote media
|
||||
if remote_mxcs:
|
||||
hash_sql_in_list_clause, hash_sql_args = make_tuple_in_list_sql_clause(
|
||||
txn.database_engine,
|
||||
("media_origin", "media_id"),
|
||||
remote_mxcs,
|
||||
)
|
||||
|
||||
# Note that a rowcount of -1 can be used to indicate no rows were affected.
|
||||
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
|
||||
hash_sql = f"SELECT sha256, media_origin, media_id FROM remote_media_cache WHERE {hash_sql_in_list_clause}"
|
||||
txn.execute(hash_sql, hash_sql_args)
|
||||
for sha256, media_origin, media_id in txn:
|
||||
if sha256:
|
||||
hashes.add(sha256)
|
||||
else:
|
||||
remote_media.add((media_origin, media_id))
|
||||
|
||||
txn.executemany(
|
||||
"""
|
||||
UPDATE remote_media_cache
|
||||
SET quarantined_by = ?
|
||||
WHERE media_origin = ? AND media_id = ?
|
||||
""",
|
||||
[(quarantined_by, origin, media_id) for origin, media_id in remote_mxcs],
|
||||
count = self._quarantine_local_media_txn(txn, hashes, media_ids, quarantined_by)
|
||||
count += self._quarantine_remote_media_txn(
|
||||
txn, hashes, remote_media, quarantined_by
|
||||
)
|
||||
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
|
||||
|
||||
return total_media_quarantined
|
||||
return count
|
||||
|
||||
async def block_room(self, room_id: str, user_id: str) -> None:
|
||||
"""Marks the room as blocked.
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#
|
||||
#
|
||||
|
||||
SCHEMA_VERSION = 90 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 91 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
|
21
synapse/storage/schema/main/delta/91/01_media_hash.sql
Normal file
21
synapse/storage/schema/main/delta/91/01_media_hash.sql
Normal file
|
@ -0,0 +1,21 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2025 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Store the SHA256 content hash of media files.
|
||||
ALTER TABLE local_media_repository ADD COLUMN sha256 TEXT;
|
||||
ALTER TABLE remote_media_cache ADD COLUMN sha256 TEXT;
|
||||
|
||||
-- Add a background updates to handle creating the new index.
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(9101, 'local_media_repository_sha256_idx', '{}'),
|
||||
(9101, 'remote_media_cache_sha256_idx', '{}');
|
|
@ -369,6 +369,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=None,
|
||||
filesystem_id="xyz",
|
||||
sha256="abcdefg12345",
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -31,6 +31,9 @@ from synapse.rest.client import login, register, room
|
|||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import (
|
||||
random_string,
|
||||
)
|
||||
|
||||
from tests import unittest
|
||||
from tests.unittest import override_config
|
||||
|
@ -65,7 +68,6 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
|
|||
# quarantined media) into both the local store and the remote cache, plus
|
||||
# one additional local media that is marked as protected from quarantine.
|
||||
media_repository = hs.get_media_repository()
|
||||
test_media_content = b"example string"
|
||||
|
||||
def _create_media_and_set_attributes(
|
||||
last_accessed_ms: Optional[int],
|
||||
|
@ -73,12 +75,14 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
|
|||
is_protected: Optional[bool] = False,
|
||||
) -> MXCUri:
|
||||
# "Upload" some media to the local media store
|
||||
# If the meda
|
||||
random_content = bytes(random_string(24), "utf-8")
|
||||
mxc_uri: MXCUri = self.get_success(
|
||||
media_repository.create_content(
|
||||
media_type="text/plain",
|
||||
upload_name=None,
|
||||
content=io.BytesIO(test_media_content),
|
||||
content_length=len(test_media_content),
|
||||
content=io.BytesIO(random_content),
|
||||
content_length=len(random_content),
|
||||
auth_user=UserID.from_string(test_user_id),
|
||||
)
|
||||
)
|
||||
|
@ -129,6 +133,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
|
|||
time_now_ms=clock.time_msec(),
|
||||
upload_name="testfile.txt",
|
||||
filesystem_id="abcdefg12345",
|
||||
sha256=random_string(24),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -42,6 +42,7 @@ from twisted.web.resource import Resource
|
|||
from synapse.api.errors import Codes, HttpResponseException
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.events import EventBase
|
||||
from synapse.http.client import ByteWriteable
|
||||
from synapse.http.types import QueryParams
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.media._base import FileInfo, ThumbnailInfo
|
||||
|
@ -59,7 +60,7 @@ from synapse.util import Clock
|
|||
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel
|
||||
from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG
|
||||
from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG, SMALL_PNG_SHA256
|
||||
from tests.unittest import override_config
|
||||
from tests.utils import default_config
|
||||
|
||||
|
@ -1257,3 +1258,107 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
assert channel.code == 502
|
||||
assert channel.json_body["errcode"] == "M_TOO_LARGE"
|
||||
|
||||
|
||||
def read_body(
|
||||
response: IResponse, stream: ByteWriteable, max_size: Optional[int]
|
||||
) -> Deferred:
|
||||
d: Deferred = defer.Deferred()
|
||||
stream.write(SMALL_PNG)
|
||||
d.callback(len(SMALL_PNG))
|
||||
return d
|
||||
|
||||
|
||||
class MediaHashesTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
media.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user = self.register_user("user", "pass")
|
||||
self.tok = self.login("user", "pass")
|
||||
self.store = hs.get_datastores().main
|
||||
self.client = hs.get_federation_http_client()
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
resources = super().create_resource_dict()
|
||||
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
|
||||
return resources
|
||||
|
||||
def test_ensure_correct_sha256(self) -> None:
|
||||
"""Check that the hash does not change"""
|
||||
media = self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
|
||||
mxc = media.get("content_uri")
|
||||
assert mxc
|
||||
store_media = self.get_success(self.store.get_local_media(mxc[11:]))
|
||||
assert store_media
|
||||
self.assertEqual(
|
||||
store_media.sha256,
|
||||
SMALL_PNG_SHA256,
|
||||
)
|
||||
|
||||
def test_ensure_multiple_correct_sha256(self) -> None:
|
||||
"""Check that two media items have the same hash."""
|
||||
media_a = self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
|
||||
mxc_a = media_a.get("content_uri")
|
||||
assert mxc_a
|
||||
store_media_a = self.get_success(self.store.get_local_media(mxc_a[11:]))
|
||||
assert store_media_a
|
||||
|
||||
media_b = self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
|
||||
mxc_b = media_b.get("content_uri")
|
||||
assert mxc_b
|
||||
store_media_b = self.get_success(self.store.get_local_media(mxc_b[11:]))
|
||||
assert store_media_b
|
||||
|
||||
self.assertNotEqual(
|
||||
store_media_a.media_id,
|
||||
store_media_b.media_id,
|
||||
)
|
||||
self.assertEqual(
|
||||
store_media_a.sha256,
|
||||
store_media_b.sha256,
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"enable_authenticated_media": False,
|
||||
}
|
||||
)
|
||||
# mock actually reading file body
|
||||
@patch(
|
||||
"synapse.http.matrixfederationclient.read_body_with_max_size",
|
||||
read_body,
|
||||
)
|
||||
def test_ensure_correct_sha256_federated(self) -> None:
|
||||
"""Check that federated media have the same hash."""
|
||||
|
||||
# Mock getting a file over federation
|
||||
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
|
||||
resp = MagicMock(spec=IResponse)
|
||||
resp.code = 200
|
||||
resp.length = 500
|
||||
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
|
||||
resp.phrase = b"OK"
|
||||
return resp
|
||||
|
||||
self.client._send_request = _send_request # type: ignore
|
||||
|
||||
# first request should go through
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/media/v3/download/remote.org/abc",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel.code == 200
|
||||
store_media = self.get_success(
|
||||
self.store.get_cached_remote_media("remote.org", "abc")
|
||||
)
|
||||
assert store_media
|
||||
self.assertEqual(
|
||||
store_media.sha256,
|
||||
SMALL_PNG_SHA256,
|
||||
)
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#
|
||||
|
||||
import urllib.parse
|
||||
from typing import Dict
|
||||
from typing import Dict, cast
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
|
@ -32,6 +32,7 @@ from synapse.http.server import JsonResource
|
|||
from synapse.rest.admin import VersionServlet
|
||||
from synapse.rest.client import login, media, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
@ -227,10 +228,25 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
# Upload some media
|
||||
response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
|
||||
response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
|
||||
response_3 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
|
||||
|
||||
# Extract media IDs
|
||||
server_and_media_id_1 = response_1["content_uri"][6:]
|
||||
server_and_media_id_2 = response_2["content_uri"][6:]
|
||||
server_and_media_id_3 = response_3["content_uri"][6:]
|
||||
|
||||
# Remove the hash from the media to simulate historic media.
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.update_local_media(
|
||||
media_id=server_and_media_id_3.split("/")[1],
|
||||
media_type="image/png",
|
||||
upload_name=None,
|
||||
media_length=123,
|
||||
user_id=UserID.from_string(non_admin_user),
|
||||
# Hack to force some media to have no hash.
|
||||
sha256=cast(str, None),
|
||||
)
|
||||
)
|
||||
|
||||
# Quarantine all media by this user
|
||||
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
|
||||
|
@ -244,12 +260,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
self.pump(1.0)
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(
|
||||
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
|
||||
channel.json_body, {"num_quarantined": 3}, "Expected 3 quarantined items"
|
||||
)
|
||||
|
||||
# Attempt to access each piece of media
|
||||
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
|
||||
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
|
||||
self._ensure_quarantined(admin_user_tok, server_and_media_id_3)
|
||||
|
||||
def test_cannot_quarantine_safe_media(self) -> None:
|
||||
self.register_user("user_admin", "pass", admin=True)
|
||||
|
|
|
@ -35,7 +35,7 @@ from synapse.server import HomeServer
|
|||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import SMALL_PNG
|
||||
from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG
|
||||
from tests.unittest import override_config
|
||||
|
||||
VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds
|
||||
|
@ -598,23 +598,27 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
|
|||
|
||||
|
||||
class QuarantineMediaByIDTestCase(_AdminMediaTests):
|
||||
def upload_media_and_return_media_id(self, data: bytes) -> str:
|
||||
# Upload some media into the room
|
||||
response = self.helper.upload_media(
|
||||
data,
|
||||
tok=self.admin_user_tok,
|
||||
expect_code=200,
|
||||
)
|
||||
# Extract media ID from the response
|
||||
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
|
||||
return server_and_media_id.split("/")[1]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.server_name = hs.hostname
|
||||
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
self.admin_user_tok = self.login("admin", "pass")
|
||||
|
||||
# Upload some media into the room
|
||||
response = self.helper.upload_media(
|
||||
SMALL_PNG,
|
||||
tok=self.admin_user_tok,
|
||||
expect_code=200,
|
||||
)
|
||||
# Extract media ID from the response
|
||||
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
|
||||
self.media_id = server_and_media_id.split("/")[1]
|
||||
|
||||
self.media_id = self.upload_media_and_return_media_id(SMALL_PNG)
|
||||
self.media_id_2 = self.upload_media_and_return_media_id(SMALL_PNG)
|
||||
self.media_id_3 = self.upload_media_and_return_media_id(SMALL_PNG)
|
||||
self.media_id_other = self.upload_media_and_return_media_id(SMALL_CMYK_JPEG)
|
||||
self.url = "/_synapse/admin/v1/media/%s/%s/%s"
|
||||
|
||||
@parameterized.expand(["quarantine", "unquarantine"])
|
||||
|
@ -686,6 +690,52 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
|
|||
assert media_info is not None
|
||||
self.assertFalse(media_info.quarantined_by)
|
||||
|
||||
def test_quarantine_media_match_hash(self) -> None:
|
||||
"""
|
||||
Tests that quarantining removes all media with the same hash
|
||||
"""
|
||||
|
||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||
assert media_info is not None
|
||||
self.assertFalse(media_info.quarantined_by)
|
||||
|
||||
# quarantining
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url % ("quarantine", self.server_name, self.media_id),
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
self.assertFalse(channel.json_body)
|
||||
|
||||
# Test that ALL similar media was quarantined.
|
||||
for media in [self.media_id, self.media_id_2, self.media_id_3]:
|
||||
media_info = self.get_success(self.store.get_local_media(media))
|
||||
assert media_info is not None
|
||||
self.assertTrue(media_info.quarantined_by)
|
||||
|
||||
# Test that other media was not.
|
||||
media_info = self.get_success(self.store.get_local_media(self.media_id_other))
|
||||
assert media_info is not None
|
||||
self.assertFalse(media_info.quarantined_by)
|
||||
|
||||
# remove from quarantine
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url % ("unquarantine", self.server_name, self.media_id),
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
self.assertFalse(channel.json_body)
|
||||
|
||||
# Test that ALL similar media is now reset.
|
||||
for media in [self.media_id, self.media_id_2, self.media_id_3]:
|
||||
media_info = self.get_success(self.store.get_local_media(media))
|
||||
assert media_info is not None
|
||||
self.assertFalse(media_info.quarantined_by)
|
||||
|
||||
def test_quarantine_protected_media(self) -> None:
|
||||
"""
|
||||
Tests that quarantining from protected media fails
|
||||
|
|
|
@ -137,6 +137,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
|
|||
time_now_ms=clock.time_msec(),
|
||||
upload_name="test.png",
|
||||
filesystem_id=file_id,
|
||||
sha256=file_id,
|
||||
)
|
||||
)
|
||||
self.register_user("user", "password")
|
||||
|
@ -2593,6 +2594,7 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
|
|||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name="remote_test.png",
|
||||
filesystem_id=file_id,
|
||||
sha256=file_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -2725,6 +2727,7 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
|
|||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name="remote_test.png",
|
||||
filesystem_id=file_id,
|
||||
sha256=file_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -61,6 +61,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
|
|||
time_now_ms=clock.time_msec(),
|
||||
upload_name="test.png",
|
||||
filesystem_id=file_id,
|
||||
sha256=file_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -139,6 +139,8 @@ SMALL_PNG = unhexlify(
|
|||
b"0000001f15c4890000000a49444154789c63000100000500010d"
|
||||
b"0a2db40000000049454e44ae426082"
|
||||
)
|
||||
# The SHA256 hexdigest for the above bytes.
|
||||
SMALL_PNG_SHA256 = "ebf4f635a17d10d6eb46ba680b70142419aa3220f228001a036d311a22ee9d2a"
|
||||
|
||||
# A small CMYK-encoded JPEG image used in some tests.
|
||||
#
|
||||
|
|
Loading…
Add table
Reference in a new issue