Ratelimiting of remote media downloads (#17256)

This commit is contained in:
Shay 2024-06-05 05:43:36 -07:00 committed by GitHub
parent aabf577166
commit fcbc79bb87
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 372 additions and 14 deletions

View file

@ -0,0 +1 @@
Improve ratelimiting in Synapse (#17256).

View file

@ -1946,6 +1946,24 @@ Example configuration:
max_image_pixels: 35M
```
---
### `remote_media_download_burst_count`
Remote media downloads are ratelimited using a [leaky bucket algorithm](https://en.wikipedia.org/wiki/Leaky_bucket), where a given "bucket" is keyed to the IP address of the requester when requesting remote media downloads. This configuration option sets the size of the bucket against which the size in bytes of downloads are penalized - if the bucket is full, ie a given number of bytes have already been downloaded, further downloads will be denied until the bucket drains. Defaults to 500MiB. See also `remote_media_download_per_second` which determines the rate at which the "bucket" is emptied and thus has available space to authorize new requests.
Example configuration:
```yaml
remote_media_download_burst_count: 200M
```
---
### `remote_media_download_per_second`
Works in conjunction with `remote_media_download_burst_count` to ratelimit remote media downloads - this configuration option determines the rate at which the "bucket" (see above) leaks in bytes per second. As requests are made to download remote media, the size of those requests in bytes is added to the bucket, and once the bucket has reached it's capacity, no more requests will be allowed until a number of bytes has "drained" from the bucket. This setting determines the rate at which bytes drain from the bucket, with the practical effect that the larger the number, the faster the bucket leaks, allowing for more bytes downloaded over a shorter period of time. Defaults to 87KiB per second. See also `remote_media_download_burst_count`.
Example configuration:
```yaml
remote_media_download_per_second: 40K
```
---
### `prevent_media_downloads_from`
A list of domains to never download media from. Media from these

View file

@ -218,3 +218,13 @@ class RatelimitConfig(Config):
"rc_media_create",
defaults={"per_second": 10, "burst_count": 50},
)
self.remote_media_downloads = RatelimitSettings(
key="rc_remote_media_downloads",
per_second=self.parse_size(
config.get("remote_media_download_per_second", "87K")
),
burst_count=self.parse_size(
config.get("remote_media_download_burst_count", "500M")
),
)

View file

@ -56,6 +56,7 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
@ -1877,6 +1878,8 @@ class FederationClient(FederationBase):
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
try:
return await self.transport_layer.download_media_v3(
@ -1885,6 +1888,8 @@ class FederationClient(FederationBase):
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
@ -1905,6 +1910,8 @@ class FederationClient(FederationBase):
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)

View file

@ -43,6 +43,7 @@ import ijson
from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
@ -819,6 +820,8 @@ class TransportLayerClient:
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
@ -834,6 +837,8 @@ class TransportLayerClient:
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
},
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
async def download_media_v3(
@ -843,6 +848,8 @@ class TransportLayerClient:
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
@ -862,6 +869,8 @@ class TransportLayerClient:
"allow_redirect": "true",
},
follow_redirects=True,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)

View file

@ -57,7 +57,7 @@ from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IBodyProducer, IResponse
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
@ -68,6 +68,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
@ -1411,9 +1412,11 @@ class MatrixFederationHttpClient:
destination: str,
path: str,
output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: str,
max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
@ -1422,6 +1425,10 @@ class MatrixFederationHttpClient:
destination: The remote server to send the HTTP request to.
path: The HTTP path to GET.
output_stream: File to write the response body to.
download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
requester IP
ip_address: IP address of the requester
max_size: maximum allowable size in bytes of the file
args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
@ -1441,11 +1448,27 @@ class MatrixFederationHttpClient:
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
SynapseError: If the requested file exceeds ratelimits
"""
request = MatrixFederationRequest(
method="GET", destination=destination, path=path, query=args
)
# check for a minimum balance of 1MiB in ratelimiter before initiating request
send_req, _ = await download_ratelimiter.can_do_action(
requester=None, key=ip_address, n_actions=1048576, update=False
)
if not send_req:
msg = "Requested file size exceeds ratelimits"
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
response = await self._send_request(
request,
retry_on_dns_fail=retry_on_dns_fail,
@ -1455,12 +1478,36 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
expected_size = response.length
# if we don't get an expected length then use the max length
if expected_size == UNKNOWN_LENGTH:
expected_size = max_size
logger.debug(
f"File size unknown, assuming file is max allowable size: {max_size}"
)
read_body, _ = await download_ratelimiter.can_do_action(
requester=None,
key=ip_address,
n_actions=expected_size,
)
if not read_body:
msg = "Requested file size exceeds ratelimits"
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
try:
d = read_body_with_max_size(response, output_stream, max_size)
# add a byte of headroom to max size as function errs at >=
d = read_body_with_max_size(response, output_stream, expected_size + 1)
d.addTimeout(self.default_timeout_seconds, self.reactor)
length = await make_deferred_yieldable(d)
except BodyExceededMaxSize:
msg = "Requested file is too large > %r bytes" % (max_size,)
msg = "Requested file is too large > %r bytes" % (expected_size,)
logger.warning(
"{%s} [%s] %s",
request.txn_id,

View file

@ -42,6 +42,7 @@ from synapse.api.errors import (
SynapseError,
cs_error,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.repository import ThumbnailRequirement
from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest
@ -111,6 +112,12 @@ class MediaRepository:
)
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
self.download_ratelimiter = Ratelimiter(
store=hs.get_storage_controllers().main,
clock=hs.get_clock(),
cfg=hs.config.ratelimiting.remote_media_downloads,
)
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
@ -464,6 +471,7 @@ class MediaRepository:
media_id: str,
name: Optional[str],
max_timeout_ms: int,
ip_address: str,
) -> None:
"""Respond to requests for remote media.
@ -475,6 +483,7 @@ class MediaRepository:
the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
ip_address: the IP address of the requester
Returns:
Resolves once a response has successfully been written to request
@ -500,7 +509,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id, max_timeout_ms
server_name,
media_id,
max_timeout_ms,
self.download_ratelimiter,
ip_address,
)
# We deliberately stream the file outside the lock
@ -517,7 +530,7 @@ class MediaRepository:
respond_404(request)
async def get_remote_media_info(
self, server_name: str, media_id: str, max_timeout_ms: int
self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@ -527,6 +540,7 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
ip_address: IP address of the requester
Returns:
The media info of the file
@ -542,7 +556,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id, max_timeout_ms
server_name,
media_id,
max_timeout_ms,
self.download_ratelimiter,
ip_address,
)
# Ensure we actually use the responder so that it releases resources
@ -553,7 +571,12 @@ class MediaRepository:
return media_info
async def _get_remote_media_impl(
self, server_name: str, media_id: str, max_timeout_ms: int
self,
server_name: str,
media_id: str,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@ -564,6 +587,9 @@ class MediaRepository:
remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP.
ip_address: the IP address of the requester
Returns:
A tuple of responder and the media info of the file.
@ -596,7 +622,7 @@ class MediaRepository:
try:
media_info = await self._download_remote_file(
server_name, media_id, max_timeout_ms
server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
)
except SynapseError:
raise
@ -630,6 +656,8 @@ class MediaRepository:
server_name: str,
media_id: str,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@ -641,6 +669,9 @@ class MediaRepository:
locally generated.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP
ip_address: the IP address of the requester
Returns:
The media info of the file.
@ -658,6 +689,8 @@ class MediaRepository:
output_stream=f,
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except RequestSendFailed as e:
logger.warning(

View file

@ -359,9 +359,10 @@ class ThumbnailProvider:
desired_method: str,
desired_type: str,
max_timeout_ms: int,
ip_address: str,
) -> None:
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
server_name, media_id, max_timeout_ms, ip_address
)
if not media_info:
respond_404(request)
@ -422,12 +423,13 @@ class ThumbnailProvider:
method: str,
m_type: str,
max_timeout_ms: int,
ip_address: str,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
server_name, media_id, max_timeout_ms, ip_address
)
if not media_info:
return

View file

@ -174,6 +174,7 @@ class UnstableThumbnailResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
remote_resp_function = (
self.thumbnailer.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
@ -188,6 +189,7 @@ class UnstableThumbnailResource(RestServlet):
method,
m_type,
max_timeout_ms,
ip_address,
)
self.media_repo.mark_recently_accessed(server_name, media_id)

View file

@ -97,6 +97,12 @@ class DownloadResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
await self.media_repo.get_remote_media(
request, server_name, media_id, file_name, max_timeout_ms
request,
server_name,
media_id,
file_name,
max_timeout_ms,
ip_address,
)

View file

@ -104,6 +104,7 @@ class ThumbnailResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
remote_resp_function = (
self.thumbnail_provider.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
@ -118,5 +119,6 @@ class ThumbnailResource(RestServlet):
method,
m_type,
max_timeout_ms,
ip_address,
)
self.media_repo.mark_recently_accessed(server_name, media_id)

View file

@ -25,7 +25,7 @@ import tempfile
from binascii import unhexlify
from io import BytesIO
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock, patch
from urllib import parse
import attr
@ -37,9 +37,12 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http_headers import Headers
from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
from twisted.web.resource import Resource
from synapse.api.errors import Codes, HttpResponseException
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
@ -59,6 +62,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import SMALL_PNG
from tests.unittest import override_config
from tests.utils import default_config
@ -251,9 +255,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
destination: str,
path: str,
output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: Any,
max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
@ -878,3 +884,218 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
tok=self.tok,
expect_code=400,
)
class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config["media_store_path"] = self.media_store_path
provider_config = {
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.repo = hs.get_media_repository()
self.client = hs.get_federation_http_client()
self.store = hs.get_datastores().main
def create_resource_dict(self) -> Dict[str, Resource]:
# We need to manually set the resource tree to include media, the
# default only does `/_matrix/client` APIs.
return {"/_matrix/media": self.hs.get_media_repository_resource()}
# mock actually reading file body
def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred:
d: Deferred = defer.Deferred()
d.callback(31457280)
return d
def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred:
d: Deferred = defer.Deferred()
d.callback(52428800)
return d
@patch(
"synapse.http.matrixfederationclient.read_body_with_max_size",
read_body_with_max_size_30MiB,
)
def test_download_ratelimit_default(self) -> None:
"""
Test remote media download ratelimiting against default configuration - 500MB bucket
and 87kb/second drain rate
"""
# mock out actually sending the request, returns a 30MiB response
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
resp = MagicMock(spec=IResponse)
resp.code = 200
resp.length = 31457280
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
resp.phrase = b"OK"
return resp
self.client._send_request = _send_request # type: ignore
# first request should go through
channel = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
shorthand=False,
)
assert channel.code == 200
# next 15 should go through
for i in range(15):
channel2 = self.make_request(
"GET",
f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
shorthand=False,
)
assert channel2.code == 200
# 17th will hit ratelimit
channel3 = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
shorthand=False,
)
assert channel3.code == 429
# however, a request from a different IP will go through
channel4 = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
shorthand=False,
client_ip="187.233.230.159",
)
assert channel4.code == 200
# at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
# 30MiB download is authorized - The last download was blocked at 503,316,480.
# The next download will be authorized when bucket hits 492,830,720
# (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
# needs to drain before another download will be authorized, that will take ~=
# 2 minutes (10,485,760/89,088/60)
self.reactor.pump([2.0 * 60.0])
# enough has drained and next request goes through
channel5 = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb",
shorthand=False,
)
assert channel5.code == 200
@override_config(
{
"remote_media_download_per_second": "50M",
"remote_media_download_burst_count": "50M",
}
)
@patch(
"synapse.http.matrixfederationclient.read_body_with_max_size",
read_body_with_max_size_50MiB,
)
def test_download_rate_limit_config(self) -> None:
"""
Test that download rate limit config options are correctly picked up and applied
"""
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
resp = MagicMock(spec=IResponse)
resp.code = 200
resp.length = 52428800
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
resp.phrase = b"OK"
return resp
self.client._send_request = _send_request # type: ignore
# first request should go through
channel = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
shorthand=False,
)
assert channel.code == 200
# immediate second request should fail
channel = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1",
shorthand=False,
)
assert channel.code == 429
# advance half a second
self.reactor.pump([0.5])
# request still fails
channel = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2",
shorthand=False,
)
assert channel.code == 429
# advance another half second
self.reactor.pump([0.5])
# enough has drained from bucket and request is successful
channel = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3",
shorthand=False,
)
assert channel.code == 200
@patch(
"synapse.http.matrixfederationclient.read_body_with_max_size",
read_body_with_max_size_30MiB,
)
def test_download_ratelimit_max_size_sub(self) -> None:
"""
Test that if no content-length is provided, the default max size is applied instead
"""
# mock out actually sending the request
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
resp = MagicMock(spec=IResponse)
resp.code = 200
resp.length = UNKNOWN_LENGTH
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
resp.phrase = b"OK"
return resp
self.client._send_request = _send_request # type: ignore
# ten requests should go through using the max size (500MB/50MB)
for i in range(10):
channel2 = self.make_request(
"GET",
f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
shorthand=False,
)
assert channel2.code == 200
# eleventh will hit ratelimit
channel3 = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
shorthand=False,
)
assert channel3.code == 429