Don't allow unsupported content-type

Co-authored-by: Eric Eastwood <erice@element.io>
This commit is contained in:
Devon Hudson 2024-11-05 15:05:22 -07:00 committed by Quentin Gliech
parent d82e1ed357
commit 4b7154c585
No known key found for this signature in database
GPG key ID: 22D62B84552719FC
2 changed files with 89 additions and 0 deletions

View file

@ -21,6 +21,7 @@
import contextlib import contextlib
import logging import logging
import time import time
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
import attr import attr
@ -139,6 +140,41 @@ class SynapseRequest(Request):
self.synapse_site.site_tag, self.synapse_site.site_tag,
) )
# Twisted machinery: this method is called by the Channel once the full request has
# been received, to dispatch the request to a resource.
#
# We're patching Twisted to bail/abort early when we see someone trying to upload
# `multipart/form-data` so we can avoid Twisted parsing the entire request body into
# in-memory (specific problem of this specific `Content-Type`). This protects us
# from an attacker uploading something bigger than the available RAM and crashing
# the server with a `MemoryError`, or carefully block just enough resources to cause
# all other requests to fail.
#
# FIXME: This can be removed once we Twisted releases a fix and we update to a
# version that is patched
def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
if command == b"POST":
ctype = self.requestHeaders.getRawHeaders(b"content-type")
if ctype and b"multipart/form-data" in ctype[0]:
self.method, self.uri = command, path
self.clientproto = version
self.code = HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value
self.code_message = bytes(
HTTPStatus.UNSUPPORTED_MEDIA_TYPE.phrase, "ascii"
)
self.responseHeaders.setRawHeaders(b"content-length", [b"0"])
logger.warning(
"Aborting connection from %s because `content-type: multipart/form-data` is unsupported: %s %s",
self.client,
command,
path,
)
self.write(b"")
self.loseConnection()
return
return super().requestReceived(command, path, version)
def handleContentChunk(self, data: bytes) -> None: def handleContentChunk(self, data: bytes) -> None:
# we should have a `content` by now. # we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()" assert self.content, "handleContentChunk() called before gotLength()"

View file

@ -90,3 +90,56 @@ class SynapseRequestTestCase(HomeserverTestCase):
# default max upload size is 50M, so it should drop on the next buffer after # default max upload size is 50M, so it should drop on the next buffer after
# that. # that.
self.assertEqual(sent, 50 * 1024 * 1024 + 1024) self.assertEqual(sent, 50 * 1024 * 1024 + 1024)
def test_content_type_multipart(self) -> None:
"""HTTP POST requests with `content-type: multipart/form-data` should be rejected"""
self.hs.start_listening()
# find the HTTP server which is configured to listen on port 0
(port, factory, _backlog, interface) = self.reactor.tcpServers[0]
self.assertEqual(interface, "::")
self.assertEqual(port, 0)
# as a control case, first send a regular request.
# complete the connection and wire it up to a fake transport
client_address = IPv6Address("TCP", "::1", 2345)
protocol = factory.buildProtocol(client_address)
transport = StringTransport()
protocol.makeConnection(transport)
protocol.dataReceived(
b"POST / HTTP/1.1\r\n"
b"Connection: close\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"0\r\n"
b"\r\n"
)
while not transport.disconnecting:
self.reactor.advance(1)
# we should get a 404
self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")
# now send request with content-type header
protocol = factory.buildProtocol(client_address)
transport = StringTransport()
protocol.makeConnection(transport)
protocol.dataReceived(
b"POST / HTTP/1.1\r\n"
b"Connection: close\r\n"
b"Transfer-Encoding: chunked\r\n"
b"Content-Type: multipart/form-data\r\n"
b"\r\n"
b"0\r\n"
b"\r\n"
)
while not transport.disconnecting:
self.reactor.advance(1)
# we should get a 415
self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 415 ")