Add missing type hints to config base classes (#11377)

This commit is contained in:
Patrick Cloke 2021-11-23 10:21:19 -05:00 committed by GitHub
parent 7cebaf9644
commit 55669bd3de
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 184 additions and 109 deletions

1
changelog.d/11377.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug introduced in v1.45.0 where the `read_templates` method of the module API would error.

1
changelog.d/11377.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to configuration classes.

View file

@ -151,6 +151,9 @@ disallow_untyped_defs = True
[mypy-synapse.app.*] [mypy-synapse.app.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.config._base]
disallow_untyped_defs = True
[mypy-synapse.crypto.*] [mypy-synapse.crypto.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@ -20,7 +20,18 @@ import os
from collections import OrderedDict from collections import OrderedDict
from hashlib import sha256 from hashlib import sha256
from textwrap import dedent from textwrap import dedent
from typing import Any, Iterable, List, MutableMapping, Optional, Union from typing import (
Any,
Dict,
Iterable,
List,
MutableMapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import attr import attr
import jinja2 import jinja2
@ -78,7 +89,7 @@ CONFIG_FILE_HEADER = """\
""" """
def path_exists(file_path): def path_exists(file_path: str) -> bool:
"""Check if a file exists """Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error Unlike os.path.exists, this throws an exception if there is an error
@ -86,7 +97,7 @@ def path_exists(file_path):
the parent dir). the parent dir).
Returns: Returns:
bool: True if the file exists; False if not. True if the file exists; False if not.
""" """
try: try:
os.stat(file_path) os.stat(file_path)
@ -102,15 +113,15 @@ class Config:
A configuration section, containing configuration keys and values. A configuration section, containing configuration keys and values.
Attributes: Attributes:
section (str): The section title of this config object, such as section: The section title of this config object, such as
"tls" or "logger". This is used to refer to it on the root "tls" or "logger". This is used to refer to it on the root
logger (for example, `config.tls.some_option`). Must be logger (for example, `config.tls.some_option`). Must be
defined in subclasses. defined in subclasses.
""" """
section = None section: str
def __init__(self, root_config=None): def __init__(self, root_config: "RootConfig" = None):
self.root = root_config self.root = root_config
# Get the path to the default Synapse template directory # Get the path to the default Synapse template directory
@ -119,7 +130,7 @@ class Config:
) )
@staticmethod @staticmethod
def parse_size(value): def parse_size(value: Union[str, int]) -> int:
if isinstance(value, int): if isinstance(value, int):
return value return value
sizes = {"K": 1024, "M": 1024 * 1024} sizes = {"K": 1024, "M": 1024 * 1024}
@ -162,15 +173,15 @@ class Config:
return int(value) * size return int(value) * size
@staticmethod @staticmethod
def abspath(file_path): def abspath(file_path: str) -> str:
return os.path.abspath(file_path) if file_path else file_path return os.path.abspath(file_path) if file_path else file_path
@classmethod @classmethod
def path_exists(cls, file_path): def path_exists(cls, file_path: str) -> bool:
return path_exists(file_path) return path_exists(file_path)
@classmethod @classmethod
def check_file(cls, file_path, config_name): def check_file(cls, file_path: Optional[str], config_name: str) -> str:
if file_path is None: if file_path is None:
raise ConfigError("Missing config for %s." % (config_name,)) raise ConfigError("Missing config for %s." % (config_name,))
try: try:
@ -183,7 +194,7 @@ class Config:
return cls.abspath(file_path) return cls.abspath(file_path)
@classmethod @classmethod
def ensure_directory(cls, dir_path): def ensure_directory(cls, dir_path: str) -> str:
dir_path = cls.abspath(dir_path) dir_path = cls.abspath(dir_path)
os.makedirs(dir_path, exist_ok=True) os.makedirs(dir_path, exist_ok=True)
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):
@ -191,7 +202,7 @@ class Config:
return dir_path return dir_path
@classmethod @classmethod
def read_file(cls, file_path, config_name): def read_file(cls, file_path: Any, config_name: str) -> str:
"""Deprecated: call read_file directly""" """Deprecated: call read_file directly"""
return read_file(file_path, (config_name,)) return read_file(file_path, (config_name,))
@ -284,6 +295,9 @@ class Config:
return [env.get_template(filename) for filename in filenames] return [env.get_template(filename) for filename in filenames]
TRootConfig = TypeVar("TRootConfig", bound="RootConfig")
class RootConfig: class RootConfig:
""" """
Holder of an application's configuration. Holder of an application's configuration.
@ -308,7 +322,9 @@ class RootConfig:
raise Exception("Failed making %s: %r" % (config_class.section, e)) raise Exception("Failed making %s: %r" % (config_class.section, e))
setattr(self, config_class.section, conf) setattr(self, config_class.section, conf)
def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: def invoke_all(
self, func_name: str, *args: Any, **kwargs: Any
) -> MutableMapping[str, Any]:
""" """
Invoke a function on all instantiated config objects this RootConfig is Invoke a function on all instantiated config objects this RootConfig is
configured to use. configured to use.
@ -317,6 +333,7 @@ class RootConfig:
func_name: Name of function to invoke func_name: Name of function to invoke
*args *args
**kwargs **kwargs
Returns: Returns:
ordered dictionary of config section name and the result of the ordered dictionary of config section name and the result of the
function from it. function from it.
@ -332,7 +349,7 @@ class RootConfig:
return res return res
@classmethod @classmethod
def invoke_all_static(cls, func_name: str, *args, **kwargs): def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: any) -> None:
""" """
Invoke a static function on config objects this RootConfig is Invoke a static function on config objects this RootConfig is
configured to use. configured to use.
@ -341,6 +358,7 @@ class RootConfig:
func_name: Name of function to invoke func_name: Name of function to invoke
*args *args
**kwargs **kwargs
Returns: Returns:
ordered dictionary of config section name and the result of the ordered dictionary of config section name and the result of the
function from it. function from it.
@ -351,16 +369,16 @@ class RootConfig:
def generate_config( def generate_config(
self, self,
config_dir_path, config_dir_path: str,
data_dir_path, data_dir_path: str,
server_name, server_name: str,
generate_secrets=False, generate_secrets: bool = False,
report_stats=None, report_stats: Optional[bool] = None,
open_private_ports=False, open_private_ports: bool = False,
listeners=None, listeners: Optional[List[dict]] = None,
tls_certificate_path=None, tls_certificate_path: Optional[str] = None,
tls_private_key_path=None, tls_private_key_path: Optional[str] = None,
): ) -> str:
""" """
Build a default configuration file Build a default configuration file
@ -368,27 +386,27 @@ class RootConfig:
(eg with --generate_config). (eg with --generate_config).
Args: Args:
config_dir_path (str): The path where the config files are kept. Used to config_dir_path: The path where the config files are kept. Used to
create filenames for things like the log config and the signing key. create filenames for things like the log config and the signing key.
data_dir_path (str): The path where the data files are kept. Used to create data_dir_path: The path where the data files are kept. Used to create
filenames for things like the database and media store. filenames for things like the database and media store.
server_name (str): The server name. Used to initialise the server_name server_name: The server name. Used to initialise the server_name
config param, but also used in the names of some of the config files. config param, but also used in the names of some of the config files.
generate_secrets (bool): True if we should generate new secrets for things generate_secrets: True if we should generate new secrets for things
like the macaroon_secret_key. If False, these parameters will be left like the macaroon_secret_key. If False, these parameters will be left
unset. unset.
report_stats (bool|None): Initial setting for the report_stats setting. report_stats: Initial setting for the report_stats setting.
If None, report_stats will be left unset. If None, report_stats will be left unset.
open_private_ports (bool): True to leave private ports (such as the non-TLS open_private_ports: True to leave private ports (such as the non-TLS
HTTP listener) open to the internet. HTTP listener) open to the internet.
listeners (list(dict)|None): A list of descriptions of the listeners listeners: A list of descriptions of the listeners synapse should
synapse should start with each of which specifies a port (str), a list of start with each of which specifies a port (int), a list of
resources (list(str)), tls (bool) and type (str). For example: resources (list(str)), tls (bool) and type (str). For example:
[{ [{
"port": 8448, "port": 8448,
@ -403,16 +421,12 @@ class RootConfig:
"type": "http", "type": "http",
}], }],
tls_certificate_path: The path to the tls certificate.
database (str|None): The database type to configure, either `psycog2` tls_private_key_path: The path to the tls private key.
or `sqlite3`.
tls_certificate_path (str|None): The path to the tls certificate.
tls_private_key_path (str|None): The path to the tls private key.
Returns: Returns:
str: the yaml config file The yaml config file
""" """
return CONFIG_FILE_HEADER + "\n\n".join( return CONFIG_FILE_HEADER + "\n\n".join(
@ -432,12 +446,15 @@ class RootConfig:
) )
@classmethod @classmethod
def load_config(cls, description, argv): def load_config(
cls: Type[TRootConfig], description: str, argv: List[str]
) -> TRootConfig:
"""Parse the commandline and config files """Parse the commandline and config files
Doesn't support config-file-generation: used by the worker apps. Doesn't support config-file-generation: used by the worker apps.
Returns: Config object. Returns:
Config object.
""" """
config_parser = argparse.ArgumentParser(description=description) config_parser = argparse.ArgumentParser(description=description)
cls.add_arguments_to_parser(config_parser) cls.add_arguments_to_parser(config_parser)
@ -446,7 +463,7 @@ class RootConfig:
return obj return obj
@classmethod @classmethod
def add_arguments_to_parser(cls, config_parser): def add_arguments_to_parser(cls, config_parser: argparse.ArgumentParser) -> None:
"""Adds all the config flags to an ArgumentParser. """Adds all the config flags to an ArgumentParser.
Doesn't support config-file-generation: used by the worker apps. Doesn't support config-file-generation: used by the worker apps.
@ -454,7 +471,7 @@ class RootConfig:
Used for workers where we want to add extra flags/subcommands. Used for workers where we want to add extra flags/subcommands.
Args: Args:
config_parser (ArgumentParser): App description config_parser: App description
""" """
config_parser.add_argument( config_parser.add_argument(
@ -477,7 +494,9 @@ class RootConfig:
cls.invoke_all_static("add_arguments", config_parser) cls.invoke_all_static("add_arguments", config_parser)
@classmethod @classmethod
def load_config_with_parser(cls, parser, argv): def load_config_with_parser(
cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str]
) -> Tuple[TRootConfig, argparse.Namespace]:
"""Parse the commandline and config files with the given parser """Parse the commandline and config files with the given parser
Doesn't support config-file-generation: used by the worker apps. Doesn't support config-file-generation: used by the worker apps.
@ -485,13 +504,12 @@ class RootConfig:
Used for workers where we want to add extra flags/subcommands. Used for workers where we want to add extra flags/subcommands.
Args: Args:
parser (ArgumentParser) parser
argv (list[str]) argv
Returns: Returns:
tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed Returns the parsed config object and the parsed argparse.Namespace
config object and the parsed argparse.Namespace object from object from parser.parse_args(..)`
`parser.parse_args(..)`
""" """
obj = cls() obj = cls()
@ -520,12 +538,15 @@ class RootConfig:
return obj, config_args return obj, config_args
@classmethod @classmethod
def load_or_generate_config(cls, description, argv): def load_or_generate_config(
cls: Type[TRootConfig], description: str, argv: List[str]
) -> Optional[TRootConfig]:
"""Parse the commandline and config files """Parse the commandline and config files
Supports generation of config files, so is used for the main homeserver app. Supports generation of config files, so is used for the main homeserver app.
Returns: Config object, or None if --generate-config or --generate-keys was set Returns:
Config object, or None if --generate-config or --generate-keys was set
""" """
parser = argparse.ArgumentParser(description=description) parser = argparse.ArgumentParser(description=description)
parser.add_argument( parser.add_argument(
@ -680,16 +701,21 @@ class RootConfig:
return obj return obj
def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None): def parse_config_dict(
self,
config_dict: Dict[str, Any],
config_dir_path: Optional[str] = None,
data_dir_path: Optional[str] = None,
) -> None:
"""Read the information from the config dict into this Config object. """Read the information from the config dict into this Config object.
Args: Args:
config_dict (dict): Configuration data, as read from the yaml config_dict: Configuration data, as read from the yaml
config_dir_path (str): The path where the config files are kept. Used to config_dir_path: The path where the config files are kept. Used to
create filenames for things like the log config and the signing key. create filenames for things like the log config and the signing key.
data_dir_path (str): The path where the data files are kept. Used to create data_dir_path: The path where the data files are kept. Used to create
filenames for things like the database and media store. filenames for things like the database and media store.
""" """
self.invoke_all( self.invoke_all(
@ -699,17 +725,20 @@ class RootConfig:
data_dir_path=data_dir_path, data_dir_path=data_dir_path,
) )
def generate_missing_files(self, config_dict, config_dir_path): def generate_missing_files(
self, config_dict: Dict[str, Any], config_dir_path: str
) -> None:
self.invoke_all("generate_files", config_dict, config_dir_path) self.invoke_all("generate_files", config_dict, config_dir_path)
def read_config_files(config_files): def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]:
"""Read the config files into a dict """Read the config files into a dict
Args: Args:
config_files (iterable[str]): A list of the config files to read config_files: A list of the config files to read
Returns: dict Returns:
The configuration dictionary.
""" """
specified_config = {} specified_config = {}
for config_file in config_files: for config_file in config_files:
@ -733,17 +762,17 @@ def read_config_files(config_files):
return specified_config return specified_config
def find_config_files(search_paths): def find_config_files(search_paths: List[str]) -> List[str]:
"""Finds config files using a list of search paths. If a path is a file """Finds config files using a list of search paths. If a path is a file
then that file path is added to the list. If a search path is a directory then that file path is added to the list. If a search path is a directory
then all the "*.yaml" files in that directory are added to the list in then all the "*.yaml" files in that directory are added to the list in
sorted order. sorted order.
Args: Args:
search_paths(list(str)): A list of paths to search. search_paths: A list of paths to search.
Returns: Returns:
list(str): A list of file paths. A list of file paths.
""" """
config_files = [] config_files = []
@ -777,7 +806,7 @@ def find_config_files(search_paths):
return config_files return config_files
@attr.s @attr.s(auto_attribs=True)
class ShardedWorkerHandlingConfig: class ShardedWorkerHandlingConfig:
"""Algorithm for choosing which instance is responsible for handling some """Algorithm for choosing which instance is responsible for handling some
sharded work. sharded work.
@ -787,7 +816,7 @@ class ShardedWorkerHandlingConfig:
below). below).
""" """
instances = attr.ib(type=List[str]) instances: List[str]
def should_handle(self, instance_name: str, key: str) -> bool: def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.""" """Whether this instance is responsible for handling the given key."""

View file

@ -1,4 +1,18 @@
from typing import Any, Iterable, List, Optional import argparse
from typing import (
Any,
Dict,
Iterable,
List,
MutableMapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import jinja2
from synapse.config import ( from synapse.config import (
account_validity, account_validity,
@ -19,6 +33,7 @@ from synapse.config import (
logger, logger,
metrics, metrics,
modules, modules,
oembed,
oidc, oidc,
password_auth_providers, password_auth_providers,
push, push,
@ -27,6 +42,7 @@ from synapse.config import (
registration, registration,
repository, repository,
retention, retention,
room,
room_directory, room_directory,
saml2, saml2,
server, server,
@ -51,7 +67,9 @@ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str MISSING_REPORT_STATS_SPIEL: str
MISSING_SERVER_NAME: str MISSING_SERVER_NAME: str
def path_exists(file_path: str): ... def path_exists(file_path: str) -> bool: ...
TRootConfig = TypeVar("TRootConfig", bound="RootConfig")
class RootConfig: class RootConfig:
server: server.ServerConfig server: server.ServerConfig
@ -61,6 +79,7 @@ class RootConfig:
logging: logger.LoggingConfig logging: logger.LoggingConfig
ratelimiting: ratelimiting.RatelimitConfig ratelimiting: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig media: repository.ContentRepositoryConfig
oembed: oembed.OembedConfig
captcha: captcha.CaptchaConfig captcha: captcha.CaptchaConfig
voip: voip.VoipConfig voip: voip.VoipConfig
registration: registration.RegistrationConfig registration: registration.RegistrationConfig
@ -80,6 +99,7 @@ class RootConfig:
authproviders: password_auth_providers.PasswordAuthProviderConfig authproviders: password_auth_providers.PasswordAuthProviderConfig
push: push.PushConfig push: push.PushConfig
spamchecker: spam_checker.SpamCheckerConfig spamchecker: spam_checker.SpamCheckerConfig
room: room.RoomConfig
groups: groups.GroupsConfig groups: groups.GroupsConfig
userdirectory: user_directory.UserDirectoryConfig userdirectory: user_directory.UserDirectoryConfig
consent: consent.ConsentConfig consent: consent.ConsentConfig
@ -87,72 +107,85 @@ class RootConfig:
servernotices: server_notices.ServerNoticesConfig servernotices: server_notices.ServerNoticesConfig
roomdirectory: room_directory.RoomDirectoryConfig roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig tracing: tracer.TracerConfig
redis: redis.RedisConfig redis: redis.RedisConfig
modules: modules.ModulesConfig modules: modules.ModulesConfig
caches: cache.CacheConfig caches: cache.CacheConfig
federation: federation.FederationConfig federation: federation.FederationConfig
retention: retention.RetentionConfig retention: retention.RetentionConfig
config_classes: List = ... config_classes: List[Type["Config"]] = ...
def __init__(self) -> None: ... def __init__(self) -> None: ...
def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ... def invoke_all(
self, func_name: str, *args: Any, **kwargs: Any
) -> MutableMapping[str, Any]: ...
@classmethod @classmethod
def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ... def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
def __getattr__(self, item: str): ...
def parse_config_dict( def parse_config_dict(
self, self,
config_dict: Any, config_dict: Dict[str, Any],
config_dir_path: Optional[Any] = ..., config_dir_path: Optional[str] = ...,
data_dir_path: Optional[Any] = ..., data_dir_path: Optional[str] = ...,
) -> None: ... ) -> None: ...
read_config: Any = ...
def generate_config( def generate_config(
self, self,
config_dir_path: str, config_dir_path: str,
data_dir_path: str, data_dir_path: str,
server_name: str, server_name: str,
generate_secrets: bool = ..., generate_secrets: bool = ...,
report_stats: Optional[str] = ..., report_stats: Optional[bool] = ...,
open_private_ports: bool = ..., open_private_ports: bool = ...,
listeners: Optional[Any] = ..., listeners: Optional[Any] = ...,
database_conf: Optional[Any] = ...,
tls_certificate_path: Optional[str] = ..., tls_certificate_path: Optional[str] = ...,
tls_private_key_path: Optional[str] = ..., tls_private_key_path: Optional[str] = ...,
): ... ) -> str: ...
@classmethod @classmethod
def load_or_generate_config(cls, description: Any, argv: Any): ... def load_or_generate_config(
cls: Type[TRootConfig], description: str, argv: List[str]
) -> Optional[TRootConfig]: ...
@classmethod @classmethod
def load_config(cls, description: Any, argv: Any): ... def load_config(
cls: Type[TRootConfig], description: str, argv: List[str]
) -> TRootConfig: ...
@classmethod @classmethod
def add_arguments_to_parser(cls, config_parser: Any) -> None: ... def add_arguments_to_parser(
cls, config_parser: argparse.ArgumentParser
) -> None: ...
@classmethod @classmethod
def load_config_with_parser(cls, parser: Any, argv: Any): ... def load_config_with_parser(
cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str]
) -> Tuple[TRootConfig, argparse.Namespace]: ...
def generate_missing_files( def generate_missing_files(
self, config_dict: dict, config_dir_path: str self, config_dict: dict, config_dir_path: str
) -> None: ... ) -> None: ...
class Config: class Config:
root: RootConfig root: RootConfig
default_template_dir: str
def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ... def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ...
def __getattr__(self, item: str, from_root: bool = ...): ...
@staticmethod @staticmethod
def parse_size(value: Any): ... def parse_size(value: Union[str, int]) -> int: ...
@staticmethod @staticmethod
def parse_duration(value: Any): ... def parse_duration(value: Union[str, int]) -> int: ...
@staticmethod @staticmethod
def abspath(file_path: Optional[str]): ... def abspath(file_path: Optional[str]) -> str: ...
@classmethod @classmethod
def path_exists(cls, file_path: str): ... def path_exists(cls, file_path: str) -> bool: ...
@classmethod @classmethod
def check_file(cls, file_path: str, config_name: str): ... def check_file(cls, file_path: str, config_name: str) -> str: ...
@classmethod @classmethod
def ensure_directory(cls, dir_path: str): ... def ensure_directory(cls, dir_path: str) -> str: ...
@classmethod @classmethod
def read_file(cls, file_path: str, config_name: str): ... def read_file(cls, file_path: str, config_name: str) -> str: ...
def read_template(self, filenames: str) -> jinja2.Template: ...
def read_templates(
self,
filenames: List[str],
custom_template_directories: Optional[Iterable[str]] = None,
) -> List[jinja2.Template]: ...
def read_config_files(config_files: List[str]): ... def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: ...
def find_config_files(search_paths: List[str]): ... def find_config_files(search_paths: List[str]) -> List[str]: ...
class ShardedWorkerHandlingConfig: class ShardedWorkerHandlingConfig:
instances: List[str] instances: List[str]

View file

@ -15,7 +15,7 @@
import os import os
import re import re
import threading import threading
from typing import Callable, Dict from typing import Callable, Dict, Optional
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
@ -217,7 +217,7 @@ class CacheConfig(Config):
expiry_time = cache_config.get("expiry_time") expiry_time = cache_config.get("expiry_time")
if expiry_time: if expiry_time:
self.expiry_time_msec = self.parse_duration(expiry_time) self.expiry_time_msec: Optional[int] = self.parse_duration(expiry_time)
else: else:
self.expiry_time_msec = None self.expiry_time_msec = None

View file

@ -16,6 +16,7 @@
import hashlib import hashlib
import logging import logging
import os import os
from typing import Any, Dict
import attr import attr
import jsonschema import jsonschema
@ -312,7 +313,7 @@ class KeyConfig(Config):
) )
return keys return keys
def generate_files(self, config, config_dir_path): def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
if "signing_key" in config: if "signing_key" in config:
return return

View file

@ -18,7 +18,7 @@ import os
import sys import sys
import threading import threading
from string import Template from string import Template
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, Dict
import yaml import yaml
from zope.interface import implementer from zope.interface import implementer
@ -185,7 +185,7 @@ class LoggingConfig(Config):
help=argparse.SUPPRESS, help=argparse.SUPPRESS,
) )
def generate_files(self, config, config_dir_path): def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
log_config = config.get("log_config") log_config = config.get("log_config")
if log_config and not os.path.exists(log_config): if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log") log_file = self.abspath("homeserver.log")

View file

@ -421,7 +421,7 @@ class ServerConfig(Config):
# before redacting them. # before redacting them.
redaction_retention_period = config.get("redaction_retention_period", "7d") redaction_retention_period = config.get("redaction_retention_period", "7d")
if redaction_retention_period is not None: if redaction_retention_period is not None:
self.redaction_retention_period = self.parse_duration( self.redaction_retention_period: Optional[int] = self.parse_duration(
redaction_retention_period redaction_retention_period
) )
else: else:
@ -430,7 +430,7 @@ class ServerConfig(Config):
# How long to keep entries in the `users_ips` table. # How long to keep entries in the `users_ips` table.
user_ips_max_age = config.get("user_ips_max_age", "28d") user_ips_max_age = config.get("user_ips_max_age", "28d")
if user_ips_max_age is not None: if user_ips_max_age is not None:
self.user_ips_max_age = self.parse_duration(user_ips_max_age) self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age)
else: else:
self.user_ips_max_age = None self.user_ips_max_age = None

View file

@ -245,7 +245,7 @@ class TlsConfig(Config):
cert_path = self.tls_certificate_file cert_path = self.tls_certificate_file
logger.info("Loading TLS certificate from %s", cert_path) logger.info("Loading TLS certificate from %s", cert_path)
cert_pem = self.read_file(cert_path, "tls_certificate_path") cert_pem = self.read_file(cert_path, "tls_certificate_path")
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem.encode())
return cert return cert

View file

@ -1014,7 +1014,7 @@ class ModuleApi:
A list containing the loaded templates, with the orders matching the one of A list containing the loaded templates, with the orders matching the one of
the filenames parameter. the filenames parameter.
""" """
return self._hs.config.read_templates( return self._hs.config.server.read_templates(
filenames, filenames,
(td for td in (self.custom_template_dir, custom_template_directory) if td), (td for td in (self.custom_template_dir, custom_template_directory) if td),
) )

View file

@ -1198,8 +1198,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts = now_ms + self._account_validity_period expiration_ts = now_ms + self._account_validity_period
if use_delta: if use_delta:
assert self._account_validity_startup_job_max_delta is not None
expiration_ts = random.randrange( expiration_ts = random.randrange(
expiration_ts - self._account_validity_startup_job_max_delta, int(expiration_ts - self._account_validity_startup_job_max_delta),
expiration_ts, expiration_ts,
) )

View file

@ -46,15 +46,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
"was: %r" % (config.key.macaroon_secret_key,) "was: %r" % (config.key.macaroon_secret_key,)
) )
config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
assert config2 is not None
self.assertTrue( self.assertTrue(
hasattr(config.key, "macaroon_secret_key"), hasattr(config2.key, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key", "Want config to have attr macaroon_secret_key",
) )
if len(config.key.macaroon_secret_key) < 5: if len(config2.key.macaroon_secret_key) < 5:
self.fail( self.fail(
"Want macaroon secret key to be string of at least length 5," "Want macaroon secret key to be string of at least length 5,"
"was: %r" % (config.key.macaroon_secret_key,) "was: %r" % (config2.key.macaroon_secret_key,)
) )
def test_load_succeeds_if_macaroon_secret_key_missing(self): def test_load_succeeds_if_macaroon_secret_key_missing(self):
@ -62,6 +63,9 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config1 = HomeServerConfig.load_config("", ["-c", self.config_file]) config1 = HomeServerConfig.load_config("", ["-c", self.config_file])
config2 = HomeServerConfig.load_config("", ["-c", self.config_file]) config2 = HomeServerConfig.load_config("", ["-c", self.config_file])
config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
assert config1 is not None
assert config2 is not None
assert config3 is not None
self.assertEqual( self.assertEqual(
config1.key.macaroon_secret_key, config2.key.macaroon_secret_key config1.key.macaroon_secret_key, config2.key.macaroon_secret_key
) )
@ -78,14 +82,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config = HomeServerConfig.load_config("", ["-c", self.config_file]) config = HomeServerConfig.load_config("", ["-c", self.config_file])
self.assertFalse(config.registration.enable_registration) self.assertFalse(config.registration.enable_registration)
config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
self.assertFalse(config.registration.enable_registration) assert config2 is not None
self.assertFalse(config2.registration.enable_registration)
# Check that either config value is clobbered by the command line. # Check that either config value is clobbered by the command line.
config = HomeServerConfig.load_or_generate_config( config3 = HomeServerConfig.load_or_generate_config(
"", ["-c", self.config_file, "--enable-registration"] "", ["-c", self.config_file, "--enable-registration"]
) )
self.assertTrue(config.registration.enable_registration) assert config3 is not None
self.assertTrue(config3.registration.enable_registration)
def test_stats_enabled(self): def test_stats_enabled(self):
self.generate_config_and_remove_lines_containing("enable_metrics") self.generate_config_and_remove_lines_containing("enable_metrics")