Bugbear: Add Mutable Parameter fixes (#9682)

Part of #9366

Adds fixes for B006 (mutable default arguments) and B008 (function calls in argument defaults), the two flake8-bugbear checks covering mutable parameter lint errors; the common fix pattern is sketched below.

Signed-off-by: Jonathan de Jong <jonathan@automatia.nl>
Jonathan de Jong 2021-04-08 23:38:54 +02:00 committed by GitHub
parent 64f4f506c5
commit 2ca4e349e9
38 changed files with 224 additions and 113 deletions
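
For context on the two checks: Python evaluates a default argument once, when the def statement runs, so a mutable default such as {} or [] is a single object shared by every call that omits the argument. B006 flags those mutable defaults; B008 flags function calls in defaults (such as the StateFilter.all() defaults further down), which are likewise evaluated only once. The fix applied throughout this commit is to default to None and coalesce inside the function. A minimal sketch of the failure mode and the idiom — the function and parameter names are illustrative, not taken from Synapse:

from typing import Optional

# B006: this dict is created once, at definition time, and shared by every call.
def add_header_shared(key: str, value: str, headers: dict = {}) -> dict:
    headers[key] = value
    return headers

add_header_shared("a", "1")  # {'a': '1'}
add_header_shared("b", "2")  # {'a': '1', 'b': '2'}  <- state leaked between calls

# The pattern used throughout this commit: default to None, then coalesce.
def add_header(key: str, value: str, headers: Optional[dict] = None) -> dict:
    headers = headers or {}
    headers[key] = value
    return headers

add_header("a", "1")  # {'a': '1'}
add_header("b", "2")  # {'b': '2'}  <- a fresh dict on each call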

changelog.d/9682.misc Normal file
View file

@ -0,0 +1 @@
Introduce flake8-bugbear to the test suite and fix some of its lint violations.

View file

@ -24,6 +24,7 @@ import sys
import time
import urllib
from http import TwistedHttpClient
from typing import Optional
import nacl.encoding
import nacl.signing
@ -718,7 +719,7 @@ class SynapseCmd(cmd.Cmd):
method,
path,
data=None,
query_params={"access_token": None},
query_params: Optional[dict] = None,
alt_text=None,
):
"""Runs an HTTP request and pretty prints the output.
@ -729,6 +730,8 @@ class SynapseCmd(cmd.Cmd):
data: Raw JSON data if any
query_params: dict of query parameters to add to the url
"""
query_params = query_params or {"access_token": None}
url = self._url() + path
if "access_token" in query_params:
query_params["access_token"] = self._tok()

View file

@ -16,6 +16,7 @@
import json
import urllib
from pprint import pformat
from typing import Optional
from twisted.internet import defer, reactor
from twisted.web.client import Agent, readBody
@ -85,8 +86,9 @@ class TwistedHttpClient(HttpClient):
body = yield readBody(response)
defer.returnValue(json.loads(body))
def _create_put_request(self, url, json_data, headers_dict={}):
def _create_put_request(self, url, json_data, headers_dict: Optional[dict] = None):
"""Wrapper of _create_request to issue a PUT request"""
headers_dict = headers_dict or {}
if "Content-Type" not in headers_dict:
raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
@ -95,14 +97,22 @@ class TwistedHttpClient(HttpClient):
"PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict
)
def _create_get_request(self, url, headers_dict={}):
def _create_get_request(self, url, headers_dict: Optional[dict] = None):
"""Wrapper of _create_request to issue a GET request"""
return self._create_request("GET", url, headers_dict=headers_dict)
return self._create_request("GET", url, headers_dict=headers_dict or {})
@defer.inlineCallbacks
def do_request(
self, method, url, data=None, qparams=None, jsonreq=True, headers={}
self,
method,
url,
data=None,
qparams=None,
jsonreq=True,
headers: Optional[dict] = None,
):
headers = headers or {}
if qparams:
url = "%s?%s" % (url, urllib.urlencode(qparams, True))
@ -123,8 +133,12 @@ class TwistedHttpClient(HttpClient):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def _create_request(self, method, url, producer=None, headers_dict={}):
def _create_request(
self, method, url, producer=None, headers_dict: Optional[dict] = None
):
"""Creates and sends a request to the given url"""
headers_dict = headers_dict or {}
headers_dict["User-Agent"] = ["Synapse Cmd Client"]
retries_left = 5

View file

@ -18,8 +18,8 @@ ignore =
# E203: whitespace before ':' (which is contrary to pep8?)
# E731: do not assign a lambda expression, use a def
# E501: Line too long (black enforces this for us)
# B00*: Subsection of the bugbear suite (TODO: add in remaining fixes)
ignore=W503,W504,E203,E731,E501,B006,B007,B008
# B007: Subsection of the bugbear suite (TODO: add in remaining fixes)
ignore=W503,W504,E203,E731,E501,B007
[isort]
line_length = 88
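
The setup.cfg hunk above is what actually enables the checks: flake8 only reports the B-prefixed codes when the flake8-bugbear plugin is installed, and with it in place, removing B006 and B008 from the ignore list is what makes these two checks enforced by the lint run. B007 (loop control variable not used within the loop body) remains ignored for now, as the TODO notes.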

View file

@ -49,7 +49,7 @@ This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
from typing import List
from typing import List, Optional
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.events import EventBase
@ -191,11 +191,11 @@ class _TransactionController:
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: List[JsonDict] = [],
ephemeral: Optional[List[JsonDict]] = None,
):
try:
txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral
service=service, events=events, ephemeral=ephemeral or []
)
service_is_up = await self._is_service_up(service)
if service_is_up:

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import Dict, Optional
from ._base import Config
@ -21,8 +21,10 @@ class RateLimitConfig:
def __init__(
self,
config: Dict[str, float],
defaults={"per_second": 0.17, "burst_count": 3.0},
defaults: Optional[Dict[str, float]] = None,
):
defaults = defaults or {"per_second": 0.17, "burst_count": 3.0}
self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
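
One subtlety of the x = x or default idiom: or treats any falsy value, not just None, as "use the default". For this constructor that means a caller who explicitly passes an empty dict now gets the built-in per_second/burst_count defaults rather than a KeyError, which is almost certainly the desired behaviour, but the stricter alternative is worth knowing. A small illustrative comparison (not code from the commit):

from typing import Dict, Optional

BUILTIN_DEFAULTS = {"per_second": 0.17, "burst_count": 3.0}

def resolve(defaults: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    # `or` also replaces an explicitly passed empty dict, because {} is falsy.
    return defaults or BUILTIN_DEFAULTS

def resolve_strict(defaults: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    # An explicit None check only fills in a genuinely missing argument.
    return BUILTIN_DEFAULTS if defaults is None else defaults

resolve({})         # {'per_second': 0.17, 'burst_count': 3.0}
resolve_strict({})  # {}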

View file

@ -330,9 +330,11 @@ class FrozenEvent(EventBase):
self,
event_dict: JsonDict,
room_version: RoomVersion,
internal_metadata_dict: JsonDict = {},
internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
):
internal_metadata_dict = internal_metadata_dict or {}
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
@ -386,9 +388,11 @@ class FrozenEventV2(EventBase):
self,
event_dict: JsonDict,
room_version: RoomVersion,
internal_metadata_dict: JsonDict = {},
internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
):
internal_metadata_dict = internal_metadata_dict or {}
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
@ -507,9 +511,11 @@ def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
def make_event_from_dict(
event_dict: JsonDict,
room_version: RoomVersion = RoomVersions.V1,
internal_metadata_dict: JsonDict = {},
internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
) -> EventBase:
"""Construct an EventBase from the given event dict"""
event_type = _event_type_from_format_version(room_version.event_format)
return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason)
return event_type(
event_dict, room_version, internal_metadata_dict or {}, rejected_reason
)
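
make_event_from_dict shows the second flavour of the fix used in this commit: rather than reassigning the parameter at the top of the function, the or fallback is applied at the single point of use (here, the constructor call). The two forms are equivalent; reassigning up front reads better once the parameter is referenced more than once. A hypothetical side-by-side:

from typing import Optional

def render_once(extra: Optional[dict] = None) -> dict:
    # coalesce at the point of use -- fine when the parameter appears once
    return {"base": True, **(extra or {})}

def render_repeatedly(extra: Optional[dict] = None) -> dict:
    # coalesce up front -- clearer when the parameter is referenced several times
    extra = extra or {}
    result = {"base": True, **extra}
    result["count"] = len(extra)
    return result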

View file

@ -18,6 +18,7 @@ server protocol.
"""
import logging
from typing import Optional
import attr
@ -98,7 +99,7 @@ class Transaction(JsonEncodedObject):
"pdus",
]
def __init__(self, transaction_id=None, pdus=[], **kwargs):
def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
"""If we include a list of pdus then we decode then as PDU's
automatically.
"""
@ -107,7 +108,7 @@ class Transaction(JsonEncodedObject):
if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"]
super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs)
super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)
@staticmethod
def create_new(pdus, **kwargs):

View file

@ -182,7 +182,7 @@ class ApplicationServicesHandler:
self,
stream_key: str,
new_token: Optional[int],
users: Collection[Union[str, UserID]] = [],
users: Optional[Collection[Union[str, UserID]]] = None,
):
"""This is called by the notifier in the background
when a ephemeral event handled by the homeserver.
@ -215,7 +215,7 @@ class ApplicationServicesHandler:
# We only start a new background process if necessary rather than
# optimistically (to cut down on overhead).
self._notify_interested_services_ephemeral(
services, stream_key, new_token, users
services, stream_key, new_token, users or []
)
@wrap_as_background_process("notify_interested_services_ephemeral")

View file

@ -1790,7 +1790,7 @@ class FederationHandler(BaseHandler):
room_id: str,
user_id: str,
membership: str,
content: JsonDict = {},
content: JsonDict,
params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
) -> Tuple[str, EventBase, RoomVersion]:
(

View file

@ -137,7 +137,7 @@ class MessageHandler:
self,
user_id: str,
room_id: str,
state_filter: StateFilter = StateFilter.all(),
state_filter: Optional[StateFilter] = None,
at_token: Optional[StreamToken] = None,
is_guest: bool = False,
) -> List[dict]:
@ -164,6 +164,8 @@ class MessageHandler:
AuthError (403) if the user doesn't have permission to view
members of this room.
"""
state_filter = state_filter or StateFilter.all()
if at_token:
# FIXME this claims to get the state at a stream position, but
# get_recent_events_for_room operates by topo ordering. This therefore
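
The state_filter change at the top of this file is a B008 case rather than a B006 one: StateFilter.all() is never mutated, but the call runs once when the def statement is evaluated and the single resulting object is then shared by every caller that omits the argument. The cure is the same None-plus-coalesce idiom, which builds the default per call. A minimal sketch with a stand-in class (StateFilter itself is not reproduced here):

from typing import Optional

class Filter:
    """Stand-in for something like StateFilter."""
    @staticmethod
    def all() -> "Filter":
        return Filter()

# B008: Filter.all() runs once, when this def is evaluated, and every
# call that omits the argument shares that one instance.
def get_state_shared(flt: Filter = Filter.all()):
    return flt

# Fixed: the default is constructed per call, only when the caller omitted it.
def get_state(flt: Optional[Filter] = None):
    return flt or Filter.all()

assert get_state_shared() is get_state_shared()  # same object every time
assert get_state() is not get_state()            # fresh object per call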
@ -874,7 +876,7 @@ class EventCreationHandler:
event: EventBase,
context: EventContext,
ratelimit: bool = True,
extra_users: List[UserID] = [],
extra_users: Optional[List[UserID]] = None,
ignore_shadow_ban: bool = False,
) -> EventBase:
"""Processes a new event.
@ -902,6 +904,7 @@ class EventCreationHandler:
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
extra_users = extra_users or []
# we don't apply shadow-banning to membership events here. Invites are blocked
# higher up the stack, and we allow shadow-banned users to send join and leave
@ -1071,7 +1074,7 @@ class EventCreationHandler:
event: EventBase,
context: EventContext,
ratelimit: bool = True,
extra_users: List[UserID] = [],
extra_users: Optional[List[UserID]] = None,
) -> EventBase:
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
@ -1083,6 +1086,8 @@ class EventCreationHandler:
it was de-duplicated (e.g. because we had already persisted an
event with the same transaction ID.)
"""
extra_users = extra_users or []
assert self.storage.persistence is not None
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id

View file

@ -169,7 +169,7 @@ class RegistrationHandler(BaseHandler):
user_type: Optional[str] = None,
default_display_name: Optional[str] = None,
address: Optional[str] = None,
bind_emails: Iterable[str] = [],
bind_emails: Optional[Iterable[str]] = None,
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
auth_provider_id: Optional[str] = None,
@ -204,6 +204,8 @@ class RegistrationHandler(BaseHandler):
Raises:
SynapseError if there was a problem registering.
"""
bind_emails = bind_emails or []
await self.check_registration_ratelimit(address)
result = await self.spam_checker.check_registration_for_spam(

View file

@ -548,7 +548,7 @@ class SyncHandler:
)
async def get_state_after_event(
self, event: EventBase, state_filter: StateFilter = StateFilter.all()
self, event: EventBase, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the room state after the given event
@ -558,7 +558,7 @@ class SyncHandler:
state_filter: The state filter used to fetch state from the database.
"""
state_ids = await self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter
event.event_id, state_filter=state_filter or StateFilter.all()
)
if event.is_state():
state_ids = dict(state_ids)
@ -569,7 +569,7 @@ class SyncHandler:
self,
room_id: str,
stream_position: StreamToken,
state_filter: StateFilter = StateFilter.all(),
state_filter: Optional[StateFilter] = None,
) -> StateMap[str]:
"""Get the room state at a particular stream position
@ -589,7 +589,7 @@ class SyncHandler:
if last_events:
last_event = last_events[-1]
state = await self.get_state_after_event(
last_event, state_filter=state_filter
last_event, state_filter=state_filter or StateFilter.all()
)
else:

View file

@ -297,7 +297,7 @@ class SimpleHttpClient:
def __init__(
self,
hs: "HomeServer",
treq_args: Dict[str, Any] = {},
treq_args: Optional[Dict[str, Any]] = None,
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
use_proxy: bool = False,
@ -317,7 +317,7 @@ class SimpleHttpClient:
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
self._extra_treq_args = treq_args
self._extra_treq_args = treq_args or {}
self.user_agent = hs.version_string
self.clock = hs.get_clock()

View file

@ -27,7 +27,7 @@ from twisted.python.failure import Failure
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
from twisted.web.iweb import IAgent, IPolicyForHTTPS
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
@ -88,12 +88,14 @@ class ProxyAgent(_AgentBase):
self,
reactor,
proxy_reactor=None,
contextFactory=BrowserLikePolicyForHTTPS(),
contextFactory: Optional[IPolicyForHTTPS] = None,
connectTimeout=None,
bindAddress=None,
pool=None,
use_proxy=False,
):
contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
_AgentBase.__init__(self, reactor, pool)
if proxy_reactor is None:

View file

@ -486,7 +486,7 @@ def start_active_span_from_request(
def start_active_span_from_edu(
edu_content,
operation_name,
references=[],
references: Optional[list] = None,
tags=None,
start_time=None,
ignore_active_span=False,
@ -501,6 +501,7 @@ def start_active_span_from_edu(
For the other args see opentracing.tracer
"""
references = references or []
if opentracing is None:
return noop_context_manager()

View file

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
from twisted.internet import defer
@ -127,7 +127,7 @@ class ModuleApi:
return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
@defer.inlineCallbacks
def register(self, localpart, displayname=None, emails=[]):
def register(self, localpart, displayname=None, emails: Optional[List[str]] = None):
"""Registers a new user with given localpart and optional displayname, emails.
Also returns an access token for the new user.
@ -147,11 +147,13 @@ class ModuleApi:
logger.warning(
"Using deprecated ModuleApi.register which creates a dummy user device."
)
user_id = yield self.register_user(localpart, displayname, emails)
user_id = yield self.register_user(localpart, displayname, emails or [])
_, access_token = yield self.register_device(user_id)
return user_id, access_token
def register_user(self, localpart, displayname=None, emails=[]):
def register_user(
self, localpart, displayname=None, emails: Optional[List[str]] = None
):
"""Registers a new user with given localpart and optional displayname, emails.
Args:
@ -170,7 +172,7 @@ class ModuleApi:
self._hs.get_registration_handler().register_user(
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
bind_emails=emails or [],
)
)

View file

@ -276,7 +276,7 @@ class Notifier:
event: EventBase,
event_pos: PersistedEventPosition,
max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
extra_users: Optional[Collection[UserID]] = None,
):
"""Unwraps event and calls `on_new_room_event_args`."""
self.on_new_room_event_args(
@ -286,7 +286,7 @@ class Notifier:
state_key=event.get("state_key"),
membership=event.content.get("membership"),
max_room_stream_token=max_room_stream_token,
extra_users=extra_users,
extra_users=extra_users or [],
)
def on_new_room_event_args(
@ -297,7 +297,7 @@ class Notifier:
membership: Optional[str],
event_pos: PersistedEventPosition,
max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
extra_users: Optional[Collection[UserID]] = None,
):
"""Used by handlers to inform the notifier something has happened
in the room, room event wise.
@ -313,7 +313,7 @@ class Notifier:
self.pending_new_room_events.append(
_PendingRoomEventEntry(
event_pos=event_pos,
extra_users=extra_users,
extra_users=extra_users or [],
room_id=room_id,
type=event_type,
state_key=state_key,
@ -382,14 +382,14 @@ class Notifier:
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Collection[Union[str, UserID]] = [],
users: Optional[Collection[Union[str, UserID]]] = None,
):
try:
stream_token = None
if isinstance(new_token, int):
stream_token = new_token
self.appservice_handler.notify_interested_services_ephemeral(
stream_key, stream_token, users
stream_key, stream_token, users or []
)
except Exception:
logger.exception("Error notifying application services of event")
@ -404,13 +404,16 @@ class Notifier:
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Collection[Union[str, UserID]] = [],
rooms: Collection[str] = [],
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[Collection[str]] = None,
):
"""Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms.
"""
users = users or []
rooms = rooms or []
with Measure(self.clock, "on_new_event"):
user_streams = set()

View file

@ -900,7 +900,7 @@ class DatabasePool:
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
insertion_values: Optional[Dict[str, Any]] = None,
desc: str = "simple_upsert",
lock: bool = True,
) -> Optional[bool]:
@ -927,6 +927,8 @@ class DatabasePool:
Native upserts always return None. Emulated upserts return True if a
new entry was created, False if an existing one was updated.
"""
insertion_values = insertion_values or {}
attempts = 0
while True:
try:
@ -964,7 +966,7 @@ class DatabasePool:
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
insertion_values: Optional[Dict[str, Any]] = None,
lock: bool = True,
) -> Optional[bool]:
"""
@ -982,6 +984,8 @@ class DatabasePool:
Native upserts always return None. Emulated upserts return True if a
new entry was created, False if an existing one was updated.
"""
insertion_values = insertion_values or {}
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
@ -1003,7 +1007,7 @@ class DatabasePool:
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
insertion_values: Optional[Dict[str, Any]] = None,
lock: bool = True,
) -> bool:
"""
@ -1017,6 +1021,8 @@ class DatabasePool:
Returns True if a new entry was created, False if an existing
one was updated.
"""
insertion_values = insertion_values or {}
# We need to lock the table :(, unless we're *really* careful
if lock:
self.engine.lock_table(txn, table)
@ -1077,7 +1083,7 @@ class DatabasePool:
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
insertion_values: Optional[Dict[str, Any]] = None,
) -> None:
"""
Use the native UPSERT functionality in recent PostgreSQL versions.
@ -1090,7 +1096,7 @@ class DatabasePool:
"""
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(insertion_values)
allvalues.update(insertion_values or {})
if not values:
latter = "NOTHING"
@ -1513,7 +1519,7 @@ class DatabasePool:
column: str,
iterable: Iterable[Any],
retcols: Iterable[str],
keyvalues: Dict[str, Any] = {},
keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch",
batch_size: int = 100,
) -> List[Any]:
@ -1531,6 +1537,8 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
batch_size: the number of rows for each select query
"""
keyvalues = keyvalues or {}
results = [] # type: List[Dict[str, Any]]
if not iterable:

View file

@ -320,8 +320,8 @@ class PersistEventsStore:
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
state_delta_for_room: Dict[str, DeltaState] = {},
new_forward_extremeties: Dict[str, List[str]] = {},
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
):
"""Insert some number of room events into the necessary database tables.
@ -342,6 +342,9 @@ class PersistEventsStore:
extremities.
"""
state_delta_for_room = state_delta_for_room or {}
new_forward_extremeties = new_forward_extremeties or {}
all_events_and_contexts = events_and_contexts
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering

View file

@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore):
user_id: str,
membership: str,
is_admin: bool = False,
content: JsonDict = {},
content: Optional[JsonDict] = None,
local_attestation: Optional[dict] = None,
remote_attestation: Optional[dict] = None,
is_publicised: bool = False,
@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_publicised: Whether this should be publicised.
"""
content = content or {}
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
self.db_pool.simple_delete_txn(

View file

@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# FIXME: how should this be cached?
async def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all()
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Map from type/state_key to event ID.
"""
where_clause, where_args = state_filter.make_sql_filter_clause()
where_clause, where_args = (
state_filter or StateFilter.all()
).make_sql_filter_clause()
if not where_clause:
# We delegate to the cached version

View file

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Optional
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
return count
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
self, txn, groups, state_filter: Optional[StateFilter] = None
):
state_filter = state_filter or StateFilter.all()
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()

View file

@ -15,7 +15,7 @@
import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple
from typing import Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
Dict of state group to state map.
"""
state_filter = state_filter or StateFilter.all()
member_filter, non_member_filter = state_filter.get_member_split()

View file

@ -449,7 +449,7 @@ class StateGroupStorage:
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events(
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@ -465,7 +465,7 @@ class StateGroupStorage:
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
groups, state_filter or StateFilter.all()
)
state_event_map = await self.stores.main.get_events(
@ -485,7 +485,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events(
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
@ -502,7 +502,7 @@ class StateGroupStorage:
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
groups, state_filter or StateFilter.all()
)
event_to_state = {
@ -513,7 +513,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids}
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
@ -525,11 +525,13 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event
"""
state_map = await self.get_state_for_events([event_id], state_filter)
state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id]
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
@ -541,11 +543,13 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event
"""
state_map = await self.get_state_ids_for_events([event_id], state_filter)
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id]
def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@ -558,7 +562,9 @@ class StateGroupStorage:
Returns:
Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(groups, state_filter)
return self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
async def store_state_group(
self,

View file

@ -17,7 +17,7 @@ import logging
import threading
from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
import attr
@ -91,7 +91,14 @@ class StreamIdGenerator:
# ... persist event ...
"""
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
def __init__(
self,
db_conn,
table,
column,
extra_tables: Iterable[Tuple[str, str]] = (),
step=1,
):
assert step != 0
self._lock = threading.Lock()
self._step = step

View file

@ -57,12 +57,14 @@ def enumerate_leaves(node, depth):
class _Node:
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
def __init__(self, prev_node, next_node, key, value, callbacks=set()):
def __init__(
self, prev_node, next_node, key, value, callbacks: Optional[set] = None
):
self.prev_node = prev_node
self.next_node = next_node
self.key = key
self.value = value
self.callbacks = callbacks
self.callbacks = callbacks or set()
class LruCache(Generic[KT, VT]):
@ -176,10 +178,10 @@ class LruCache(Generic[KT, VT]):
self.len = synchronized(cache_len)
def add_node(key, value, callbacks=set()):
def add_node(key, value, callbacks: Optional[set] = None):
prev_node = list_root
next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value, callbacks)
node = _Node(prev_node, next_node, key, value, callbacks or set())
prev_node.next_node = node
next_node.prev_node = node
cache[key] = node
@ -237,7 +239,7 @@ class LruCache(Generic[KT, VT]):
def cache_get(
key: KT,
default: Optional[T] = None,
callbacks: Iterable[Callable[[], None]] = [],
callbacks: Iterable[Callable[[], None]] = (),
update_metrics: bool = True,
):
node = cache.get(key, None)
@ -253,7 +255,7 @@ class LruCache(Generic[KT, VT]):
return default
@synchronized
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
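
Where a parameter is only ever iterated and never mutated, the hunks above take a lighter route: keep a real default but make it immutable — () for the callbacks arguments of cache_get and cache_set, and similarly frozenset() in some of the test helpers below. Sharing one immutable empty container across calls cannot leak state, so the None dance is unnecessary. A short illustrative sketch of the trade-off:

from typing import Callable, Iterable, Optional

# Fine: the shared empty tuple cannot be mutated, so nothing leaks between calls.
def fire_callbacks(callbacks: Iterable[Callable[[], None]] = ()) -> None:
    for cb in callbacks:
        cb()

# Still needs the None idiom: this function mutates the container it was given.
def collect_callbacks(callbacks: Optional[set] = None) -> set:
    callbacks = callbacks or set()
    callbacks.add(lambda: None)
    return callbacks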

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from mock import Mock
@ -180,7 +181,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
_check_logcontext(context)
def _handle_well_known_connection(
self, client_factory, expected_sni, content, response_headers={}
self,
client_factory,
expected_sni,
content,
response_headers: Optional[dict] = None,
):
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response.
@ -202,10 +207,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(
request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
)
self._send_well_known_response(request, content, headers=response_headers)
self._send_well_known_response(request, content, headers=response_headers or {})
return well_known_server
def _send_well_known_response(self, request, content, headers={}):
def _send_well_known_response(
self, request, content, headers: Optional[dict] = None
):
"""Check that an incoming request looks like a valid .well-known request, and
send back the response.
"""
@ -213,7 +220,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(request.path, b"/.well-known/matrix/server")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
# send back a response
for k, v in headers.items():
for k, v in (headers or {}).items():
request.setHeader(k, v)
request.write(content)
request.finish()

View file

@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return resource
def make_worker_hs(
self, worker_app: str, extra_config: dict = {}, **kwargs
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
@ -283,7 +283,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config = self._get_worker_hs_config()
config["worker_app"] = worker_app
config.update(extra_config)
config.update(extra_config or {})
worker_hs = self.setup_test_homeserver(
homeserver_to_use=GenericWorkerServer,

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Iterable, Optional
from canonicaljson import encode_canonical_json
@ -332,15 +333,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
room_id=ROOM_ID,
type="m.room.message",
key=None,
internal={},
internal: Optional[dict] = None,
depth=None,
prev_events=[],
auth_events=[],
prev_state=[],
prev_events: Optional[list] = None,
auth_events: Optional[list] = None,
prev_state: Optional[list] = None,
redacts=None,
push_actions=[],
push_actions: Iterable = frozenset(),
**content
):
prev_events = prev_events or []
auth_events = auth_events or []
prev_state = prev_state or []
if depth is None:
depth = self.event_id
@ -369,7 +373,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if redacts is not None:
event_dict["redacts"] = redacts
event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {})
self.event_id += 1
state_handler = self.hs.get_state_handler()

View file

@ -19,6 +19,7 @@
"""Tests REST events for /rooms paths."""
import json
from typing import Iterable
from urllib import parse as urlparse
from mock import Mock
@ -207,7 +208,9 @@ class RoomPermissionsTestCase(RoomBase):
)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def _test_get_membership(self, room=None, members=[], expect_code=None):
def _test_get_membership(
self, room=None, members: Iterable = frozenset(), expect_code=None
):
for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
channel = self.make_request("GET", path)

View file

@ -132,7 +132,7 @@ class RestHelper:
src: str,
targ: str,
membership: str,
extra_data: dict = {},
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
) -> None:
@ -156,7 +156,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
data = {"membership": membership}
data.update(extra_data)
data.update(extra_data or {})
channel = make_request(
self.hs.get_reactor(),
@ -187,7 +187,13 @@ class RestHelper:
)
def send_event(
self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
self,
room_id,
type,
content: Optional[dict] = None,
txn_id=None,
tok=None,
expect_code=200,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@ -201,7 +207,7 @@ class RestHelper:
self.site,
"PUT",
path,
json.dumps(content).encode("utf8"),
json.dumps(content or {}).encode("utf8"),
)
assert (

View file

@ -16,6 +16,7 @@
import itertools
import json
import urllib
from typing import Optional
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
@ -681,7 +682,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
relation_type,
event_type,
key=None,
content={},
content: Optional[dict] = None,
access_token=None,
parent_id=None,
):
@ -713,7 +714,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, original_id, relation_type, event_type, query),
json.dumps(content).encode("utf-8"),
json.dumps(content or {}).encode("utf-8"),
access_token=access_token,
)
return channel

View file

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from synapse.storage.database import DatabasePool
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@ -43,7 +45,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
self, instance_name="master", writers=["master"]
self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
@ -53,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers,
writers=writers or ["master"],
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
@ -476,7 +478,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
self, instance_name="master", writers=["master"]
self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
@ -486,7 +488,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers,
writers=writers or ["master"],
positive=False,
)
@ -612,7 +614,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
self, instance_name="master", writers=["master"]
self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
@ -625,7 +627,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
("foobar2", "instance_name", "stream_id"),
],
sequence_name="foobar_seq",
writers=writers,
writers=writers or ["master"],
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from canonicaljson import json
@ -47,10 +48,15 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.depth = 1
def inject_room_member(
self, room, user, membership, replaces_state=None, extra_content={}
self,
room,
user,
membership,
replaces_state=None,
extra_content: Optional[dict] = None,
):
content = {"membership": membership}
content.update(extra_content)
content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from mock import Mock
@ -37,7 +38,7 @@ def create_event(
state_key=None,
depth=2,
event_id=None,
prev_events=[],
prev_events: Optional[List[str]] = None,
**kwargs
):
global _next_event_id
@ -58,7 +59,7 @@ def create_event(
"sender": "@user_id:example.com",
"room_id": "!room_id:example.com",
"depth": depth,
"prev_events": prev_events,
"prev_events": prev_events or [],
}
if state_key is not None:

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from mock import Mock
@ -147,9 +148,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
return event
@defer.inlineCallbacks
def inject_room_member(self, user_id, membership="join", extra_content={}):
def inject_room_member(
self, user_id, membership="join", extra_content: Optional[dict] = None
):
content = {"membership": membership}
content.update(extra_content)
content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{

View file

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from synapse.config.homeserver import HomeServerConfig
from synapse.util.ratelimitutils import FederationRateLimiter
@ -89,9 +91,9 @@ def _await_resolution(reactor, d):
return (reactor.seconds() - start_time) * 1000
def build_rc_config(settings={}):
def build_rc_config(settings: Optional[dict] = None):
config_dict = default_config("test")
config_dict.update(settings)
config_dict.update(settings or {})
config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "")
return config.rc_federation