Use a chain cover index to efficiently calculate auth chain difference (#8868)

Erik Johnston 2021-01-11 16:09:22 +00:00 committed by GitHub
parent 671138f658
commit 1315a2e8be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 1777 additions and 56 deletions

1
changelog.d/8868.misc Normal file

@ -0,0 +1 @@
Improve efficiency of large state resolutions for new rooms.

32
docs/auth_chain_diff.dot Normal file

@ -0,0 +1,32 @@
digraph auth {
nodesep=0.5;
rankdir="RL";
C [label="Create (1,1)"];
BJ [label="Bob's Join (2,1)", color=red];
BJ2 [label="Bob's Join (2,2)", color=red];
BJ2 -> BJ [color=red, dir=none];
subgraph cluster_foo {
A1 [label="Alice's invite (4,1)", color=blue];
A2 [label="Alice's Join (4,2)", color=blue];
A3 [label="Alice's Join (4,3)", color=blue];
A3 -> A2 -> A1 [color=blue, dir=none];
color=none;
}
PL1 [label="Power Level (3,1)", color=darkgreen];
PL2 [label="Power Level (3,2)", color=darkgreen];
PL2 -> PL1 [color=darkgreen, dir=none];
{rank = same; C; BJ; PL1; A1;}
A1 -> C [color=grey];
A1 -> BJ [color=grey];
PL1 -> C [color=grey];
BJ2 -> PL1 [penwidth=2];
A3 -> PL2 [penwidth=2];
A1 -> PL1 -> BJ -> C [penwidth=2];
}

Binary file not shown. Size: 41 KiB

@ -0,0 +1,108 @@
# Auth Chain Difference Algorithm
The auth chain difference algorithm is used by V2 state resolution, where a
naive implementation can be a significant source of CPU and DB usage.
### Definitions
A *state set* is a set of state events; e.g. the input of a state resolution
algorithm is a collection of state sets.
The *auth chain* of a set of events are all the events' auth events and *their*
auth events, recursively (i.e. the events reachable by walking the graph induced
by an event's auth events links).
The *auth chain difference* of a collection of state sets is the union minus the
intersection of the sets of auth chains corresponding to the state sets, i.e. an
event is in the auth chain difference if it is reachable by walking the auth
event graph from at least one of the state sets but not from *all* of the state
sets.
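As a point of reference, the definitions above can be written down directly. The following is only a minimal sketch of the naive approach, not Synapse's implementation: the function names, the `auth_events` mapping and the event IDs are all hypothetical, and the graph is assumed to fit in memory.

```python
from typing import Dict, Iterable, List, Set


def auth_chain(events: Iterable[str], auth_events: Dict[str, List[str]]) -> Set[str]:
    """Return every event reachable from `events` by following auth event links.

    `auth_events` is a hypothetical in-memory map from event ID to the IDs of
    its auth events.
    """
    seen: Set[str] = set()
    stack = [a for e in events for a in auth_events.get(e, [])]
    while stack:
        event_id = stack.pop()
        if event_id not in seen:
            seen.add(event_id)
            stack.extend(auth_events.get(event_id, []))
    return seen


def naive_auth_chain_difference(
    state_sets: List[Set[str]], auth_events: Dict[str, List[str]]
) -> Set[str]:
    """Union minus intersection of the auth chains of each state set."""
    chains = [auth_chain(s, auth_events) for s in state_sets]
    return set().union(*chains) - set.intersection(*chains)


# Made-up auth graph, for illustration only.
AUTH_EVENTS = {
    "create": [],
    "bob_join": ["create"],
    "power": ["create", "bob_join"],
    "alice_invite": ["create", "bob_join", "power"],
    "alice_join": ["create", "power", "alice_invite"],
}

difference = naive_auth_chain_difference([{"alice_join"}, {"power"}], AUTH_EVENTS)
```

Here `power` and `alice_invite` are reachable only from the first state set, so they form the difference. The cost is that every auth chain is materialised in full, which is what the algorithms below try to avoid.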
## Breadth First Walk Algorithm
A way of calculating the auth chain difference without calculating the full auth
chains for each state set is to do a parallel breadth first walk (ordered by
depth) of each state set's auth chain. By tracking which events are reachable
from each state set we can finish early if every pending event is reachable from
every state set.
This can work well for state sets that have a small auth chain difference, but
can be very inefficient for larger differences. However, this algorithm is still
used if we don't have a chain cover index for the room (e.g. because we're in
the process of indexing it).
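The walk described above might be sketched as follows. This is an illustrative simplification rather than the code Synapse uses: it assumes a hypothetical in-memory `auth_events` map and a `depth` map in which every auth event has a strictly smaller depth than the events that reference it.

```python
import heapq
from typing import Dict, List, Set, Tuple


def auth_chain_difference_by_walk(
    state_sets: List[Set[str]],
    auth_events: Dict[str, List[str]],
    depth: Dict[str, int],
) -> Set[str]:
    n = len(state_sets)
    # Event ID -> indices of the state sets it is known to be reachable from.
    reached_by: Dict[str, Set[int]] = {}
    # Pending events, deepest first (depth is negated as heapq is a min-heap).
    heap: List[Tuple[int, str]] = []

    def visit(event_id: str, sets: Set[int]) -> None:
        if event_id not in reached_by:
            reached_by[event_id] = set()
            heapq.heappush(heap, (-depth[event_id], event_id))
        reached_by[event_id] |= sets

    for i, state_set in enumerate(state_sets):
        for event_id in state_set:
            for auth_id in auth_events.get(event_id, []):
                visit(auth_id, {i})

    difference: Set[str] = set()
    while heap:
        # Finish early: if every pending event is reachable from every state
        # set then so is everything below it.
        if all(len(reached_by[e]) == n for _, e in heap):
            break
        _, event_id = heapq.heappop(heap)
        sets = reached_by[event_id]
        if len(sets) < n:
            difference.add(event_id)
        for auth_id in auth_events.get(event_id, []):
            visit(auth_id, sets)
    return difference


# Made-up auth graph and depths, for illustration only.
AUTH_EVENTS = {
    "create": [],
    "bob_join": ["create"],
    "power": ["create", "bob_join"],
    "alice_invite": ["create", "bob_join", "power"],
    "alice_join": ["create", "power", "alice_invite"],
}
DEPTH = {"create": 1, "bob_join": 2, "power": 3, "alice_invite": 4, "alice_join": 5}

walked = auth_chain_difference_by_walk([{"alice_join"}, {"power"}], AUTH_EVENTS, DEPTH)
```

Because events are popped in decreasing depth order, the set of state sets that reach an event is final by the time it is popped. The walk stops once only `create` and `bob_join` are pending, as both are reachable from every state set.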
## Chain Cover Index
Synapse computes auth chain differences by pre-computing a "chain cover" index
for the auth chain in a room, allowing efficient reachability queries like "is
event A in the auth chain of event B". This is done by assigning every event a
*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links*
between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A`
is in the auth chain of `B`) if and only if either:
1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s
sequence number; or
2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that
`L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`.
There are actually two potential implementations, one where we store links from
each chain to every other reachable chain (the transitive closure of the links
graph), and one where we remove redundant links (the transitive reduction of the
links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1`
would not be stored. Synapse uses the former implementation so that it doesn't
need to recurse to test reachability between chains.
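The two reachability conditions can be expressed as a small predicate. This is a sketch only: the `Link` type and `in_auth_chain` function are hypothetical names, links are assumed to be held in a plain list, and the example links are read off the `auth_chain_diff.dot` graph in this commit (grey edges included, since the transitive closure is stored).

```python
from typing import List, NamedTuple, Tuple


class Link(NamedTuple):
    origin_chain: int
    start_seq_no: int
    target_chain: int
    end_seq_no: int


def in_auth_chain(a: Tuple[int, int], b: Tuple[int, int], links: List[Link]) -> bool:
    """Is event `a` in the auth chain of event `b`? Both are (chain ID, seq no)."""
    a_chain, a_seq = a
    b_chain, b_seq = b
    # Condition 1: same chain, and `a` comes strictly earlier in it.
    if a_chain == b_chain and a_seq < b_seq:
        return True
    # Condition 2: a link from b's chain to a's chain that starts at or below
    # `b` and ends at or above `a`.
    return any(
        link.origin_chain == b_chain
        and link.target_chain == a_chain
        and link.start_seq_no <= b_seq
        and a_seq <= link.end_seq_no
        for link in links
    )


# Links taken from the example graph: (origin chain, seq) -> (target chain, seq).
EXAMPLE_LINKS = [
    Link(4, 1, 1, 1),
    Link(4, 1, 2, 1),
    Link(4, 1, 3, 1),
    Link(4, 3, 3, 2),
    Link(3, 1, 1, 1),
    Link(3, 1, 2, 1),
    Link(2, 1, 1, 1),
    Link(2, 2, 3, 1),
]
```

For instance, `in_auth_chain((3, 1), (2, 2), EXAMPLE_LINKS)` holds because of the link `(2,2) -> (3,1)`, whereas the same query from `(2, 1)` fails because that link starts above it. No recursion over chains is needed, which is the reason given above for storing the transitive closure.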
### Example
An example auth graph would look like the following, where chains have been
formed based on type/state_key and are denoted by colour and are labelled with
`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey
are those that would be removed in the second implementation described above).
![Example](auth_chain_diff.dot.png)
Note that we don't include all links between events and their auth events, as
most of those links would be redundant. For example, all events point to the
create event, but each chain only needs the one link from its base to the
create event.
## Using the Index
This index can be used to calculate the auth chain difference of the state sets
by looking at the chain ID and sequence numbers reachable from each state set:
1. For every state set, look up the chain ID/sequence numbers of each state event.
2. Use the index to find all chains and the maximum sequence number reachable
from each state set.
3. The auth chain difference is then all events in each chain that have sequence
numbers between the maximum sequence number reachable from *any* state set and
the minimum reachable by *all* state sets (if any).
Note that step 2 is effectively calculating the auth chain for each state set
(in terms of chain IDs and sequence numbers), and step 3 is calculating the
difference between the union and intersection of the auth chains.
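The three steps can be sketched in memory as follows. This is an assumption-laden illustration, not the database-backed code in this commit: step 1 is taken as already done (state sets are given as `(chain ID, sequence number)` pairs), links are plain tuples, and the function names are hypothetical.

```python
from typing import Dict, List, Set, Tuple

ChainPos = Tuple[int, int]  # (chain ID, sequence number)
LinkTuple = Tuple[int, int, int, int]  # origin chain, origin seq, target chain, target seq


def reachable_maxima(state_set: Set[ChainPos], links: List[LinkTuple]) -> Dict[int, int]:
    """Step 2: map chain ID -> maximum sequence number reachable from the set."""
    own: Dict[int, int] = {}
    for chain_id, seq_no in state_set:
        own[chain_id] = max(seq_no, own.get(chain_id, 0))
    reach = dict(own)
    for o_chain, o_seq, t_chain, t_seq in links:
        # A link is usable only if it starts at or below the set's position
        # in the origin chain. Links are assumed to be transitively closed,
        # so only links out of the set's own chains need checking.
        if o_seq <= own.get(o_chain, 0):
            reach[t_chain] = max(t_seq, reach.get(t_chain, 0))
    return reach


def difference_ranges(
    state_sets: List[Set[ChainPos]], links: List[LinkTuple]
) -> Dict[int, Tuple[int, int]]:
    """Step 3: map chain ID -> half-open range (low, high] in the difference."""
    maxima = [reachable_maxima(s, links) for s in state_sets]
    ranges: Dict[int, Tuple[int, int]] = {}
    for chain_id in set().union(*maxima):
        low = min(m.get(chain_id, 0) for m in maxima)
        high = max(m.get(chain_id, 0) for m in maxima)
        if low < high:
            ranges[chain_id] = (low, high)
    return ranges


# Links read off the example graph (grey edges included).
EXAMPLE_LINKS = [
    (4, 1, 1, 1), (4, 1, 2, 1), (4, 1, 3, 1), (4, 3, 3, 2),
    (3, 1, 1, 1), (3, 1, 2, 1), (2, 1, 1, 1), (2, 2, 3, 1),
]
S1 = {(4, 1), (2, 2)}
S2 = {(4, 3), (2, 1)}
ranges = difference_ranges([S1, S2], EXAMPLE_LINKS)
```

A chain that is unreachable from some state set contributes a minimum of `0`, so every event in that chain up to the maximum falls into the difference.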
### Worked Example
For example, given the above graph, we can calculate the difference between
state sets consisting of:
1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and
2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`.
Using the index we see that the following auth chains are reachable from each
state set:
1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)`
2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)`
And so, for each chain, the ranges that are in the auth chain difference are:
1. Chain 1: None, (since everything can reach the create event).
2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state
sets and the maximum reachable is `2` (corresponding to Bob's second join).
3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power
level).
4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins).
So the final result is: Bob's second join `(2,2)`, the second power level
`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`.
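The arithmetic of this worked example can be checked with a few lines of Python. The per-chain maxima are copied from the lists above, and the variable names are purely illustrative.

```python
# Chain ID -> maximum sequence number reachable, as listed above.
s1 = {1: 1, 2: 2, 3: 1, 4: 1}
s2 = {1: 1, 2: 1, 3: 2, 4: 3}

difference = []
for chain_id in sorted(s1.keys() | s2.keys()):
    low = min(s1.get(chain_id, 0), s2.get(chain_id, 0))
    high = max(s1.get(chain_id, 0), s2.get(chain_id, 0))
    # Everything in the half-open range (low, high] is in the difference.
    for seq_no in range(low + 1, high + 1):
        difference.append((chain_id, seq_no))

print(difference)  # [(2, 2), (3, 2), (4, 2), (4, 3)]
```

Note that `(2,2)` and `(4,3)` are themselves members of `S1` and `S2` respectively: the maxima start from the state events' own positions, so state events present in only some of the sets appear in the result.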


@ -179,6 +179,9 @@ class LoggingDatabaseConnection:
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
R = TypeVar("R")
class LoggingTransaction: class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()
@ -266,6 +269,20 @@ class LoggingTransaction:
for val in args: for val in args:
self.execute(sql, val) self.execute(sql, val)
def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
using postgres.
Always sets fetch=True when calling `execute_values`, so will return the
results.
"""
assert isinstance(self.database_engine, PostgresEngine)
from psycopg2.extras import execute_values # type: ignore
return self._do_execute(
lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
)
def execute(self, sql: str, *args: Any) -> None: def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args) self._do_execute(self.txn.execute, sql, *args)
@ -276,7 +293,7 @@ class LoggingTransaction:
"Strip newlines out of SQL so that the loggers in the DB are on one line" "Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip()) return " ".join(line.strip() for line in sql.splitlines() if line.strip())
def _do_execute(self, func, sql: str, *args: Any) -> None: def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
sql = self._make_sql_one_line(sql) sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values? # TODO(paul): Maybe use 'info' and 'debug' for values?
@ -347,9 +364,6 @@ class PerformanceCounters:
return top_n_counters return top_n_counters
R = TypeVar("R")
class DatabasePool: class DatabasePool:
"""Wraps a single physical database and connection pool. """Wraps a single physical database and connection pool.


@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.types import Collection from synapse.types import Collection
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
The set of the difference in auth chains. The set of the difference in auth chains.
""" """
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id)
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains",
self._get_auth_chain_difference_using_cover_index_txn,
room_id,
state_sets,
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
# for the events in question, so we fall back to the old method.
pass
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_auth_chain_difference", "get_auth_chain_difference",
self._get_auth_chain_difference_txn, self._get_auth_chain_difference_txn,
state_sets, state_sets,
) )
def _get_auth_chain_difference_using_cover_index_txn(
self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using the chain index.
See docs/auth_chain_difference_algorithm.md for details
"""
# First we look up the chain ID/sequence numbers for all the events, and
# work out the chain/sequence numbers reachable from each state set.
initial_events = set(state_sets[0]).union(*state_sets[1:])
# Map from event_id -> (chain ID, seq no)
chain_info = {} # type: Dict[str, Tuple[int, int]]
# Map from chain ID -> seq no -> event Id
chain_to_event = {} # type: Dict[int, Dict[int, str]]
# All the chains that we've found that are reachable from the state
# sets.
seen_chains = set() # type: Set[int]
sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(initial_events, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)
for event_id, chain_id, sequence_number in txn:
chain_info[event_id] = (chain_id, sequence_number)
seen_chains.add(chain_id)
chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(chain_info)
if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info,
)
raise _NoChainCoverIndex(room_id)
# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
set_to_chain = [] # type: List[Dict[int, int]]
for state_set in state_sets:
chains = {} # type: Dict[int, int]
set_to_chain.append(chains)
for event_id in state_set:
chain_id, seq_no = chain_info[event_id]
chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
# Now we look up all links for the chains we have, adding chains to
# set_to_chain that are reachable from each set.
sql = """
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
"""
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
for batch in batch_iter(set(seen_chains), 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch
)
txn.execute(sql % (clause,), args)
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
for chains in set_to_chain:
# chains are only reachable if the origin sequence number of
# the link is less than the max sequence number in the
# origin chain.
if origin_sequence_number <= chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number, chains.get(target_chain_id, 0),
)
seen_chains.add(target_chain_id)
# Now for each chain we figure out the maximum sequence number reachable
# from *any* state set and the minimum sequence number reachable from
# *all* state sets. Events in that range are in the auth chain
# difference.
result = set()
# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
chain_to_gap = {} # type: Dict[int, Tuple[int, int]]
for chain_id in seen_chains:
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
if min_seq_no < max_seq_no:
# We have a non empty gap, try and fill it from the events that
# we have, otherwise add them to the list of gaps to pull out
# from the DB.
for seq_no in range(min_seq_no + 1, max_seq_no + 1):
event_id = chain_to_event.get(chain_id, {}).get(seq_no)
if event_id:
result.add(event_id)
else:
chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
break
if not chain_to_gap:
# If there are no gaps to fetch, we're done!
return result
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
WHERE
c.chain_id = l.chain_id
AND min_seq < sequence_number AND sequence_number <= max_seq
"""
args = [
(chain_id, min_no, max_no)
for chain_id, (min_no, max_no) in chain_to_gap.items()
]
rows = txn.execute_values(sql, args)
result.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.
sql = """
SELECT event_id FROM event_auth_chains
WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
"""
for chain_id, (min_no, max_no) in chain_to_gap.items():
txn.execute(sql, (chain_id, min_no, max_no))
result.update(r for r, in txn)
return result
def _get_auth_chain_difference_txn( def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]] self, txn, state_sets: List[Set[str]]
) -> Set[str]: ) -> Set[str]:
"""Calculates the auth chain difference using a breadth first search.
This is used when we don't have a cover index for the room.
"""
# Algorithm Description # Algorithm Description
# ~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~


@ -17,7 +17,17 @@
import itertools import itertools
import logging import logging
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Tuple,
)
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
@ -33,9 +43,10 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap, get_domain_from_id from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter, sorted_topologically
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -89,6 +100,14 @@ class PersistEventsStore:
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
def get_chain_id_txn(txn):
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]
self._event_chain_id_gen = build_sequence_generator(
db.engine, get_chain_id_txn, "event_auth_chain_id"
)
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@ -366,26 +385,7 @@ class PersistEventsStore:
# Insert into event_to_state_groups. # Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts) self._store_event_state_mappings_txn(txn, events_and_contexts)
# We want to store event_auth mappings for rejected events, as they're self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
# used in state res v2.
# This is only necessary if the rejected event appears in an accepted
# event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
values=[
{
"event_id": event.event_id,
"room_id": event.room_id,
"auth_id": auth_id,
}
for event, _ in events_and_contexts
for auth_id in event.auth_event_ids()
if event.is_state()
],
)
# _store_rejected_events_txn filters out any events which were # _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list. # rejected, and returns the filtered list.
@ -407,6 +407,381 @@ class PersistEventsStore:
# room_memberships, where applicable. # room_memberships, where applicable.
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
def _persist_event_auth_chain_txn(
self, txn: LoggingTransaction, events: List[EventBase],
) -> None:
# We only care about state events, so return early if there are no state events.
if not any(e.is_state() for e in events):
return
# We want to store event_auth mappings for rejected events, as they're
# used in state res v2.
# This is only necessary if the rejected event appears in an accepted
# event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
values=[
{
"event_id": event.event_id,
"room_id": event.room_id,
"auth_id": auth_id,
}
for event in events
for auth_id in event.auth_event_ids()
if event.is_state()
],
)
# We now calculate chain ID/sequence numbers for any state events we're
# persisting. We ignore out of band memberships as we're not in the room
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
rows = self.db_pool.simple_select_many_txn(
txn,
table="rooms",
column="room_id",
iterable={event.room_id for event in events if event.is_state()},
keyvalues={},
retcols=("room_id", "has_auth_chain_index"),
)
rooms_using_chain_index = {
row["room_id"] for row in rows if row["has_auth_chain_index"]
}
state_events = {
event.event_id: event
for event in events
if event.is_state() and event.room_id in rooms_using_chain_index
}
if not state_events:
return
# Map from event ID to chain ID/sequence number.
chain_map = {} # type: Dict[str, Tuple[int, int]]
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
# we don't need the overhead of fetching/parsing the full event JSON.
event_to_types = {
e.event_id: (e.type, e.state_key) for e in state_events.values()
}
event_to_auth_chain = {
e.event_id: e.auth_event_ids() for e in state_events.values()
}
# Set of event IDs to calculate chain ID/seq numbers for.
events_to_calc_chain_id_for = set(state_events)
# We check if there are any events that need to be handled in the rooms
# we're looking at. These should just be out of band memberships, where
# we didn't have the auth chain when we first persisted.
rows = self.db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="room_id",
iterable={e.room_id for e in state_events.values()},
retcols=("event_id", "type", "state_key"),
)
for row in rows:
event_id = row["event_id"]
event_type = row["type"]
state_key = row["state_key"]
# (We could pull out the auth events for all rows at once using
# simple_select_many, but this case happens rarely and almost always
# with a single row.)
auth_events = self.db_pool.simple_select_onecol_txn(
txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
)
events_to_calc_chain_id_for.add(event_id)
event_to_types[event_id] = (event_type, state_key)
event_to_auth_chain[event_id] = auth_events
# First we get the chain ID and sequence numbers for the events'
# auth events (that aren't also currently being persisted).
#
# Note that there is an edge case here where we might not have
# calculated chains and sequence numbers for events that were "out
# of band". We handle this case by fetching the necessary info and
# adding it to the set of events to calculate chain IDs for.
missing_auth_chains = {
a_id
for auth_events in event_to_auth_chain.values()
for a_id in auth_events
if a_id not in events_to_calc_chain_id_for
}
# We loop here in case we find an out of band membership and need to
# fetch their auth event info.
while missing_auth_chains:
sql = """
SELECT event_id, events.type, state_key, chain_id, sequence_number
FROM events
INNER JOIN state_events USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
WHERE
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", missing_auth_chains,
)
txn.execute(sql + clause, args)
missing_auth_chains.clear()
for auth_id, event_type, state_key, chain_id, sequence_number in txn:
event_to_types[auth_id] = (event_type, state_key)
if chain_id is None:
# No chain ID, so the event was persisted out of band.
# We add to list of events to calculate auth chains for.
events_to_calc_chain_id_for.add(auth_id)
event_to_auth_chain[
auth_id
] = self.db_pool.simple_select_onecol_txn(
txn,
"event_auth",
keyvalues={"event_id": auth_id},
retcol="auth_id",
)
missing_auth_chains.update(
e
for e in event_to_auth_chain[auth_id]
if e not in event_to_types
)
else:
chain_map[auth_id] = (chain_id, sequence_number)
# Now we check if we have any events where we don't have auth chain,
# this should only be out of band memberships.
for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
for auth_id in event_to_auth_chain[event_id]:
if (
auth_id not in chain_map
and auth_id not in events_to_calc_chain_id_for
):
events_to_calc_chain_id_for.discard(event_id)
# If this is an event we're trying to persist we add it to
# the list of events to calculate chain IDs for next time
# around. (Otherwise we will have already added it to the
# table).
event = state_events.get(event_id)
if event:
self.db_pool.simple_insert_txn(
txn,
table="event_auth_chain_to_calculate",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
)
# We stop checking the event's auth events since we've
# discarded it.
break
if not events_to_calc_chain_id_for:
return
# We now calculate the chain IDs/sequence numbers for the events. We
# do this by looking at the chain ID and sequence number of any auth
# event with the same type/state_key and incrementing the sequence
# number by one. If there was no match or the chain ID/sequence
# number is already taken we generate a new chain.
#
# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events
# before the event itself.
chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
existing_chain_id = None
for auth_id in event_to_auth_chain[event_id]:
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map[auth_id]
break
new_chain_tuple = None
if existing_chain_id:
# We found a chain ID/sequence number candidate, check it's
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
already_allocated = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_auth_chains",
keyvalues={
"chain_id": proposed_new_id,
"sequence_number": proposed_new_seq,
},
retcol="event_id",
allow_none=True,
)
if already_allocated:
# Mark it as already allocated so we don't need to hit
# the DB again.
chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
else:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)
if not new_chain_tuple:
new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
chains_tuples_allocated.add(new_chain_tuple)
chain_map[event_id] = new_chain_tuple
new_chain_tuples[event_id] = new_chain_tuple
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
values=[
{"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
for event_id, (c_id, seq) in new_chain_tuples.items()
],
)
self.db_pool.simple_delete_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
iterable=new_chain_tuples,
)
# Now we need to calculate any new links between chains caused by
# the new events.
#
# Links are pairs of chain ID/sequence numbers such that for any
# event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
# if and only if there is at least one link (CA, S1) -> (CB, S2)
# where SA >= S1 and S2 >= SB.
#
# We try and avoid adding redundant links to the table, e.g. if we
# have two links between two chains which both start/end at the
# sequence number event (or cross) then one can be safely dropped.
#
# To calculate new links we look at every new event and:
# 1. Fetch the chain ID/sequence numbers of its auth events,
# discarding any that are reachable by other auth events, or
# that have the same chain ID as the event.
# 2. For each retained auth event we:
# a. Add a link from the event's to the auth event's chain
# ID/sequence number; and
# b. Add a link from the event to every chain reachable by the
# auth event.
# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
rows = self.db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_links",
column="origin_chain_id",
iterable={chain_id for chain_id, _ in chain_map.values()},
keyvalues={},
retcols=(
"origin_chain_id",
"origin_sequence_number",
"target_chain_id",
"target_sequence_number",
),
)
for row in rows:
chain_links.add_link(
(row["origin_chain_id"], row["origin_sequence_number"]),
(row["target_chain_id"], row["target_sequence_number"]),
new=False,
)
# We do this in topological order to avoid adding redundant links.
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
chain_id, sequence_number = chain_map[event_id]
# Filter out auth events that are reachable by other auth
# events. We do this by looking at every permutation of pairs of
# auth events (A, B) to check if B is reachable from A.
reduction = {
a_id
for a_id in event_to_auth_chain[event_id]
if chain_map[a_id][0] != chain_id
}
for start_auth_id, end_auth_id in itertools.permutations(
event_to_auth_chain[event_id], r=2,
):
if chain_links.exists_path_from(
chain_map[start_auth_id], chain_map[end_auth_id]
):
reduction.discard(end_auth_id)
# Step 2, figure out what the new links are from the reduced
# list of auth events.
for auth_id in reduction:
auth_chain_id, auth_sequence_number = chain_map[auth_id]
# Step 2a, add link between the event and auth event
chain_links.add_link(
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
)
# Step 2b, add a link to chains reachable from the auth
# event.
for target_id, target_seq in chain_links.get_links_from(
(auth_chain_id, auth_sequence_number)
):
if target_id == chain_id:
continue
chain_links.add_link(
(chain_id, sequence_number), (target_id, target_seq)
)
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
values=[
{
"origin_chain_id": source_id,
"origin_sequence_number": source_seq,
"target_chain_id": target_id,
"target_sequence_number": target_seq,
}
for (
source_id,
source_seq,
target_id,
target_seq,
) in chain_links.get_additions()
],
)
def _persist_transaction_ids_txn( def _persist_transaction_ids_txn(
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,
@ -1521,3 +1896,131 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier() if not ev.internal_metadata.is_outlier()
], ],
) )
@attr.s(slots=True)
class _LinkMap:
"""A helper type for tracking links between chains.
"""
# Stores the set of links as nested maps: source chain ID -> target chain ID
# -> source sequence number -> target sequence number.
maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
# Stores the links that have been added (with new set to true), as tuples of
# `(source chain ID, source sequence no, target chain ID, target sequence no.)`
additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
def add_link(
self,
src_tuple: Tuple[int, int],
target_tuple: Tuple[int, int],
new: bool = True,
) -> bool:
"""Add a new link between two chains, ensuring no redundant links are added.
New links should be added in topological order.
Args:
src_tuple: The chain ID/sequence number of the source of the link.
target_tuple: The chain ID/sequence number of the target of the link.
new: Whether this is a "new" link, i.e. should it be returned
by `get_additions`.
Returns:
True if a link was added, false if the given link was dropped as redundant
"""
src_chain, src_seq = src_tuple
target_chain, target_seq = target_tuple
current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
assert src_chain != target_chain
if new:
# Check if the new link is redundant
for current_seq_src, current_seq_target in current_links.items():
# If a link "crosses" another link then it's redundant. For example
# in the following link 1 (L1) is redundant, as any event reachable
# via L1 is *also* reachable via L2.
#
# Chain A Chain B
# | |
# L1 |------ |
# | | |
# L2 |---- | -->|
# | | |
# | |--->|
# | |
# | |
#
# So we only need to keep links which *do not* cross, i.e. links
# that both start and end above or below an existing link.
#
# Note, since we add links in topological ordering we should never
# see `src_seq` less than `current_seq_src`.
if current_seq_src <= src_seq and target_seq <= current_seq_target:
# This new link is redundant, nothing to do.
return False
self.additions.add((src_chain, src_seq, target_chain, target_seq))
current_links[src_seq] = target_seq
return True
def get_links_from(
self, src_tuple: Tuple[int, int]
) -> Generator[Tuple[int, int], None, None]:
"""Gets the chains reachable from the given chain/sequence number.
Yields:
The chain ID and sequence number the link points to.
"""
src_chain, src_seq = src_tuple
for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
for link_src_seq, target_seq in sequence_numbers.items():
if link_src_seq <= src_seq:
yield target_id, target_seq
def get_links_between(
self, source_chain: int, target_chain: int
) -> Generator[Tuple[int, int], None, None]:
"""Gets the links between two chains.
Yields:
The source and target sequence numbers.
"""
yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
"""Gets any newly added links.
Yields:
The source chain ID/sequence number and target chain ID/sequence number
"""
for src_chain, src_seq, target_chain, _ in self.additions:
target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
if target_seq is not None:
yield (src_chain, src_seq, target_chain, target_seq)
def exists_path_from(
self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
) -> bool:
"""Checks if there is a path between the source chain ID/sequence and
target chain ID/sequence.
"""
src_chain, src_seq = src_tuple
target_chain, target_seq = target_tuple
if src_chain == target_chain:
return target_seq <= src_seq
links = self.get_links_between(src_chain, target_chain)
for link_start_seq, link_end_seq in links:
if link_start_seq <= src_seq and target_seq <= link_end_seq:
return True
return False
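The redundancy rule used by `add_link` can be exercised on its own. The following is a minimal sketch, not the Synapse implementation: `link_is_redundant` is a hypothetical helper, and it assumes the existing links between one pair of chains are given as a plain dict of source sequence number to target sequence number.

```python
from typing import Dict


def link_is_redundant(existing: Dict[int, int], src_seq: int, target_seq: int) -> bool:
    """Hypothetical helper mirroring the crossing-link check in `add_link`.

    `existing` maps source sequence number -> target sequence number for
    the links already stored between one pair of chains.
    """
    for current_src, current_target in existing.items():
        # An existing link that starts at or below the new source and ends
        # at or above the new target already reaches everything the new
        # link would reach.
        if current_src <= src_seq and target_seq <= current_target:
            return True
    return False


links = {1: 1}  # one stored link: (chain A, seq 1) -> (chain B, seq 1)
print(link_is_redundant(links, 4, 1))  # True: the 1 -> 1 link already covers it
print(link_is_redundant(links, 3, 3))  # False: it reaches further up chain B
```

The two calls correspond to the `(1, 4) -> (2, 1)` and `(1, 3) -> (2, 3)` cases in `LinkMapTestCase`.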


@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore):
return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"), retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
desc="get_room",
allow_none=True,
)
@ -1166,6 +1166,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
# It's overridden by RoomStore for the synapse master.
raise NotImplementedError()
async def has_auth_chain_index(self, room_id: str) -> bool:
"""Check if the room has (or can have) a chain cover index.
Defaults to True if we have neither an entry in the `rooms` table nor any
events for the room.
"""
has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
table="rooms",
keyvalues={"room_id": room_id},
retcol="has_auth_chain_index",
desc="has_auth_chain_index",
allow_none=True,
)
if has_auth_chain_index:
return True
# It's possible that we already have events for the room in our DB
# without a corresponding room entry. If we do then we don't want to
# mark the room as having an auth chain cover index.
max_ordering = await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"room_id": room_id},
retcol="MAX(stream_ordering)",
allow_none=True,
desc="has_auth_chain_index",
)
return max_ordering is None
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@ -1179,12 +1210,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
Called when we join a room over federation, and overwrites any room version
currently in the table.
"""
# It's possible that we already have events for the room in our DB
# without a corresponding room entry. If we do then we don't want to
# mark the room as having an auth chain cover index.
has_auth_chain_index = await self.has_auth_chain_index(room_id)
await self.db_pool.simple_upsert(
desc="upsert_room_on_join",
table="rooms",
keyvalues={"room_id": room_id},
values={"room_version": room_version.identifier},
insertion_values={"is_public": False, "creator": ""}, insertion_values={
"is_public": False,
"creator": "",
"has_auth_chain_index": has_auth_chain_index,
},
# rooms has a unique constraint on room_id, so no need to lock when doing an
# emulated upsert.
lock=False,
@ -1219,6 +1259,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"creator": room_creator_user_id,
"is_public": is_public,
"room_version": room_version.identifier,
"has_auth_chain_index": True,
},
)
if is_public:
@ -1247,6 +1288,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
When we receive an invite or any other event over federation that may relate to a room
we are not in, store the version of the room if we don't already know the room version.
"""
# It's possible that we already have events for the room in our DB
# without a corresponding room entry. If we do then we don't want to
# mark the room as having an auth chain cover index.
has_auth_chain_index = await self.has_auth_chain_index(room_id)
await self.db_pool.simple_upsert(
desc="maybe_store_room_on_outlier_membership",
table="rooms",
@ -1256,6 +1302,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"room_version": room_version.identifier,
"is_public": False,
"creator": "",
"has_auth_chain_index": has_auth_chain_index,
},
# rooms has a unique constraint on room_id, so no need to lock when doing an
# emulated upsert.


@ -0,0 +1,52 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- See docs/auth_chain_difference_algorithm.md
CREATE TABLE event_auth_chains (
event_id TEXT PRIMARY KEY,
chain_id BIGINT NOT NULL,
sequence_number BIGINT NOT NULL
);
CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number);
CREATE TABLE event_auth_chain_links (
origin_chain_id BIGINT NOT NULL,
origin_sequence_number BIGINT NOT NULL,
target_chain_id BIGINT NOT NULL,
target_sequence_number BIGINT NOT NULL
);
CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id);
-- Events that we have persisted but not calculated auth chains for,
-- e.g. out of band memberships (where we don't have the auth chain)
CREATE TABLE event_auth_chain_to_calculate (
event_id TEXT PRIMARY KEY,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL
);
CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id);
-- Whether we've calculated the above index for a room.
ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN;
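To see how the two main tables fit together, the following is a minimal sketch using Python's built-in `sqlite3` module with an in-memory database. The event IDs, the sample rows, and the `links_from` helper with its query are illustrative assumptions, not queries taken from Synapse; only the table and column names come from the schema above.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE event_auth_chains (
        event_id TEXT PRIMARY KEY,
        chain_id BIGINT NOT NULL,
        sequence_number BIGINT NOT NULL
    );
    CREATE UNIQUE INDEX event_auth_chains_c_seq_index
        ON event_auth_chains (chain_id, sequence_number);
    CREATE TABLE event_auth_chain_links (
        origin_chain_id BIGINT NOT NULL,
        origin_sequence_number BIGINT NOT NULL,
        target_chain_id BIGINT NOT NULL,
        target_sequence_number BIGINT NOT NULL
    );
    """
)

# Illustrative data: three events, each at the start of its own chain.
conn.executemany(
    "INSERT INTO event_auth_chains VALUES (?, ?, ?)",
    [("$create", 1, 1), ("$bob_join", 2, 1), ("$power", 3, 1)],
)
conn.executemany(
    "INSERT INTO event_auth_chain_links VALUES (?, ?, ?, ?)",
    [(2, 1, 1, 1), (3, 1, 1, 1), (3, 1, 2, 1)],
)


def links_from(conn: sqlite3.Connection, event_id: str):
    """Hypothetical helper: links usable from an event, i.e. links whose
    origin is in the event's chain at or below its sequence number."""
    return conn.execute(
        """
        SELECT l.target_chain_id, l.target_sequence_number
        FROM event_auth_chains AS c
        JOIN event_auth_chain_links AS l
          ON l.origin_chain_id = c.chain_id
         AND l.origin_sequence_number <= c.sequence_number
        WHERE c.event_id = ?
        ORDER BY l.target_chain_id
        """,
        (event_id,),
    ).fetchall()


print(links_from(conn, "$power"))   # [(1, 1), (2, 1)]
print(links_from(conn, "$create"))  # []
```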


@ -0,0 +1,16 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id;


@ -13,8 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
from itertools import islice
from typing import Iterable, Iterator, Sequence, Tuple, TypeVar from typing import (
Dict,
Generator,
Iterable,
Iterator,
Mapping,
Sequence,
Set,
Tuple,
TypeVar,
)
from synapse.types import Collection
T = TypeVar("T")
@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
If the input is empty, no chunks are returned.
"""
return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
def sorted_topologically(
nodes: Iterable[T], graph: Mapping[T, Collection[T]],
) -> Generator[T, None, None]:
"""Given a set of nodes and a graph, yield the nodes in topological order.
For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
"""
# This is implemented by Kahn's algorithm.
degree_map = {node: 0 for node in nodes}
reverse_graph = {} # type: Dict[T, Set[T]]
for node, edges in graph.items():
if node not in degree_map:
continue
for edge in edges:
if edge in degree_map:
degree_map[node] += 1
reverse_graph.setdefault(edge, set()).add(node)
reverse_graph.setdefault(node, set())
zero_degree = [node for node, degree in degree_map.items() if degree == 0]
heapq.heapify(zero_degree)
while zero_degree:
node = heapq.heappop(zero_degree)
yield node
for edge in reverse_graph[node]:
if edge in degree_map:
degree_map[edge] -= 1
if degree_map[edge] == 0:
heapq.heappush(zero_degree, edge)
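The effect of combining Kahn's algorithm with a heap can be seen in a standalone sketch. The function below, `topo_sort_sketch`, is a hypothetical simplified rewrite for illustration (integer nodes, list output) rather than the function from this commit; the min-heap is what makes the output deterministic when several nodes are ready at once.

```python
import heapq
from typing import Dict, Iterable, List


def topo_sort_sketch(nodes: Iterable[int], graph: Dict[int, List[int]]) -> List[int]:
    """Kahn's algorithm with a min-heap for deterministic tie-breaking."""
    # Number of not-yet-emitted dependencies for each node we care about.
    in_degree = {node: 0 for node in nodes}
    # Reverse edges: dependency -> nodes that depend on it.
    dependants: Dict[int, List[int]] = {node: [] for node in in_degree}
    for node, edges in graph.items():
        if node not in in_degree:
            continue
        for edge in edges:
            if edge in in_degree:
                in_degree[node] += 1
                dependants[edge].append(node)

    ready = [node for node, degree in in_degree.items() if degree == 0]
    heapq.heapify(ready)
    ordered = []
    while ready:
        # Always emit the smallest ready node first.
        node = heapq.heappop(ready)
        ordered.append(node)
        for dependant in dependants[node]:
            in_degree[dependant] -= 1
            if in_degree[dependant] == 0:
                heapq.heappush(ready, dependant)
    return ordered


fork = {1: [], 2: [1], 3: [1], 4: [2, 3]}
print(topo_sort_sketch([4, 3, 2, 1], fork))  # [1, 2, 3, 4]
print(topo_sort_sketch([4, 3], fork))        # [3, 4]
```

Edges that point outside the requested node set are ignored, which is why sorting only `[4, 3]` still works.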


@ -0,0 +1,472 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Tuple
from twisted.trial import unittest
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.storage.databases.main.events import _LinkMap
from tests.unittest import HomeserverTestCase
class EventChainStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self._next_stream_ordering = 1
def test_simple(self):
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
works.
"""
event_factory = self.hs.get_event_builder_factory()
bob = "@creator:test"
alice = "@alice:test"
room_id = "!room:test"
# Ensure that we have a rooms entry so that we generate the chain index.
self.get_success(
self.store.store_room(
room_id=room_id,
room_creator_user_id="",
is_public=True,
room_version=RoomVersions.V6,
)
)
create = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Create,
"state_key": "",
"sender": bob,
"room_id": room_id,
"content": {"tag": "create"},
},
).build(prev_event_ids=[], auth_event_ids=[])
)
bob_join = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": bob,
"sender": bob,
"room_id": room_id,
"content": {"tag": "bob_join"},
},
).build(prev_event_ids=[], auth_event_ids=[create.event_id])
)
power = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.PowerLevels,
"state_key": "",
"sender": bob,
"room_id": room_id,
"content": {"tag": "power"},
},
).build(
prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
)
)
alice_invite = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": bob,
"room_id": room_id,
"content": {"tag": "alice_invite"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
)
)
alice_join = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": alice,
"room_id": room_id,
"content": {"tag": "alice_join"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
)
)
power_2 = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.PowerLevels,
"state_key": "",
"sender": bob,
"room_id": room_id,
"content": {"tag": "power_2"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
)
)
bob_join_2 = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": bob,
"sender": bob,
"room_id": room_id,
"content": {"tag": "bob_join_2"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
)
)
alice_join2 = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": alice,
"room_id": room_id,
"content": {"tag": "alice_join2"},
},
).build(
prev_event_ids=[],
auth_event_ids=[
create.event_id,
alice_join.event_id,
power_2.event_id,
],
)
)
events = [
create,
bob_join,
power,
alice_invite,
alice_join,
bob_join_2,
power_2,
alice_join2,
]
expected_links = [
(bob_join, create),
(power, create),
(power, bob_join),
(alice_invite, create),
(alice_invite, power),
(alice_invite, bob_join),
(bob_join_2, power),
(alice_join2, power_2),
]
self.persist(events)
chain_map, link_map = self.fetch_chains(events)
# Check that the expected links and only the expected links have been
# added.
self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
for start, end in expected_links:
start_id, start_seq = chain_map[start.event_id]
end_id, end_seq = chain_map[end.event_id]
self.assertIn(
(start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
)
# Test that everything can reach the create event, but the create event
# can't reach anything.
for event in events[1:]:
self.assertTrue(
link_map.exists_path_from(
chain_map[event.event_id], chain_map[create.event_id]
),
)
self.assertFalse(
link_map.exists_path_from(
chain_map[create.event_id], chain_map[event.event_id],
),
)
def test_out_of_order_events(self):
"""Test that we handle persisting events that we don't have the full
auth chain for yet (which should only happen for out of band memberships).
"""
event_factory = self.hs.get_event_builder_factory()
bob = "@creator:test"
alice = "@alice:test"
room_id = "!room:test"
# Ensure that we have a rooms entry so that we generate the chain index.
self.get_success(
self.store.store_room(
room_id=room_id,
room_creator_user_id="",
is_public=True,
room_version=RoomVersions.V6,
)
)
# First persist the base room.
create = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Create,
"state_key": "",
"sender": bob,
"room_id": room_id,
"content": {"tag": "create"},
},
).build(prev_event_ids=[], auth_event_ids=[])
)
bob_join = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": bob,
"sender": bob,
"room_id": room_id,
"content": {"tag": "bob_join"},
},
).build(prev_event_ids=[], auth_event_ids=[create.event_id])
)
power = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.PowerLevels,
"state_key": "",
"sender": bob,
"room_id": room_id,
"content": {"tag": "power"},
},
).build(
prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
)
)
self.persist([create, bob_join, power])
# Now persist an invite and a couple of memberships out of order.
alice_invite = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": bob,
"room_id": room_id,
"content": {"tag": "alice_invite"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
)
)
alice_join = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": alice,
"room_id": room_id,
"content": {"tag": "alice_join"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
)
)
alice_join2 = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": alice,
"room_id": room_id,
"content": {"tag": "alice_join2"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
)
)
self.persist([alice_join])
self.persist([alice_join2])
self.persist([alice_invite])
# The end result should be sane.
events = [create, bob_join, power, alice_invite, alice_join]
chain_map, link_map = self.fetch_chains(events)
expected_links = [
(bob_join, create),
(power, create),
(power, bob_join),
(alice_invite, create),
(alice_invite, power),
(alice_invite, bob_join),
]
# Check that the expected links and only the expected links have been
# added.
self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
for start, end in expected_links:
start_id, start_seq = chain_map[start.event_id]
end_id, end_seq = chain_map[end.event_id]
self.assertIn(
(start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
)
def persist(
self, events: List[EventBase],
):
"""Persist the given events and run the auth chain calculation for
them.
"""
persist_events_store = self.hs.get_datastores().persist_events
for e in events:
e.internal_metadata.stream_ordering = self._next_stream_ordering
self._next_stream_ordering += 1
def _persist(txn):
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
# Actually call the function that calculates the auth chain stuff.
persist_events_store._persist_event_auth_chain_txn(txn, events)
self.get_success(
persist_events_store.db_pool.runInteraction("_persist", _persist,)
)
def fetch_chains(
self, events: List[EventBase]
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
# Fetch the map from event ID -> (chain ID, sequence number)
rows = self.get_success(
self.store.db_pool.simple_select_many_batch(
table="event_auth_chains",
column="event_id",
iterable=[e.event_id for e in events],
retcols=("event_id", "chain_id", "sequence_number"),
keyvalues={},
)
)
chain_map = {
row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
}
# Fetch all the links and pass them to the _LinkMap.
rows = self.get_success(
self.store.db_pool.simple_select_many_batch(
table="event_auth_chain_links",
column="origin_chain_id",
iterable=[chain_id for chain_id, _ in chain_map.values()],
retcols=(
"origin_chain_id",
"origin_sequence_number",
"target_chain_id",
"target_sequence_number",
),
keyvalues={},
)
)
link_map = _LinkMap()
for row in rows:
added = link_map.add_link(
(row["origin_chain_id"], row["origin_sequence_number"]),
(row["target_chain_id"], row["target_sequence_number"]),
)
# We shouldn't have persisted any redundant links
self.assertTrue(added)
return chain_map, link_map
class LinkMapTestCase(unittest.TestCase):
def test_simple(self):
"""Basic tests for the LinkMap.
"""
link_map = _LinkMap()
link_map.add_link((1, 1), (2, 1), new=False)
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
self.assertCountEqual(link_map.get_additions(), [])
self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
# Attempting to add a redundant link is ignored.
self.assertFalse(link_map.add_link((1, 4), (2, 1)))
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
# Adding new non-redundant links works
self.assertTrue(link_map.add_link((1, 3), (2, 3)))
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertTrue(link_map.add_link((2, 5), (1, 3)))
self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
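The reachability assertions in `LinkMapTestCase` follow from a single-hop check. The following is a minimal sketch of that check under simplifying assumptions: `path_exists` is a hypothetical standalone function, and links are stored in a flat dict keyed by `(source chain, target chain)` instead of the nested maps used by `_LinkMap`.

```python
from typing import Dict, Tuple

# (source chain, target chain) -> {source sequence no: target sequence no}
Links = Dict[Tuple[int, int], Dict[int, int]]


def path_exists(links: Links, src: Tuple[int, int], target: Tuple[int, int]) -> bool:
    """Hypothetical helper: can `src` reach `target` within a chain or via one link?"""
    src_chain, src_seq = src
    target_chain, target_seq = target
    if src_chain == target_chain:
        # Within a chain, an event reaches everything at or below it.
        return target_seq <= src_seq
    for link_src, link_target in links.get((src_chain, target_chain), {}).items():
        # The link must start at or below the source and end at or above
        # the target.
        if link_src <= src_seq and target_seq <= link_target:
            return True
    return False


links = {(1, 2): {1: 1}}
print(path_exists(links, (1, 5), (2, 1)))  # True
print(path_exists(links, (1, 5), (2, 2)))  # False
print(path_exists(links, (1, 5), (1, 1)))  # True
print(path_exists(links, (1, 1), (1, 5)))  # False
```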


@ -13,6 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import attr
from parameterized import parameterized
from synapse.events import _EventInternalMetadata
import tests.unittest
import tests.utils
@ -113,7 +118,154 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
def test_auth_difference(self): @parameterized.expand([(True,), (False,)])
def test_auth_difference(self, use_chain_cover_index: bool):
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
# where the top are the most recent events.
#
# A B
# \ /
# D E
# \ |
# ` F C
# | /|
# G ´ |
# | \ |
# H I
# | |
# K J
auth_graph = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
"d": ["f"],
"e": ["f"],
"f": ["g"],
"g": ["h", "i"],
"h": ["k"],
"i": ["j"],
"k": [],
"j": [],
}
depth_map = {
"a": 7,
"b": 7,
"c": 4,
"d": 6,
"e": 6,
"f": 5,
"g": 3,
"h": 2,
"i": 2,
"k": 1,
"j": 1,
}
# Mark the room as having (or not having) a chain cover index, depending
# on the test parameter.
def store_room(txn):
self.store.db_pool.simple_insert_txn(
txn,
"rooms",
{
"room_id": room_id,
"creator": "room_creator_user_id",
"is_public": True,
"room_version": "6",
"has_auth_chain_index": use_chain_cover_index,
},
)
self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
def insert_event(txn):
stream_ordering = 0
for event_id in auth_graph:
stream_ordering += 1
depth = depth_map[event_id]
self.store.db_pool.simple_insert_txn(
txn,
table="events",
values={
"event_id": event_id,
"room_id": room_id,
"depth": depth,
"topological_ordering": depth,
"type": "m.test",
"processed": True,
"outlier": False,
"stream_ordering": stream_ordering,
},
)
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
[
FakeEvent(event_id, room_id, auth_graph[event_id])
for event_id in auth_graph
],
)
self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
# Now actually test that various combinations give the right result:
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b", "c"})
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "d", "e"})
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}])
)
self.assertSetEqual(difference, set())
def test_auth_difference_partial_cover(self):
"""Test that we correctly handle rooms where not all events have a chain
cover calculated. This can happen in some obscure edge cases, including
during the background update that calculates the chain cover for old
rooms.
"""
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@ -162,8 +314,24 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
def insert_event(txn, event_id, stream_ordering): def insert_event(txn):
# First insert the room and mark it as having a chain cover.
self.store.db_pool.simple_insert_txn(
txn,
"rooms",
{
"room_id": room_id,
"creator": "room_creator_user_id",
"is_public": True,
"room_version": "6",
"has_auth_chain_index": True,
},
)
stream_ordering = 0
for event_id in auth_graph:
stream_ordering += 1
depth = depth_map[event_id]
self.store.db_pool.simple_insert_txn(
@ -181,24 +349,39 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
self.store.db_pool.simple_insert_many_txn( # Insert all events apart from 'B'
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn, txn,
table="event_auth", [
values=[ FakeEvent(event_id, room_id, auth_graph[event_id])
{"event_id": event_id, "room_id": room_id, "auth_id": a} for event_id in auth_graph
for a in auth_graph[event_id] if event_id != "b"
], ],
) )
next_stream_ordering = 0 # Now we insert the event 'B' without a chain cover, by temporarily
for event_id in auth_graph: # pretending the room doesn't have a chain cover.
next_stream_ordering += 1
self.get_success( self.store.db_pool.simple_update_txn(
self.store.db_pool.runInteraction( txn,
"insert", insert_event, event_id, next_stream_ordering table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": False},
) )
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn, [FakeEvent("b", room_id, auth_graph["b"])],
) )
self.store.db_pool.simple_update_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": True},
)
self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
# Now actually test that various combinations give the right result:
difference = self.get_success(
@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.store.get_auth_chain_difference(room_id, [{"a"}])
)
self.assertSetEqual(difference, set())
@attr.s
class FakeEvent:
event_id = attr.ib()
room_id = attr.ib()
auth_events = attr.ib()
type = "foo"
state_key = "foo"
internal_metadata = _EventInternalMetadata({})
def auth_event_ids(self):
return self.auth_events
def is_state(self):
return True


@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.iterutils import chunk_seq from typing import Dict, List
from synapse.util.iterutils import chunk_seq, sorted_topologically
from tests.unittest import TestCase
@ -45,3 +47,40 @@ class ChunkSeqTests(TestCase):
self.assertEqual(
list(parts), [],
)
class SortTopologically(TestCase):
def test_empty(self):
"Test that an empty graph works correctly"
graph = {} # type: Dict[int, List[int]]
self.assertEqual(list(sorted_topologically([], graph)), [])
def test_disconnected(self):
"Test that a graph with no edges works"
graph = {1: [], 2: []} # type: Dict[int, List[int]]
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
def test_linear(self):
"Test that a simple `4 -> 3 -> 2 -> 1` graph works"
graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_subset(self):
"Test that only sorting a subset of the graph works"
graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
def test_fork(self):
"Test that a forked graph works"
graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]]
# Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
# always get the same one.
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])