1
0
Fork 0
mirror of https://github.com/element-hq/synapse.git synced 2025-04-08 11:13:59 +00:00

Refactor to support handling hashes for both.

This commit is contained in:
Half-Shot 2025-03-26 15:34:49 +00:00
parent e75d498a1b
commit 52d74c2081

View file

@ -1134,70 +1134,27 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _quarantine_local_media_txn(
self,
txn: LoggingTransaction,
mxcs: List[str],
hashes: Set[str],
media_ids: Set[str],
quarantined_by: Optional[str],
) -> int:
"""Quarantine and unquarantine local items
"""Quarantine and unquarantine local media items.
Args:
txn (cursor)
mxcs: A list of local media ids
hashes: A set of sha256 hashes for any media that should be quarantined
media_ids: A set of media IDs for any media that should be quarantined
quarantined_by: The ID of the user who initiated the quarantine request
If it is `None` media will be removed from quarantine
Returns:
The total number of media items quarantined
"""
if not mxcs:
# Shortcircuit if the mxc list is empty
return 0
# First, determine the hashes of the media we want to delete locally.
# We also want the media_ids for any media that lacks a hash.
hash_sql_many_clause_sql, hash_sql_many_clause_args = make_in_list_sql_clause(
txn.database_engine, "media_id", mxcs
)
hash_sql = f"SELECT sha256, media_id FROM local_media_repository WHERE {hash_sql_many_clause_sql}"
if quarantined_by is not None:
hash_sql += " AND safe_from_quarantine = FALSE"
txn.execute(hash_sql, hash_sql_many_clause_args)
results = txn.fetchall()
# Split results into hashes, and hashless media.
hashes = set()
non_hashed_media_ids = set()
for sha256, media_id in txn:
if sha256:
hashes.add(sha256)
else:
non_hashed_media_ids.add(media_id)
total_media_quarantined = 0
# Effectively a legacy path, update any media
# that was explicitly named.
if non_hashed_media_ids:
# Effectively a legacy path, update any media that was explicitly named.
if media_ids:
sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
txn.database_engine, "media_id", non_hashed_media_ids
)
sql = f"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE {sql_many_clause_sql}"""
if quarantined_by is not None:
sql += " AND safe_from_quarantine = FALSE"
txn.execute(sql, [quarantined_by] + sql_many_clause_args)
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
# Update any media that was identified via hash.
if hashes:
# Update all the tables to set the quarantined_by flag
# We also pick up any media with a matching hash.
sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
txn.database_engine, "sha256", hashes
txn.database_engine, "media_id", media_ids
)
sql = f"""
UPDATE local_media_repository
@ -1211,59 +1168,49 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
# Update any media that was identified via hash.
if hashes:
sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
txn.database_engine, "sha256", hashes
)
sql = f"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE {sql_many_clause_sql}"""
if quarantined_by is not None:
sql += " AND safe_from_quarantine = FALSE"
txn.execute(sql, [quarantined_by] + sql_many_clause_args)
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
return total_media_quarantined
def _quarantine_remote_media_txn(
self,
txn: LoggingTransaction,
mxcs: List[Tuple[str, str]],
hashes: Set[str],
media: Set[Tuple[str, str]],
quarantined_by: Optional[str],
) -> int:
"""Quarantine and unquarantine remote items
Args:
txn (cursor)
mxcs: A list of tuples of media_origin, media_id
hashes: A set of sha256 hashes for any media that should be quarantined
media_ids: A set of tuples (media_origin, media_id) for any media that should be quarantined
quarantined_by: The ID of the user who initiated the quarantine request
If it is `None` media will be removed from quarantine
Returns:
The total number of media items quarantined
"""
if not mxcs:
# Shortcircuit if the mxc list is empty
return 0
# First, determine the hashes of the media we want to delete locally.
# We also want the media_ids for any media that lacks a hash.
hashes = set()
non_hashed_media_ids = set()
hash_sql_in_list_clause, hash_sql_args = make_tuple_in_list_sql_clause(
txn.database_engine,
("media_origin", "media_id"),
mxcs,
)
hash_sql = f"SELECT sha256, media_origin, media_id FROM remote_media_cache WHERE {hash_sql_in_list_clause}"
txn.execute(hash_sql, hash_sql_args)
# Split results into hashes, and hashless media.
for sha256, media_origin, media_id in txn:
if sha256:
hashes.add(sha256)
else:
non_hashed_media_ids.add((media_origin, media_id))
total_media_quarantined = 0
# Effectively a legacy path, update any media
# that was explicitly named.
if non_hashed_media_ids:
if media:
sql_in_list_clause, sql_args = make_tuple_in_list_sql_clause(
txn.database_engine,
("media_origin", "media_id"),
non_hashed_media_ids,
media,
)
sql = f"""
UPDATE remote_media_cache
@ -1274,19 +1221,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
total_media_quarantined = 0
# Update any media that was identified via hash.
if hashes:
sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
txn.database_engine, "sha256", hashes
)
# Update all the tables to set the quarantined_by flag
# We also pick up any media with a matching hash.
sql = f"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE {sql_many_clause_sql}"""
txn.execute(sql, [quarantined_by] + sql_many_clause_args)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
return total_media_quarantined
@ -1310,9 +1253,47 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
Returns:
The total number of media items quarantined
"""
hashes = set()
media_ids = set()
remote_media = set()
count = self._quarantine_local_media_txn(txn, local_mxcs, quarantined_by)
count += self._quarantine_remote_media_txn(txn, remote_mxcs, quarantined_by)
# First, determine the hashes of the media we want to delete.
# We also want the media_ids for any media that lacks a hash.
if local_mxcs:
hash_sql_many_clause_sql, hash_sql_many_clause_args = (
make_in_list_sql_clause(txn.database_engine, "media_id", local_mxcs)
)
hash_sql = f"SELECT sha256, media_id FROM local_media_repository WHERE {hash_sql_many_clause_sql}"
if quarantined_by is not None:
hash_sql += " AND safe_from_quarantine = FALSE"
txn.execute(hash_sql, hash_sql_many_clause_args)
for sha256, media_id in txn:
if sha256:
hashes.add(sha256)
else:
media_ids.add(media_id)
# Do the same for remote media
if remote_mxcs:
hash_sql_in_list_clause, hash_sql_args = make_tuple_in_list_sql_clause(
txn.database_engine,
("media_origin", "media_id"),
remote_mxcs,
)
hash_sql = f"SELECT sha256, media_origin, media_id FROM remote_media_cache WHERE {hash_sql_in_list_clause}"
txn.execute(hash_sql, hash_sql_args)
for sha256, media_origin, media_id in txn:
if sha256:
hashes.add(sha256)
else:
remote_media.add((media_origin, media_id))
count = self._quarantine_local_media_txn(txn, hashes, media_ids, quarantined_by)
count += self._quarantine_remote_media_txn(
txn, hashes, remote_media, quarantined_by
)
return count