mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
parent
2a1470cd05
commit
864f144543
22 changed files with 104 additions and 40 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -10,6 +10,7 @@
|
|||
*.tac
|
||||
_trial_temp/
|
||||
_trial_temp*/
|
||||
/out
|
||||
|
||||
# stuff that is likely to exist when you run a server locally
|
||||
/*.db
|
||||
|
|
1
changelog.d/6150.misc
Normal file
1
changelog.d/6150.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Expand type-checking on modules imported by synapse.config.
|
|
@ -17,6 +17,7 @@
|
|||
"""Contains exceptions and error codes."""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from six import iteritems
|
||||
from six.moves import http_client
|
||||
|
@ -111,7 +112,7 @@ class ProxiedRequestError(SynapseError):
|
|||
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
|
||||
super(ProxiedRequestError, self).__init__(code, msg, errcode)
|
||||
if additional_fields is None:
|
||||
self._additional_fields = {}
|
||||
self._additional_fields = {} # type: Dict
|
||||
else:
|
||||
self._additional_fields = dict(additional_fields)
|
||||
|
||||
|
|
|
@ -12,6 +12,9 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import attr
|
||||
|
||||
|
||||
|
@ -102,4 +105,4 @@ KNOWN_ROOM_VERSIONS = {
|
|||
RoomVersions.V4,
|
||||
RoomVersions.V5,
|
||||
)
|
||||
} # type: dict[str, RoomVersion]
|
||||
} # type: Dict[str, RoomVersion]
|
||||
|
|
|
@ -263,7 +263,9 @@ def start(hs, listeners=None):
|
|||
refresh_certificate(hs)
|
||||
|
||||
# Start the tracer
|
||||
synapse.logging.opentracing.init_tracer(hs.config)
|
||||
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
|
||||
hs.config
|
||||
)
|
||||
|
||||
# It is now safe to start your Synapse.
|
||||
hs.start_listening(listeners)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from six import string_types
|
||||
from six.moves.urllib import parse as urlparse
|
||||
|
@ -56,8 +57,8 @@ def load_appservices(hostname, config_files):
|
|||
return []
|
||||
|
||||
# Dicts of value -> filename
|
||||
seen_as_tokens = {}
|
||||
seen_ids = {}
|
||||
seen_as_tokens = {} # type: Dict[str, str]
|
||||
seen_ids = {} # type: Dict[str, str]
|
||||
|
||||
appservices = []
|
||||
|
||||
|
|
|
@ -73,8 +73,8 @@ DEFAULT_CONFIG = """\
|
|||
|
||||
|
||||
class ConsentConfig(Config):
|
||||
def __init__(self):
|
||||
super(ConsentConfig, self).__init__()
|
||||
def __init__(self, *args):
|
||||
super(ConsentConfig, self).__init__(*args)
|
||||
|
||||
self.user_consent_version = None
|
||||
self.user_consent_template_dir = None
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
from ._base import Config
|
||||
|
@ -22,7 +24,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
|
|||
|
||||
class PasswordAuthProviderConfig(Config):
|
||||
def read_config(self, config, **kwargs):
|
||||
self.password_providers = []
|
||||
self.password_providers = [] # type: List[Any]
|
||||
providers = []
|
||||
|
||||
# We want to be backwards compatible with the old `ldap_config`
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from typing import Dict, List
|
||||
|
||||
from synapse.python_dependencies import DependencyException, check_requirements
|
||||
from synapse.util.module_loader import load_module
|
||||
|
@ -61,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
|
|||
Dictionary mapping from media type string to list of
|
||||
ThumbnailRequirement tuples.
|
||||
"""
|
||||
requirements = {}
|
||||
requirements = {} # type: Dict[str, List]
|
||||
for size in thumbnail_sizes:
|
||||
width = size["width"]
|
||||
height = size["height"]
|
||||
|
@ -130,7 +131,7 @@ class ContentRepositoryConfig(Config):
|
|||
#
|
||||
# We don't create the storage providers here as not all workers need
|
||||
# them to be started.
|
||||
self.media_storage_providers = []
|
||||
self.media_storage_providers = [] # type: List[tuple]
|
||||
|
||||
for provider_config in storage_providers:
|
||||
# We special case the module "file_system" so as not to need to
|
||||
|
|
|
@ -19,6 +19,7 @@ import logging
|
|||
import os.path
|
||||
import re
|
||||
from textwrap import indent
|
||||
from typing import List
|
||||
|
||||
import attr
|
||||
import yaml
|
||||
|
@ -243,7 +244,7 @@ class ServerConfig(Config):
|
|||
# events with profile information that differ from the target's global profile.
|
||||
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
|
||||
|
||||
self.listeners = []
|
||||
self.listeners = [] # type: List[dict]
|
||||
for listener in config.get("listeners", []):
|
||||
if not isinstance(listener.get("port", None), int):
|
||||
raise ConfigError(
|
||||
|
@ -287,7 +288,10 @@ class ServerConfig(Config):
|
|||
validator=attr.validators.instance_of(bool), default=False
|
||||
)
|
||||
complexity = attr.ib(
|
||||
validator=attr.validators.instance_of((int, float)), default=1.0
|
||||
validator=attr.validators.instance_of(
|
||||
(float, int) # type: ignore[arg-type] # noqa
|
||||
),
|
||||
default=1.0,
|
||||
)
|
||||
complexity_error = attr.ib(
|
||||
validator=attr.validators.instance_of(str),
|
||||
|
@ -366,7 +370,7 @@ class ServerConfig(Config):
|
|||
"cleanup_extremities_with_dummy_events", True
|
||||
)
|
||||
|
||||
def has_tls_listener(self):
|
||||
def has_tls_listener(self) -> bool:
|
||||
return any(l["tls"] for l in self.listeners)
|
||||
|
||||
def generate_config_section(
|
||||
|
|
|
@ -59,8 +59,8 @@ class ServerNoticesConfig(Config):
|
|||
None if server notices are not enabled.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ServerNoticesConfig, self).__init__()
|
||||
def __init__(self, *args):
|
||||
super(ServerNoticesConfig, self).__init__(*args)
|
||||
self.server_notices_mxid = None
|
||||
self.server_notices_mxid_display_name = None
|
||||
self.server_notices_mxid_avatar_url = None
|
||||
|
|
|
@ -170,6 +170,7 @@ import inspect
|
|||
import logging
|
||||
import re
|
||||
from functools import wraps
|
||||
from typing import Dict
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
|
@ -547,7 +548,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
|
|||
return
|
||||
|
||||
span = opentracing.tracer.active_span
|
||||
carrier = {}
|
||||
carrier = {} # type: Dict[str, str]
|
||||
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
|
||||
|
||||
for key, value in carrier.items():
|
||||
|
@ -584,7 +585,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
|
|||
|
||||
span = opentracing.tracer.active_span
|
||||
|
||||
carrier = {}
|
||||
carrier = {} # type: Dict[str, str]
|
||||
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
|
||||
|
||||
for key, value in carrier.items():
|
||||
|
@ -639,7 +640,7 @@ def get_active_span_text_map(destination=None):
|
|||
if destination and not whitelisted_homeserver(destination):
|
||||
return {}
|
||||
|
||||
carrier = {}
|
||||
carrier = {} # type: Dict[str, str]
|
||||
opentracing.tracer.inject(
|
||||
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
|
||||
)
|
||||
|
@ -653,7 +654,7 @@ def active_span_context_as_string():
|
|||
Returns:
|
||||
The active span context encoded as a string.
|
||||
"""
|
||||
carrier = {}
|
||||
carrier = {} # type: Dict[str, str]
|
||||
if opentracing:
|
||||
opentracing.tracer.inject(
|
||||
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
|
||||
|
|
|
@ -119,7 +119,11 @@ def trace_function(f):
|
|||
logger = logging.getLogger(name)
|
||||
level = logging.DEBUG
|
||||
|
||||
s = inspect.currentframe().f_back
|
||||
frame = inspect.currentframe()
|
||||
if frame is None:
|
||||
raise Exception("Can't get current frame!")
|
||||
|
||||
s = frame.f_back
|
||||
|
||||
to_print = [
|
||||
"\t%s:%s %s. Args: args=%s, kwargs=%s"
|
||||
|
@ -144,7 +148,7 @@ def trace_function(f):
|
|||
pathname=pathname,
|
||||
lineno=lineno,
|
||||
msg=msg,
|
||||
args=None,
|
||||
args=tuple(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
|
@ -157,7 +161,12 @@ def trace_function(f):
|
|||
|
||||
|
||||
def get_previous_frames():
|
||||
s = inspect.currentframe().f_back.f_back
|
||||
|
||||
frame = inspect.currentframe()
|
||||
if frame is None:
|
||||
raise Exception("Can't get current frame!")
|
||||
|
||||
s = frame.f_back.f_back
|
||||
to_return = []
|
||||
while s:
|
||||
if s.f_globals["__name__"].startswith("synapse"):
|
||||
|
@ -174,7 +183,10 @@ def get_previous_frames():
|
|||
|
||||
|
||||
def get_previous_frame(ignore=[]):
|
||||
s = inspect.currentframe().f_back.f_back
|
||||
frame = inspect.currentframe()
|
||||
if frame is None:
|
||||
raise Exception("Can't get current frame!")
|
||||
s = frame.f_back.f_back
|
||||
|
||||
while s:
|
||||
if s.f_globals["__name__"].startswith("synapse"):
|
||||
|
|
|
@ -125,7 +125,7 @@ class InFlightGauge(object):
|
|||
)
|
||||
|
||||
# Counts number of in flight blocks for a given set of label values
|
||||
self._registrations = {}
|
||||
self._registrations = {} # type: Dict
|
||||
|
||||
# Protects access to _registrations
|
||||
self._lock = threading.Lock()
|
||||
|
@ -226,7 +226,7 @@ class BucketCollector(object):
|
|||
# Fetch the data -- this must be synchronous!
|
||||
data = self.data_collector()
|
||||
|
||||
buckets = {}
|
||||
buckets = {} # type: Dict[float, int]
|
||||
|
||||
res = []
|
||||
for x in data.keys():
|
||||
|
|
|
@ -36,9 +36,9 @@ from twisted.web.resource import Resource
|
|||
try:
|
||||
from prometheus_client.samples import Sample
|
||||
except ImportError:
|
||||
Sample = namedtuple(
|
||||
Sample = namedtuple( # type: ignore[no-redef] # noqa
|
||||
"Sample", ["name", "labels", "value", "timestamp", "exemplar"]
|
||||
) # type: ignore
|
||||
)
|
||||
|
||||
|
||||
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Set
|
||||
from typing import List, Set
|
||||
|
||||
from pkg_resources import (
|
||||
DistributionNotFound,
|
||||
|
@ -73,6 +73,7 @@ REQUIREMENTS = [
|
|||
"netaddr>=0.7.18",
|
||||
"Jinja2>=2.9",
|
||||
"bleach>=1.4.3",
|
||||
"typing-extensions>=3.7.4",
|
||||
]
|
||||
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
|
@ -144,7 +145,11 @@ def check_requirements(for_feature=None):
|
|||
deps_needed.append(dependency)
|
||||
errors.append(
|
||||
"Needed %s, got %s==%s"
|
||||
% (dependency, e.dist.project_name, e.dist.version)
|
||||
% (
|
||||
dependency,
|
||||
e.dist.project_name, # type: ignore[attr-defined] # noqa
|
||||
e.dist.version, # type: ignore[attr-defined] # noqa
|
||||
)
|
||||
)
|
||||
except DistributionNotFound:
|
||||
deps_needed.append(dependency)
|
||||
|
@ -159,7 +164,7 @@ def check_requirements(for_feature=None):
|
|||
if not for_feature:
|
||||
# Check the optional dependencies are up to date. We allow them to not be
|
||||
# installed.
|
||||
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), [])
|
||||
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]
|
||||
|
||||
for dependency in OPTS:
|
||||
try:
|
||||
|
@ -168,7 +173,11 @@ def check_requirements(for_feature=None):
|
|||
deps_needed.append(dependency)
|
||||
errors.append(
|
||||
"Needed optional %s, got %s==%s"
|
||||
% (dependency, e.dist.project_name, e.dist.version)
|
||||
% (
|
||||
dependency,
|
||||
e.dist.project_name, # type: ignore[attr-defined] # noqa
|
||||
e.dist.version, # type: ignore[attr-defined] # noqa
|
||||
)
|
||||
)
|
||||
except DistributionNotFound:
|
||||
# If it's not found, we don't care
|
||||
|
|
|
@ -318,6 +318,7 @@ class StreamToken(
|
|||
)
|
||||
):
|
||||
_SEPARATOR = "_"
|
||||
START = None # type: StreamToken
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, string):
|
||||
|
@ -402,7 +403,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||
followed by the "stream_ordering" id of the event it comes after.
|
||||
"""
|
||||
|
||||
__slots__ = []
|
||||
__slots__ = [] # type: list
|
||||
|
||||
@classmethod
|
||||
def parse(cls, string):
|
||||
|
|
|
@ -13,9 +13,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Sequence, Set, Union
|
||||
|
||||
from six.moves import range
|
||||
|
||||
|
@ -213,7 +215,9 @@ class Linearizer(object):
|
|||
# the first element is the number of things executing, and
|
||||
# the second element is an OrderedDict, where the keys are deferreds for the
|
||||
# things blocked from executing.
|
||||
self.key_to_defer = {}
|
||||
self.key_to_defer = (
|
||||
{}
|
||||
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
|
||||
|
||||
def queue(self, key):
|
||||
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
|
||||
|
@ -340,10 +344,10 @@ class ReadWriteLock(object):
|
|||
|
||||
def __init__(self):
|
||||
# Latest readers queued
|
||||
self.key_to_current_readers = {}
|
||||
self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
|
||||
|
||||
# Latest writer queued
|
||||
self.key_to_current_writer = {}
|
||||
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def read(self, key):
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import six
|
||||
from six.moves import intern
|
||||
|
@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):
|
|||
|
||||
|
||||
caches_by_name = {}
|
||||
collectors_by_name = {}
|
||||
collectors_by_name = {} # type: Dict
|
||||
|
||||
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
|
||||
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
|
||||
|
|
|
@ -18,10 +18,12 @@ import inspect
|
|||
import logging
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
from typing import Any, cast
|
||||
|
||||
from six import itervalues
|
||||
|
||||
from prometheus_client import Gauge
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -37,6 +39,18 @@ from . import register_cache
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _CachedFunction(Protocol):
|
||||
invalidate = None # type: Any
|
||||
invalidate_all = None # type: Any
|
||||
invalidate_many = None # type: Any
|
||||
prefill = None # type: Any
|
||||
cache = None # type: Any
|
||||
num_args = None # type: Any
|
||||
|
||||
def __name__(self):
|
||||
...
|
||||
|
||||
|
||||
cache_pending_metric = Gauge(
|
||||
"synapse_util_caches_cache_pending",
|
||||
"Number of lookups currently pending for this cache",
|
||||
|
@ -245,7 +259,9 @@ class Cache(object):
|
|||
|
||||
|
||||
class _CacheDescriptorBase(object):
|
||||
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
|
||||
def __init__(
|
||||
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
|
||||
):
|
||||
self.orig = orig
|
||||
|
||||
if inlineCallbacks:
|
||||
|
@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||
return tuple(get_cache_key_gen(args, kwargs))
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def wrapped(*args, **kwargs):
|
||||
def _wrapped(*args, **kwargs):
|
||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||
# whenever we are invalidated
|
||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||
|
@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||
|
||||
return make_deferred_yieldable(observer)
|
||||
|
||||
wrapped = cast(_CachedFunction, _wrapped)
|
||||
|
||||
if self.num_args == 1:
|
||||
wrapped.invalidate = lambda key: cache.invalidate(key[0])
|
||||
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Dict
|
||||
|
||||
from six import itervalues
|
||||
|
||||
SENTINEL = object()
|
||||
|
@ -12,7 +14,7 @@ class TreeCache(object):
|
|||
|
||||
def __init__(self):
|
||||
self.size = 0
|
||||
self.root = {}
|
||||
self.root = {} # type: Dict
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
return self.set(key, value)
|
||||
|
|
|
@ -54,5 +54,5 @@ def load_python_module(location: str):
|
|||
if spec is None:
|
||||
raise Exception("Unable to load module at %s" % (location,))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
spec.loader.exec_module(mod) # type: ignore
|
||||
return mod
|
||||
|
|
Loading…
Reference in a new issue