mirror of
https://github.com/element-hq/synapse.git
synced 2025-01-20 18:42:33 +00:00
Annotate synapse.storage.util (#10892)
Also mark `synapse.streams` as having has no untyped defs Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
This commit is contained in:
parent
797ee7812d
commit
51a5da74cc
8 changed files with 124 additions and 65 deletions
1
changelog.d/10892.misc
Normal file
1
changelog.d/10892.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add further type hints to `synapse.storage.util`.
|
6
mypy.ini
6
mypy.ini
|
@ -105,6 +105,12 @@ disallow_untyped_defs = True
|
||||||
[mypy-synapse.state.*]
|
[mypy-synapse.state.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.util.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.streams.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.util.batching_queue]
|
[mypy-synapse.util.batching_queue]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -13,14 +13,14 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.storage.types import Connection
|
from synapse.storage.database import LoggingDatabaseConnection
|
||||||
from synapse.storage.util.id_generators import _load_current_id
|
from synapse.storage.util.id_generators import _load_current_id
|
||||||
|
|
||||||
|
|
||||||
class SlavedIdTracker:
|
class SlavedIdTracker:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
db_conn: Connection,
|
db_conn: LoggingDatabaseConnection,
|
||||||
table: str,
|
table: str,
|
||||||
column: str,
|
column: str,
|
||||||
extra_tables: Optional[List[Tuple[str, str]]] = None,
|
extra_tables: Optional[List[Tuple[str, str]]] = None,
|
||||||
|
|
|
@ -15,9 +15,8 @@
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from synapse.replication.tcp.streams import PushersStream
|
from synapse.replication.tcp.streams import PushersStream
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||||
from synapse.storage.databases.main.pusher import PusherWorkerStore
|
from synapse.storage.databases.main.pusher import PusherWorkerStore
|
||||||
from synapse.storage.types import Connection
|
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
@ -27,7 +26,12 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
self._pushers_id_gen = SlavedIdTracker( # type: ignore
|
self._pushers_id_gen = SlavedIdTracker( # type: ignore
|
||||||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||||
|
|
|
@ -18,8 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional,
|
||||||
|
|
||||||
from synapse.push import PusherConfig, ThrottleParams
|
from synapse.push import PusherConfig, ThrottleParams
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||||
from synapse.storage.types import Connection
|
|
||||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
@ -32,7 +31,12 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PusherWorkerStore(SQLBaseStore):
|
class PusherWorkerStore(SQLBaseStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
self._pushers_id_gen = StreamIdGenerator(
|
self._pushers_id_gen = StreamIdGenerator(
|
||||||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||||
|
|
|
@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.databases.main.stats import StatsStore
|
from synapse.storage.databases.main.stats import StatsStore
|
||||||
from synapse.storage.types import Connection, Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.id_generators import IdGenerator
|
from synapse.storage.util.id_generators import IdGenerator
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import UserID, UserInfo
|
from synapse.types import UserID, UserInfo
|
||||||
|
@ -1775,7 +1775,12 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self._ignore_unknown_session_error = (
|
self._ignore_unknown_session_error = (
|
||||||
|
|
|
@ -16,42 +16,62 @@ import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
from types import TracebackType
|
||||||
|
from typing import (
|
||||||
|
AsyncContextManager,
|
||||||
|
ContextManager,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
Generic,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from sortedcontainers import SortedSet
|
from sortedcontainers import SortedSet
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
)
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.sequence import PostgresSequenceGenerator
|
from synapse.storage.util.sequence import PostgresSequenceGenerator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class IdGenerator:
|
class IdGenerator:
|
||||||
def __init__(self, db_conn, table, column):
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
table: str,
|
||||||
|
column: str,
|
||||||
|
):
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._next_id = _load_current_id(db_conn, table, column)
|
self._next_id = _load_current_id(db_conn, table, column)
|
||||||
|
|
||||||
def get_next(self):
|
def get_next(self) -> int:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._next_id += 1
|
self._next_id += 1
|
||||||
return self._next_id
|
return self._next_id
|
||||||
|
|
||||||
|
|
||||||
def _load_current_id(db_conn, table, column, step=1):
|
def _load_current_id(
|
||||||
"""
|
db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
|
||||||
|
) -> int:
|
||||||
Args:
|
|
||||||
db_conn (object):
|
|
||||||
table (str):
|
|
||||||
column (str):
|
|
||||||
step (int):
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int
|
|
||||||
"""
|
|
||||||
# debug logging for https://github.com/matrix-org/synapse/issues/7968
|
# debug logging for https://github.com/matrix-org/synapse/issues/7968
|
||||||
logger.info("initialising stream generator for %s(%s)", table, column)
|
logger.info("initialising stream generator for %s(%s)", table, column)
|
||||||
cur = db_conn.cursor(txn_name="_load_current_id")
|
cur = db_conn.cursor(txn_name="_load_current_id")
|
||||||
|
@ -59,7 +79,9 @@ def _load_current_id(db_conn, table, column, step=1):
|
||||||
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
|
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
|
||||||
else:
|
else:
|
||||||
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
|
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
|
||||||
(val,) = cur.fetchone()
|
result = cur.fetchone()
|
||||||
|
assert result is not None
|
||||||
|
(val,) = result
|
||||||
cur.close()
|
cur.close()
|
||||||
current_id = int(val) if val else step
|
current_id = int(val) if val else step
|
||||||
return (max if step > 0 else min)(current_id, step)
|
return (max if step > 0 else min)(current_id, step)
|
||||||
|
@ -93,16 +115,16 @@ class StreamIdGenerator:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
db_conn,
|
db_conn: LoggingDatabaseConnection,
|
||||||
table,
|
table: str,
|
||||||
column,
|
column: str,
|
||||||
extra_tables: Iterable[Tuple[str, str]] = (),
|
extra_tables: Iterable[Tuple[str, str]] = (),
|
||||||
step=1,
|
step: int = 1,
|
||||||
):
|
) -> None:
|
||||||
assert step != 0
|
assert step != 0
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._step = step
|
self._step: int = step
|
||||||
self._current = _load_current_id(db_conn, table, column, step)
|
self._current: int = _load_current_id(db_conn, table, column, step)
|
||||||
for table, column in extra_tables:
|
for table, column in extra_tables:
|
||||||
self._current = (max if step > 0 else min)(
|
self._current = (max if step > 0 else min)(
|
||||||
self._current, _load_current_id(db_conn, table, column, step)
|
self._current, _load_current_id(db_conn, table, column, step)
|
||||||
|
@ -115,7 +137,7 @@ class StreamIdGenerator:
|
||||||
# The key and values are the same, but we never look at the values.
|
# The key and values are the same, but we never look at the values.
|
||||||
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
|
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
|
||||||
|
|
||||||
def get_next(self):
|
def get_next(self) -> AsyncContextManager[int]:
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
async with stream_id_gen.get_next() as stream_id:
|
async with stream_id_gen.get_next() as stream_id:
|
||||||
|
@ -128,7 +150,7 @@ class StreamIdGenerator:
|
||||||
self._unfinished_ids[next_id] = next_id
|
self._unfinished_ids[next_id] = next_id
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def manager():
|
def manager() -> Generator[int, None, None]:
|
||||||
try:
|
try:
|
||||||
yield next_id
|
yield next_id
|
||||||
finally:
|
finally:
|
||||||
|
@ -137,7 +159,7 @@ class StreamIdGenerator:
|
||||||
|
|
||||||
return _AsyncCtxManagerWrapper(manager())
|
return _AsyncCtxManagerWrapper(manager())
|
||||||
|
|
||||||
def get_next_mult(self, n):
|
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
async with stream_id_gen.get_next(n) as stream_ids:
|
async with stream_id_gen.get_next(n) as stream_ids:
|
||||||
|
@ -155,7 +177,7 @@ class StreamIdGenerator:
|
||||||
self._unfinished_ids[next_id] = next_id
|
self._unfinished_ids[next_id] = next_id
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def manager():
|
def manager() -> Generator[Sequence[int], None, None]:
|
||||||
try:
|
try:
|
||||||
yield next_ids
|
yield next_ids
|
||||||
finally:
|
finally:
|
||||||
|
@ -215,7 +237,7 @@ class MultiWriterIdGenerator:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
db_conn,
|
db_conn: LoggingDatabaseConnection,
|
||||||
db: DatabasePool,
|
db: DatabasePool,
|
||||||
stream_name: str,
|
stream_name: str,
|
||||||
instance_name: str,
|
instance_name: str,
|
||||||
|
@ -223,7 +245,7 @@ class MultiWriterIdGenerator:
|
||||||
sequence_name: str,
|
sequence_name: str,
|
||||||
writers: List[str],
|
writers: List[str],
|
||||||
positive: bool = True,
|
positive: bool = True,
|
||||||
):
|
) -> None:
|
||||||
self._db = db
|
self._db = db
|
||||||
self._stream_name = stream_name
|
self._stream_name = stream_name
|
||||||
self._instance_name = instance_name
|
self._instance_name = instance_name
|
||||||
|
@ -285,9 +307,9 @@ class MultiWriterIdGenerator:
|
||||||
|
|
||||||
def _load_current_ids(
|
def _load_current_ids(
|
||||||
self,
|
self,
|
||||||
db_conn,
|
db_conn: LoggingDatabaseConnection,
|
||||||
tables: List[Tuple[str, str, str]],
|
tables: List[Tuple[str, str, str]],
|
||||||
):
|
) -> None:
|
||||||
cur = db_conn.cursor(txn_name="_load_current_ids")
|
cur = db_conn.cursor(txn_name="_load_current_ids")
|
||||||
|
|
||||||
# Load the current positions of all writers for the stream.
|
# Load the current positions of all writers for the stream.
|
||||||
|
@ -335,7 +357,9 @@ class MultiWriterIdGenerator:
|
||||||
"agg": "MAX" if self._positive else "-MIN",
|
"agg": "MAX" if self._positive else "-MIN",
|
||||||
}
|
}
|
||||||
cur.execute(sql)
|
cur.execute(sql)
|
||||||
(stream_id,) = cur.fetchone()
|
result = cur.fetchone()
|
||||||
|
assert result is not None
|
||||||
|
(stream_id,) = result
|
||||||
|
|
||||||
max_stream_id = max(max_stream_id, stream_id)
|
max_stream_id = max(max_stream_id, stream_id)
|
||||||
|
|
||||||
|
@ -354,7 +378,7 @@ class MultiWriterIdGenerator:
|
||||||
|
|
||||||
self._persisted_upto_position = min_stream_id
|
self._persisted_upto_position = min_stream_id
|
||||||
|
|
||||||
rows = []
|
rows: List[Tuple[str, int]] = []
|
||||||
for table, instance_column, id_column in tables:
|
for table, instance_column, id_column in tables:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT %(instance)s, %(id)s FROM %(table)s
|
SELECT %(instance)s, %(id)s FROM %(table)s
|
||||||
|
@ -367,7 +391,8 @@ class MultiWriterIdGenerator:
|
||||||
}
|
}
|
||||||
cur.execute(sql, (min_stream_id * self._return_factor,))
|
cur.execute(sql, (min_stream_id * self._return_factor,))
|
||||||
|
|
||||||
rows.extend(cur)
|
# Cast safety: this corresponds to the types returned by the query above.
|
||||||
|
rows.extend(cast(Iterable[Tuple[str, int]], cur))
|
||||||
|
|
||||||
# Sort so that we handle rows in order for each instance.
|
# Sort so that we handle rows in order for each instance.
|
||||||
rows.sort()
|
rows.sort()
|
||||||
|
@ -385,13 +410,13 @@ class MultiWriterIdGenerator:
|
||||||
|
|
||||||
cur.close()
|
cur.close()
|
||||||
|
|
||||||
def _load_next_id_txn(self, txn) -> int:
|
def _load_next_id_txn(self, txn: Cursor) -> int:
|
||||||
return self._sequence_gen.get_next_id_txn(txn)
|
return self._sequence_gen.get_next_id_txn(txn)
|
||||||
|
|
||||||
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
|
||||||
return self._sequence_gen.get_next_mult_txn(txn, n)
|
return self._sequence_gen.get_next_mult_txn(txn, n)
|
||||||
|
|
||||||
def get_next(self):
|
def get_next(self) -> AsyncContextManager[int]:
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
async with stream_id_gen.get_next() as stream_id:
|
async with stream_id_gen.get_next() as stream_id:
|
||||||
|
@ -403,9 +428,12 @@ class MultiWriterIdGenerator:
|
||||||
if self._writers and self._instance_name not in self._writers:
|
if self._writers and self._instance_name not in self._writers:
|
||||||
raise Exception("Tried to allocate stream ID on non-writer")
|
raise Exception("Tried to allocate stream ID on non-writer")
|
||||||
|
|
||||||
return _MultiWriterCtxManager(self)
|
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
|
||||||
|
# controls the return type. If `None` or omitted, the context manager yields
|
||||||
|
# a single integer stream_id; otherwise it yields a list of stream_ids.
|
||||||
|
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
|
||||||
|
|
||||||
def get_next_mult(self, n: int):
|
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
async with stream_id_gen.get_next_mult(5) as stream_ids:
|
async with stream_id_gen.get_next_mult(5) as stream_ids:
|
||||||
|
@ -417,9 +445,10 @@ class MultiWriterIdGenerator:
|
||||||
if self._writers and self._instance_name not in self._writers:
|
if self._writers and self._instance_name not in self._writers:
|
||||||
raise Exception("Tried to allocate stream ID on non-writer")
|
raise Exception("Tried to allocate stream ID on non-writer")
|
||||||
|
|
||||||
return _MultiWriterCtxManager(self, n)
|
# Cast safety: see get_next.
|
||||||
|
return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
|
||||||
|
|
||||||
def get_next_txn(self, txn: LoggingTransaction):
|
def get_next_txn(self, txn: LoggingTransaction) -> int:
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
|
@ -457,7 +486,7 @@ class MultiWriterIdGenerator:
|
||||||
|
|
||||||
return self._return_factor * next_id
|
return self._return_factor * next_id
|
||||||
|
|
||||||
def _mark_id_as_finished(self, next_id: int):
|
def _mark_id_as_finished(self, next_id: int) -> None:
|
||||||
"""The ID has finished being processed so we should advance the
|
"""The ID has finished being processed so we should advance the
|
||||||
current position if possible.
|
current position if possible.
|
||||||
"""
|
"""
|
||||||
|
@ -534,7 +563,7 @@ class MultiWriterIdGenerator:
|
||||||
for name, i in self._current_positions.items()
|
for name, i in self._current_positions.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def advance(self, instance_name: str, new_id: int):
|
def advance(self, instance_name: str, new_id: int) -> None:
|
||||||
"""Advance the position of the named writer to the given ID, if greater
|
"""Advance the position of the named writer to the given ID, if greater
|
||||||
than existing entry.
|
than existing entry.
|
||||||
"""
|
"""
|
||||||
|
@ -560,7 +589,7 @@ class MultiWriterIdGenerator:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self._return_factor * self._persisted_upto_position
|
return self._return_factor * self._persisted_upto_position
|
||||||
|
|
||||||
def _add_persisted_position(self, new_id: int):
|
def _add_persisted_position(self, new_id: int) -> None:
|
||||||
"""Record that we have persisted a position.
|
"""Record that we have persisted a position.
|
||||||
|
|
||||||
This is used to keep the `_current_positions` up to date.
|
This is used to keep the `_current_positions` up to date.
|
||||||
|
@ -606,7 +635,7 @@ class MultiWriterIdGenerator:
|
||||||
# do.
|
# do.
|
||||||
break
|
break
|
||||||
|
|
||||||
def _update_stream_positions_table_txn(self, txn: Cursor):
|
def _update_stream_positions_table_txn(self, txn: Cursor) -> None:
|
||||||
"""Update the `stream_positions` table with newly persisted position."""
|
"""Update the `stream_positions` table with newly persisted position."""
|
||||||
|
|
||||||
if not self._writers:
|
if not self._writers:
|
||||||
|
@ -628,20 +657,25 @@ class MultiWriterIdGenerator:
|
||||||
txn.execute(sql, (self._stream_name, self._instance_name, pos))
|
txn.execute(sql, (self._stream_name, self._instance_name, pos))
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(frozen=True, auto_attribs=True)
|
||||||
class _AsyncCtxManagerWrapper:
|
class _AsyncCtxManagerWrapper(Generic[T]):
|
||||||
"""Helper class to convert a plain context manager to an async one.
|
"""Helper class to convert a plain context manager to an async one.
|
||||||
|
|
||||||
This is mainly useful if you have a plain context manager but the interface
|
This is mainly useful if you have a plain context manager but the interface
|
||||||
requires an async one.
|
requires an async one.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
inner = attr.ib()
|
inner: ContextManager[T]
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self) -> T:
|
||||||
return self.inner.__enter__()
|
return self.inner.__enter__()
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc: Optional[BaseException],
|
||||||
|
tb: Optional[TracebackType],
|
||||||
|
) -> Optional[bool]:
|
||||||
return self.inner.__exit__(exc_type, exc, tb)
|
return self.inner.__exit__(exc_type, exc, tb)
|
||||||
|
|
||||||
|
|
||||||
|
@ -671,7 +705,12 @@ class _MultiWriterCtxManager:
|
||||||
else:
|
else:
|
||||||
return [i * self.id_gen._return_factor for i in self.stream_ids]
|
return [i * self.id_gen._return_factor for i in self.stream_ids]
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc: Optional[BaseException],
|
||||||
|
tb: Optional[TracebackType],
|
||||||
|
) -> bool:
|
||||||
for i in self.stream_ids:
|
for i in self.stream_ids:
|
||||||
self.id_gen._mark_id_as_finished(i)
|
self.id_gen._mark_id_as_finished(i)
|
||||||
|
|
||||||
|
|
|
@ -81,7 +81,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
|
||||||
id_column: str,
|
id_column: str,
|
||||||
stream_name: Optional[str] = None,
|
stream_name: Optional[str] = None,
|
||||||
positive: bool = True,
|
positive: bool = True,
|
||||||
):
|
) -> None:
|
||||||
"""Should be called during start up to test that the current value of
|
"""Should be called during start up to test that the current value of
|
||||||
the sequence is greater than or equal to the maximum ID in the table.
|
the sequence is greater than or equal to the maximum ID in the table.
|
||||||
|
|
||||||
|
@ -122,7 +122,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
||||||
id_column: str,
|
id_column: str,
|
||||||
stream_name: Optional[str] = None,
|
stream_name: Optional[str] = None,
|
||||||
positive: bool = True,
|
positive: bool = True,
|
||||||
):
|
) -> None:
|
||||||
"""See SequenceGenerator.check_consistency for docstring."""
|
"""See SequenceGenerator.check_consistency for docstring."""
|
||||||
|
|
||||||
txn = db_conn.cursor(txn_name="sequence.check_consistency")
|
txn = db_conn.cursor(txn_name="sequence.check_consistency")
|
||||||
|
@ -244,7 +244,7 @@ class LocalSequenceGenerator(SequenceGenerator):
|
||||||
id_column: str,
|
id_column: str,
|
||||||
stream_name: Optional[str] = None,
|
stream_name: Optional[str] = None,
|
||||||
positive: bool = True,
|
positive: bool = True,
|
||||||
):
|
) -> None:
|
||||||
# There is nothing to do for in memory sequences
|
# There is nothing to do for in memory sequences
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue