mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Bugbear: Add Mutable Parameter fixes (#9682)
Part of #9366 Adds in fixes for B006 and B008, both relating to mutable parameter lint errors. Signed-off-by: Jonathan de Jong <jonathan@automatia.nl>
This commit is contained in:
parent
64f4f506c5
commit
2ca4e349e9
38 changed files with 224 additions and 113 deletions
1
changelog.d/9682.misc
Normal file
1
changelog.d/9682.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Introduce flake8-bugbear to the test suite and fix some of its lint violations.
|
|
@ -24,6 +24,7 @@ import sys
|
||||||
import time
|
import time
|
||||||
import urllib
|
import urllib
|
||||||
from http import TwistedHttpClient
|
from http import TwistedHttpClient
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import nacl.encoding
|
import nacl.encoding
|
||||||
import nacl.signing
|
import nacl.signing
|
||||||
|
@ -718,7 +719,7 @@ class SynapseCmd(cmd.Cmd):
|
||||||
method,
|
method,
|
||||||
path,
|
path,
|
||||||
data=None,
|
data=None,
|
||||||
query_params={"access_token": None},
|
query_params: Optional[dict] = None,
|
||||||
alt_text=None,
|
alt_text=None,
|
||||||
):
|
):
|
||||||
"""Runs an HTTP request and pretty prints the output.
|
"""Runs an HTTP request and pretty prints the output.
|
||||||
|
@ -729,6 +730,8 @@ class SynapseCmd(cmd.Cmd):
|
||||||
data: Raw JSON data if any
|
data: Raw JSON data if any
|
||||||
query_params: dict of query parameters to add to the url
|
query_params: dict of query parameters to add to the url
|
||||||
"""
|
"""
|
||||||
|
query_params = query_params or {"access_token": None}
|
||||||
|
|
||||||
url = self._url() + path
|
url = self._url() + path
|
||||||
if "access_token" in query_params:
|
if "access_token" in query_params:
|
||||||
query_params["access_token"] = self._tok()
|
query_params["access_token"] = self._tok()
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import json
|
import json
|
||||||
import urllib
|
import urllib
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.web.client import Agent, readBody
|
from twisted.web.client import Agent, readBody
|
||||||
|
@ -85,8 +86,9 @@ class TwistedHttpClient(HttpClient):
|
||||||
body = yield readBody(response)
|
body = yield readBody(response)
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
def _create_put_request(self, url, json_data, headers_dict={}):
|
def _create_put_request(self, url, json_data, headers_dict: Optional[dict] = None):
|
||||||
"""Wrapper of _create_request to issue a PUT request"""
|
"""Wrapper of _create_request to issue a PUT request"""
|
||||||
|
headers_dict = headers_dict or {}
|
||||||
|
|
||||||
if "Content-Type" not in headers_dict:
|
if "Content-Type" not in headers_dict:
|
||||||
raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
|
raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
|
||||||
|
@ -95,14 +97,22 @@ class TwistedHttpClient(HttpClient):
|
||||||
"PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict
|
"PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_get_request(self, url, headers_dict={}):
|
def _create_get_request(self, url, headers_dict: Optional[dict] = None):
|
||||||
"""Wrapper of _create_request to issue a GET request"""
|
"""Wrapper of _create_request to issue a GET request"""
|
||||||
return self._create_request("GET", url, headers_dict=headers_dict)
|
return self._create_request("GET", url, headers_dict=headers_dict or {})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_request(
|
def do_request(
|
||||||
self, method, url, data=None, qparams=None, jsonreq=True, headers={}
|
self,
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
data=None,
|
||||||
|
qparams=None,
|
||||||
|
jsonreq=True,
|
||||||
|
headers: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
|
headers = headers or {}
|
||||||
|
|
||||||
if qparams:
|
if qparams:
|
||||||
url = "%s?%s" % (url, urllib.urlencode(qparams, True))
|
url = "%s?%s" % (url, urllib.urlencode(qparams, True))
|
||||||
|
|
||||||
|
@ -123,8 +133,12 @@ class TwistedHttpClient(HttpClient):
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_request(self, method, url, producer=None, headers_dict={}):
|
def _create_request(
|
||||||
|
self, method, url, producer=None, headers_dict: Optional[dict] = None
|
||||||
|
):
|
||||||
"""Creates and sends a request to the given url"""
|
"""Creates and sends a request to the given url"""
|
||||||
|
headers_dict = headers_dict or {}
|
||||||
|
|
||||||
headers_dict["User-Agent"] = ["Synapse Cmd Client"]
|
headers_dict["User-Agent"] = ["Synapse Cmd Client"]
|
||||||
|
|
||||||
retries_left = 5
|
retries_left = 5
|
||||||
|
|
|
@ -18,8 +18,8 @@ ignore =
|
||||||
# E203: whitespace before ':' (which is contrary to pep8?)
|
# E203: whitespace before ':' (which is contrary to pep8?)
|
||||||
# E731: do not assign a lambda expression, use a def
|
# E731: do not assign a lambda expression, use a def
|
||||||
# E501: Line too long (black enforces this for us)
|
# E501: Line too long (black enforces this for us)
|
||||||
# B00*: Subsection of the bugbear suite (TODO: add in remaining fixes)
|
# B007: Subsection of the bugbear suite (TODO: add in remaining fixes)
|
||||||
ignore=W503,W504,E203,E731,E501,B006,B007,B008
|
ignore=W503,W504,E203,E731,E501,B007
|
||||||
|
|
||||||
[isort]
|
[isort]
|
||||||
line_length = 88
|
line_length = 88
|
||||||
|
|
|
@ -49,7 +49,7 @@ This is all tied together by the AppServiceScheduler which DIs the required
|
||||||
components.
|
components.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from synapse.appservice import ApplicationService, ApplicationServiceState
|
from synapse.appservice import ApplicationService, ApplicationServiceState
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
@ -191,11 +191,11 @@ class _TransactionController:
|
||||||
self,
|
self,
|
||||||
service: ApplicationService,
|
service: ApplicationService,
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
ephemeral: List[JsonDict] = [],
|
ephemeral: Optional[List[JsonDict]] = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
txn = await self.store.create_appservice_txn(
|
txn = await self.store.create_appservice_txn(
|
||||||
service=service, events=events, ephemeral=ephemeral
|
service=service, events=events, ephemeral=ephemeral or []
|
||||||
)
|
)
|
||||||
service_is_up = await self._is_service_up(service)
|
service_is_up = await self._is_service_up(service)
|
||||||
if service_is_up:
|
if service_is_up:
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from ._base import Config
|
from ._base import Config
|
||||||
|
|
||||||
|
@ -21,8 +21,10 @@ class RateLimitConfig:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Dict[str, float],
|
config: Dict[str, float],
|
||||||
defaults={"per_second": 0.17, "burst_count": 3.0},
|
defaults: Optional[Dict[str, float]] = None,
|
||||||
):
|
):
|
||||||
|
defaults = defaults or {"per_second": 0.17, "burst_count": 3.0}
|
||||||
|
|
||||||
self.per_second = config.get("per_second", defaults["per_second"])
|
self.per_second = config.get("per_second", defaults["per_second"])
|
||||||
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
|
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
|
||||||
|
|
||||||
|
|
|
@ -330,9 +330,11 @@ class FrozenEvent(EventBase):
|
||||||
self,
|
self,
|
||||||
event_dict: JsonDict,
|
event_dict: JsonDict,
|
||||||
room_version: RoomVersion,
|
room_version: RoomVersion,
|
||||||
internal_metadata_dict: JsonDict = {},
|
internal_metadata_dict: Optional[JsonDict] = None,
|
||||||
rejected_reason: Optional[str] = None,
|
rejected_reason: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
internal_metadata_dict = internal_metadata_dict or {}
|
||||||
|
|
||||||
event_dict = dict(event_dict)
|
event_dict = dict(event_dict)
|
||||||
|
|
||||||
# Signatures is a dict of dicts, and this is faster than doing a
|
# Signatures is a dict of dicts, and this is faster than doing a
|
||||||
|
@ -386,9 +388,11 @@ class FrozenEventV2(EventBase):
|
||||||
self,
|
self,
|
||||||
event_dict: JsonDict,
|
event_dict: JsonDict,
|
||||||
room_version: RoomVersion,
|
room_version: RoomVersion,
|
||||||
internal_metadata_dict: JsonDict = {},
|
internal_metadata_dict: Optional[JsonDict] = None,
|
||||||
rejected_reason: Optional[str] = None,
|
rejected_reason: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
internal_metadata_dict = internal_metadata_dict or {}
|
||||||
|
|
||||||
event_dict = dict(event_dict)
|
event_dict = dict(event_dict)
|
||||||
|
|
||||||
# Signatures is a dict of dicts, and this is faster than doing a
|
# Signatures is a dict of dicts, and this is faster than doing a
|
||||||
|
@ -507,9 +511,11 @@ def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
|
||||||
def make_event_from_dict(
|
def make_event_from_dict(
|
||||||
event_dict: JsonDict,
|
event_dict: JsonDict,
|
||||||
room_version: RoomVersion = RoomVersions.V1,
|
room_version: RoomVersion = RoomVersions.V1,
|
||||||
internal_metadata_dict: JsonDict = {},
|
internal_metadata_dict: Optional[JsonDict] = None,
|
||||||
rejected_reason: Optional[str] = None,
|
rejected_reason: Optional[str] = None,
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
"""Construct an EventBase from the given event dict"""
|
"""Construct an EventBase from the given event dict"""
|
||||||
event_type = _event_type_from_format_version(room_version.event_format)
|
event_type = _event_type_from_format_version(room_version.event_format)
|
||||||
return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason)
|
return event_type(
|
||||||
|
event_dict, room_version, internal_metadata_dict or {}, rejected_reason
|
||||||
|
)
|
||||||
|
|
|
@ -18,6 +18,7 @@ server protocol.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -98,7 +99,7 @@ class Transaction(JsonEncodedObject):
|
||||||
"pdus",
|
"pdus",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, transaction_id=None, pdus=[], **kwargs):
|
def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
|
||||||
"""If we include a list of pdus then we decode then as PDU's
|
"""If we include a list of pdus then we decode then as PDU's
|
||||||
automatically.
|
automatically.
|
||||||
"""
|
"""
|
||||||
|
@ -107,7 +108,7 @@ class Transaction(JsonEncodedObject):
|
||||||
if "edus" in kwargs and not kwargs["edus"]:
|
if "edus" in kwargs and not kwargs["edus"]:
|
||||||
del kwargs["edus"]
|
del kwargs["edus"]
|
||||||
|
|
||||||
super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs)
|
super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_new(pdus, **kwargs):
|
def create_new(pdus, **kwargs):
|
||||||
|
|
|
@ -182,7 +182,7 @@ class ApplicationServicesHandler:
|
||||||
self,
|
self,
|
||||||
stream_key: str,
|
stream_key: str,
|
||||||
new_token: Optional[int],
|
new_token: Optional[int],
|
||||||
users: Collection[Union[str, UserID]] = [],
|
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||||
):
|
):
|
||||||
"""This is called by the notifier in the background
|
"""This is called by the notifier in the background
|
||||||
when a ephemeral event handled by the homeserver.
|
when a ephemeral event handled by the homeserver.
|
||||||
|
@ -215,7 +215,7 @@ class ApplicationServicesHandler:
|
||||||
# We only start a new background process if necessary rather than
|
# We only start a new background process if necessary rather than
|
||||||
# optimistically (to cut down on overhead).
|
# optimistically (to cut down on overhead).
|
||||||
self._notify_interested_services_ephemeral(
|
self._notify_interested_services_ephemeral(
|
||||||
services, stream_key, new_token, users
|
services, stream_key, new_token, users or []
|
||||||
)
|
)
|
||||||
|
|
||||||
@wrap_as_background_process("notify_interested_services_ephemeral")
|
@wrap_as_background_process("notify_interested_services_ephemeral")
|
||||||
|
|
|
@ -1790,7 +1790,7 @@ class FederationHandler(BaseHandler):
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
membership: str,
|
membership: str,
|
||||||
content: JsonDict = {},
|
content: JsonDict,
|
||||||
params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
|
params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
|
||||||
) -> Tuple[str, EventBase, RoomVersion]:
|
) -> Tuple[str, EventBase, RoomVersion]:
|
||||||
(
|
(
|
||||||
|
|
|
@ -137,7 +137,7 @@ class MessageHandler:
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
state_filter: StateFilter = StateFilter.all(),
|
state_filter: Optional[StateFilter] = None,
|
||||||
at_token: Optional[StreamToken] = None,
|
at_token: Optional[StreamToken] = None,
|
||||||
is_guest: bool = False,
|
is_guest: bool = False,
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
|
@ -164,6 +164,8 @@ class MessageHandler:
|
||||||
AuthError (403) if the user doesn't have permission to view
|
AuthError (403) if the user doesn't have permission to view
|
||||||
members of this room.
|
members of this room.
|
||||||
"""
|
"""
|
||||||
|
state_filter = state_filter or StateFilter.all()
|
||||||
|
|
||||||
if at_token:
|
if at_token:
|
||||||
# FIXME this claims to get the state at a stream position, but
|
# FIXME this claims to get the state at a stream position, but
|
||||||
# get_recent_events_for_room operates by topo ordering. This therefore
|
# get_recent_events_for_room operates by topo ordering. This therefore
|
||||||
|
@ -874,7 +876,7 @@ class EventCreationHandler:
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
context: EventContext,
|
context: EventContext,
|
||||||
ratelimit: bool = True,
|
ratelimit: bool = True,
|
||||||
extra_users: List[UserID] = [],
|
extra_users: Optional[List[UserID]] = None,
|
||||||
ignore_shadow_ban: bool = False,
|
ignore_shadow_ban: bool = False,
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
"""Processes a new event.
|
"""Processes a new event.
|
||||||
|
@ -902,6 +904,7 @@ class EventCreationHandler:
|
||||||
Raises:
|
Raises:
|
||||||
ShadowBanError if the requester has been shadow-banned.
|
ShadowBanError if the requester has been shadow-banned.
|
||||||
"""
|
"""
|
||||||
|
extra_users = extra_users or []
|
||||||
|
|
||||||
# we don't apply shadow-banning to membership events here. Invites are blocked
|
# we don't apply shadow-banning to membership events here. Invites are blocked
|
||||||
# higher up the stack, and we allow shadow-banned users to send join and leave
|
# higher up the stack, and we allow shadow-banned users to send join and leave
|
||||||
|
@ -1071,7 +1074,7 @@ class EventCreationHandler:
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
context: EventContext,
|
context: EventContext,
|
||||||
ratelimit: bool = True,
|
ratelimit: bool = True,
|
||||||
extra_users: List[UserID] = [],
|
extra_users: Optional[List[UserID]] = None,
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
"""Called when we have fully built the event, have already
|
"""Called when we have fully built the event, have already
|
||||||
calculated the push actions for the event, and checked auth.
|
calculated the push actions for the event, and checked auth.
|
||||||
|
@ -1083,6 +1086,8 @@ class EventCreationHandler:
|
||||||
it was de-duplicated (e.g. because we had already persisted an
|
it was de-duplicated (e.g. because we had already persisted an
|
||||||
event with the same transaction ID.)
|
event with the same transaction ID.)
|
||||||
"""
|
"""
|
||||||
|
extra_users = extra_users or []
|
||||||
|
|
||||||
assert self.storage.persistence is not None
|
assert self.storage.persistence is not None
|
||||||
assert self._events_shard_config.should_handle(
|
assert self._events_shard_config.should_handle(
|
||||||
self._instance_name, event.room_id
|
self._instance_name, event.room_id
|
||||||
|
|
|
@ -169,7 +169,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_type: Optional[str] = None,
|
user_type: Optional[str] = None,
|
||||||
default_display_name: Optional[str] = None,
|
default_display_name: Optional[str] = None,
|
||||||
address: Optional[str] = None,
|
address: Optional[str] = None,
|
||||||
bind_emails: Iterable[str] = [],
|
bind_emails: Optional[Iterable[str]] = None,
|
||||||
by_admin: bool = False,
|
by_admin: bool = False,
|
||||||
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
|
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
|
||||||
auth_provider_id: Optional[str] = None,
|
auth_provider_id: Optional[str] = None,
|
||||||
|
@ -204,6 +204,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if there was a problem registering.
|
SynapseError if there was a problem registering.
|
||||||
"""
|
"""
|
||||||
|
bind_emails = bind_emails or []
|
||||||
|
|
||||||
await self.check_registration_ratelimit(address)
|
await self.check_registration_ratelimit(address)
|
||||||
|
|
||||||
result = await self.spam_checker.check_registration_for_spam(
|
result = await self.spam_checker.check_registration_for_spam(
|
||||||
|
|
|
@ -548,7 +548,7 @@ class SyncHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_state_after_event(
|
async def get_state_after_event(
|
||||||
self, event: EventBase, state_filter: StateFilter = StateFilter.all()
|
self, event: EventBase, state_filter: Optional[StateFilter] = None
|
||||||
) -> StateMap[str]:
|
) -> StateMap[str]:
|
||||||
"""
|
"""
|
||||||
Get the room state after the given event
|
Get the room state after the given event
|
||||||
|
@ -558,7 +558,7 @@ class SyncHandler:
|
||||||
state_filter: The state filter used to fetch state from the database.
|
state_filter: The state filter used to fetch state from the database.
|
||||||
"""
|
"""
|
||||||
state_ids = await self.state_store.get_state_ids_for_event(
|
state_ids = await self.state_store.get_state_ids_for_event(
|
||||||
event.event_id, state_filter=state_filter
|
event.event_id, state_filter=state_filter or StateFilter.all()
|
||||||
)
|
)
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
state_ids = dict(state_ids)
|
state_ids = dict(state_ids)
|
||||||
|
@ -569,7 +569,7 @@ class SyncHandler:
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
stream_position: StreamToken,
|
stream_position: StreamToken,
|
||||||
state_filter: StateFilter = StateFilter.all(),
|
state_filter: Optional[StateFilter] = None,
|
||||||
) -> StateMap[str]:
|
) -> StateMap[str]:
|
||||||
"""Get the room state at a particular stream position
|
"""Get the room state at a particular stream position
|
||||||
|
|
||||||
|
@ -589,7 +589,7 @@ class SyncHandler:
|
||||||
if last_events:
|
if last_events:
|
||||||
last_event = last_events[-1]
|
last_event = last_events[-1]
|
||||||
state = await self.get_state_after_event(
|
state = await self.get_state_after_event(
|
||||||
last_event, state_filter=state_filter
|
last_event, state_filter=state_filter or StateFilter.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -297,7 +297,7 @@ class SimpleHttpClient:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hs: "HomeServer",
|
hs: "HomeServer",
|
||||||
treq_args: Dict[str, Any] = {},
|
treq_args: Optional[Dict[str, Any]] = None,
|
||||||
ip_whitelist: Optional[IPSet] = None,
|
ip_whitelist: Optional[IPSet] = None,
|
||||||
ip_blacklist: Optional[IPSet] = None,
|
ip_blacklist: Optional[IPSet] = None,
|
||||||
use_proxy: bool = False,
|
use_proxy: bool = False,
|
||||||
|
@ -317,7 +317,7 @@ class SimpleHttpClient:
|
||||||
|
|
||||||
self._ip_whitelist = ip_whitelist
|
self._ip_whitelist = ip_whitelist
|
||||||
self._ip_blacklist = ip_blacklist
|
self._ip_blacklist = ip_blacklist
|
||||||
self._extra_treq_args = treq_args
|
self._extra_treq_args = treq_args or {}
|
||||||
|
|
||||||
self.user_agent = hs.version_string
|
self.user_agent = hs.version_string
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
|
@ -27,7 +27,7 @@ from twisted.python.failure import Failure
|
||||||
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
|
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
|
||||||
from twisted.web.error import SchemeNotSupported
|
from twisted.web.error import SchemeNotSupported
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web.iweb import IAgent
|
from twisted.web.iweb import IAgent, IPolicyForHTTPS
|
||||||
|
|
||||||
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
|
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
|
||||||
|
|
||||||
|
@ -88,12 +88,14 @@ class ProxyAgent(_AgentBase):
|
||||||
self,
|
self,
|
||||||
reactor,
|
reactor,
|
||||||
proxy_reactor=None,
|
proxy_reactor=None,
|
||||||
contextFactory=BrowserLikePolicyForHTTPS(),
|
contextFactory: Optional[IPolicyForHTTPS] = None,
|
||||||
connectTimeout=None,
|
connectTimeout=None,
|
||||||
bindAddress=None,
|
bindAddress=None,
|
||||||
pool=None,
|
pool=None,
|
||||||
use_proxy=False,
|
use_proxy=False,
|
||||||
):
|
):
|
||||||
|
contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
|
||||||
|
|
||||||
_AgentBase.__init__(self, reactor, pool)
|
_AgentBase.__init__(self, reactor, pool)
|
||||||
|
|
||||||
if proxy_reactor is None:
|
if proxy_reactor is None:
|
||||||
|
|
|
@ -486,7 +486,7 @@ def start_active_span_from_request(
|
||||||
def start_active_span_from_edu(
|
def start_active_span_from_edu(
|
||||||
edu_content,
|
edu_content,
|
||||||
operation_name,
|
operation_name,
|
||||||
references=[],
|
references: Optional[list] = None,
|
||||||
tags=None,
|
tags=None,
|
||||||
start_time=None,
|
start_time=None,
|
||||||
ignore_active_span=False,
|
ignore_active_span=False,
|
||||||
|
@ -501,6 +501,7 @@ def start_active_span_from_edu(
|
||||||
|
|
||||||
For the other args see opentracing.tracer
|
For the other args see opentracing.tracer
|
||||||
"""
|
"""
|
||||||
|
references = references or []
|
||||||
|
|
||||||
if opentracing is None:
|
if opentracing is None:
|
||||||
return noop_context_manager()
|
return noop_context_manager()
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -127,7 +127,7 @@ class ModuleApi:
|
||||||
return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
|
return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def register(self, localpart, displayname=None, emails=[]):
|
def register(self, localpart, displayname=None, emails: Optional[List[str]] = None):
|
||||||
"""Registers a new user with given localpart and optional displayname, emails.
|
"""Registers a new user with given localpart and optional displayname, emails.
|
||||||
|
|
||||||
Also returns an access token for the new user.
|
Also returns an access token for the new user.
|
||||||
|
@ -147,11 +147,13 @@ class ModuleApi:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Using deprecated ModuleApi.register which creates a dummy user device."
|
"Using deprecated ModuleApi.register which creates a dummy user device."
|
||||||
)
|
)
|
||||||
user_id = yield self.register_user(localpart, displayname, emails)
|
user_id = yield self.register_user(localpart, displayname, emails or [])
|
||||||
_, access_token = yield self.register_device(user_id)
|
_, access_token = yield self.register_device(user_id)
|
||||||
return user_id, access_token
|
return user_id, access_token
|
||||||
|
|
||||||
def register_user(self, localpart, displayname=None, emails=[]):
|
def register_user(
|
||||||
|
self, localpart, displayname=None, emails: Optional[List[str]] = None
|
||||||
|
):
|
||||||
"""Registers a new user with given localpart and optional displayname, emails.
|
"""Registers a new user with given localpart and optional displayname, emails.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -170,7 +172,7 @@ class ModuleApi:
|
||||||
self._hs.get_registration_handler().register_user(
|
self._hs.get_registration_handler().register_user(
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
default_display_name=displayname,
|
default_display_name=displayname,
|
||||||
bind_emails=emails,
|
bind_emails=emails or [],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -276,7 +276,7 @@ class Notifier:
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
event_pos: PersistedEventPosition,
|
event_pos: PersistedEventPosition,
|
||||||
max_room_stream_token: RoomStreamToken,
|
max_room_stream_token: RoomStreamToken,
|
||||||
extra_users: Collection[UserID] = [],
|
extra_users: Optional[Collection[UserID]] = None,
|
||||||
):
|
):
|
||||||
"""Unwraps event and calls `on_new_room_event_args`."""
|
"""Unwraps event and calls `on_new_room_event_args`."""
|
||||||
self.on_new_room_event_args(
|
self.on_new_room_event_args(
|
||||||
|
@ -286,7 +286,7 @@ class Notifier:
|
||||||
state_key=event.get("state_key"),
|
state_key=event.get("state_key"),
|
||||||
membership=event.content.get("membership"),
|
membership=event.content.get("membership"),
|
||||||
max_room_stream_token=max_room_stream_token,
|
max_room_stream_token=max_room_stream_token,
|
||||||
extra_users=extra_users,
|
extra_users=extra_users or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_new_room_event_args(
|
def on_new_room_event_args(
|
||||||
|
@ -297,7 +297,7 @@ class Notifier:
|
||||||
membership: Optional[str],
|
membership: Optional[str],
|
||||||
event_pos: PersistedEventPosition,
|
event_pos: PersistedEventPosition,
|
||||||
max_room_stream_token: RoomStreamToken,
|
max_room_stream_token: RoomStreamToken,
|
||||||
extra_users: Collection[UserID] = [],
|
extra_users: Optional[Collection[UserID]] = None,
|
||||||
):
|
):
|
||||||
"""Used by handlers to inform the notifier something has happened
|
"""Used by handlers to inform the notifier something has happened
|
||||||
in the room, room event wise.
|
in the room, room event wise.
|
||||||
|
@ -313,7 +313,7 @@ class Notifier:
|
||||||
self.pending_new_room_events.append(
|
self.pending_new_room_events.append(
|
||||||
_PendingRoomEventEntry(
|
_PendingRoomEventEntry(
|
||||||
event_pos=event_pos,
|
event_pos=event_pos,
|
||||||
extra_users=extra_users,
|
extra_users=extra_users or [],
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
type=event_type,
|
type=event_type,
|
||||||
state_key=state_key,
|
state_key=state_key,
|
||||||
|
@ -382,14 +382,14 @@ class Notifier:
|
||||||
self,
|
self,
|
||||||
stream_key: str,
|
stream_key: str,
|
||||||
new_token: Union[int, RoomStreamToken],
|
new_token: Union[int, RoomStreamToken],
|
||||||
users: Collection[Union[str, UserID]] = [],
|
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
stream_token = None
|
stream_token = None
|
||||||
if isinstance(new_token, int):
|
if isinstance(new_token, int):
|
||||||
stream_token = new_token
|
stream_token = new_token
|
||||||
self.appservice_handler.notify_interested_services_ephemeral(
|
self.appservice_handler.notify_interested_services_ephemeral(
|
||||||
stream_key, stream_token, users
|
stream_key, stream_token, users or []
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error notifying application services of event")
|
logger.exception("Error notifying application services of event")
|
||||||
|
@ -404,13 +404,16 @@ class Notifier:
|
||||||
self,
|
self,
|
||||||
stream_key: str,
|
stream_key: str,
|
||||||
new_token: Union[int, RoomStreamToken],
|
new_token: Union[int, RoomStreamToken],
|
||||||
users: Collection[Union[str, UserID]] = [],
|
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||||
rooms: Collection[str] = [],
|
rooms: Optional[Collection[str]] = None,
|
||||||
):
|
):
|
||||||
"""Used to inform listeners that something has happened event wise.
|
"""Used to inform listeners that something has happened event wise.
|
||||||
|
|
||||||
Will wake up all listeners for the given users and rooms.
|
Will wake up all listeners for the given users and rooms.
|
||||||
"""
|
"""
|
||||||
|
users = users or []
|
||||||
|
rooms = rooms or []
|
||||||
|
|
||||||
with Measure(self.clock, "on_new_event"):
|
with Measure(self.clock, "on_new_event"):
|
||||||
user_streams = set()
|
user_streams = set()
|
||||||
|
|
||||||
|
|
|
@ -900,7 +900,7 @@ class DatabasePool:
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
values: Dict[str, Any],
|
values: Dict[str, Any],
|
||||||
insertion_values: Dict[str, Any] = {},
|
insertion_values: Optional[Dict[str, Any]] = None,
|
||||||
desc: str = "simple_upsert",
|
desc: str = "simple_upsert",
|
||||||
lock: bool = True,
|
lock: bool = True,
|
||||||
) -> Optional[bool]:
|
) -> Optional[bool]:
|
||||||
|
@ -927,6 +927,8 @@ class DatabasePool:
|
||||||
Native upserts always return None. Emulated upserts return True if a
|
Native upserts always return None. Emulated upserts return True if a
|
||||||
new entry was created, False if an existing one was updated.
|
new entry was created, False if an existing one was updated.
|
||||||
"""
|
"""
|
||||||
|
insertion_values = insertion_values or {}
|
||||||
|
|
||||||
attempts = 0
|
attempts = 0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -964,7 +966,7 @@ class DatabasePool:
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
values: Dict[str, Any],
|
values: Dict[str, Any],
|
||||||
insertion_values: Dict[str, Any] = {},
|
insertion_values: Optional[Dict[str, Any]] = None,
|
||||||
lock: bool = True,
|
lock: bool = True,
|
||||||
) -> Optional[bool]:
|
) -> Optional[bool]:
|
||||||
"""
|
"""
|
||||||
|
@ -982,6 +984,8 @@ class DatabasePool:
|
||||||
Native upserts always return None. Emulated upserts return True if a
|
Native upserts always return None. Emulated upserts return True if a
|
||||||
new entry was created, False if an existing one was updated.
|
new entry was created, False if an existing one was updated.
|
||||||
"""
|
"""
|
||||||
|
insertion_values = insertion_values or {}
|
||||||
|
|
||||||
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
|
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
|
||||||
self.simple_upsert_txn_native_upsert(
|
self.simple_upsert_txn_native_upsert(
|
||||||
txn, table, keyvalues, values, insertion_values=insertion_values
|
txn, table, keyvalues, values, insertion_values=insertion_values
|
||||||
|
@ -1003,7 +1007,7 @@ class DatabasePool:
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
values: Dict[str, Any],
|
values: Dict[str, Any],
|
||||||
insertion_values: Dict[str, Any] = {},
|
insertion_values: Optional[Dict[str, Any]] = None,
|
||||||
lock: bool = True,
|
lock: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -1017,6 +1021,8 @@ class DatabasePool:
|
||||||
Returns True if a new entry was created, False if an existing
|
Returns True if a new entry was created, False if an existing
|
||||||
one was updated.
|
one was updated.
|
||||||
"""
|
"""
|
||||||
|
insertion_values = insertion_values or {}
|
||||||
|
|
||||||
# We need to lock the table :(, unless we're *really* careful
|
# We need to lock the table :(, unless we're *really* careful
|
||||||
if lock:
|
if lock:
|
||||||
self.engine.lock_table(txn, table)
|
self.engine.lock_table(txn, table)
|
||||||
|
@ -1077,7 +1083,7 @@ class DatabasePool:
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
values: Dict[str, Any],
|
values: Dict[str, Any],
|
||||||
insertion_values: Dict[str, Any] = {},
|
insertion_values: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Use the native UPSERT functionality in recent PostgreSQL versions.
|
Use the native UPSERT functionality in recent PostgreSQL versions.
|
||||||
|
@ -1090,7 +1096,7 @@ class DatabasePool:
|
||||||
"""
|
"""
|
||||||
allvalues = {} # type: Dict[str, Any]
|
allvalues = {} # type: Dict[str, Any]
|
||||||
allvalues.update(keyvalues)
|
allvalues.update(keyvalues)
|
||||||
allvalues.update(insertion_values)
|
allvalues.update(insertion_values or {})
|
||||||
|
|
||||||
if not values:
|
if not values:
|
||||||
latter = "NOTHING"
|
latter = "NOTHING"
|
||||||
|
@ -1513,7 +1519,7 @@ class DatabasePool:
|
||||||
column: str,
|
column: str,
|
||||||
iterable: Iterable[Any],
|
iterable: Iterable[Any],
|
||||||
retcols: Iterable[str],
|
retcols: Iterable[str],
|
||||||
keyvalues: Dict[str, Any] = {},
|
keyvalues: Optional[Dict[str, Any]] = None,
|
||||||
desc: str = "simple_select_many_batch",
|
desc: str = "simple_select_many_batch",
|
||||||
batch_size: int = 100,
|
batch_size: int = 100,
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
|
@ -1531,6 +1537,8 @@ class DatabasePool:
|
||||||
desc: description of the transaction, for logging and metrics
|
desc: description of the transaction, for logging and metrics
|
||||||
batch_size: the number of rows for each select query
|
batch_size: the number of rows for each select query
|
||||||
"""
|
"""
|
||||||
|
keyvalues = keyvalues or {}
|
||||||
|
|
||||||
results = [] # type: List[Dict[str, Any]]
|
results = [] # type: List[Dict[str, Any]]
|
||||||
|
|
||||||
if not iterable:
|
if not iterable:
|
||||||
|
|
|
@ -320,8 +320,8 @@ class PersistEventsStore:
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||||
backfilled: bool,
|
backfilled: bool,
|
||||||
state_delta_for_room: Dict[str, DeltaState] = {},
|
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
|
||||||
new_forward_extremeties: Dict[str, List[str]] = {},
|
new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
|
||||||
):
|
):
|
||||||
"""Insert some number of room events into the necessary database tables.
|
"""Insert some number of room events into the necessary database tables.
|
||||||
|
|
||||||
|
@ -342,6 +342,9 @@ class PersistEventsStore:
|
||||||
extremities.
|
extremities.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
state_delta_for_room = state_delta_for_room or {}
|
||||||
|
new_forward_extremeties = new_forward_extremeties or {}
|
||||||
|
|
||||||
all_events_and_contexts = events_and_contexts
|
all_events_and_contexts = events_and_contexts
|
||||||
|
|
||||||
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
|
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
|
||||||
|
|
|
@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
user_id: str,
|
user_id: str,
|
||||||
membership: str,
|
membership: str,
|
||||||
is_admin: bool = False,
|
is_admin: bool = False,
|
||||||
content: JsonDict = {},
|
content: Optional[JsonDict] = None,
|
||||||
local_attestation: Optional[dict] = None,
|
local_attestation: Optional[dict] = None,
|
||||||
remote_attestation: Optional[dict] = None,
|
remote_attestation: Optional[dict] = None,
|
||||||
is_publicised: bool = False,
|
is_publicised: bool = False,
|
||||||
|
@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
is_publicised: Whether this should be publicised.
|
is_publicised: Whether this should be publicised.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
content = content or {}
|
||||||
|
|
||||||
def _register_user_group_membership_txn(txn, next_id):
|
def _register_user_group_membership_txn(txn, next_id):
|
||||||
# TODO: Upsert?
|
# TODO: Upsert?
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
|
|
|
@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
# FIXME: how should this be cached?
|
# FIXME: how should this be cached?
|
||||||
async def get_filtered_current_state_ids(
|
async def get_filtered_current_state_ids(
|
||||||
self, room_id: str, state_filter: StateFilter = StateFilter.all()
|
self, room_id: str, state_filter: Optional[StateFilter] = None
|
||||||
) -> StateMap[str]:
|
) -> StateMap[str]:
|
||||||
"""Get the current state event of a given type for a room based on the
|
"""Get the current state event of a given type for a room based on the
|
||||||
current_state_events table. This may not be as up-to-date as the result
|
current_state_events table. This may not be as up-to-date as the result
|
||||||
|
@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
Map from type/state_key to event ID.
|
Map from type/state_key to event ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
where_clause, where_args = state_filter.make_sql_filter_clause()
|
where_clause, where_args = (
|
||||||
|
state_filter or StateFilter.all()
|
||||||
|
).make_sql_filter_clause()
|
||||||
|
|
||||||
if not where_clause:
|
if not where_clause:
|
||||||
# We delegate to the cached version
|
# We delegate to the cached version
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
|
@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def _get_state_groups_from_groups_txn(
|
def _get_state_groups_from_groups_txn(
|
||||||
self, txn, groups, state_filter=StateFilter.all()
|
self, txn, groups, state_filter: Optional[StateFilter] = None
|
||||||
):
|
):
|
||||||
|
state_filter = state_filter or StateFilter.all()
|
||||||
|
|
||||||
results = {group: {} for group in groups}
|
results = {group: {} for group in groups}
|
||||||
|
|
||||||
where_clause, where_args = state_filter.make_sql_filter_clause()
|
where_clause, where_args = state_filter.make_sql_filter_clause()
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Dict, Iterable, List, Set, Tuple
|
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
return state_filter.filter_state(state_dict_ids), not missing_types
|
return state_filter.filter_state(state_dict_ids), not missing_types
|
||||||
|
|
||||||
async def _get_state_for_groups(
|
async def _get_state_for_groups(
|
||||||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||||
) -> Dict[int, MutableStateMap[str]]:
|
) -> Dict[int, MutableStateMap[str]]:
|
||||||
"""Gets the state at each of a list of state groups, optionally
|
"""Gets the state at each of a list of state groups, optionally
|
||||||
filtering by type/state_key
|
filtering by type/state_key
|
||||||
|
@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
Dict of state group to state map.
|
Dict of state group to state map.
|
||||||
"""
|
"""
|
||||||
|
state_filter = state_filter or StateFilter.all()
|
||||||
|
|
||||||
member_filter, non_member_filter = state_filter.get_member_split()
|
member_filter, non_member_filter = state_filter.get_member_split()
|
||||||
|
|
||||||
|
|
|
@ -449,7 +449,7 @@ class StateGroupStorage:
|
||||||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||||
|
|
||||||
async def get_state_for_events(
|
async def get_state_for_events(
|
||||||
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
|
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
|
||||||
) -> Dict[str, StateMap[EventBase]]:
|
) -> Dict[str, StateMap[EventBase]]:
|
||||||
"""Given a list of event_ids and type tuples, return a list of state
|
"""Given a list of event_ids and type tuples, return a list of state
|
||||||
dicts for each event.
|
dicts for each event.
|
||||||
|
@ -465,7 +465,7 @@ class StateGroupStorage:
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.values())
|
||||||
group_to_state = await self.stores.state._get_state_for_groups(
|
group_to_state = await self.stores.state._get_state_for_groups(
|
||||||
groups, state_filter
|
groups, state_filter or StateFilter.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
state_event_map = await self.stores.main.get_events(
|
state_event_map = await self.stores.main.get_events(
|
||||||
|
@ -485,7 +485,7 @@ class StateGroupStorage:
|
||||||
return {event: event_to_state[event] for event in event_ids}
|
return {event: event_to_state[event] for event in event_ids}
|
||||||
|
|
||||||
async def get_state_ids_for_events(
|
async def get_state_ids_for_events(
|
||||||
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
|
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
|
||||||
) -> Dict[str, StateMap[str]]:
|
) -> Dict[str, StateMap[str]]:
|
||||||
"""
|
"""
|
||||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||||
|
@ -502,7 +502,7 @@ class StateGroupStorage:
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.values())
|
||||||
group_to_state = await self.stores.state._get_state_for_groups(
|
group_to_state = await self.stores.state._get_state_for_groups(
|
||||||
groups, state_filter
|
groups, state_filter or StateFilter.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
event_to_state = {
|
event_to_state = {
|
||||||
|
@ -513,7 +513,7 @@ class StateGroupStorage:
|
||||||
return {event: event_to_state[event] for event in event_ids}
|
return {event: event_to_state[event] for event in event_ids}
|
||||||
|
|
||||||
async def get_state_for_event(
|
async def get_state_for_event(
|
||||||
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||||
) -> StateMap[EventBase]:
|
) -> StateMap[EventBase]:
|
||||||
"""
|
"""
|
||||||
Get the state dict corresponding to a particular event
|
Get the state dict corresponding to a particular event
|
||||||
|
@ -525,11 +525,13 @@ class StateGroupStorage:
|
||||||
Returns:
|
Returns:
|
||||||
A dict from (type, state_key) -> state_event
|
A dict from (type, state_key) -> state_event
|
||||||
"""
|
"""
|
||||||
state_map = await self.get_state_for_events([event_id], state_filter)
|
state_map = await self.get_state_for_events(
|
||||||
|
[event_id], state_filter or StateFilter.all()
|
||||||
|
)
|
||||||
return state_map[event_id]
|
return state_map[event_id]
|
||||||
|
|
||||||
async def get_state_ids_for_event(
|
async def get_state_ids_for_event(
|
||||||
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||||
) -> StateMap[str]:
|
) -> StateMap[str]:
|
||||||
"""
|
"""
|
||||||
Get the state dict corresponding to a particular event
|
Get the state dict corresponding to a particular event
|
||||||
|
@ -541,11 +543,13 @@ class StateGroupStorage:
|
||||||
Returns:
|
Returns:
|
||||||
A dict from (type, state_key) -> state_event
|
A dict from (type, state_key) -> state_event
|
||||||
"""
|
"""
|
||||||
state_map = await self.get_state_ids_for_events([event_id], state_filter)
|
state_map = await self.get_state_ids_for_events(
|
||||||
|
[event_id], state_filter or StateFilter.all()
|
||||||
|
)
|
||||||
return state_map[event_id]
|
return state_map[event_id]
|
||||||
|
|
||||||
def _get_state_for_groups(
|
def _get_state_for_groups(
|
||||||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||||
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
|
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
|
||||||
"""Gets the state at each of a list of state groups, optionally
|
"""Gets the state at each of a list of state groups, optionally
|
||||||
filtering by type/state_key
|
filtering by type/state_key
|
||||||
|
@ -558,7 +562,9 @@ class StateGroupStorage:
|
||||||
Returns:
|
Returns:
|
||||||
Dict of state group to state map.
|
Dict of state group to state map.
|
||||||
"""
|
"""
|
||||||
return self.stores.state._get_state_for_groups(groups, state_filter)
|
return self.stores.state._get_state_for_groups(
|
||||||
|
groups, state_filter or StateFilter.all()
|
||||||
|
)
|
||||||
|
|
||||||
async def store_state_group(
|
async def store_state_group(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -91,7 +91,14 @@ class StreamIdGenerator:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_conn,
|
||||||
|
table,
|
||||||
|
column,
|
||||||
|
extra_tables: Iterable[Tuple[str, str]] = (),
|
||||||
|
step=1,
|
||||||
|
):
|
||||||
assert step != 0
|
assert step != 0
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._step = step
|
self._step = step
|
||||||
|
|
|
@ -57,12 +57,14 @@ def enumerate_leaves(node, depth):
|
||||||
class _Node:
|
class _Node:
|
||||||
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
|
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
|
||||||
|
|
||||||
def __init__(self, prev_node, next_node, key, value, callbacks=set()):
|
def __init__(
|
||||||
|
self, prev_node, next_node, key, value, callbacks: Optional[set] = None
|
||||||
|
):
|
||||||
self.prev_node = prev_node
|
self.prev_node = prev_node
|
||||||
self.next_node = next_node
|
self.next_node = next_node
|
||||||
self.key = key
|
self.key = key
|
||||||
self.value = value
|
self.value = value
|
||||||
self.callbacks = callbacks
|
self.callbacks = callbacks or set()
|
||||||
|
|
||||||
|
|
||||||
class LruCache(Generic[KT, VT]):
|
class LruCache(Generic[KT, VT]):
|
||||||
|
@ -176,10 +178,10 @@ class LruCache(Generic[KT, VT]):
|
||||||
|
|
||||||
self.len = synchronized(cache_len)
|
self.len = synchronized(cache_len)
|
||||||
|
|
||||||
def add_node(key, value, callbacks=set()):
|
def add_node(key, value, callbacks: Optional[set] = None):
|
||||||
prev_node = list_root
|
prev_node = list_root
|
||||||
next_node = prev_node.next_node
|
next_node = prev_node.next_node
|
||||||
node = _Node(prev_node, next_node, key, value, callbacks)
|
node = _Node(prev_node, next_node, key, value, callbacks or set())
|
||||||
prev_node.next_node = node
|
prev_node.next_node = node
|
||||||
next_node.prev_node = node
|
next_node.prev_node = node
|
||||||
cache[key] = node
|
cache[key] = node
|
||||||
|
@ -237,7 +239,7 @@ class LruCache(Generic[KT, VT]):
|
||||||
def cache_get(
|
def cache_get(
|
||||||
key: KT,
|
key: KT,
|
||||||
default: Optional[T] = None,
|
default: Optional[T] = None,
|
||||||
callbacks: Iterable[Callable[[], None]] = [],
|
callbacks: Iterable[Callable[[], None]] = (),
|
||||||
update_metrics: bool = True,
|
update_metrics: bool = True,
|
||||||
):
|
):
|
||||||
node = cache.get(key, None)
|
node = cache.get(key, None)
|
||||||
|
@ -253,7 +255,7 @@ class LruCache(Generic[KT, VT]):
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
|
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
|
||||||
node = cache.get(key, None)
|
node = cache.get(key, None)
|
||||||
if node is not None:
|
if node is not None:
|
||||||
# We sometimes store large objects, e.g. dicts, which cause
|
# We sometimes store large objects, e.g. dicts, which cause
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
@ -180,7 +181,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
_check_logcontext(context)
|
_check_logcontext(context)
|
||||||
|
|
||||||
def _handle_well_known_connection(
|
def _handle_well_known_connection(
|
||||||
self, client_factory, expected_sni, content, response_headers={}
|
self,
|
||||||
|
client_factory,
|
||||||
|
expected_sni,
|
||||||
|
content,
|
||||||
|
response_headers: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
|
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
|
||||||
request is for a .well-known, and send the response.
|
request is for a .well-known, and send the response.
|
||||||
|
@ -202,10 +207,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
|
request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
|
||||||
)
|
)
|
||||||
self._send_well_known_response(request, content, headers=response_headers)
|
self._send_well_known_response(request, content, headers=response_headers or {})
|
||||||
return well_known_server
|
return well_known_server
|
||||||
|
|
||||||
def _send_well_known_response(self, request, content, headers={}):
|
def _send_well_known_response(
|
||||||
|
self, request, content, headers: Optional[dict] = None
|
||||||
|
):
|
||||||
"""Check that an incoming request looks like a valid .well-known request, and
|
"""Check that an incoming request looks like a valid .well-known request, and
|
||||||
send back the response.
|
send back the response.
|
||||||
"""
|
"""
|
||||||
|
@ -213,7 +220,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
self.assertEqual(request.path, b"/.well-known/matrix/server")
|
self.assertEqual(request.path, b"/.well-known/matrix/server")
|
||||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
|
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
|
||||||
# send back a response
|
# send back a response
|
||||||
for k, v in headers.items():
|
for k, v in (headers or {}).items():
|
||||||
request.setHeader(k, v)
|
request.setHeader(k, v)
|
||||||
request.write(content)
|
request.write(content)
|
||||||
request.finish()
|
request.finish()
|
||||||
|
|
|
@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
return resource
|
return resource
|
||||||
|
|
||||||
def make_worker_hs(
|
def make_worker_hs(
|
||||||
self, worker_app: str, extra_config: dict = {}, **kwargs
|
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
|
||||||
) -> HomeServer:
|
) -> HomeServer:
|
||||||
"""Make a new worker HS instance, correctly connecting replcation
|
"""Make a new worker HS instance, correctly connecting replcation
|
||||||
stream to the master HS.
|
stream to the master HS.
|
||||||
|
@ -283,7 +283,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
config = self._get_worker_hs_config()
|
config = self._get_worker_hs_config()
|
||||||
config["worker_app"] = worker_app
|
config["worker_app"] = worker_app
|
||||||
config.update(extra_config)
|
config.update(extra_config or {})
|
||||||
|
|
||||||
worker_hs = self.setup_test_homeserver(
|
worker_hs = self.setup_test_homeserver(
|
||||||
homeserver_to_use=GenericWorkerServer,
|
homeserver_to_use=GenericWorkerServer,
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
|
@ -332,15 +333,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
room_id=ROOM_ID,
|
room_id=ROOM_ID,
|
||||||
type="m.room.message",
|
type="m.room.message",
|
||||||
key=None,
|
key=None,
|
||||||
internal={},
|
internal: Optional[dict] = None,
|
||||||
depth=None,
|
depth=None,
|
||||||
prev_events=[],
|
prev_events: Optional[list] = None,
|
||||||
auth_events=[],
|
auth_events: Optional[list] = None,
|
||||||
prev_state=[],
|
prev_state: Optional[list] = None,
|
||||||
redacts=None,
|
redacts=None,
|
||||||
push_actions=[],
|
push_actions: Iterable = frozenset(),
|
||||||
**content
|
**content
|
||||||
):
|
):
|
||||||
|
prev_events = prev_events or []
|
||||||
|
auth_events = auth_events or []
|
||||||
|
prev_state = prev_state or []
|
||||||
|
|
||||||
if depth is None:
|
if depth is None:
|
||||||
depth = self.event_id
|
depth = self.event_id
|
||||||
|
@ -369,7 +373,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
if redacts is not None:
|
if redacts is not None:
|
||||||
event_dict["redacts"] = redacts
|
event_dict["redacts"] = redacts
|
||||||
|
|
||||||
event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
|
event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {})
|
||||||
|
|
||||||
self.event_id += 1
|
self.event_id += 1
|
||||||
state_handler = self.hs.get_state_handler()
|
state_handler = self.hs.get_state_handler()
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
"""Tests REST events for /rooms paths."""
|
"""Tests REST events for /rooms paths."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import Iterable
|
||||||
from urllib import parse as urlparse
|
from urllib import parse as urlparse
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
@ -207,7 +208,9 @@ class RoomPermissionsTestCase(RoomBase):
|
||||||
)
|
)
|
||||||
self.assertEquals(403, channel.code, msg=channel.result["body"])
|
self.assertEquals(403, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def _test_get_membership(self, room=None, members=[], expect_code=None):
|
def _test_get_membership(
|
||||||
|
self, room=None, members: Iterable = frozenset(), expect_code=None
|
||||||
|
):
|
||||||
for member in members:
|
for member in members:
|
||||||
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
|
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
|
||||||
channel = self.make_request("GET", path)
|
channel = self.make_request("GET", path)
|
||||||
|
|
|
@ -132,7 +132,7 @@ class RestHelper:
|
||||||
src: str,
|
src: str,
|
||||||
targ: str,
|
targ: str,
|
||||||
membership: str,
|
membership: str,
|
||||||
extra_data: dict = {},
|
extra_data: Optional[dict] = None,
|
||||||
tok: Optional[str] = None,
|
tok: Optional[str] = None,
|
||||||
expect_code: int = 200,
|
expect_code: int = 200,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -156,7 +156,7 @@ class RestHelper:
|
||||||
path = path + "?access_token=%s" % tok
|
path = path + "?access_token=%s" % tok
|
||||||
|
|
||||||
data = {"membership": membership}
|
data = {"membership": membership}
|
||||||
data.update(extra_data)
|
data.update(extra_data or {})
|
||||||
|
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.hs.get_reactor(),
|
||||||
|
@ -187,7 +187,13 @@ class RestHelper:
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_event(
|
def send_event(
|
||||||
self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
|
self,
|
||||||
|
room_id,
|
||||||
|
type,
|
||||||
|
content: Optional[dict] = None,
|
||||||
|
txn_id=None,
|
||||||
|
tok=None,
|
||||||
|
expect_code=200,
|
||||||
):
|
):
|
||||||
if txn_id is None:
|
if txn_id is None:
|
||||||
txn_id = "m%s" % (str(time.time()))
|
txn_id = "m%s" % (str(time.time()))
|
||||||
|
@ -201,7 +207,7 @@ class RestHelper:
|
||||||
self.site,
|
self.site,
|
||||||
"PUT",
|
"PUT",
|
||||||
path,
|
path,
|
||||||
json.dumps(content).encode("utf8"),
|
json.dumps(content or {}).encode("utf8"),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import urllib
|
import urllib
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, RelationTypes
|
from synapse.api.constants import EventTypes, RelationTypes
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
|
@ -681,7 +682,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
relation_type,
|
relation_type,
|
||||||
event_type,
|
event_type,
|
||||||
key=None,
|
key=None,
|
||||||
content={},
|
content: Optional[dict] = None,
|
||||||
access_token=None,
|
access_token=None,
|
||||||
parent_id=None,
|
parent_id=None,
|
||||||
):
|
):
|
||||||
|
@ -713,7 +714,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
"POST",
|
"POST",
|
||||||
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
|
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
|
||||||
% (self.room, original_id, relation_type, event_type, query),
|
% (self.room, original_id, relation_type, event_type, query),
|
||||||
json.dumps(content).encode("utf-8"),
|
json.dumps(content or {}).encode("utf-8"),
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
return channel
|
return channel
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.engines import IncorrectDatabaseSetup
|
from synapse.storage.engines import IncorrectDatabaseSetup
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
|
@ -43,7 +45,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_id_generator(
|
def _create_id_generator(
|
||||||
self, instance_name="master", writers=["master"]
|
self, instance_name="master", writers: Optional[List[str]] = None
|
||||||
) -> MultiWriterIdGenerator:
|
) -> MultiWriterIdGenerator:
|
||||||
def _create(conn):
|
def _create(conn):
|
||||||
return MultiWriterIdGenerator(
|
return MultiWriterIdGenerator(
|
||||||
|
@ -53,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
instance_name=instance_name,
|
instance_name=instance_name,
|
||||||
tables=[("foobar", "instance_name", "stream_id")],
|
tables=[("foobar", "instance_name", "stream_id")],
|
||||||
sequence_name="foobar_seq",
|
sequence_name="foobar_seq",
|
||||||
writers=writers,
|
writers=writers or ["master"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
||||||
|
@ -476,7 +478,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_id_generator(
|
def _create_id_generator(
|
||||||
self, instance_name="master", writers=["master"]
|
self, instance_name="master", writers: Optional[List[str]] = None
|
||||||
) -> MultiWriterIdGenerator:
|
) -> MultiWriterIdGenerator:
|
||||||
def _create(conn):
|
def _create(conn):
|
||||||
return MultiWriterIdGenerator(
|
return MultiWriterIdGenerator(
|
||||||
|
@ -486,7 +488,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
instance_name=instance_name,
|
instance_name=instance_name,
|
||||||
tables=[("foobar", "instance_name", "stream_id")],
|
tables=[("foobar", "instance_name", "stream_id")],
|
||||||
sequence_name="foobar_seq",
|
sequence_name="foobar_seq",
|
||||||
writers=writers,
|
writers=writers or ["master"],
|
||||||
positive=False,
|
positive=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -612,7 +614,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_id_generator(
|
def _create_id_generator(
|
||||||
self, instance_name="master", writers=["master"]
|
self, instance_name="master", writers: Optional[List[str]] = None
|
||||||
) -> MultiWriterIdGenerator:
|
) -> MultiWriterIdGenerator:
|
||||||
def _create(conn):
|
def _create(conn):
|
||||||
return MultiWriterIdGenerator(
|
return MultiWriterIdGenerator(
|
||||||
|
@ -625,7 +627,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
("foobar2", "instance_name", "stream_id"),
|
("foobar2", "instance_name", "stream_id"),
|
||||||
],
|
],
|
||||||
sequence_name="foobar_seq",
|
sequence_name="foobar_seq",
|
||||||
writers=writers,
|
writers=writers or ["master"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
|
@ -47,10 +48,15 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self.depth = 1
|
self.depth = 1
|
||||||
|
|
||||||
def inject_room_member(
|
def inject_room_member(
|
||||||
self, room, user, membership, replaces_state=None, extra_content={}
|
self,
|
||||||
|
room,
|
||||||
|
user,
|
||||||
|
membership,
|
||||||
|
replaces_state=None,
|
||||||
|
extra_content: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
content = {"membership": membership}
|
content = {"membership": membership}
|
||||||
content.update(extra_content)
|
content.update(extra_content or {})
|
||||||
builder = self.event_builder_factory.for_room_version(
|
builder = self.event_builder_factory.for_room_version(
|
||||||
RoomVersions.V1,
|
RoomVersions.V1,
|
||||||
{
|
{
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
@ -37,7 +38,7 @@ def create_event(
|
||||||
state_key=None,
|
state_key=None,
|
||||||
depth=2,
|
depth=2,
|
||||||
event_id=None,
|
event_id=None,
|
||||||
prev_events=[],
|
prev_events: Optional[List[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
global _next_event_id
|
global _next_event_id
|
||||||
|
@ -58,7 +59,7 @@ def create_event(
|
||||||
"sender": "@user_id:example.com",
|
"sender": "@user_id:example.com",
|
||||||
"room_id": "!room_id:example.com",
|
"room_id": "!room_id:example.com",
|
||||||
"depth": depth,
|
"depth": depth,
|
||||||
"prev_events": prev_events,
|
"prev_events": prev_events or [],
|
||||||
}
|
}
|
||||||
|
|
||||||
if state_key is not None:
|
if state_key is not None:
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
@ -147,9 +148,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
return event
|
return event
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def inject_room_member(self, user_id, membership="join", extra_content={}):
|
def inject_room_member(
|
||||||
|
self, user_id, membership="join", extra_content: Optional[dict] = None
|
||||||
|
):
|
||||||
content = {"membership": membership}
|
content = {"membership": membership}
|
||||||
content.update(extra_content)
|
content.update(extra_content or {})
|
||||||
builder = self.event_builder_factory.for_room_version(
|
builder = self.event_builder_factory.for_room_version(
|
||||||
RoomVersions.V1,
|
RoomVersions.V1,
|
||||||
{
|
{
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
|
||||||
|
@ -89,9 +91,9 @@ def _await_resolution(reactor, d):
|
||||||
return (reactor.seconds() - start_time) * 1000
|
return (reactor.seconds() - start_time) * 1000
|
||||||
|
|
||||||
|
|
||||||
def build_rc_config(settings={}):
|
def build_rc_config(settings: Optional[dict] = None):
|
||||||
config_dict = default_config("test")
|
config_dict = default_config("test")
|
||||||
config_dict.update(settings)
|
config_dict.update(settings or {})
|
||||||
config = HomeServerConfig()
|
config = HomeServerConfig()
|
||||||
config.parse_config_dict(config_dict, "", "")
|
config.parse_config_dict(config_dict, "", "")
|
||||||
return config.rc_federation
|
return config.rc_federation
|
||||||
|
|
Loading…
Reference in a new issue