mirror of
https://github.com/element-hq/synapse.git
synced 2025-03-05 15:37:02 +00:00
Type tests.utils
(#13028)
* Cast to postgres types when handling postgres db * Remove unused method * Easy annotations * Annotate create_room * Use `ParamSpec` to annotate looping_call * Annotate `default_config` * Track `now` as a float `time_ms` returns an int like the proper Synapse `Clock` * Introduce a `Timer` dataclass * Introduce a Looper type * Suppress checking of a mock * tests.utils is typed * Changelog * Whoops, import ParamSpec from typing_extensions * ditch the psycopg2 casts
This commit is contained in:
parent
68695d8007
commit
6ba732fefe
5 changed files with 101 additions and 45 deletions
1
changelog.d/13028.misc
Normal file
1
changelog.d/13028.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type annotations to `tests.utils`.
|
3
mypy.ini
3
mypy.ini
|
@ -126,6 +126,9 @@ disallow_untyped_defs = True
|
||||||
[mypy-tests.federation.transport.test_client]
|
[mypy-tests.federation.transport.test_client]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-tests.utils]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
||||||
;; Dependencies without annotations
|
;; Dependencies without annotations
|
||||||
;; Before ignoring a module, check to see if type stubs are available.
|
;; Before ignoring a module, check to see if type stubs are available.
|
||||||
|
|
|
@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Generator, Optional
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
from matrix_common.versionstring import get_distribution_version_string
|
from matrix_common.versionstring import get_distribution_version_string
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from twisted.internet import defer, task
|
from twisted.internet import defer, task
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
@ -82,6 +83,9 @@ def unwrapFirstError(failure: Failure) -> Failure:
|
||||||
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
|
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
|
||||||
|
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
class Clock:
|
class Clock:
|
||||||
"""
|
"""
|
||||||
|
@ -110,7 +114,7 @@ class Clock:
|
||||||
return int(self.time() * 1000)
|
return int(self.time() * 1000)
|
||||||
|
|
||||||
def looping_call(
|
def looping_call(
|
||||||
self, f: Callable, msec: float, *args: Any, **kwargs: Any
|
self, f: Callable[P, object], msec: float, *args: P.args, **kwargs: P.kwargs
|
||||||
) -> LoopingCall:
|
) -> LoopingCall:
|
||||||
"""Call a function repeatedly.
|
"""Call a function repeatedly.
|
||||||
|
|
||||||
|
|
|
@ -109,7 +109,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
|
||||||
|
|
||||||
@wrap_as_background_process("LruCache._expire_old_entries")
|
@wrap_as_background_process("LruCache._expire_old_entries")
|
||||||
async def _expire_old_entries(
|
async def _expire_old_entries(
|
||||||
clock: Clock, expiry_seconds: int, autotune_config: Optional[dict]
|
clock: Clock, expiry_seconds: float, autotune_config: Optional[dict]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Walks the global cache list to find cache entries that haven't been
|
"""Walks the global cache list to find cache entries that haven't been
|
||||||
accessed in the given number of seconds, or if a given memory threshold has been breached.
|
accessed in the given number of seconds, or if a given memory threshold has been breached.
|
||||||
|
|
134
tests/utils.py
134
tests/utils.py
|
@ -15,12 +15,17 @@
|
||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Callable, Dict, List, Tuple, Union, overload
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from typing_extensions import Literal, ParamSpec
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||||
from synapse.logging.context import current_context, set_current_context
|
from synapse.logging.context import current_context, set_current_context
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.database import LoggingDatabaseConnection
|
from synapse.storage.database import LoggingDatabaseConnection
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.prepare_database import prepare_database
|
from synapse.storage.prepare_database import prepare_database
|
||||||
|
@ -50,12 +55,11 @@ SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None
|
||||||
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
|
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
|
||||||
|
|
||||||
|
|
||||||
def setupdb():
|
def setupdb() -> None:
|
||||||
# If we're using PostgreSQL, set up the db once
|
# If we're using PostgreSQL, set up the db once
|
||||||
if USE_POSTGRES_FOR_TESTS:
|
if USE_POSTGRES_FOR_TESTS:
|
||||||
# create a PostgresEngine
|
# create a PostgresEngine
|
||||||
db_engine = create_engine({"name": "psycopg2", "args": {}})
|
db_engine = create_engine({"name": "psycopg2", "args": {}})
|
||||||
|
|
||||||
# connect to postgres to create the base database.
|
# connect to postgres to create the base database.
|
||||||
db_conn = db_engine.module.connect(
|
db_conn = db_engine.module.connect(
|
||||||
user=POSTGRES_USER,
|
user=POSTGRES_USER,
|
||||||
|
@ -82,11 +86,11 @@ def setupdb():
|
||||||
port=POSTGRES_PORT,
|
port=POSTGRES_PORT,
|
||||||
password=POSTGRES_PASSWORD,
|
password=POSTGRES_PASSWORD,
|
||||||
)
|
)
|
||||||
db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
|
logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
|
||||||
prepare_database(db_conn, db_engine, None)
|
prepare_database(logging_conn, db_engine, None)
|
||||||
db_conn.close()
|
logging_conn.close()
|
||||||
|
|
||||||
def _cleanup():
|
def _cleanup() -> None:
|
||||||
db_conn = db_engine.module.connect(
|
db_conn = db_engine.module.connect(
|
||||||
user=POSTGRES_USER,
|
user=POSTGRES_USER,
|
||||||
host=POSTGRES_HOST,
|
host=POSTGRES_HOST,
|
||||||
|
@ -103,7 +107,19 @@ def setupdb():
|
||||||
atexit.register(_cleanup)
|
atexit.register(_cleanup)
|
||||||
|
|
||||||
|
|
||||||
def default_config(name, parse=False):
|
@overload
|
||||||
|
def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def default_config(name: str, parse: Literal[True]) -> HomeServerConfig:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def default_config(
|
||||||
|
name: str, parse: bool = False
|
||||||
|
) -> Union[Dict[str, object], HomeServerConfig]:
|
||||||
"""
|
"""
|
||||||
Create a reasonable test config.
|
Create a reasonable test config.
|
||||||
"""
|
"""
|
||||||
|
@ -181,90 +197,122 @@ def default_config(name, parse=False):
|
||||||
return config_dict
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
def mock_getRawHeaders(headers=None):
|
def mock_getRawHeaders(headers=None): # type: ignore[no-untyped-def]
|
||||||
headers = headers if headers is not None else {}
|
headers = headers if headers is not None else {}
|
||||||
|
|
||||||
def getRawHeaders(name, default=None):
|
def getRawHeaders(name, default=None): # type: ignore[no-untyped-def]
|
||||||
|
# If the requested header is present, the real twisted function returns
|
||||||
|
# List[str] if name is a str and List[bytes] if name is a bytes.
|
||||||
|
# This mock doesn't support that behaviour.
|
||||||
|
# Fortunately, none of the current callers of mock_getRawHeaders() provide a
|
||||||
|
# headers dict, so we don't encounter this discrepancy in practice.
|
||||||
return headers.get(name, default)
|
return headers.get(name, default)
|
||||||
|
|
||||||
return getRawHeaders
|
return getRawHeaders
|
||||||
|
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
|
class Timer:
|
||||||
|
absolute_time: float
|
||||||
|
callback: Callable[[], None]
|
||||||
|
expired: bool
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Make this generic over a ParamSpec?
|
||||||
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
|
class Looper:
|
||||||
|
func: Callable[..., Any]
|
||||||
|
interval: float # seconds
|
||||||
|
last: float
|
||||||
|
args: Tuple[object, ...]
|
||||||
|
kwargs: Dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
class MockClock:
|
class MockClock:
|
||||||
now = 1000
|
now = 1000.0
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
# list of lists of [absolute_time, callback, expired] in no particular
|
# Timers in no particular order
|
||||||
# order
|
self.timers: List[Timer] = []
|
||||||
self.timers = []
|
self.loopers: List[Looper] = []
|
||||||
self.loopers = []
|
|
||||||
|
|
||||||
def time(self):
|
def time(self) -> float:
|
||||||
return self.now
|
return self.now
|
||||||
|
|
||||||
def time_msec(self):
|
def time_msec(self) -> int:
|
||||||
return self.time() * 1000
|
return int(self.time() * 1000)
|
||||||
|
|
||||||
def call_later(self, delay, callback, *args, **kwargs):
|
def call_later(
|
||||||
|
self,
|
||||||
|
delay: float,
|
||||||
|
callback: Callable[P, object],
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> Timer:
|
||||||
ctx = current_context()
|
ctx = current_context()
|
||||||
|
|
||||||
def wrapped_callback():
|
def wrapped_callback() -> None:
|
||||||
set_current_context(ctx)
|
set_current_context(ctx)
|
||||||
callback(*args, **kwargs)
|
callback(*args, **kwargs)
|
||||||
|
|
||||||
t = [self.now + delay, wrapped_callback, False]
|
t = Timer(self.now + delay, wrapped_callback, False)
|
||||||
self.timers.append(t)
|
self.timers.append(t)
|
||||||
|
|
||||||
return t
|
return t
|
||||||
|
|
||||||
def looping_call(self, function, interval, *args, **kwargs):
|
def looping_call(
|
||||||
self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
|
self,
|
||||||
|
function: Callable[P, object],
|
||||||
|
interval: float,
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> None:
|
||||||
|
# This type-ignore should be redundant once we use a mypy release with
|
||||||
|
# https://github.com/python/mypy/pull/12668.
|
||||||
|
self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type]
|
||||||
|
|
||||||
def cancel_call_later(self, timer, ignore_errs=False):
|
def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None:
|
||||||
if timer[2]:
|
if timer.expired:
|
||||||
if not ignore_errs:
|
if not ignore_errs:
|
||||||
raise Exception("Cannot cancel an expired timer")
|
raise Exception("Cannot cancel an expired timer")
|
||||||
|
|
||||||
timer[2] = True
|
timer.expired = True
|
||||||
self.timers = [t for t in self.timers if t != timer]
|
self.timers = [t for t in self.timers if t != timer]
|
||||||
|
|
||||||
# For unit testing
|
# For unit testing
|
||||||
def advance_time(self, secs):
|
def advance_time(self, secs: float) -> None:
|
||||||
self.now += secs
|
self.now += secs
|
||||||
|
|
||||||
timers = self.timers
|
timers = self.timers
|
||||||
self.timers = []
|
self.timers = []
|
||||||
|
|
||||||
for t in timers:
|
for t in timers:
|
||||||
time, callback, expired = t
|
if t.expired:
|
||||||
|
|
||||||
if expired:
|
|
||||||
raise Exception("Timer already expired")
|
raise Exception("Timer already expired")
|
||||||
|
|
||||||
if self.now >= time:
|
if self.now >= t.absolute_time:
|
||||||
t[2] = True
|
t.expired = True
|
||||||
callback()
|
t.callback()
|
||||||
else:
|
else:
|
||||||
self.timers.append(t)
|
self.timers.append(t)
|
||||||
|
|
||||||
for looped in self.loopers:
|
for looped in self.loopers:
|
||||||
func, interval, last, args, kwargs = looped
|
if looped.last + looped.interval < self.now:
|
||||||
if last + interval < self.now:
|
looped.func(*looped.args, **looped.kwargs)
|
||||||
func(*args, **kwargs)
|
looped.last = self.now
|
||||||
looped[2] = self.now
|
|
||||||
|
|
||||||
def advance_time_msec(self, ms):
|
def advance_time_msec(self, ms: float) -> None:
|
||||||
self.advance_time(ms / 1000.0)
|
self.advance_time(ms / 1000.0)
|
||||||
|
|
||||||
def time_bound_deferred(self, d, *args, **kwargs):
|
|
||||||
# We don't bother timing things out for now.
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
|
||||||
async def create_room(hs, room_id: str, creator_id: str):
|
|
||||||
"""Creates and persist a creation event for the given room"""
|
"""Creates and persist a creation event for the given room"""
|
||||||
|
|
||||||
persistence_store = hs.get_storage_controllers().persistence
|
persistence_store = hs.get_storage_controllers().persistence
|
||||||
|
assert persistence_store is not None
|
||||||
store = hs.get_datastores().main
|
store = hs.get_datastores().main
|
||||||
event_builder_factory = hs.get_event_builder_factory()
|
event_builder_factory = hs.get_event_builder_factory()
|
||||||
event_creation_handler = hs.get_event_creation_handler()
|
event_creation_handler = hs.get_event_creation_handler()
|
||||||
|
|
Loading…
Add table
Reference in a new issue