Bugbear: Add Mutable Parameter fixes (#9682)

Part of #9366

Adds in fixes for B006 and B008, both relating to mutable parameter lint errors.

Signed-off-by: Jonathan de Jong <jonathan@automatia.nl>
This commit is contained in:
Jonathan de Jong 2021-04-08 23:38:54 +02:00 committed by GitHub
parent 64f4f506c5
commit 2ca4e349e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
38 changed files with 224 additions and 113 deletions

1
changelog.d/9682.misc Normal file
View file

@ -0,0 +1 @@
Introduce flake8-bugbear to the test suite and fix some of its lint violations.

View file

@ -24,6 +24,7 @@ import sys
import time import time
import urllib import urllib
from http import TwistedHttpClient from http import TwistedHttpClient
from typing import Optional
import nacl.encoding import nacl.encoding
import nacl.signing import nacl.signing
@ -718,7 +719,7 @@ class SynapseCmd(cmd.Cmd):
method, method,
path, path,
data=None, data=None,
query_params={"access_token": None}, query_params: Optional[dict] = None,
alt_text=None, alt_text=None,
): ):
"""Runs an HTTP request and pretty prints the output. """Runs an HTTP request and pretty prints the output.
@ -729,6 +730,8 @@ class SynapseCmd(cmd.Cmd):
data: Raw JSON data if any data: Raw JSON data if any
query_params: dict of query parameters to add to the url query_params: dict of query parameters to add to the url
""" """
query_params = query_params or {"access_token": None}
url = self._url() + path url = self._url() + path
if "access_token" in query_params: if "access_token" in query_params:
query_params["access_token"] = self._tok() query_params["access_token"] = self._tok()

View file

@ -16,6 +16,7 @@
import json import json
import urllib import urllib
from pprint import pformat from pprint import pformat
from typing import Optional
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.client import Agent, readBody from twisted.web.client import Agent, readBody
@ -85,8 +86,9 @@ class TwistedHttpClient(HttpClient):
body = yield readBody(response) body = yield readBody(response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
def _create_put_request(self, url, json_data, headers_dict={}): def _create_put_request(self, url, json_data, headers_dict: Optional[dict] = None):
"""Wrapper of _create_request to issue a PUT request""" """Wrapper of _create_request to issue a PUT request"""
headers_dict = headers_dict or {}
if "Content-Type" not in headers_dict: if "Content-Type" not in headers_dict:
raise defer.error(RuntimeError("Must include Content-Type header for PUTs")) raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
@ -95,14 +97,22 @@ class TwistedHttpClient(HttpClient):
"PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict "PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict
) )
def _create_get_request(self, url, headers_dict={}): def _create_get_request(self, url, headers_dict: Optional[dict] = None):
"""Wrapper of _create_request to issue a GET request""" """Wrapper of _create_request to issue a GET request"""
return self._create_request("GET", url, headers_dict=headers_dict) return self._create_request("GET", url, headers_dict=headers_dict or {})
@defer.inlineCallbacks @defer.inlineCallbacks
def do_request( def do_request(
self, method, url, data=None, qparams=None, jsonreq=True, headers={} self,
method,
url,
data=None,
qparams=None,
jsonreq=True,
headers: Optional[dict] = None,
): ):
headers = headers or {}
if qparams: if qparams:
url = "%s?%s" % (url, urllib.urlencode(qparams, True)) url = "%s?%s" % (url, urllib.urlencode(qparams, True))
@ -123,8 +133,12 @@ class TwistedHttpClient(HttpClient):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, method, url, producer=None, headers_dict={}): def _create_request(
self, method, url, producer=None, headers_dict: Optional[dict] = None
):
"""Creates and sends a request to the given url""" """Creates and sends a request to the given url"""
headers_dict = headers_dict or {}
headers_dict["User-Agent"] = ["Synapse Cmd Client"] headers_dict["User-Agent"] = ["Synapse Cmd Client"]
retries_left = 5 retries_left = 5

View file

@ -18,8 +18,8 @@ ignore =
# E203: whitespace before ':' (which is contrary to pep8?) # E203: whitespace before ':' (which is contrary to pep8?)
# E731: do not assign a lambda expression, use a def # E731: do not assign a lambda expression, use a def
# E501: Line too long (black enforces this for us) # E501: Line too long (black enforces this for us)
# B00*: Subsection of the bugbear suite (TODO: add in remaining fixes) # B007: Subsection of the bugbear suite (TODO: add in remaining fixes)
ignore=W503,W504,E203,E731,E501,B006,B007,B008 ignore=W503,W504,E203,E731,E501,B007
[isort] [isort]
line_length = 88 line_length = 88

View file

@ -49,7 +49,7 @@ This is all tied together by the AppServiceScheduler which DIs the required
components. components.
""" """
import logging import logging
from typing import List from typing import List, Optional
from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.events import EventBase from synapse.events import EventBase
@ -191,11 +191,11 @@ class _TransactionController:
self, self,
service: ApplicationService, service: ApplicationService,
events: List[EventBase], events: List[EventBase],
ephemeral: List[JsonDict] = [], ephemeral: Optional[List[JsonDict]] = None,
): ):
try: try:
txn = await self.store.create_appservice_txn( txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral service=service, events=events, ephemeral=ephemeral or []
) )
service_is_up = await self._is_service_up(service) service_is_up = await self._is_service_up(service)
if service_is_up: if service_is_up:

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict from typing import Dict, Optional
from ._base import Config from ._base import Config
@ -21,8 +21,10 @@ class RateLimitConfig:
def __init__( def __init__(
self, self,
config: Dict[str, float], config: Dict[str, float],
defaults={"per_second": 0.17, "burst_count": 3.0}, defaults: Optional[Dict[str, float]] = None,
): ):
defaults = defaults or {"per_second": 0.17, "burst_count": 3.0}
self.per_second = config.get("per_second", defaults["per_second"]) self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = int(config.get("burst_count", defaults["burst_count"])) self.burst_count = int(config.get("burst_count", defaults["burst_count"]))

View file

@ -330,9 +330,11 @@ class FrozenEvent(EventBase):
self, self,
event_dict: JsonDict, event_dict: JsonDict,
room_version: RoomVersion, room_version: RoomVersion,
internal_metadata_dict: JsonDict = {}, internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None, rejected_reason: Optional[str] = None,
): ):
internal_metadata_dict = internal_metadata_dict or {}
event_dict = dict(event_dict) event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a # Signatures is a dict of dicts, and this is faster than doing a
@ -386,9 +388,11 @@ class FrozenEventV2(EventBase):
self, self,
event_dict: JsonDict, event_dict: JsonDict,
room_version: RoomVersion, room_version: RoomVersion,
internal_metadata_dict: JsonDict = {}, internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None, rejected_reason: Optional[str] = None,
): ):
internal_metadata_dict = internal_metadata_dict or {}
event_dict = dict(event_dict) event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a # Signatures is a dict of dicts, and this is faster than doing a
@ -507,9 +511,11 @@ def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
def make_event_from_dict( def make_event_from_dict(
event_dict: JsonDict, event_dict: JsonDict,
room_version: RoomVersion = RoomVersions.V1, room_version: RoomVersion = RoomVersions.V1,
internal_metadata_dict: JsonDict = {}, internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None, rejected_reason: Optional[str] = None,
) -> EventBase: ) -> EventBase:
"""Construct an EventBase from the given event dict""" """Construct an EventBase from the given event dict"""
event_type = _event_type_from_format_version(room_version.event_format) event_type = _event_type_from_format_version(room_version.event_format)
return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason) return event_type(
event_dict, room_version, internal_metadata_dict or {}, rejected_reason
)

View file

@ -18,6 +18,7 @@ server protocol.
""" """
import logging import logging
from typing import Optional
import attr import attr
@ -98,7 +99,7 @@ class Transaction(JsonEncodedObject):
"pdus", "pdus",
] ]
def __init__(self, transaction_id=None, pdus=[], **kwargs): def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
"""If we include a list of pdus then we decode then as PDU's """If we include a list of pdus then we decode then as PDU's
automatically. automatically.
""" """
@ -107,7 +108,7 @@ class Transaction(JsonEncodedObject):
if "edus" in kwargs and not kwargs["edus"]: if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"] del kwargs["edus"]
super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)
@staticmethod @staticmethod
def create_new(pdus, **kwargs): def create_new(pdus, **kwargs):

View file

@ -182,7 +182,7 @@ class ApplicationServicesHandler:
self, self,
stream_key: str, stream_key: str,
new_token: Optional[int], new_token: Optional[int],
users: Collection[Union[str, UserID]] = [], users: Optional[Collection[Union[str, UserID]]] = None,
): ):
"""This is called by the notifier in the background """This is called by the notifier in the background
when a ephemeral event handled by the homeserver. when a ephemeral event handled by the homeserver.
@ -215,7 +215,7 @@ class ApplicationServicesHandler:
# We only start a new background process if necessary rather than # We only start a new background process if necessary rather than
# optimistically (to cut down on overhead). # optimistically (to cut down on overhead).
self._notify_interested_services_ephemeral( self._notify_interested_services_ephemeral(
services, stream_key, new_token, users services, stream_key, new_token, users or []
) )
@wrap_as_background_process("notify_interested_services_ephemeral") @wrap_as_background_process("notify_interested_services_ephemeral")

View file

@ -1790,7 +1790,7 @@ class FederationHandler(BaseHandler):
room_id: str, room_id: str,
user_id: str, user_id: str,
membership: str, membership: str,
content: JsonDict = {}, content: JsonDict,
params: Optional[Dict[str, Union[str, Iterable[str]]]] = None, params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
) -> Tuple[str, EventBase, RoomVersion]: ) -> Tuple[str, EventBase, RoomVersion]:
( (

View file

@ -137,7 +137,7 @@ class MessageHandler:
self, self,
user_id: str, user_id: str,
room_id: str, room_id: str,
state_filter: StateFilter = StateFilter.all(), state_filter: Optional[StateFilter] = None,
at_token: Optional[StreamToken] = None, at_token: Optional[StreamToken] = None,
is_guest: bool = False, is_guest: bool = False,
) -> List[dict]: ) -> List[dict]:
@ -164,6 +164,8 @@ class MessageHandler:
AuthError (403) if the user doesn't have permission to view AuthError (403) if the user doesn't have permission to view
members of this room. members of this room.
""" """
state_filter = state_filter or StateFilter.all()
if at_token: if at_token:
# FIXME this claims to get the state at a stream position, but # FIXME this claims to get the state at a stream position, but
# get_recent_events_for_room operates by topo ordering. This therefore # get_recent_events_for_room operates by topo ordering. This therefore
@ -874,7 +876,7 @@ class EventCreationHandler:
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
ratelimit: bool = True, ratelimit: bool = True,
extra_users: List[UserID] = [], extra_users: Optional[List[UserID]] = None,
ignore_shadow_ban: bool = False, ignore_shadow_ban: bool = False,
) -> EventBase: ) -> EventBase:
"""Processes a new event. """Processes a new event.
@ -902,6 +904,7 @@ class EventCreationHandler:
Raises: Raises:
ShadowBanError if the requester has been shadow-banned. ShadowBanError if the requester has been shadow-banned.
""" """
extra_users = extra_users or []
# we don't apply shadow-banning to membership events here. Invites are blocked # we don't apply shadow-banning to membership events here. Invites are blocked
# higher up the stack, and we allow shadow-banned users to send join and leave # higher up the stack, and we allow shadow-banned users to send join and leave
@ -1071,7 +1074,7 @@ class EventCreationHandler:
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
ratelimit: bool = True, ratelimit: bool = True,
extra_users: List[UserID] = [], extra_users: Optional[List[UserID]] = None,
) -> EventBase: ) -> EventBase:
"""Called when we have fully built the event, have already """Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth. calculated the push actions for the event, and checked auth.
@ -1083,6 +1086,8 @@ class EventCreationHandler:
it was de-duplicated (e.g. because we had already persisted an it was de-duplicated (e.g. because we had already persisted an
event with the same transaction ID.) event with the same transaction ID.)
""" """
extra_users = extra_users or []
assert self.storage.persistence is not None assert self.storage.persistence is not None
assert self._events_shard_config.should_handle( assert self._events_shard_config.should_handle(
self._instance_name, event.room_id self._instance_name, event.room_id

View file

@ -169,7 +169,7 @@ class RegistrationHandler(BaseHandler):
user_type: Optional[str] = None, user_type: Optional[str] = None,
default_display_name: Optional[str] = None, default_display_name: Optional[str] = None,
address: Optional[str] = None, address: Optional[str] = None,
bind_emails: Iterable[str] = [], bind_emails: Optional[Iterable[str]] = None,
by_admin: bool = False, by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None, user_agent_ips: Optional[List[Tuple[str, str]]] = None,
auth_provider_id: Optional[str] = None, auth_provider_id: Optional[str] = None,
@ -204,6 +204,8 @@ class RegistrationHandler(BaseHandler):
Raises: Raises:
SynapseError if there was a problem registering. SynapseError if there was a problem registering.
""" """
bind_emails = bind_emails or []
await self.check_registration_ratelimit(address) await self.check_registration_ratelimit(address)
result = await self.spam_checker.check_registration_for_spam( result = await self.spam_checker.check_registration_for_spam(

View file

@ -548,7 +548,7 @@ class SyncHandler:
) )
async def get_state_after_event( async def get_state_after_event(
self, event: EventBase, state_filter: StateFilter = StateFilter.all() self, event: EventBase, state_filter: Optional[StateFilter] = None
) -> StateMap[str]: ) -> StateMap[str]:
""" """
Get the room state after the given event Get the room state after the given event
@ -558,7 +558,7 @@ class SyncHandler:
state_filter: The state filter used to fetch state from the database. state_filter: The state filter used to fetch state from the database.
""" """
state_ids = await self.state_store.get_state_ids_for_event( state_ids = await self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter event.event_id, state_filter=state_filter or StateFilter.all()
) )
if event.is_state(): if event.is_state():
state_ids = dict(state_ids) state_ids = dict(state_ids)
@ -569,7 +569,7 @@ class SyncHandler:
self, self,
room_id: str, room_id: str,
stream_position: StreamToken, stream_position: StreamToken,
state_filter: StateFilter = StateFilter.all(), state_filter: Optional[StateFilter] = None,
) -> StateMap[str]: ) -> StateMap[str]:
"""Get the room state at a particular stream position """Get the room state at a particular stream position
@ -589,7 +589,7 @@ class SyncHandler:
if last_events: if last_events:
last_event = last_events[-1] last_event = last_events[-1]
state = await self.get_state_after_event( state = await self.get_state_after_event(
last_event, state_filter=state_filter last_event, state_filter=state_filter or StateFilter.all()
) )
else: else:

View file

@ -297,7 +297,7 @@ class SimpleHttpClient:
def __init__( def __init__(
self, self,
hs: "HomeServer", hs: "HomeServer",
treq_args: Dict[str, Any] = {}, treq_args: Optional[Dict[str, Any]] = None,
ip_whitelist: Optional[IPSet] = None, ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None, ip_blacklist: Optional[IPSet] = None,
use_proxy: bool = False, use_proxy: bool = False,
@ -317,7 +317,7 @@ class SimpleHttpClient:
self._ip_whitelist = ip_whitelist self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist self._ip_blacklist = ip_blacklist
self._extra_treq_args = treq_args self._extra_treq_args = treq_args or {}
self.user_agent = hs.version_string self.user_agent = hs.version_string
self.clock = hs.get_clock() self.clock = hs.get_clock()

View file

@ -27,7 +27,7 @@ from twisted.python.failure import Failure
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
from twisted.web.error import SchemeNotSupported from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent from twisted.web.iweb import IAgent, IPolicyForHTTPS
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
@ -88,12 +88,14 @@ class ProxyAgent(_AgentBase):
self, self,
reactor, reactor,
proxy_reactor=None, proxy_reactor=None,
contextFactory=BrowserLikePolicyForHTTPS(), contextFactory: Optional[IPolicyForHTTPS] = None,
connectTimeout=None, connectTimeout=None,
bindAddress=None, bindAddress=None,
pool=None, pool=None,
use_proxy=False, use_proxy=False,
): ):
contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
_AgentBase.__init__(self, reactor, pool) _AgentBase.__init__(self, reactor, pool)
if proxy_reactor is None: if proxy_reactor is None:

View file

@ -486,7 +486,7 @@ def start_active_span_from_request(
def start_active_span_from_edu( def start_active_span_from_edu(
edu_content, edu_content,
operation_name, operation_name,
references=[], references: Optional[list] = None,
tags=None, tags=None,
start_time=None, start_time=None,
ignore_active_span=False, ignore_active_span=False,
@ -501,6 +501,7 @@ def start_active_span_from_edu(
For the other args see opentracing.tracer For the other args see opentracing.tracer
""" """
references = references or []
if opentracing is None: if opentracing is None:
return noop_context_manager() return noop_context_manager()

View file

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
from twisted.internet import defer from twisted.internet import defer
@ -127,7 +127,7 @@ class ModuleApi:
return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, localpart, displayname=None, emails=[]): def register(self, localpart, displayname=None, emails: Optional[List[str]] = None):
"""Registers a new user with given localpart and optional displayname, emails. """Registers a new user with given localpart and optional displayname, emails.
Also returns an access token for the new user. Also returns an access token for the new user.
@ -147,11 +147,13 @@ class ModuleApi:
logger.warning( logger.warning(
"Using deprecated ModuleApi.register which creates a dummy user device." "Using deprecated ModuleApi.register which creates a dummy user device."
) )
user_id = yield self.register_user(localpart, displayname, emails) user_id = yield self.register_user(localpart, displayname, emails or [])
_, access_token = yield self.register_device(user_id) _, access_token = yield self.register_device(user_id)
return user_id, access_token return user_id, access_token
def register_user(self, localpart, displayname=None, emails=[]): def register_user(
self, localpart, displayname=None, emails: Optional[List[str]] = None
):
"""Registers a new user with given localpart and optional displayname, emails. """Registers a new user with given localpart and optional displayname, emails.
Args: Args:
@ -170,7 +172,7 @@ class ModuleApi:
self._hs.get_registration_handler().register_user( self._hs.get_registration_handler().register_user(
localpart=localpart, localpart=localpart,
default_display_name=displayname, default_display_name=displayname,
bind_emails=emails, bind_emails=emails or [],
) )
) )

View file

@ -276,7 +276,7 @@ class Notifier:
event: EventBase, event: EventBase,
event_pos: PersistedEventPosition, event_pos: PersistedEventPosition,
max_room_stream_token: RoomStreamToken, max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [], extra_users: Optional[Collection[UserID]] = None,
): ):
"""Unwraps event and calls `on_new_room_event_args`.""" """Unwraps event and calls `on_new_room_event_args`."""
self.on_new_room_event_args( self.on_new_room_event_args(
@ -286,7 +286,7 @@ class Notifier:
state_key=event.get("state_key"), state_key=event.get("state_key"),
membership=event.content.get("membership"), membership=event.content.get("membership"),
max_room_stream_token=max_room_stream_token, max_room_stream_token=max_room_stream_token,
extra_users=extra_users, extra_users=extra_users or [],
) )
def on_new_room_event_args( def on_new_room_event_args(
@ -297,7 +297,7 @@ class Notifier:
membership: Optional[str], membership: Optional[str],
event_pos: PersistedEventPosition, event_pos: PersistedEventPosition,
max_room_stream_token: RoomStreamToken, max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [], extra_users: Optional[Collection[UserID]] = None,
): ):
"""Used by handlers to inform the notifier something has happened """Used by handlers to inform the notifier something has happened
in the room, room event wise. in the room, room event wise.
@ -313,7 +313,7 @@ class Notifier:
self.pending_new_room_events.append( self.pending_new_room_events.append(
_PendingRoomEventEntry( _PendingRoomEventEntry(
event_pos=event_pos, event_pos=event_pos,
extra_users=extra_users, extra_users=extra_users or [],
room_id=room_id, room_id=room_id,
type=event_type, type=event_type,
state_key=state_key, state_key=state_key,
@ -382,14 +382,14 @@ class Notifier:
self, self,
stream_key: str, stream_key: str,
new_token: Union[int, RoomStreamToken], new_token: Union[int, RoomStreamToken],
users: Collection[Union[str, UserID]] = [], users: Optional[Collection[Union[str, UserID]]] = None,
): ):
try: try:
stream_token = None stream_token = None
if isinstance(new_token, int): if isinstance(new_token, int):
stream_token = new_token stream_token = new_token
self.appservice_handler.notify_interested_services_ephemeral( self.appservice_handler.notify_interested_services_ephemeral(
stream_key, stream_token, users stream_key, stream_token, users or []
) )
except Exception: except Exception:
logger.exception("Error notifying application services of event") logger.exception("Error notifying application services of event")
@ -404,13 +404,16 @@ class Notifier:
self, self,
stream_key: str, stream_key: str,
new_token: Union[int, RoomStreamToken], new_token: Union[int, RoomStreamToken],
users: Collection[Union[str, UserID]] = [], users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Collection[str] = [], rooms: Optional[Collection[str]] = None,
): ):
"""Used to inform listeners that something has happened event wise. """Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
users = users or []
rooms = rooms or []
with Measure(self.clock, "on_new_event"): with Measure(self.clock, "on_new_event"):
user_streams = set() user_streams = set()

View file

@ -900,7 +900,7 @@ class DatabasePool:
table: str, table: str,
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
values: Dict[str, Any], values: Dict[str, Any],
insertion_values: Dict[str, Any] = {}, insertion_values: Optional[Dict[str, Any]] = None,
desc: str = "simple_upsert", desc: str = "simple_upsert",
lock: bool = True, lock: bool = True,
) -> Optional[bool]: ) -> Optional[bool]:
@ -927,6 +927,8 @@ class DatabasePool:
Native upserts always return None. Emulated upserts return True if a Native upserts always return None. Emulated upserts return True if a
new entry was created, False if an existing one was updated. new entry was created, False if an existing one was updated.
""" """
insertion_values = insertion_values or {}
attempts = 0 attempts = 0
while True: while True:
try: try:
@ -964,7 +966,7 @@ class DatabasePool:
table: str, table: str,
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
values: Dict[str, Any], values: Dict[str, Any],
insertion_values: Dict[str, Any] = {}, insertion_values: Optional[Dict[str, Any]] = None,
lock: bool = True, lock: bool = True,
) -> Optional[bool]: ) -> Optional[bool]:
""" """
@ -982,6 +984,8 @@ class DatabasePool:
Native upserts always return None. Emulated upserts return True if a Native upserts always return None. Emulated upserts return True if a
new entry was created, False if an existing one was updated. new entry was created, False if an existing one was updated.
""" """
insertion_values = insertion_values or {}
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
self.simple_upsert_txn_native_upsert( self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values txn, table, keyvalues, values, insertion_values=insertion_values
@ -1003,7 +1007,7 @@ class DatabasePool:
table: str, table: str,
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
values: Dict[str, Any], values: Dict[str, Any],
insertion_values: Dict[str, Any] = {}, insertion_values: Optional[Dict[str, Any]] = None,
lock: bool = True, lock: bool = True,
) -> bool: ) -> bool:
""" """
@ -1017,6 +1021,8 @@ class DatabasePool:
Returns True if a new entry was created, False if an existing Returns True if a new entry was created, False if an existing
one was updated. one was updated.
""" """
insertion_values = insertion_values or {}
# We need to lock the table :(, unless we're *really* careful # We need to lock the table :(, unless we're *really* careful
if lock: if lock:
self.engine.lock_table(txn, table) self.engine.lock_table(txn, table)
@ -1077,7 +1083,7 @@ class DatabasePool:
table: str, table: str,
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
values: Dict[str, Any], values: Dict[str, Any],
insertion_values: Dict[str, Any] = {}, insertion_values: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
""" """
Use the native UPSERT functionality in recent PostgreSQL versions. Use the native UPSERT functionality in recent PostgreSQL versions.
@ -1090,7 +1096,7 @@ class DatabasePool:
""" """
allvalues = {} # type: Dict[str, Any] allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues) allvalues.update(keyvalues)
allvalues.update(insertion_values) allvalues.update(insertion_values or {})
if not values: if not values:
latter = "NOTHING" latter = "NOTHING"
@ -1513,7 +1519,7 @@ class DatabasePool:
column: str, column: str,
iterable: Iterable[Any], iterable: Iterable[Any],
retcols: Iterable[str], retcols: Iterable[str],
keyvalues: Dict[str, Any] = {}, keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch", desc: str = "simple_select_many_batch",
batch_size: int = 100, batch_size: int = 100,
) -> List[Any]: ) -> List[Any]:
@ -1531,6 +1537,8 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics desc: description of the transaction, for logging and metrics
batch_size: the number of rows for each select query batch_size: the number of rows for each select query
""" """
keyvalues = keyvalues or {}
results = [] # type: List[Dict[str, Any]] results = [] # type: List[Dict[str, Any]]
if not iterable: if not iterable:

View file

@ -320,8 +320,8 @@ class PersistEventsStore:
txn: LoggingTransaction, txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]], events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool, backfilled: bool,
state_delta_for_room: Dict[str, DeltaState] = {}, state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
new_forward_extremeties: Dict[str, List[str]] = {}, new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
): ):
"""Insert some number of room events into the necessary database tables. """Insert some number of room events into the necessary database tables.
@ -342,6 +342,9 @@ class PersistEventsStore:
extremities. extremities.
""" """
state_delta_for_room = state_delta_for_room or {}
new_forward_extremeties = new_forward_extremeties or {}
all_events_and_contexts = events_and_contexts all_events_and_contexts = events_and_contexts
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering

View file

@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore):
user_id: str, user_id: str,
membership: str, membership: str,
is_admin: bool = False, is_admin: bool = False,
content: JsonDict = {}, content: Optional[JsonDict] = None,
local_attestation: Optional[dict] = None, local_attestation: Optional[dict] = None,
remote_attestation: Optional[dict] = None, remote_attestation: Optional[dict] = None,
is_publicised: bool = False, is_publicised: bool = False,
@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_publicised: Whether this should be publicised. is_publicised: Whether this should be publicised.
""" """
content = content or {}
def _register_user_group_membership_txn(txn, next_id): def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert? # TODO: Upsert?
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(

View file

@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# FIXME: how should this be cached? # FIXME: how should this be cached?
async def get_filtered_current_state_ids( async def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all() self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]: ) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the """Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result current_state_events table. This may not be as up-to-date as the result
@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Map from type/state_key to event ID. Map from type/state_key to event ID.
""" """
where_clause, where_args = state_filter.make_sql_filter_clause() where_clause, where_args = (
state_filter or StateFilter.all()
).make_sql_filter_clause()
if not where_clause: if not where_clause:
# We delegate to the cached version # We delegate to the cached version

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
return count return count
def _get_state_groups_from_groups_txn( def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all() self, txn, groups, state_filter: Optional[StateFilter] = None
): ):
state_filter = state_filter or StateFilter.all()
results = {group: {} for group in groups} results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause() where_clause, where_args = state_filter.make_sql_filter_clause()

View file

@ -15,7 +15,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple from typing import Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types return state_filter.filter_state(state_dict_ids), not missing_types
async def _get_state_for_groups( async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]: ) -> Dict[int, MutableStateMap[str]]:
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key filtering by type/state_key
@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns: Returns:
Dict of state group to state map. Dict of state group to state map.
""" """
state_filter = state_filter or StateFilter.all()
member_filter, non_member_filter = state_filter.get_member_split() member_filter, non_member_filter = state_filter.get_member_split()

View file

@ -449,7 +449,7 @@ class StateGroupStorage:
return self.stores.state._get_state_groups_from_groups(groups, state_filter) return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events( async def get_state_for_events(
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all() self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]: ) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
dicts for each event. dicts for each event.
@ -465,7 +465,7 @@ class StateGroupStorage:
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter groups, state_filter or StateFilter.all()
) )
state_event_map = await self.stores.main.get_events( state_event_map = await self.stores.main.get_events(
@ -485,7 +485,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events( async def get_state_ids_for_events(
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all() self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[str]]: ) -> Dict[str, StateMap[str]]:
""" """
Get the state dicts corresponding to a list of events, containing the event_ids Get the state dicts corresponding to a list of events, containing the event_ids
@ -502,7 +502,7 @@ class StateGroupStorage:
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter groups, state_filter or StateFilter.all()
) )
event_to_state = { event_to_state = {
@ -513,7 +513,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
async def get_state_for_event( async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all() self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]: ) -> StateMap[EventBase]:
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
@ -525,11 +525,13 @@ class StateGroupStorage:
Returns: Returns:
A dict from (type, state_key) -> state_event A dict from (type, state_key) -> state_event
""" """
state_map = await self.get_state_for_events([event_id], state_filter) state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id] return state_map[event_id]
async def get_state_ids_for_event( async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all() self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]: ) -> StateMap[str]:
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
@ -541,11 +543,13 @@ class StateGroupStorage:
Returns: Returns:
A dict from (type, state_key) -> state_event A dict from (type, state_key) -> state_event
""" """
state_map = await self.get_state_ids_for_events([event_id], state_filter) state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id] return state_map[event_id]
def _get_state_for_groups( def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]: ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key filtering by type/state_key
@ -558,7 +562,9 @@ class StateGroupStorage:
Returns: Returns:
Dict of state group to state map. Dict of state group to state map.
""" """
return self.stores.state._get_state_for_groups(groups, state_filter) return self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
async def store_state_group( async def store_state_group(
self, self,

View file

@ -17,7 +17,7 @@ import logging
import threading import threading
from collections import OrderedDict from collections import OrderedDict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
import attr import attr
@ -91,7 +91,14 @@ class StreamIdGenerator:
# ... persist event ... # ... persist event ...
""" """
def __init__(self, db_conn, table, column, extra_tables=[], step=1): def __init__(
self,
db_conn,
table,
column,
extra_tables: Iterable[Tuple[str, str]] = (),
step=1,
):
assert step != 0 assert step != 0
self._lock = threading.Lock() self._lock = threading.Lock()
self._step = step self._step = step

View file

@ -57,12 +57,14 @@ def enumerate_leaves(node, depth):
class _Node: class _Node:
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
def __init__(self, prev_node, next_node, key, value, callbacks=set()): def __init__(
self, prev_node, next_node, key, value, callbacks: Optional[set] = None
):
self.prev_node = prev_node self.prev_node = prev_node
self.next_node = next_node self.next_node = next_node
self.key = key self.key = key
self.value = value self.value = value
self.callbacks = callbacks self.callbacks = callbacks or set()
class LruCache(Generic[KT, VT]): class LruCache(Generic[KT, VT]):
@ -176,10 +178,10 @@ class LruCache(Generic[KT, VT]):
self.len = synchronized(cache_len) self.len = synchronized(cache_len)
def add_node(key, value, callbacks=set()): def add_node(key, value, callbacks: Optional[set] = None):
prev_node = list_root prev_node = list_root
next_node = prev_node.next_node next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value, callbacks) node = _Node(prev_node, next_node, key, value, callbacks or set())
prev_node.next_node = node prev_node.next_node = node
next_node.prev_node = node next_node.prev_node = node
cache[key] = node cache[key] = node
@ -237,7 +239,7 @@ class LruCache(Generic[KT, VT]):
def cache_get( def cache_get(
key: KT, key: KT,
default: Optional[T] = None, default: Optional[T] = None,
callbacks: Iterable[Callable[[], None]] = [], callbacks: Iterable[Callable[[], None]] = (),
update_metrics: bool = True, update_metrics: bool = True,
): ):
node = cache.get(key, None) node = cache.get(key, None)
@ -253,7 +255,7 @@ class LruCache(Generic[KT, VT]):
return default return default
@synchronized @synchronized
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []): def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
# We sometimes store large objects, e.g. dicts, which cause # We sometimes store large objects, e.g. dicts, which cause

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional
from mock import Mock from mock import Mock
@ -180,7 +181,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
_check_logcontext(context) _check_logcontext(context)
def _handle_well_known_connection( def _handle_well_known_connection(
self, client_factory, expected_sni, content, response_headers={} self,
client_factory,
expected_sni,
content,
response_headers: Optional[dict] = None,
): ):
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the """Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response. request is for a .well-known, and send the response.
@ -202,10 +207,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual( self.assertEqual(
request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
) )
self._send_well_known_response(request, content, headers=response_headers) self._send_well_known_response(request, content, headers=response_headers or {})
return well_known_server return well_known_server
def _send_well_known_response(self, request, content, headers={}): def _send_well_known_response(
self, request, content, headers: Optional[dict] = None
):
"""Check that an incoming request looks like a valid .well-known request, and """Check that an incoming request looks like a valid .well-known request, and
send back the response. send back the response.
""" """
@ -213,7 +220,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(request.path, b"/.well-known/matrix/server") self.assertEqual(request.path, b"/.well-known/matrix/server")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
# send back a response # send back a response
for k, v in headers.items(): for k, v in (headers or {}).items():
request.setHeader(k, v) request.setHeader(k, v)
request.write(content) request.write(content)
request.finish() request.finish()

View file

@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return resource return resource
def make_worker_hs( def make_worker_hs(
self, worker_app: str, extra_config: dict = {}, **kwargs self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
) -> HomeServer: ) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation """Make a new worker HS instance, correctly connecting replcation
stream to the master HS. stream to the master HS.
@ -283,7 +283,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config = self._get_worker_hs_config() config = self._get_worker_hs_config()
config["worker_app"] = worker_app config["worker_app"] = worker_app
config.update(extra_config) config.update(extra_config or {})
worker_hs = self.setup_test_homeserver( worker_hs = self.setup_test_homeserver(
homeserver_to_use=GenericWorkerServer, homeserver_to_use=GenericWorkerServer,

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, Optional
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -332,15 +333,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
room_id=ROOM_ID, room_id=ROOM_ID,
type="m.room.message", type="m.room.message",
key=None, key=None,
internal={}, internal: Optional[dict] = None,
depth=None, depth=None,
prev_events=[], prev_events: Optional[list] = None,
auth_events=[], auth_events: Optional[list] = None,
prev_state=[], prev_state: Optional[list] = None,
redacts=None, redacts=None,
push_actions=[], push_actions: Iterable = frozenset(),
**content **content
): ):
prev_events = prev_events or []
auth_events = auth_events or []
prev_state = prev_state or []
if depth is None: if depth is None:
depth = self.event_id depth = self.event_id
@ -369,7 +373,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if redacts is not None: if redacts is not None:
event_dict["redacts"] = redacts event_dict["redacts"] = redacts
event = make_event_from_dict(event_dict, internal_metadata_dict=internal) event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {})
self.event_id += 1 self.event_id += 1
state_handler = self.hs.get_state_handler() state_handler = self.hs.get_state_handler()

View file

@ -19,6 +19,7 @@
"""Tests REST events for /rooms paths.""" """Tests REST events for /rooms paths."""
import json import json
from typing import Iterable
from urllib import parse as urlparse from urllib import parse as urlparse
from mock import Mock from mock import Mock
@ -207,7 +208,9 @@ class RoomPermissionsTestCase(RoomBase):
) )
self.assertEquals(403, channel.code, msg=channel.result["body"]) self.assertEquals(403, channel.code, msg=channel.result["body"])
def _test_get_membership(self, room=None, members=[], expect_code=None): def _test_get_membership(
self, room=None, members: Iterable = frozenset(), expect_code=None
):
for member in members: for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member) path = "/rooms/%s/state/m.room.member/%s" % (room, member)
channel = self.make_request("GET", path) channel = self.make_request("GET", path)

View file

@ -132,7 +132,7 @@ class RestHelper:
src: str, src: str,
targ: str, targ: str,
membership: str, membership: str,
extra_data: dict = {}, extra_data: Optional[dict] = None,
tok: Optional[str] = None, tok: Optional[str] = None,
expect_code: int = 200, expect_code: int = 200,
) -> None: ) -> None:
@ -156,7 +156,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
data = {"membership": membership} data = {"membership": membership}
data.update(extra_data) data.update(extra_data or {})
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.hs.get_reactor(),
@ -187,7 +187,13 @@ class RestHelper:
) )
def send_event( def send_event(
self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200 self,
room_id,
type,
content: Optional[dict] = None,
txn_id=None,
tok=None,
expect_code=200,
): ):
if txn_id is None: if txn_id is None:
txn_id = "m%s" % (str(time.time())) txn_id = "m%s" % (str(time.time()))
@ -201,7 +207,7 @@ class RestHelper:
self.site, self.site,
"PUT", "PUT",
path, path,
json.dumps(content).encode("utf8"), json.dumps(content or {}).encode("utf8"),
) )
assert ( assert (

View file

@ -16,6 +16,7 @@
import itertools import itertools
import json import json
import urllib import urllib
from typing import Optional
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin from synapse.rest import admin
@ -681,7 +682,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
relation_type, relation_type,
event_type, event_type,
key=None, key=None,
content={}, content: Optional[dict] = None,
access_token=None, access_token=None,
parent_id=None, parent_id=None,
): ):
@ -713,7 +714,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"POST", "POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, original_id, relation_type, event_type, query), % (self.room, original_id, relation_type, event_type, query),
json.dumps(content).encode("utf-8"), json.dumps(content or {}).encode("utf-8"),
access_token=access_token, access_token=access_token,
) )
return channel return channel

View file

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
@ -43,7 +45,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
) )
def _create_id_generator( def _create_id_generator(
self, instance_name="master", writers=["master"] self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator: ) -> MultiWriterIdGenerator:
def _create(conn): def _create(conn):
return MultiWriterIdGenerator( return MultiWriterIdGenerator(
@ -53,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
instance_name=instance_name, instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")], tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq", sequence_name="foobar_seq",
writers=writers, writers=writers or ["master"],
) )
return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
@ -476,7 +478,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
) )
def _create_id_generator( def _create_id_generator(
self, instance_name="master", writers=["master"] self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator: ) -> MultiWriterIdGenerator:
def _create(conn): def _create(conn):
return MultiWriterIdGenerator( return MultiWriterIdGenerator(
@ -486,7 +488,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
instance_name=instance_name, instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")], tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq", sequence_name="foobar_seq",
writers=writers, writers=writers or ["master"],
positive=False, positive=False,
) )
@ -612,7 +614,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
) )
def _create_id_generator( def _create_id_generator(
self, instance_name="master", writers=["master"] self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator: ) -> MultiWriterIdGenerator:
def _create(conn): def _create(conn):
return MultiWriterIdGenerator( return MultiWriterIdGenerator(
@ -625,7 +627,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
("foobar2", "instance_name", "stream_id"), ("foobar2", "instance_name", "stream_id"),
], ],
sequence_name="foobar_seq", sequence_name="foobar_seq",
writers=writers, writers=writers or ["master"],
) )
return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) return self.get_success_or_raise(self.db_pool.runWithConnection(_create))

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
from canonicaljson import json from canonicaljson import json
@ -47,10 +48,15 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.depth = 1 self.depth = 1
def inject_room_member( def inject_room_member(
self, room, user, membership, replaces_state=None, extra_content={} self,
room,
user,
membership,
replaces_state=None,
extra_content: Optional[dict] = None,
): ):
content = {"membership": membership} content = {"membership": membership}
content.update(extra_content) content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(
RoomVersions.V1, RoomVersions.V1,
{ {

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional
from mock import Mock from mock import Mock
@ -37,7 +38,7 @@ def create_event(
state_key=None, state_key=None,
depth=2, depth=2,
event_id=None, event_id=None,
prev_events=[], prev_events: Optional[List[str]] = None,
**kwargs **kwargs
): ):
global _next_event_id global _next_event_id
@ -58,7 +59,7 @@ def create_event(
"sender": "@user_id:example.com", "sender": "@user_id:example.com",
"room_id": "!room_id:example.com", "room_id": "!room_id:example.com",
"depth": depth, "depth": depth,
"prev_events": prev_events, "prev_events": prev_events or [],
} }
if state_key is not None: if state_key is not None:

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional
from mock import Mock from mock import Mock
@ -147,9 +148,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
return event return event
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_member(self, user_id, membership="join", extra_content={}): def inject_room_member(
self, user_id, membership="join", extra_content: Optional[dict] = None
):
content = {"membership": membership} content = {"membership": membership}
content.update(extra_content) content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(
RoomVersions.V1, RoomVersions.V1,
{ {

View file

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
@ -89,9 +91,9 @@ def _await_resolution(reactor, d):
return (reactor.seconds() - start_time) * 1000 return (reactor.seconds() - start_time) * 1000
def build_rc_config(settings={}): def build_rc_config(settings: Optional[dict] = None):
config_dict = default_config("test") config_dict = default_config("test")
config_dict.update(settings) config_dict.update(settings or {})
config = HomeServerConfig() config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "") config.parse_config_dict(config_dict, "", "")
return config.rc_federation return config.rc_federation