Preparatory work for tweaking performance of auth chain lookups (#16833)

This commit is contained in:
Erik Johnston 2024-01-23 11:26:27 +00:00 committed by GitHub
parent fa2700f001
commit 14c725f73b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 163 additions and 27 deletions

1
changelog.d/16833.misc Normal file
View file

@ -0,0 +1 @@
Preparatory work for tweaking performance of auth chain lookups.

View file

@ -159,6 +159,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
unique_columns=("event_id", "room_id"),
)
self.db_pool.updates.register_background_index_update(
update_name="event_auth_chain_links_origin_index",
index_name="event_auth_chain_links_origin_index",
table="event_auth_chain_links",
columns=("origin_chain_id", "origin_sequence_number"),
)
async def get_auth_chain(
self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
@ -271,39 +278,64 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains.
for batch2 in batch_iter(event_chains, 1000):
chains_to_fetch = set(event_chains.keys())
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 100))
chains_to_fetch.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
links: Dict[int, List[Tuple[int, int, int]]] = {}
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
# chains are only reachable if the origin sequence number of
# the link is less than the max sequence number in the
# origin chain.
if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number,
chains.get(target_chain_id, 0),
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)
for chain_id in links:
if chain_id not in event_chains:
continue
_materialize(chain_id, event_chains[chain_id], links, chains)
chains_to_fetch.difference_update(chains)
# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
for chain_id, seq_no in event_chains.items():
@ -529,41 +561,64 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
# Now we look up all links for the chains we have, adding chains to
# set_to_chain that are reachable from each set.
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
for batch2 in batch_iter(set(seen_chains), 1000):
chains_to_fetch = set(seen_chains)
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 100))
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
links: Dict[int, List[Tuple[int, int, int]]] = {}
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
for chains in set_to_chain:
# chains are only reachable if the origin sequence number of
# the link is less than the max sequence number in the
# origin chain.
if origin_sequence_number <= chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number,
chains.get(target_chain_id, 0),
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)
seen_chains.add(target_chain_id)
for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:
continue
_materialize(chain_id, chains[chain_id], links, chains)
chains_to_fetch.difference_update(chains)
seen_chains.update(chains)
# Now for each chain we figure out the maximum sequence number reachable
# from *any* state set and the minimum sequence number reachable from
@ -2103,3 +2158,49 @@ class EventFederationStore(EventFederationWorkerStore):
)
return batch_size
def _materialize(
origin_chain_id: int,
origin_sequence_number: int,
links: Dict[int, List[Tuple[int, int, int]]],
materialized: Dict[int, int],
) -> None:
"""Helper function for fetching auth chain links. For a given origin chain
ID / sequence number and a dictionary of links, updates the materialized
dict with the reachable chains.
To get a dict of all chains reachable from a set of chains this function can
be called in a loop, once per origin chain with the same links and
materialized args. The materialized dict will the result.
Args:
origin_chain_id, origin_sequence_number
links: map of the links between chains as a dict from origin chain ID
to list of 3-tuples of origin sequence number, target chain ID and
target sequence number.
materialized: dict to update with new reachability information, as a
map from chain ID to max sequence number reachable.
"""
# Do a standard graph traversal.
stack = [(origin_chain_id, origin_sequence_number)]
while stack:
c, s = stack.pop()
chain_links = links.get(c, [])
for (
sequence_number,
target_chain_id,
target_sequence_number,
) in chain_links:
# Ignore any links that are higher up the chain
if sequence_number > s:
continue
# Check if we have already visited the target chain before, if so we
# can skip it.
if materialized.get(target_chain_id, 0) < target_sequence_number:
stack.append((target_chain_id, target_sequence_number))
materialized[target_chain_id] = target_sequence_number

View file

@ -18,7 +18,7 @@
#
#
SCHEMA_VERSION = 83 # remember to update the list below when updating
SCHEMA_VERSION = 84 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the

View file

@ -0,0 +1,18 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2023 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Force the statistics for these tables to show that the number of distinct
-- chain IDs are proportional to the total rows, as postgres has trouble
-- figuring that out by itself.
ALTER TABLE event_auth_chain_links ALTER origin_chain_id SET (n_distinct = -0.5);
ALTER TABLE event_auth_chain_links ALTER target_chain_id SET (n_distinct = -0.5);

View file

@ -0,0 +1,16 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2023 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(8402, 'event_auth_chain_links_origin_index', '{}');