Merge branch 'develop' into release-v1.119

This commit is contained in:
Devon Hudson 2024-11-08 10:06:46 -07:00
commit f377cee7ec
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
45 changed files with 498 additions and 190 deletions

View file

@ -36,11 +36,11 @@ IS_PR = os.environ["GITHUB_REF"].startswith("refs/pull/")
# First calculate the various trial jobs. # First calculate the various trial jobs.
# #
# For PRs, we only run each type of test with the oldest Python version supported (which # For PRs, we only run each type of test with the oldest Python version supported (which
# is Python 3.8 right now) # is Python 3.9 right now)
trial_sqlite_tests = [ trial_sqlite_tests = [
{ {
"python-version": "3.8", "python-version": "3.9",
"database": "sqlite", "database": "sqlite",
"extras": "all", "extras": "all",
} }
@ -53,12 +53,12 @@ if not IS_PR:
"database": "sqlite", "database": "sqlite",
"extras": "all", "extras": "all",
} }
for version in ("3.9", "3.10", "3.11", "3.12", "3.13") for version in ("3.10", "3.11", "3.12", "3.13")
) )
trial_postgres_tests = [ trial_postgres_tests = [
{ {
"python-version": "3.8", "python-version": "3.9",
"database": "postgres", "database": "postgres",
"postgres-version": "11", "postgres-version": "11",
"extras": "all", "extras": "all",
@ -77,7 +77,7 @@ if not IS_PR:
trial_no_extra_tests = [ trial_no_extra_tests = [
{ {
"python-version": "3.8", "python-version": "3.9",
"database": "sqlite", "database": "sqlite",
"extras": "", "extras": "",
} }
@ -99,24 +99,24 @@ set_output("trial_test_matrix", test_matrix)
# First calculate the various sytest jobs. # First calculate the various sytest jobs.
# #
# For each type of test we only run on focal on PRs # For each type of test we only run on bullseye on PRs
sytest_tests = [ sytest_tests = [
{ {
"sytest-tag": "focal", "sytest-tag": "bullseye",
}, },
{ {
"sytest-tag": "focal", "sytest-tag": "bullseye",
"postgres": "postgres", "postgres": "postgres",
}, },
{ {
"sytest-tag": "focal", "sytest-tag": "bullseye",
"postgres": "multi-postgres", "postgres": "multi-postgres",
"workers": "workers", "workers": "workers",
}, },
{ {
"sytest-tag": "focal", "sytest-tag": "bullseye",
"postgres": "multi-postgres", "postgres": "multi-postgres",
"workers": "workers", "workers": "workers",
"reactor": "asyncio", "reactor": "asyncio",
@ -127,11 +127,11 @@ if not IS_PR:
sytest_tests.extend( sytest_tests.extend(
[ [
{ {
"sytest-tag": "focal", "sytest-tag": "bullseye",
"reactor": "asyncio", "reactor": "asyncio",
}, },
{ {
"sytest-tag": "focal", "sytest-tag": "bullseye",
"postgres": "postgres", "postgres": "postgres",
"reactor": "asyncio", "reactor": "asyncio",
}, },

View file

@ -1,5 +1,5 @@
#!/usr/bin/env bash #!/usr/bin/env bash
# this script is run by GitHub Actions in a plain `focal` container; it # this script is run by GitHub Actions in a plain `jammy` container; it
# - installs the minimal system requirements, and poetry; # - installs the minimal system requirements, and poetry;
# - patches the project definition file to refer to old versions only; # - patches the project definition file to refer to old versions only;
# - creates a venv with these old versions using poetry; and finally # - creates a venv with these old versions using poetry; and finally

View file

@ -132,9 +132,9 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- sytest-tag: focal - sytest-tag: bullseye
- sytest-tag: focal - sytest-tag: bullseye
postgres: postgres postgres: postgres
workers: workers workers: workers
redis: redis redis: redis

View file

@ -102,7 +102,7 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
matrix: matrix:
os: [ubuntu-20.04, macos-12] os: [ubuntu-22.04, macos-12]
arch: [x86_64, aarch64] arch: [x86_64, aarch64]
# is_pr is a flag used to exclude certain jobs from the matrix on PRs. # is_pr is a flag used to exclude certain jobs from the matrix on PRs.
# It is not read by the rest of the workflow. # It is not read by the rest of the workflow.
@ -144,7 +144,7 @@ jobs:
- name: Only build a single wheel on PR - name: Only build a single wheel on PR
if: startsWith(github.ref, 'refs/pull/') if: startsWith(github.ref, 'refs/pull/')
run: echo "CIBW_BUILD="cp38-manylinux_${{ matrix.arch }}"" >> $GITHUB_ENV run: echo "CIBW_BUILD="cp39-manylinux_${{ matrix.arch }}"" >> $GITHUB_ENV
- name: Build wheels - name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse run: python -m cibuildwheel --output-dir wheelhouse

View file

@ -397,7 +397,7 @@ jobs:
needs: needs:
- linting-done - linting-done
- changes - changes
runs-on: ubuntu-20.04 runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@ -409,12 +409,12 @@ jobs:
# their build dependencies # their build dependencies
- run: | - run: |
sudo apt-get -qq update sudo apt-get -qq update
sudo apt-get -qq install build-essential libffi-dev python-dev \ sudo apt-get -qq install build-essential libffi-dev python3-dev \
libxml2-dev libxslt-dev xmlsec1 zlib1g-dev libjpeg-dev libwebp-dev libxml2-dev libxslt-dev xmlsec1 zlib1g-dev libjpeg-dev libwebp-dev
- uses: actions/setup-python@v5 - uses: actions/setup-python@v5
with: with:
python-version: '3.8' python-version: '3.9'
- name: Prepare old deps - name: Prepare old deps
if: steps.cache-poetry-old-deps.outputs.cache-hit != 'true' if: steps.cache-poetry-old-deps.outputs.cache-hit != 'true'
@ -458,7 +458,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["pypy-3.8"] python-version: ["pypy-3.9"]
extras: ["all"] extras: ["all"]
steps: steps:
@ -580,11 +580,11 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- python-version: "3.8" - python-version: "3.9"
postgres-version: "11" postgres-version: "11"
- python-version: "3.11" - python-version: "3.13"
postgres-version: "15" postgres-version: "17"
services: services:
postgres: postgres:

View file

@ -99,11 +99,11 @@ jobs:
if: needs.check_repo.outputs.should_run_workflow == 'true' if: needs.check_repo.outputs.should_run_workflow == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
# We're using ubuntu:focal because it uses Python 3.8 which is our minimum supported Python version. # We're using debian:bullseye because it uses Python 3.9 which is our minimum supported Python version.
# This job is a canary to warn us about unreleased twisted changes that would cause problems for us if # This job is a canary to warn us about unreleased twisted changes that would cause problems for us if
# they were to be released immediately. For simplicity's sake (and to save CI runners) we use the oldest # they were to be released immediately. For simplicity's sake (and to save CI runners) we use the oldest
# version, assuming that any incompatibilities on newer versions would also be present on the oldest. # version, assuming that any incompatibilities on newer versions would also be present on the oldest.
image: matrixdotorg/sytest-synapse:focal image: matrixdotorg/sytest-synapse:bullseye
volumes: volumes:
- ${{ github.workspace }}:/src - ${{ github.workspace }}:/src

1
changelog.d/17902.misc Normal file
View file

@ -0,0 +1 @@
Update version constraint to allow the latest poetry-core 1.9.1.

1
changelog.d/17903.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a long-standing bug in Synapse which could cause one-time keys to be issued in the incorrect order, causing message decryption failures.

1
changelog.d/17906.bugfix Normal file
View file

@ -0,0 +1 @@
Fix tests to run with latest Twisted.

1
changelog.d/17907.bugfix Normal file
View file

@ -0,0 +1 @@
Fix tests to run with latest Twisted.

1
changelog.d/17908.misc Normal file
View file

@ -0,0 +1 @@
Remove support for python 3.8.

1
changelog.d/17909.misc Normal file
View file

@ -0,0 +1 @@
Update the portdb CI to use Python 3.13 and Postgres 17 as latest dependencies.

1
changelog.d/17911.bugfix Normal file
View file

@ -0,0 +1 @@
Fix tests to run with latest Twisted.

1
changelog.d/17915.bugfix Normal file
View file

@ -0,0 +1 @@
Fix experimental support for [MSC4222](https://github.com/matrix-org/matrix-spec-proposals/pull/4222) where we would return the full state on incremental syncs when using lazy loaded members and there were no new events in the timeline.

View file

@ -322,7 +322,7 @@ The following command will let you run the integration test with the most common
configuration: configuration:
```sh ```sh
$ docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:focal $ docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:bullseye
``` ```
(Note that the paths must be full paths! You could also write `$(realpath relative/path)` if needed.) (Note that the paths must be full paths! You could also write `$(realpath relative/path)` if needed.)

View file

@ -208,7 +208,7 @@ When following this route please make sure that the [Platform-specific prerequis
System requirements: System requirements:
- POSIX-compliant system (tested on Linux & OS X) - POSIX-compliant system (tested on Linux & OS X)
- Python 3.8 or later, up to Python 3.11. - Python 3.9 or later, up to Python 3.13.
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
If building on an uncommon architecture for which pre-built wheels are If building on an uncommon architecture for which pre-built wheels are

View file

@ -117,6 +117,17 @@ each upgrade are complete before moving on to the next upgrade, to avoid
stacking them up. You can monitor the currently running background updates with stacking them up. You can monitor the currently running background updates with
[the Admin API](usage/administration/admin_api/background_updates.html#status). [the Admin API](usage/administration/admin_api/background_updates.html#status).
# Upgrading to v1.119.0
## Minimum supported Python version
The minimum supported Python version has been increased from v3.8 to v3.9.
You will need Python 3.9+ to run Synapse v1.119.0 (due out Nov 7th, 2024).
If you use current versions of the Matrix.org-distributed Docker images, no action is required.
Please note that support for Ubuntu `focal` was dropped as well since it uses Python 3.8.
# Upgrading to v1.111.0 # Upgrading to v1.111.0
## New worker endpoints for authenticated client and federation media ## New worker endpoints for authenticated client and federation media

View file

@ -26,7 +26,7 @@ strict_equality = True
# Run mypy type checking with the minimum supported Python version to catch new usage # Run mypy type checking with the minimum supported Python version to catch new usage
# that isn't backwards-compatible (types, overloads, etc). # that isn't backwards-compatible (types, overloads, etc).
python_version = 3.8 python_version = 3.9
files = files =
docker/, docker/,

26
poetry.lock generated
View file

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
[[package]] [[package]]
name = "annotated-types" name = "annotated-types"
@ -11,9 +11,6 @@ files = [
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
] ]
[package.dependencies]
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
[[package]] [[package]]
name = "attrs" name = "attrs"
version = "24.2.0" version = "24.2.0"
@ -874,9 +871,7 @@ files = [
[package.dependencies] [package.dependencies]
attrs = ">=22.2.0" attrs = ">=22.2.0"
importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
jsonschema-specifications = ">=2023.03.6" jsonschema-specifications = ">=2023.03.6"
pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""}
referencing = ">=0.28.4" referencing = ">=0.28.4"
rpds-py = ">=0.7.1" rpds-py = ">=0.7.1"
@ -896,7 +891,6 @@ files = [
] ]
[package.dependencies] [package.dependencies]
importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
referencing = ">=0.28.0" referencing = ">=0.28.0"
[[package]] [[package]]
@ -912,7 +906,6 @@ files = [
[package.dependencies] [package.dependencies]
importlib-metadata = {version = ">=4.11.4", markers = "python_version < \"3.12\""} importlib-metadata = {version = ">=4.11.4", markers = "python_version < \"3.12\""}
importlib-resources = {version = "*", markers = "python_version < \"3.9\""}
"jaraco.classes" = "*" "jaraco.classes" = "*"
jeepney = {version = ">=0.4.2", markers = "sys_platform == \"linux\""} jeepney = {version = ">=0.4.2", markers = "sys_platform == \"linux\""}
pywin32-ctypes = {version = ">=0.2.0", markers = "sys_platform == \"win32\""} pywin32-ctypes = {version = ">=0.2.0", markers = "sys_platform == \"win32\""}
@ -1571,17 +1564,6 @@ files = [
[package.extras] [package.extras]
testing = ["pytest", "pytest-cov"] testing = ["pytest", "pytest-cov"]
[[package]]
name = "pkgutil-resolve-name"
version = "1.3.10"
description = "Resolve a name to an object."
optional = false
python-versions = ">=3.6"
files = [
{file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"},
{file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"},
]
[[package]] [[package]]
name = "prometheus-client" name = "prometheus-client"
version = "0.21.0" version = "0.21.0"
@ -1948,7 +1930,6 @@ files = [
[package.dependencies] [package.dependencies]
cryptography = ">=3.1" cryptography = ">=3.1"
defusedxml = "*" defusedxml = "*"
importlib-resources = {version = "*", markers = "python_version < \"3.9\""}
pyopenssl = "*" pyopenssl = "*"
python-dateutil = "*" python-dateutil = "*"
pytz = "*" pytz = "*"
@ -2164,7 +2145,6 @@ files = [
[package.dependencies] [package.dependencies]
markdown-it-py = ">=2.2.0,<3.0.0" markdown-it-py = ">=2.2.0,<3.0.0"
pygments = ">=2.13.0,<3.0.0" pygments = ">=2.13.0,<3.0.0"
typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
[package.extras] [package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"] jupyter = ["ipywidgets (>=7.5.1,<9)"]
@ -3121,5 +3101,5 @@ user-search = ["pyicu"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8.0" python-versions = "^3.9.0"
content-hash = "eaded26b4770b9d19bfcee6dee8b96203df358ce51939d9b90fdbcf605e2f5fd" content-hash = "0cd942a5193d01cbcef135a0bebd3fa0f12f7dbc63899d6f1c301e0649e9d902"

View file

@ -36,7 +36,7 @@
[tool.ruff] [tool.ruff]
line-length = 88 line-length = 88
target-version = "py38" target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
# See https://beta.ruff.rs/docs/rules/#error-e # See https://beta.ruff.rs/docs/rules/#error-e
@ -155,7 +155,7 @@ synapse_review_recent_signups = "synapse._scripts.review_recent_signups:main"
update_synapse_database = "synapse._scripts.update_synapse_database:main" update_synapse_database = "synapse._scripts.update_synapse_database:main"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.8.0" python = "^3.9.0"
# Mandatory Dependencies # Mandatory Dependencies
# ---------------------- # ----------------------
@ -178,7 +178,7 @@ Twisted = {extras = ["tls"], version = ">=18.9.0"}
treq = ">=15.1" treq = ">=15.1"
# Twisted has required pyopenssl 16.0 since about Twisted 16.6. # Twisted has required pyopenssl 16.0 since about Twisted 16.6.
pyOpenSSL = ">=16.0.0" pyOpenSSL = ">=16.0.0"
PyYAML = ">=3.13" PyYAML = ">=5.3"
pyasn1 = ">=0.1.9" pyasn1 = ">=0.1.9"
pyasn1-modules = ">=0.0.7" pyasn1-modules = ">=0.0.7"
bcrypt = ">=3.1.7" bcrypt = ">=3.1.7"
@ -241,7 +241,7 @@ authlib = { version = ">=0.15.1", optional = true }
# `contrib/systemd/log_config.yaml`. # `contrib/systemd/log_config.yaml`.
# Note: systemd-python 231 appears to have been yanked from pypi # Note: systemd-python 231 appears to have been yanked from pypi
systemd-python = { version = ">=231", optional = true } systemd-python = { version = ">=231", optional = true }
lxml = { version = ">=4.2.0", optional = true } lxml = { version = ">=4.5.2", optional = true }
sentry-sdk = { version = ">=0.7.2", optional = true } sentry-sdk = { version = ">=0.7.2", optional = true }
opentracing = { version = ">=2.2.0", optional = true } opentracing = { version = ">=2.2.0", optional = true }
jaeger-client = { version = ">=4.0.0", optional = true } jaeger-client = { version = ">=4.0.0", optional = true }
@ -370,7 +370,7 @@ tomli = ">=1.2.3"
# runtime errors caused by build system changes. # runtime errors caused by build system changes.
# We are happy to raise these upper bounds upon request, # We are happy to raise these upper bounds upon request,
# provided we check that it's safe to do so (i.e. that CI passes). # provided we check that it's safe to do so (i.e. that CI passes).
requires = ["poetry-core>=1.1.0,<=1.9.0", "setuptools_rust>=1.3,<=1.8.1"] requires = ["poetry-core>=1.1.0,<=1.9.1", "setuptools_rust>=1.3,<=1.8.1"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
@ -378,13 +378,13 @@ build-backend = "poetry.core.masonry.api"
# Skip unsupported platforms (by us or by Rust). # Skip unsupported platforms (by us or by Rust).
# See https://cibuildwheel.readthedocs.io/en/stable/options/#build-skip for the list of build targets. # See https://cibuildwheel.readthedocs.io/en/stable/options/#build-skip for the list of build targets.
# We skip: # We skip:
# - CPython 3.6 and 3.7: EOLed # - CPython 3.6, 3.7 and 3.8: EOLed
# - PyPy 3.7: we only support Python 3.8+ # - PyPy 3.7 and 3.8: we only support Python 3.9+
# - musllinux i686: excluded to reduce number of wheels we build. # - musllinux i686: excluded to reduce number of wheels we build.
# c.f. https://github.com/matrix-org/synapse/pull/12595#discussion_r963107677 # c.f. https://github.com/matrix-org/synapse/pull/12595#discussion_r963107677
# - PyPy on Aarch64 and musllinux on aarch64: too slow to build. # - PyPy on Aarch64 and musllinux on aarch64: too slow to build.
# c.f. https://github.com/matrix-org/synapse/pull/14259 # c.f. https://github.com/matrix-org/synapse/pull/14259
skip = "cp36* cp37* pp37* *-musllinux_i686 pp*aarch64 *-musllinux_aarch64" skip = "cp36* cp37* cp38* pp37* pp38* *-musllinux_i686 pp*aarch64 *-musllinux_aarch64"
# We need a rust compiler # We need a rust compiler
before-all = "curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain stable -y --profile minimal" before-all = "curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain stable -y --profile minimal"

View file

@ -28,9 +28,8 @@ from typing import Collection, Optional, Sequence, Set
# example) # example)
DISTS = ( DISTS = (
"debian:bullseye", # (EOL ~2024-07) (our EOL forced by Python 3.9 is 2025-10-05) "debian:bullseye", # (EOL ~2024-07) (our EOL forced by Python 3.9 is 2025-10-05)
"debian:bookworm", # (EOL not specified yet) (our EOL forced by Python 3.11 is 2027-10-24) "debian:bookworm", # (EOL 2026-06) (our EOL forced by Python 3.11 is 2027-10-24)
"debian:sid", # (EOL not specified yet) (our EOL forced by Python 3.11 is 2027-10-24) "debian:sid", # (rolling distro, no EOL)
"ubuntu:focal", # 20.04 LTS (EOL 2025-04) (our EOL forced by Python 3.8 is 2024-10-14)
"ubuntu:jammy", # 22.04 LTS (EOL 2027-04) (our EOL forced by Python 3.10 is 2026-10-04) "ubuntu:jammy", # 22.04 LTS (EOL 2027-04) (our EOL forced by Python 3.10 is 2026-10-04)
"ubuntu:noble", # 24.04 LTS (EOL 2029-06) "ubuntu:noble", # 24.04 LTS (EOL 2029-06)
"ubuntu:oracular", # 24.10 (EOL 2025-07) "ubuntu:oracular", # 24.10 (EOL 2025-07)

View file

@ -39,8 +39,8 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
# Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the # Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the
# if-statement completely. # if-statement completely.
py_version = sys.version_info py_version = sys.version_info
if py_version < (3, 8): if py_version < (3, 9):
print("Synapse requires Python 3.8 or above.") print("Synapse requires Python 3.9 or above.")
sys.exit(1) sys.exit(1)
# Allow using the asyncio reactor via env var. # Allow using the asyncio reactor via env var.

View file

@ -615,7 +615,7 @@ class E2eKeysHandler:
3. Attempt to fetch fallback keys from the database. 3. Attempt to fetch fallback keys from the database.
Args: Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm). local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
always_include_fallback_keys: True to always include fallback keys. always_include_fallback_keys: True to always include fallback keys.
Returns: Returns:

View file

@ -196,7 +196,9 @@ class MessageHandler:
AuthError (403) if the user doesn't have permission to view AuthError (403) if the user doesn't have permission to view
members of this room. members of this room.
""" """
state_filter = state_filter or StateFilter.all() if state_filter is None:
state_filter = StateFilter.all()
user_id = requester.user.to_string() user_id = requester.user.to_string()
if at_token: if at_token:

View file

@ -1520,7 +1520,7 @@ class SyncHandler:
if sync_config.use_state_after: if sync_config.use_state_after:
delta_state_ids: MutableStateMap[str] = {} delta_state_ids: MutableStateMap[str] = {}
if members_to_fetch is not None: if members_to_fetch:
# We're lazy-loading, so the client might need some more member # We're lazy-loading, so the client might need some more member
# events to understand the events in this timeline. So we always # events to understand the events in this timeline. So we always
# fish out all the member events corresponding to the timeline # fish out all the member events corresponding to the timeline

View file

@ -39,7 +39,7 @@ from twisted.internet.endpoints import (
) )
from twisted.internet.interfaces import ( from twisted.internet.interfaces import (
IPushProducer, IPushProducer,
IReactorTCP, IReactorTime,
IStreamClientEndpoint, IStreamClientEndpoint,
) )
from twisted.internet.protocol import Factory, Protocol from twisted.internet.protocol import Factory, Protocol
@ -113,7 +113,7 @@ class RemoteHandler(logging.Handler):
port: int, port: int,
maximum_buffer: int = 1000, maximum_buffer: int = 1000,
level: int = logging.NOTSET, level: int = logging.NOTSET,
_reactor: Optional[IReactorTCP] = None, _reactor: Optional[IReactorTime] = None,
): ):
super().__init__(level=level) super().__init__(level=level)
self.host = host self.host = host

View file

@ -234,8 +234,11 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown) (ie they are outliers or unknown)
""" """
if state_filter is None:
state_filter = StateFilter.all()
await_full_state = True await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id): if not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False await_full_state = False
event_to_groups = await self.get_state_group_for_events( event_to_groups = await self.get_state_group_for_events(
@ -244,7 +247,7 @@ class StateStorageController:
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all() groups, state_filter
) )
state_event_map = await self.stores.main.get_events( state_event_map = await self.stores.main.get_events(
@ -292,10 +295,11 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown) (ie they are outliers or unknown)
""" """
if ( if state_filter is None:
await_full_state state_filter = StateFilter.all()
and state_filter
and not state_filter.must_await_full_state(self._is_mine_id) if await_full_state and not state_filter.must_await_full_state(
self._is_mine_id
): ):
# Full state is not required if the state filter is restrictive enough. # Full state is not required if the state filter is restrictive enough.
await_full_state = False await_full_state = False
@ -306,7 +310,7 @@ class StateStorageController:
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all() groups, state_filter
) )
event_to_state = { event_to_state = {
@ -335,9 +339,10 @@ class StateStorageController:
RuntimeError if we don't have a state group for the event (ie it is an RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown) outlier or is unknown)
""" """
state_map = await self.get_state_for_events( if state_filter is None:
[event_id], state_filter or StateFilter.all() state_filter = StateFilter.all()
)
state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id] return state_map[event_id]
@trace @trace
@ -365,9 +370,12 @@ class StateStorageController:
RuntimeError if we don't have a state group for the event (ie it is an RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown) outlier or is unknown)
""" """
if state_filter is None:
state_filter = StateFilter.all()
state_map = await self.get_state_ids_for_events( state_map = await self.get_state_ids_for_events(
[event_id], [event_id],
state_filter or StateFilter.all(), state_filter,
await_full_state=await_full_state, await_full_state=await_full_state,
) )
return state_map[event_id] return state_map[event_id]
@ -388,9 +396,12 @@ class StateStorageController:
at the event and `state_filter` is not satisfied by partial state. at the event and `state_filter` is not satisfied by partial state.
Defaults to `True`. Defaults to `True`.
""" """
if state_filter is None:
state_filter = StateFilter.all()
state_ids = await self.get_state_ids_for_event( state_ids = await self.get_state_ids_for_event(
event_id, event_id,
state_filter=state_filter or StateFilter.all(), state_filter=state_filter,
await_full_state=await_full_state, await_full_state=await_full_state,
) )
@ -426,6 +437,9 @@ class StateStorageController:
at the last event in the room before `stream_position` and at the last event in the room before `stream_position` and
`state_filter` is not satisfied by partial state. Defaults to `True`. `state_filter` is not satisfied by partial state. Defaults to `True`.
""" """
if state_filter is None:
state_filter = StateFilter.all()
# FIXME: This gets the state at the latest event before the stream ordering, # FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time # which might not be the same as the "current state" of the room at the time
# of the stream token if there were multiple forward extremities at the time. # of the stream token if there were multiple forward extremities at the time.
@ -442,7 +456,7 @@ class StateStorageController:
if last_event_id: if last_event_id:
state = await self.get_state_after_event( state = await self.get_state_after_event(
last_event_id, last_event_id,
state_filter=state_filter or StateFilter.all(), state_filter=state_filter,
await_full_state=await_full_state, await_full_state=await_full_state,
) )
@ -500,9 +514,10 @@ class StateStorageController:
Returns: Returns:
Dict of state group to state map. Dict of state group to state map.
""" """
return await self.stores.state._get_state_for_groups( if state_filter is None:
groups, state_filter or StateFilter.all() state_filter = StateFilter.all()
)
return await self.stores.state._get_state_for_groups(groups, state_filter)
@trace @trace
@tag_args @tag_args
@ -583,12 +598,13 @@ class StateStorageController:
Returns: Returns:
The current state of the room. The current state of the room.
""" """
if await_full_state and ( if state_filter is None:
not state_filter or state_filter.must_await_full_state(self._is_mine_id) state_filter = StateFilter.all()
):
if await_full_state and state_filter.must_await_full_state(self._is_mine_id):
await self._partial_state_room_tracker.await_full_state(room_id) await self._partial_state_room_tracker.await_full_state(room_id)
if state_filter and not state_filter.is_full(): if state_filter is not None and not state_filter.is_full():
return await self.stores.main.get_partial_filtered_current_state_ids( return await self.stores.main.get_partial_filtered_current_state_ids(
room_id, state_filter room_id, state_filter
) )

View file

@ -99,6 +99,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
unique=True, unique=True,
) )
self.db_pool.updates.register_background_index_update(
update_name="add_otk_ts_added_index",
index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx",
table="e2e_one_time_keys_json",
columns=("user_id", "device_id", "algorithm", "ts_added_ms"),
)
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore): class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
def __init__( def __init__(
@ -1122,7 +1129,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""Take a list of one time keys out of the database. """Take a list of one time keys out of the database.
Args: Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm). query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
Returns: Returns:
A tuple (results, missing) of: A tuple (results, missing) of:
@ -1310,9 +1317,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
OTK was found. OTK was found.
""" """
# Return the oldest keys from this device (based on `ts_added_ms`).
# Doing so means that keys are issued in the same order they were uploaded,
# which reduces the chances of a client expiring its copy of a (private)
# key while the public key is still on the server, waiting to be issued.
sql = """ sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ? WHERE user_id = ? AND device_id = ? AND algorithm = ?
ORDER BY ts_added_ms
LIMIT ? LIMIT ?
""" """
@ -1354,13 +1366,22 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
A list of tuples (user_id, device_id, algorithm, key_id, key_json) A list of tuples (user_id, device_id, algorithm, key_id, key_json)
for each OTK claimed. for each OTK claimed.
""" """
# Find, delete, and return the oldest keys from each device (based on
# `ts_added_ms`).
#
# Doing so means that keys are issued in the same order they were uploaded,
# which reduces the chances of a client expiring its copy of a (private)
# key while the public key is still on the server, waiting to be issued.
sql = """ sql = """
WITH claims(user_id, device_id, algorithm, claim_count) AS ( WITH claims(user_id, device_id, algorithm, claim_count) AS (
VALUES ? VALUES ?
), ranked_keys AS ( ), ranked_keys AS (
SELECT SELECT
user_id, device_id, algorithm, key_id, claim_count, user_id, device_id, algorithm, key_id, claim_count,
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r ROW_NUMBER() OVER (
PARTITION BY (user_id, device_id, algorithm)
ORDER BY ts_added_ms
) AS r
FROM e2e_one_time_keys_json FROM e2e_one_time_keys_json
JOIN claims USING (user_id, device_id, algorithm) JOIN claims USING (user_id, device_id, algorithm)
) )

View file

@ -2550,7 +2550,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
still contains events with partial state. still contains events with partial state.
""" """
try: try:
async with self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id: async with (
self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id
):
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"clear_partial_state_room", "clear_partial_state_room",
self._clear_partial_state_room_txn, self._clear_partial_state_room_txn,

View file

@ -572,10 +572,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns: Returns:
Map from type/state_key to event ID. Map from type/state_key to event ID.
""" """
if state_filter is None:
state_filter = StateFilter.all()
where_clause, where_args = ( where_clause, where_args = (state_filter).make_sql_filter_clause()
state_filter or StateFilter.all()
).make_sql_filter_clause()
if not where_clause: if not where_clause:
# We delegate to the cached version # We delegate to the cached version
@ -584,7 +584,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_filtered_current_state_ids_txn( def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> StateMap[str]: ) -> StateMap[str]:
results = StateMapWrapper(state_filter=state_filter or StateFilter.all()) results = StateMapWrapper(state_filter=state_filter)
sql = """ sql = """
SELECT type, state_key, event_id FROM current_state_events SELECT type, state_key, event_id FROM current_state_events
@ -681,7 +681,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
context: EventContext, context: EventContext,
) -> None: ) -> None:
"""Update the state group for a partial state event""" """Update the state group for a partial state event"""
async with self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id: async with (
self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id
):
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"update_state_for_partial_state_event", "update_state_for_partial_state_event",
self._update_state_for_partial_state_event_txn, self._update_state_for_partial_state_event_txn,

View file

@ -112,8 +112,8 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
Returns: Returns:
Map from state_group to a StateMap at that point. Map from state_group to a StateMap at that point.
""" """
if state_filter is None:
state_filter = state_filter or StateFilter.all() state_filter = StateFilter.all()
results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}

View file

@ -284,7 +284,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns: Returns:
Dict of state group to state map. Dict of state group to state map.
""" """
state_filter = state_filter or StateFilter.all() if state_filter is None:
state_filter = StateFilter.all()
member_filter, non_member_filter = state_filter.get_member_split() member_filter, non_member_filter = state_filter.get_member_split()

View file

@ -0,0 +1,18 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Add an index on (user_id, device_id, algorithm, ts_added_ms) on e2e_one_time_keys_json, so that OTKs can
-- efficiently be issued in the same order they were uploaded.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(8803, 'add_otk_ts_added_index', '{}');

View file

@ -68,15 +68,23 @@ class StateFilter:
include_others: bool = False include_others: bool = False
def __attrs_post_init__(self) -> None: def __attrs_post_init__(self) -> None:
if self.include_others:
# If `include_others` is set we canonicalise the filter by removing # If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary # wildcards from the types dictionary
if self.include_others:
# this is needed to work around the fact that StateFilter is frozen # this is needed to work around the fact that StateFilter is frozen
object.__setattr__( object.__setattr__(
self, self,
"types", "types",
immutabledict({k: v for k, v in self.types.items() if v is not None}), immutabledict({k: v for k, v in self.types.items() if v is not None}),
) )
else:
# Otherwise we remove entries where the value is the empty set.
object.__setattr__(
self,
"types",
immutabledict({k: v for k, v in self.types.items() if v is None or v}),
)
@staticmethod @staticmethod
def all() -> "StateFilter": def all() -> "StateFilter":

View file

@ -151,18 +151,30 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def test_claim_one_time_key(self) -> None: def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
keys = {"alg1:k1": "key1"}
res = self.get_success( res = self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys} local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
) )
) )
self.assertDictEqual( self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}} res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
) )
res2 = self.get_success( # Keys should be returned in the order they were uploaded. To test, advance time
# a little, then upload a second key with an earlier key ID; it should get
# returned second.
self.reactor.advance(1)
res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
)
# now claim both keys back. They should be in the same order
res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}}, {local_user: {device_id: {"alg1": 1}}},
self.requester, self.requester,
@ -171,12 +183,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
) )
self.assertEqual( self.assertEqual(
res2, res,
{ {
"failures": {}, "failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
}, },
) )
res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
},
)
def test_claim_one_time_key_bulk(self) -> None: def test_claim_one_time_key_bulk(self) -> None:
"""Like test_claim_one_time_key but claims multiple keys in one handler call.""" """Like test_claim_one_time_key but claims multiple keys in one handler call."""
@ -336,6 +363,47 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}" counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
) )
def test_claim_one_time_key_bulk_ordering(self) -> None:
"""Keys returned by the bulk claim call should be returned in the correct order"""
# Alice has lots of keys, uploaded in a specific order
alice = f"@alice:{self.hs.hostname}"
alice_dev = "alice_dev_1"
self.get_success(
self.handler.upload_keys_for_user(
alice,
alice_dev,
{"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
)
)
# Advance time by 1s, to ensure that there is a difference in upload time.
self.reactor.advance(1)
self.get_success(
self.handler.upload_keys_for_user(
alice,
alice_dev,
{"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
)
)
# Now claim some, and check we get the right ones.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{alice: {alice_dev: {"alg1": 2}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
# We should get the first-uploaded keys, even though they have later key ids.
# We should get a random set of two of k20, k21, k22.
self.assertEqual(claim_res["failures"], {})
claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"]
self.assertEqual(len(claimed_keys), 2)
for key_id in claimed_keys.keys():
self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"])
def test_fallback_key(self) -> None: def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"

View file

@ -661,9 +661,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
) )
) )
with patch.object( with (
patch.object(
fed_client, "make_membership_event", mock_make_membership_event fed_client, "make_membership_event", mock_make_membership_event
), patch.object(fed_client, "send_join", mock_send_join): ),
patch.object(fed_client, "send_join", mock_send_join),
):
# Join and check that our join event is rejected # Join and check that our join event is rejected
# (The join event is rejected because it doesn't have any signatures) # (The join event is rejected because it doesn't have any signatures)
join_exc = self.get_failure( join_exc = self.get_failure(
@ -708,9 +711,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler = self.hs.get_federation_handler() fed_handler = self.hs.get_federation_handler()
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
with patch.object( with (
patch.object(
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): ),
patch.object(store, "is_partial_state_room", mock_is_partial_state_room),
):
# Start the partial state sync. # Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1) self.assertEqual(mock_sync_partial_state_room.call_count, 1)
@ -760,9 +766,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler = self.hs.get_federation_handler() fed_handler = self.hs.get_federation_handler()
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
with patch.object( with (
patch.object(
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): ),
patch.object(store, "is_partial_state_room", mock_is_partial_state_room),
):
# Start the partial state sync. # Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1) self.assertEqual(mock_sync_partial_state_room.call_count, 1)

View file

@ -172,20 +172,25 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
) )
) )
with patch.object( with (
patch.object(
self.handler.federation_handler.federation_client, self.handler.federation_handler.federation_client,
"make_membership_event", "make_membership_event",
mock_make_membership_event, mock_make_membership_event,
), patch.object( ),
patch.object(
self.handler.federation_handler.federation_client, self.handler.federation_handler.federation_client,
"send_join", "send_join",
mock_send_join, mock_send_join,
), patch( ),
patch(
"synapse.event_auth._is_membership_change_allowed", "synapse.event_auth._is_membership_change_allowed",
return_value=None, return_value=None,
), patch( ),
patch(
"synapse.handlers.federation_event.check_state_dependent_auth_rules", "synapse.handlers.federation_event.check_state_dependent_auth_rules",
return_value=None, return_value=None,
),
): ):
self.get_success( self.get_success(
self.handler.update_membership( self.handler.update_membership(

View file

@ -1262,3 +1262,35 @@ class SyncStateAfterTestCase(tests.unittest.HomeserverTestCase):
) )
) )
self.assertEqual(state[("m.test_event", "")], second_state["event_id"]) self.assertEqual(state[("m.test_event", "")], second_state["event_id"])
def test_incremental_sync_lazy_loaded_no_timeline(self) -> None:
"""Test that lazy-loading with an empty timeline doesn't return the full
state.
There was a bug where an empty state filter would cause the DB to return
the full state, rather than an empty set.
"""
user = self.register_user("user", "password")
tok = self.login("user", "password")
# Create a room as the user and set some custom state.
joined_room = self.helper.create_room_as(user, tok=tok)
since_token = self.hs.get_event_sources().get_current_token()
end_stream_token = self.hs.get_event_sources().get_current_token()
state = self.get_success(
self.sync_handler._compute_state_delta_for_incremental_sync(
room_id=joined_room,
sync_config=generate_sync_config(user, use_state_after=True),
batch=TimelineBatch(
prev_batch=end_stream_token, events=[], limited=True
),
since_token=since_token,
end_token=end_stream_token,
members_to_fetch=set(),
timeline_state={},
)
)
self.assertEqual(state, {})

View file

@ -27,6 +27,7 @@ from typing import (
Callable, Callable,
ContextManager, ContextManager,
Dict, Dict,
Generator,
List, List,
Optional, Optional,
Set, Set,
@ -49,7 +50,10 @@ from synapse.http.server import (
respond_with_json, respond_with_json,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.logging.context import (
LoggingContext,
make_deferred_yieldable,
)
from synapse.types import JsonDict from synapse.types import JsonDict
from tests.server import FakeChannel, make_request from tests.server import FakeChannel, make_request
@ -199,7 +203,7 @@ def make_request_with_cancellation_test(
# #
# We would like to trigger a cancellation at the first `await`, re-run the # We would like to trigger a cancellation at the first `await`, re-run the
# request and cancel at the second `await`, and so on. By patching # request and cancel at the second `await`, and so on. By patching
# `Deferred.__next__`, we can intercept `await`s, track which ones we have or # `Deferred.__await__`, we can intercept `await`s, track which ones we have or
# have not seen, and force them to block when they wouldn't have. # have not seen, and force them to block when they wouldn't have.
# The set of previously seen `await`s. # The set of previously seen `await`s.
@ -211,7 +215,7 @@ def make_request_with_cancellation_test(
) )
for request_number in itertools.count(1): for request_number in itertools.count(1):
deferred_patch = Deferred__next__Patch(seen_awaits, request_number) deferred_patch = Deferred__await__Patch(seen_awaits, request_number)
try: try:
with mock.patch( with mock.patch(
@ -250,6 +254,8 @@ def make_request_with_cancellation_test(
) )
if respond_mock.called: if respond_mock.called:
_log_for_request(request_number, "--- response finished ---")
# The request ran to completion and we are done with testing it. # The request ran to completion and we are done with testing it.
# `respond_with_json` writes the response asynchronously, so we # `respond_with_json` writes the response asynchronously, so we
@ -311,8 +317,8 @@ def make_request_with_cancellation_test(
assert False, "unreachable" # noqa: B011 assert False, "unreachable" # noqa: B011
class Deferred__next__Patch: class Deferred__await__Patch:
"""A `Deferred.__next__` patch that will intercept `await`s and force them """A `Deferred.__await__` patch that will intercept `await`s and force them
to block once it sees a new `await`. to block once it sees a new `await`.
When done with the patch, `unblock_awaits()` must be called to clean up after any When done with the patch, `unblock_awaits()` must be called to clean up after any
@ -322,7 +328,7 @@ class Deferred__next__Patch:
Usage: Usage:
seen_awaits = set() seen_awaits = set()
deferred_patch = Deferred__next__Patch(seen_awaits, 1) deferred_patch = Deferred__await__Patch(seen_awaits, 1)
try: try:
with deferred_patch.patch(): with deferred_patch.patch():
# do things # do things
@ -335,14 +341,14 @@ class Deferred__next__Patch:
""" """
Args: Args:
seen_awaits: The set of stack traces of `await`s that have been previously seen_awaits: The set of stack traces of `await`s that have been previously
seen. When the `Deferred.__next__` patch sees a new `await`, it will add seen. When the `Deferred.__await__` patch sees a new `await`, it will add
it to the set. it to the set.
request_number: The request number to log against. request_number: The request number to log against.
""" """
self._request_number = request_number self._request_number = request_number
self._seen_awaits = seen_awaits self._seen_awaits = seen_awaits
self._original_Deferred___next__ = Deferred.__next__ # type: ignore[misc,unused-ignore] self._original_Deferred__await__ = Deferred.__await__ # type: ignore[misc,unused-ignore]
# The number of `await`s on `Deferred`s we have seen so far. # The number of `await`s on `Deferred`s we have seen so far.
self.awaits_seen = 0 self.awaits_seen = 0
@ -350,8 +356,13 @@ class Deferred__next__Patch:
# Whether we have seen a new `await` not in `seen_awaits`. # Whether we have seen a new `await` not in `seen_awaits`.
self.new_await_seen = False self.new_await_seen = False
# Whether to block new await points we see. This gets set to False once
# we have cancelled the request to allow things to run after
# cancellation.
self._block_new_awaits = True
# To force `await`s on resolved `Deferred`s to block, we make up a new # To force `await`s on resolved `Deferred`s to block, we make up a new
# unresolved `Deferred` and return it out of `Deferred.__next__` / # unresolved `Deferred` and return it out of `Deferred.__await__` /
# `coroutine.send()`. We have to resolve it later, in case the `await`ing # `coroutine.send()`. We have to resolve it later, in case the `await`ing
# coroutine is part of some shared processing, such as `@cached`. # coroutine is part of some shared processing, such as `@cached`.
self._to_unblock: Dict[Deferred, Union[object, Failure]] = {} self._to_unblock: Dict[Deferred, Union[object, Failure]] = {}
@ -360,15 +371,15 @@ class Deferred__next__Patch:
self._previous_stack: List[inspect.FrameInfo] = [] self._previous_stack: List[inspect.FrameInfo] = []
def patch(self) -> ContextManager[Mock]: def patch(self) -> ContextManager[Mock]:
"""Returns a context manager which patches `Deferred.__next__`.""" """Returns a context manager which patches `Deferred.__await__`."""
def Deferred___next__( def Deferred___await__(
deferred: "Deferred[T]", value: object = None deferred: "Deferred[T]",
) -> "Deferred[T]": ) -> Generator["Deferred[T]", None, T]:
"""Intercepts `await`s on `Deferred`s and rigs them to block once we have """Intercepts calls to `__await__`, which returns a generator
seen enough of them. yielding deferreds that we await on.
`Deferred.__next__` will normally: The generator for `__await__` will normally:
* return `self` if the `Deferred` is unresolved, in which case * return `self` if the `Deferred` is unresolved, in which case
`coroutine.send()` will return the `Deferred`, and `coroutine.send()` will return the `Deferred`, and
`_defer.inlineCallbacks` will stop running the coroutine until the `_defer.inlineCallbacks` will stop running the coroutine until the
@ -376,9 +387,43 @@ class Deferred__next__Patch:
* raise a `StopIteration(result)`, containing the result of the `await`. * raise a `StopIteration(result)`, containing the result of the `await`.
* raise another exception, which will come out of the `await`. * raise another exception, which will come out of the `await`.
""" """
# Get the original generator.
gen = self._original_Deferred__await__(deferred)
# Run the generator, handling each iteration to see if we need to
# block.
try:
while True:
# We've hit a new await point (or the deferred has
# completed), handle it.
handle_next_iteration(deferred)
# Continue on.
yield gen.send(None)
except StopIteration as e:
# We need to convert `StopIteration` into a normal return.
return e.value
def handle_next_iteration(
deferred: "Deferred[T]",
) -> None:
"""Intercepts `await`s on `Deferred`s and rigs them to block once we have
seen enough of them.
Args:
deferred: The deferred that we've captured and are intercepting
`await` calls within.
"""
if not self._block_new_awaits:
# We're no longer blocking awaits points
return
self.awaits_seen += 1 self.awaits_seen += 1
stack = _get_stack(skip_frames=1) stack = _get_stack(
skip_frames=2 # Ignore this function and `Deferred___await__` in stack trace
)
stack_hash = _hash_stack(stack) stack_hash = _hash_stack(stack)
if stack_hash not in self._seen_awaits: if stack_hash not in self._seen_awaits:
@ -389,20 +434,29 @@ class Deferred__next__Patch:
if not self.new_await_seen: if not self.new_await_seen:
# This `await` isn't interesting. Let it proceed normally. # This `await` isn't interesting. Let it proceed normally.
_log_await_stack(
stack,
self._previous_stack,
self._request_number,
"already seen",
)
# Don't log the stack. It's been seen before in a previous run. # Don't log the stack. It's been seen before in a previous run.
self._previous_stack = stack self._previous_stack = stack
return self._original_Deferred___next__(deferred, value) return
# We want to block at the current `await`. # We want to block at the current `await`.
if deferred.called and not deferred.paused: if deferred.called and not deferred.paused:
# This `Deferred` already has a result. # This `Deferred` already has a result. We chain a new,
# We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait # unresolved, `Deferred` to the end of this Deferred that it
# on. This blocks the coroutine that did this `await`. # will wait on. This blocks the coroutine that did this `await`.
# We queue it up for unblocking later. # We queue it up for unblocking later.
new_deferred: "Deferred[T]" = Deferred() new_deferred: "Deferred[T]" = Deferred()
self._to_unblock[new_deferred] = deferred.result self._to_unblock[new_deferred] = deferred.result
deferred.addBoth(lambda _: make_deferred_yieldable(new_deferred))
_log_await_stack( _log_await_stack(
stack, stack,
self._previous_stack, self._previous_stack,
@ -411,7 +465,9 @@ class Deferred__next__Patch:
) )
self._previous_stack = stack self._previous_stack = stack
return make_deferred_yieldable(new_deferred) # Continue iterating on the deferred now that we've blocked it
# again.
return
# This `Deferred` does not have a result yet. # This `Deferred` does not have a result yet.
# The `await` will block normally, so we don't have to do anything. # The `await` will block normally, so we don't have to do anything.
@ -423,9 +479,9 @@ class Deferred__next__Patch:
) )
self._previous_stack = stack self._previous_stack = stack
return self._original_Deferred___next__(deferred, value) return
return mock.patch.object(Deferred, "__next__", new=Deferred___next__) return mock.patch.object(Deferred, "__await__", new=Deferred___await__)
def unblock_awaits(self) -> None: def unblock_awaits(self) -> None:
"""Unblocks any shared processing that we forced to block. """Unblocks any shared processing that we forced to block.
@ -433,6 +489,9 @@ class Deferred__next__Patch:
Must be called when done, otherwise processing shared between multiple requests, Must be called when done, otherwise processing shared between multiple requests,
such as database queries started by `@cached`, will become permanently stuck. such as database queries started by `@cached`, will become permanently stuck.
""" """
# Also disable blocking at future await points
self._block_new_awaits = False
to_unblock = self._to_unblock to_unblock = self._to_unblock
self._to_unblock = {} self._to_unblock = {}
for deferred, result in to_unblock.items(): for deferred, result in to_unblock.items():

View file

@ -120,9 +120,11 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# #
# We have seen stringy and null values for "room" in the wild, so presumably # We have seen stringy and null values for "room" in the wild, so presumably
# some of this validation was missing in the past. # some of this validation was missing in the past.
with patch("synapse.events.validator.validate_canonicaljson"), patch( with (
"synapse.events.validator.jsonschema.validate" patch("synapse.events.validator.validate_canonicaljson"),
), patch("synapse.handlers.event_auth.check_state_dependent_auth_rules"): patch("synapse.events.validator.jsonschema.validate"),
patch("synapse.handlers.event_auth.check_state_dependent_auth_rules"),
):
pl_event_id = self.helper.send_state( pl_event_id = self.helper.send_state(
self.room_id, self.room_id,
"m.room.power_levels", "m.room.power_levels",

View file

@ -58,6 +58,7 @@ import twisted
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.internet import address, tcp, threads, udp from twisted.internet import address, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import ( from twisted.internet.interfaces import (
@ -73,6 +74,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver, IReactorPluggableNameResolver,
IReactorTime, IReactorTime,
IResolverSimple, IResolverSimple,
ITCPTransport,
ITransport, ITransport,
) )
from twisted.internet.protocol import ClientFactory, DatagramProtocol, Factory from twisted.internet.protocol import ClientFactory, DatagramProtocol, Factory
@ -780,7 +782,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
return clock, hs_clock return clock, hs_clock
@implementer(ITransport) @implementer(ITCPTransport)
@attr.s(cmp=False, auto_attribs=True) @attr.s(cmp=False, auto_attribs=True)
class FakeTransport: class FakeTransport:
""" """
@ -809,12 +811,12 @@ class FakeTransport:
will get called back for connectionLost() notifications etc. will get called back for connectionLost() notifications etc.
""" """
_peer_address: IAddress = attr.Factory( _peer_address: Union[IPv4Address, IPv6Address] = attr.Factory(
lambda: address.IPv4Address("TCP", "127.0.0.1", 5678) lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
) )
"""The value to be returned by getPeer""" """The value to be returned by getPeer"""
_host_address: IAddress = attr.Factory( _host_address: Union[IPv4Address, IPv6Address] = attr.Factory(
lambda: address.IPv4Address("TCP", "127.0.0.1", 1234) lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
) )
"""The value to be returned by getHost""" """The value to be returned by getHost"""
@ -826,10 +828,10 @@ class FakeTransport:
producer: Optional[IPushProducer] = None producer: Optional[IPushProducer] = None
autoflush: bool = True autoflush: bool = True
def getPeer(self) -> IAddress: def getPeer(self) -> Union[IPv4Address, IPv6Address]:
return self._peer_address return self._peer_address
def getHost(self) -> IAddress: def getHost(self) -> Union[IPv4Address, IPv6Address]:
return self._host_address return self._host_address
def loseConnection(self) -> None: def loseConnection(self) -> None:
@ -939,6 +941,51 @@ class FakeTransport:
logger.info("FakeTransport: Buffer now empty, completing disconnect") logger.info("FakeTransport: Buffer now empty, completing disconnect")
self.disconnected = True self.disconnected = True
## ITCPTransport methods. ##
def loseWriteConnection(self) -> None:
"""
Half-close the write side of a TCP connection.
If the protocol instance this is attached to provides
IHalfCloseableProtocol, it will get notified when the operation is
done. When closing write connection, as with loseConnection this will
only happen when buffer has emptied and there is no registered
producer.
"""
raise NotImplementedError()
def getTcpNoDelay(self) -> bool:
"""
Return if C{TCP_NODELAY} is enabled.
"""
return False
def setTcpNoDelay(self, enabled: bool) -> None:
"""
Enable/disable C{TCP_NODELAY}.
Enabling C{TCP_NODELAY} turns off Nagle's algorithm. Small packets are
sent sooner, possibly at the expense of overall throughput.
"""
# Ignore setting this.
def getTcpKeepAlive(self) -> bool:
"""
Return if C{SO_KEEPALIVE} is enabled.
"""
return False
def setTcpKeepAlive(self, enabled: bool) -> None:
"""
Enable/disable C{SO_KEEPALIVE}.
Enabling C{SO_KEEPALIVE} sends packets periodically when the connection
is otherwise idle, usually once every two hours. They are intended
to allow detection of lost peers in a non-infinite amount of time.
"""
# Ignore setting this.
def connect_client( def connect_client(
reactor: ThreadedMemoryReactorClock, client_id: int reactor: ThreadedMemoryReactorClock, client_id: int

View file

@ -1465,20 +1465,25 @@ class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase(
) )
) )
with patch.object( with (
patch.object(
self.room_member_handler.federation_handler.federation_client, self.room_member_handler.federation_handler.federation_client,
"make_membership_event", "make_membership_event",
mock_make_membership_event, mock_make_membership_event,
), patch.object( ),
patch.object(
self.room_member_handler.federation_handler.federation_client, self.room_member_handler.federation_handler.federation_client,
"send_join", "send_join",
mock_send_join, mock_send_join,
), patch( ),
patch(
"synapse.event_auth._is_membership_change_allowed", "synapse.event_auth._is_membership_change_allowed",
return_value=None, return_value=None,
), patch( ),
patch(
"synapse.handlers.federation_event.check_state_dependent_auth_rules", "synapse.handlers.federation_event.check_state_dependent_auth_rules",
return_value=None, return_value=None,
),
): ):
self.get_success( self.get_success(
self.room_member_handler.update_membership( self.room_member_handler.update_membership(

View file

@ -320,12 +320,19 @@ class ConcurrentlyExecuteTest(TestCase):
await concurrently_execute(callback, [1], 2) await concurrently_execute(callback, [1], 2)
except _TestException as e: except _TestException as e:
tb = traceback.extract_tb(e.__traceback__) tb = traceback.extract_tb(e.__traceback__)
# we expect to see "caller", "concurrently_execute", "callback",
# and some magic from inside ensureDeferred that happens when .fail # Remove twisted internals from the stack, as we don't care
# is called. # about the precise details.
tb = traceback.StackSummary(
t for t in tb if "/twisted/" not in t.filename
)
# we expect to see "caller", "concurrently_execute" at the top of the stack
self.assertEqual(tb[0].name, "caller") self.assertEqual(tb[0].name, "caller")
self.assertEqual(tb[1].name, "concurrently_execute") self.assertEqual(tb[1].name, "concurrently_execute")
self.assertEqual(tb[-2].name, "callback") # ... some stack frames from the implementation of `concurrently_execute` ...
# and at the bottom of the stack we expect to see "callback"
self.assertEqual(tb[-1].name, "callback")
else: else:
self.fail("No exception thrown") self.fail("No exception thrown")

View file

@ -109,10 +109,13 @@ class TestDependencyChecker(TestCase):
def test_checks_ignore_dev_dependencies(self) -> None: def test_checks_ignore_dev_dependencies(self) -> None:
"""Both generic and per-extra checks should ignore dev dependencies.""" """Both generic and per-extra checks should ignore dev dependencies."""
with patch( with (
patch(
"synapse.util.check_dependencies.metadata.requires", "synapse.util.check_dependencies.metadata.requires",
return_value=["dummypkg >= 1; extra == 'mypy'"], return_value=["dummypkg >= 1; extra == 'mypy'"],
), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}): ),
patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}),
):
# We're testing that none of these calls raise. # We're testing that none of these calls raise.
with self.mock_installed_package(None): with self.mock_installed_package(None):
check_requirements() check_requirements()
@ -141,10 +144,13 @@ class TestDependencyChecker(TestCase):
def test_check_for_extra_dependencies(self) -> None: def test_check_for_extra_dependencies(self) -> None:
"""Complain if a package required for an extra is missing or old.""" """Complain if a package required for an extra is missing or old."""
with patch( with (
patch(
"synapse.util.check_dependencies.metadata.requires", "synapse.util.check_dependencies.metadata.requires",
return_value=["dummypkg >= 1; extra == 'cool-extra'"], return_value=["dummypkg >= 1; extra == 'cool-extra'"],
), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}): ),
patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}),
):
with self.mock_installed_package(None): with self.mock_installed_package(None):
self.assertRaises(DependencyException, check_requirements, "cool-extra") self.assertRaises(DependencyException, check_requirements, "cool-extra")
with self.mock_installed_package(old): with self.mock_installed_package(old):

View file

@ -1,5 +1,5 @@
[tox] [tox]
envlist = py37, py38, py39, py310 envlist = py39, py310, py311, py312, py313
# we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208 # we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208
minversion = 2.3.2 minversion = 2.3.2