Add additional type hints to HTTP client. (#8812)

This also removes some duplicated code between the simple
HTTP client and matrix federation client.
Patrick Cloke 2020-11-25 13:30:47 -05:00 committed by GitHub
parent 4fd222ad70
commit 968939bdac
5 changed files with 142 additions and 149 deletions


@@ -1 +1 @@
Add type hints to matrix federation client and agent.
Add type hints to HTTP abstractions.

changelog.d/8812.misc (new file, 1 addition)

@@ -0,0 +1 @@
Add type hints to HTTP abstractions.


@@ -45,6 +45,7 @@ files =
synapse/handlers/saml_handler.py,
synapse/handlers/sync.py,
synapse/handlers/ui_auth,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/well_known_resolver.py,
synapse/http/matrixfederationclient.py,
@@ -109,7 +110,7 @@ ignore_missing_imports = True
[mypy-opentracing]
ignore_missing_imports = True
[mypy-OpenSSL]
[mypy-OpenSSL.*]
ignore_missing_imports = True
[mypy-netaddr]


@@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
import urllib.parse
from io import BytesIO
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Dict,
@@ -31,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
from netaddr import IPAddress
from netaddr import IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -39,6 +40,8 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
)
@@ -53,7 +56,7 @@ )
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from twisted.web.iweb import IAgent, IBodyProducer, IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
def check_against_blacklist(
ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
) -> bool:
"""
Compares an IP address to allowed and disallowed IP sets.
Args:
ip_address (netaddr.IPAddress)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
ip_address: The IP address to check
ip_whitelist: Allowed IP addresses.
ip_blacklist: Disallowed IP addresses.
Returns:
True if the IP address is in the blacklist and not in the whitelist.
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
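As the updated docstring states, a blacklist match is overridden by whitelist membership. A minimal sketch of the behaviour, with illustrative IPSet values:

    from netaddr import IPAddress, IPSet

    ip_blacklist = IPSet(["10.0.0.0/8", "127.0.0.0/8"])
    ip_whitelist = IPSet(["10.1.2.3"])

    # In the blacklist and not in the whitelist: blocked.
    assert check_against_blacklist(IPAddress("10.5.5.5"), ip_whitelist, ip_blacklist)
    # The whitelist entry wins over the blacklist match.
    assert not check_against_blacklist(IPAddress("10.1.2.3"), ip_whitelist, ip_blacklist)
    # Not blacklisted at all.
    assert not check_against_blacklist(IPAddress("8.8.8.8"), ip_whitelist, ip_blacklist)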
@@ -118,23 +131,30 @@ class IPBlacklistingResolver:
addresses, preventing DNS rebinding attacks on URL preview.
"""
def __init__(self, reactor, ip_whitelist, ip_blacklist):
def __init__(
self,
reactor: IReactorPluggableNameResolver,
ip_whitelist: Optional[IPSet],
ip_blacklist: IPSet,
):
"""
Args:
reactor (twisted.internet.reactor)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
reactor: The twisted reactor.
ip_whitelist: IP addresses to allow.
ip_blacklist: IP addresses to disallow.
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
def resolveHostName(self, recv, hostname, portNumber=0):
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
r = recv()
addresses = []
addresses = [] # type: List[IAddress]
def _callback():
def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False
@@ -161,15 +181,15 @@
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress):
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
@staticmethod
def addressResolved(address):
def addressResolved(address: IAddress) -> None:
addresses.append(address)
@staticmethod
def resolutionComplete():
def resolutionComplete() -> None:
_callback()
self._reactor.nameResolver.resolveHostName(
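The resolver itself is only half the story: it has to be installed as the reactor's nameResolver. Roughly how SimpleHttpClient wires it in (paraphrased from code outside the hunks shown here; the wrapper class name is made up):

    nameResolver = IPBlacklistingResolver(real_reactor, ip_whitelist, ip_blacklist)

    @implementer(IReactorPluggableNameResolver)
    class _WrappedReactor:
        # Delegate everything to the real reactor except nameResolver.
        def __getattr__(self, attr):
            if attr == "nameResolver":
                return nameResolver
            return getattr(real_reactor, attr)

    reactor = _WrappedReactor()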
@@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent):
directly (without an IP address lookup).
"""
def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
def __init__(
self,
agent: IAgent,
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
):
"""
Args:
agent (twisted.web.client.Agent): The Agent to wrap.
reactor (twisted.internet.reactor)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
agent: The Agent to wrap.
ip_whitelist: IP addresses to allow.
ip_blacklist: IP addresses to disallow.
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
def request(self, method, uri, headers=None, bodyProducer=None):
def request(
self,
method: bytes,
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
) -> defer.Deferred:
h = urllib.parse.urlparse(uri.decode("ascii"))
try:
@@ -226,23 +256,23 @@ class SimpleHttpClient:
def __init__(
self,
hs,
treq_args={},
ip_whitelist=None,
ip_blacklist=None,
http_proxy=None,
https_proxy=None,
hs: "HomeServer",
treq_args: Dict[str, Any] = {},
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
http_proxy: Optional[bytes] = None,
https_proxy: Optional[bytes] = None,
):
"""
Args:
hs (synapse.server.HomeServer)
treq_args (dict): Extra keyword arguments to be given to treq.request.
ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
hs
treq_args: Extra keyword arguments to be given to treq.request.
ip_blacklist: The IP addresses that are blacklisted that
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
http_proxy (bytes): proxy server to use for http connections. host[:port]
https_proxy (bytes): proxy server to use for https connections. host[:port]
http_proxy: proxy server to use for http connections. host[:port]
https_proxy: proxy server to use for https connections. host[:port]
"""
self.hs = hs
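Under the new signature a typical construction looks something like this (the config attribute names are assumptions for illustration, not part of this diff):

    client = SimpleHttpClient(
        hs,
        ip_whitelist=hs.config.url_preview_ip_range_whitelist,  # assumed attr
        ip_blacklist=hs.config.url_preview_ip_range_blacklist,  # assumed attr
    )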
@@ -306,7 +336,6 @@ class SimpleHttpClient:
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
self.agent,
self.reactor,
ip_whitelist=self._ip_whitelist,
ip_blacklist=self._ip_blacklist,
)
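With the redundant reactor argument gone, wrapping any IAgent needs only the agent and the IP sets (a minimal sketch; the IPSet value and URL are illustrative):

    from twisted.internet import reactor
    from twisted.web.client import Agent

    agent = BlacklistingAgentWrapper(
        Agent(reactor),
        ip_blacklist=IPSet(["127.0.0.0/8", "10.0.0.0/8"]),
    )
    d = agent.request(b"GET", b"http://example.com/")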
@@ -397,7 +426,7 @@ class SimpleHttpClient:
async def post_urlencoded_get_json(
self,
uri: str,
args: Mapping[str, Union[str, List[str]]] = {},
args: Optional[Mapping[str, Union[str, List[str]]]] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""
@@ -422,9 +451,7 @@
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
"utf8"
)
query_bytes = encode_query_args(args)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -432,7 +459,7 @@
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
@@ -479,7 +506,7 @@
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
@@ -495,7 +522,10 @@
)
async def get_json(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
self,
uri: str,
args: Optional[QueryParams] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.
@@ -516,7 +546,7 @@
"""
actual_headers = {b"Accept": [b"application/json"]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
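Since args now defaults to None, a bare call no longer needs an empty dict (URL illustrative; raises HttpResponseException on a non-2xx response):

    info = await client.get_json("https://example.com/_matrix/client/versions")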
@@ -525,7 +555,7 @@
self,
uri: str,
json_body: Any,
args: QueryParams = {},
args: Optional[QueryParams] = None,
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.
@@ -546,9 +576,9 @@
ValueError: if the response was not JSON
"""
if len(args):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
if args:
query_str = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_str)
json_str = encode_canonical_json(json_body)
@@ -558,7 +588,7 @@
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
@@ -574,7 +604,10 @@
)
async def get_raw(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
self,
uri: str,
args: Optional[QueryParams] = None,
headers: Optional[RawHeaders] = None,
) -> bytes:
"""Gets raw text from the given URI.
@@ -592,13 +625,13 @@
HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
if args:
query_str = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_str)
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request("GET", uri, headers=Headers(actual_headers))
@@ -641,7 +674,7 @@
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request("GET", url, headers=Headers(actual_headers))
@@ -649,12 +682,13 @@
if (
b"Content-Length" in resp_headers
and max_size
and int(resp_headers[b"Content-Length"][0]) > max_size
):
logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
logger.warning("Requested URL is too large > %r bytes" % (max_size,))
raise SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
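The change above also fixes the log line and error message to report the per-call max_size rather than the client-wide self.max_size. A sketch of the call site, assuming the (length, headers, uri, code) return shape used by get_file in this module (path and URL illustrative):

    with open("/tmp/preview.png", "wb") as output:
        length, headers, uri, code = await client.get_file(
            "https://example.com/image.png", output, max_size=1024 * 1024
        )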
@@ -668,7 +702,7 @@
try:
length = await make_deferred_yieldable(
_readBodyToFile(response, output_stream, max_size)
readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
# This can happen e.g. because the body is too large.
@@ -696,18 +730,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
def dataReceived(self, data):
def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
@@ -721,7 +753,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason):
def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@@ -732,35 +764,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
def readBodyToFile(
response: IResponse, stream: BinaryIO, max_size: Optional[int]
) -> defer.Deferred:
"""
Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
Args:
response: The HTTP response to read from.
stream: The file-object to write to.
max_size: The maximum file size to allow.
Returns:
A Deferred which resolves to the length of the read body.
"""
def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
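The now-shared helper returns a Deferred that fires with the number of bytes written, so callers pair it with make_deferred_yieldable, exactly as both call sites in this commit do:

    from io import BytesIO

    buffer = BytesIO()
    length = await make_deferred_yieldable(
        readBodyToFile(response, buffer, max_size=1024 * 1024)
    )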
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.
Args:
args: The query arguments, a mapping of string to string or list of strings.
def encode_urlencode_arg(arg):
if isinstance(arg, str):
return arg.encode("utf-8")
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:
return arg
Returns:
The query arguments encoded as bytes.
"""
if args is None:
return b""
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, str):
vs = [vs]
encoded_args[k] = [v.encode("utf8") for v in vs]
def _print_ex(e):
if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons:
_print_ex(ex)
else:
logger.exception(e)
query_str = urllib.parse.urlencode(encoded_args, True)
return query_str.encode("utf8")
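The consolidated helper urlencodes with doseq semantics, so list values repeat the key, and None short-circuits to b"". For example:

    assert (
        encode_query_args({"q": "hello world", "tag": ["a", "b"]})
        == b"q=hello+world&tag=a&tag=b"
    )
    assert encode_query_args(None) == b""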
class InsecureInterceptableContextFactory(ssl.ContextFactory):


@@ -19,7 +19,7 @@ import random
import sys
import urllib.parse
from io import BytesIO
from typing import BinaryIO, Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import attr
import treq
@@ -28,26 +28,27 @@ from prometheus_client import Counter
from signedjson.sign import sign_json
from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
SynapseError,
)
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
from synapse.http.client import (
BlacklistingAgentWrapper,
IPBlacklistingResolver,
encode_query_args,
readBodyToFile,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import (
@@ -250,9 +251,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
self.agent,
self.reactor,
ip_blacklist=hs.config.federation_ip_range_blacklist,
self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@@ -986,7 +985,7 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
d = _readBodyToFile(response, output_stream, max_size)
d = readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
except Exception as e:
@@ -1010,44 +1009,6 @@
return (length, headers)
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(
SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
)
)
self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
else:
self.deferred.errback(reason)
def _readBodyToFile(
response: IResponse, stream: BinaryIO, max_size: Optional[int]
) -> defer.Deferred:
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@@ -1088,18 +1049,3 @@
),
can_retry=False,
)
def encode_query_args(args: Optional[QueryArgs]) -> bytes:
if args is None:
return b""
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, str):
vs = [vs]
encoded_args[k] = [v.encode("utf8") for v in vs]
query_str = urllib.parse.urlencode(encoded_args, True)
return query_str.encode("utf8")