Merge branch 'develop' into madlittlemods/17368-bust-_membership_stream_cache

This commit is contained in:
Eric Eastwood 2024-12-02 09:49:53 -06:00
commit f5f0e36ec1
18 changed files with 338 additions and 93 deletions

View file

@ -212,7 +212,8 @@ jobs:
mv debs*/* debs/
tar -cvJf debs.tar.xz debs
- name: Attach to release
uses: softprops/action-gh-release@v2
# Pinned to work around https://github.com/softprops/action-gh-release/issues/445
uses: softprops/action-gh-release@v2.0.5
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:

1
changelog.d/17962.misc Normal file
View file

@ -0,0 +1 @@
Fix new scheduled tasks jumping the queue.

View file

@ -0,0 +1 @@
Use stable `M_USER_LOCKED` error code for locked accounts, as per [Matrix 1.12](https://spec.matrix.org/v1.12/client-server-api/#account-locking).

1
changelog.d/17970.bugfix Normal file
View file

@ -0,0 +1 @@
Fix release process to not create duplicate releases.

1
changelog.d/17972.misc Normal file
View file

@ -0,0 +1 @@
Consolidate SSO redirects through `/_matrix/client/v3/login/sso/redirect(/{idpId})`.

12
poetry.lock generated
View file

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -1917,13 +1917,13 @@ test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
[[package]]
name = "pysaml2"
version = "7.3.1"
version = "7.5.0"
description = "Python implementation of SAML Version 2 Standard"
optional = true
python-versions = ">=3.6.2,<4.0.0"
python-versions = ">=3.9,<4.0"
files = [
{file = "pysaml2-7.3.1-py3-none-any.whl", hash = "sha256:2cc66e7a371d3f5ff9601f0ed93b5276cca816fce82bb38447d5a0651f2f5193"},
{file = "pysaml2-7.3.1.tar.gz", hash = "sha256:eab22d187c6dd7707c58b5bb1688f9b8e816427667fc99d77f54399e15cd0a0a"},
{file = "pysaml2-7.5.0-py3-none-any.whl", hash = "sha256:bc6627cc344476a83c757f440a73fda1369f13b6fda1b4e16bca63ffbabb5318"},
{file = "pysaml2-7.5.0.tar.gz", hash = "sha256:f36871d4e5ee857c6b85532e942550d2cf90ea4ee943d75eb681044bbc4f54f7"},
]
[package.dependencies]
@ -1933,7 +1933,7 @@ pyopenssl = "*"
python-dateutil = "*"
pytz = "*"
requests = ">=2,<3"
xmlschema = ">=1.2.1"
xmlschema = ">=2,<3"
[package.extras]
s2repoze = ["paste", "repoze.who", "zope.interface"]

View file

@ -87,8 +87,7 @@ class Codes(str, Enum):
WEAK_PASSWORD = "M_WEAK_PASSWORD"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
# USER_LOCKED = "M_USER_LOCKED"
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
USER_LOCKED = "M_USER_LOCKED"
NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED"
CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA"

View file

@ -23,7 +23,8 @@
import hmac
from hashlib import sha256
from urllib.parse import urlencode
from typing import Optional
from urllib.parse import urlencode, urljoin
from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig
@ -66,3 +67,42 @@ class ConsentURIBuilder:
urlencode({"u": user_id, "h": mac}),
)
return consent_uri
class LoginSSORedirectURIBuilder:
def __init__(self, hs_config: HomeServerConfig):
self._public_baseurl = hs_config.server.public_baseurl
def build_login_sso_redirect_uri(
self, *, idp_id: Optional[str], client_redirect_url: str
) -> str:
"""Build a `/login/sso/redirect` URI for the given identity provider.
Builds `/_matrix/client/v3/login/sso/redirect/{idpId}?redirectUrl=xxx` when `idp_id` is specified.
Otherwise, builds `/_matrix/client/v3/login/sso/redirect?redirectUrl=xxx` when `idp_id` is `None`.
Args:
idp_id: Optional ID of the identity provider
client_redirect_url: URL to redirect the user to after login
Returns
The URI to follow when choosing a specific identity provider.
"""
base_url = urljoin(
self._public_baseurl,
f"{CLIENT_API_PREFIX}/v3/login/sso/redirect",
)
serialized_query_parameters = urlencode({"redirectUrl": client_redirect_url})
if idp_id:
resultant_url = urljoin(
# We have to add a trailing slash to the base URL to ensure that the
# last path segment is not stripped away when joining with another path.
f"{base_url}/",
f"{idp_id}?{serialized_query_parameters}",
)
else:
resultant_url = f"{base_url}?{serialized_query_parameters}"
return resultant_url

View file

@ -20,7 +20,7 @@
#
#
from typing import Any, List
from typing import Any, List, Optional
from synapse.config.sso import SsoAttributeRequirement
from synapse.types import JsonDict
@ -46,7 +46,9 @@ class CasConfig(Config):
# TODO Update this to a _synapse URL.
public_baseurl = self.root.server.public_baseurl
self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket"
self.cas_service_url: Optional[str] = (
public_baseurl + "_matrix/client/r0/login/cas/ticket"
)
self.cas_protocol_version = cas_config.get("protocol_version")
if (

View file

@ -332,8 +332,14 @@ class ServerConfig(Config):
logger.info("Using default public_baseurl %s", public_baseurl)
else:
self.serve_client_wellknown = True
# Ensure that public_baseurl ends with a trailing slash
if public_baseurl[-1] != "/":
public_baseurl += "/"
# Scrutinize user-provided config
if not isinstance(public_baseurl, str):
raise ConfigError("Must be a string", ("public_baseurl",))
self.public_baseurl = public_baseurl
# check that public_baseurl is valid

View file

@ -495,7 +495,7 @@ class LockReleasedCommand(Command):
class NewActiveTaskCommand(_SimpleCommand):
"""Sent to inform instance handling background tasks that a new active task is available to run.
"""Sent to inform instance handling background tasks that a new task is ready to run.
Format::

View file

@ -727,7 +727,7 @@ class ReplicationCommandHandler:
) -> None:
"""Called when get a new NEW_ACTIVE_TASK command."""
if self._task_scheduler:
self._task_scheduler.launch_task_by_id(cmd.data)
self._task_scheduler.on_new_task(cmd.data)
def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""

View file

@ -21,6 +21,7 @@
import logging
from typing import TYPE_CHECKING
from synapse.api.urls import LoginSSORedirectURIBuilder
from synapse.http.server import (
DirectServeHtmlResource,
finish_request,
@ -49,6 +50,8 @@ class PickIdpResource(DirectServeHtmlResource):
hs.config.sso.sso_login_idp_picker_template
)
self._server_name = hs.hostname
self._public_baseurl = hs.config.server.public_baseurl
self._login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
async def _async_render_GET(self, request: SynapseRequest) -> None:
client_redirect_url = parse_string(
@ -56,25 +59,23 @@ class PickIdpResource(DirectServeHtmlResource):
)
idp = parse_string(request, "idp", required=False)
# if we need to pick an IdP, do so
# If we need to pick an IdP, do so
if not idp:
return await self._serve_id_picker(request, client_redirect_url)
# otherwise, redirect to the IdP's redirect URI
providers = self._sso_handler.get_identity_providers()
auth_provider = providers.get(idp)
if not auth_provider:
logger.info("Unknown idp %r", idp)
self._sso_handler.render_error(
request, "unknown_idp", "Unknown identity provider ID"
# Otherwise, redirect to the login SSO redirect endpoint for the given IdP
# (which will in turn take us to the the IdP's redirect URI).
#
# We could go directly to the IdP's redirect URI, but this way we ensure that
# the user goes through the same logic as normal flow. Additionally, if a proxy
# needs to intercept the request, it only needs to intercept the one endpoint.
sso_login_redirect_url = (
self._login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id=idp, client_redirect_url=client_redirect_url
)
return
sso_url = await auth_provider.handle_redirect_request(
request, client_redirect_url.encode("utf8")
)
logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url)
logger.info("Redirecting to %s", sso_login_redirect_url)
request.redirect(sso_login_redirect_url)
finish_request(request)
async def _serve_id_picker(

View file

@ -174,9 +174,10 @@ class TaskScheduler:
The id of the scheduled task
"""
status = TaskStatus.SCHEDULED
start_now = False
if timestamp is None or timestamp < self._clock.time_msec():
timestamp = self._clock.time_msec()
status = TaskStatus.ACTIVE
start_now = True
task = ScheduledTask(
random_string(16),
@ -190,9 +191,11 @@ class TaskScheduler:
)
await self._store.insert_scheduled_task(task)
if status == TaskStatus.ACTIVE:
# If the task is ready to run immediately, run the scheduling algorithm now
# rather than waiting
if start_now:
if self._run_background_tasks:
await self._launch_task(task)
self._launch_scheduled_tasks()
else:
self._hs.get_replication_command_handler().send_new_active_task(task.id)
@ -300,23 +303,13 @@ class TaskScheduler:
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
await self._store.delete_scheduled_task(id)
def launch_task_by_id(self, id: str) -> None:
"""Try launching the task with the given ID."""
# Don't bother trying to launch new tasks if we're already at capacity.
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
return
def on_new_task(self, task_id: str) -> None:
"""Handle a notification that a new ready-to-run task has been added to the queue"""
# Just run the scheduler
self._launch_scheduled_tasks()
run_as_background_process("launch_task_by_id", self._launch_task_by_id, id)
async def _launch_task_by_id(self, id: str) -> None:
"""Helper async function for `launch_task_by_id`."""
task = await self.get_task(id)
if task:
await self._launch_task(task)
@wrap_as_background_process("launch_scheduled_tasks")
async def _launch_scheduled_tasks(self) -> None:
"""Retrieve and launch scheduled tasks that should be running at that time."""
def _launch_scheduled_tasks(self) -> None:
"""Retrieve and launch scheduled tasks that should be running at this time."""
# Don't bother trying to launch new tasks if we're already at capacity.
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
return
@ -326,20 +319,26 @@ class TaskScheduler:
self._launching_new_tasks = True
try:
for task in await self.get_tasks(
statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS
):
await self._launch_task(task)
for task in await self.get_tasks(
statuses=[TaskStatus.SCHEDULED],
max_timestamp=self._clock.time_msec(),
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
):
await self._launch_task(task)
async def inner() -> None:
try:
for task in await self.get_tasks(
statuses=[TaskStatus.ACTIVE],
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
):
# _launch_task will ignore tasks that we're already running, and
# will also do nothing if we're already at the maximum capacity.
await self._launch_task(task)
for task in await self.get_tasks(
statuses=[TaskStatus.SCHEDULED],
max_timestamp=self._clock.time_msec(),
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
):
await self._launch_task(task)
finally:
self._launching_new_tasks = False
finally:
self._launching_new_tasks = False
run_as_background_process("launch_scheduled_tasks", inner)
@wrap_as_background_process("clean_scheduled_tasks")
async def _clean_scheduled_tasks(self) -> None:

55
tests/api/test_urls.py Normal file
View file

@ -0,0 +1,55 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.urls import LoginSSORedirectURIBuilder
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
TRICKY_TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="%26=o"'
class LoginSSORedirectURIBuilderTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
def test_no_idp_id(self) -> None:
self.assertEqual(
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id=None, client_redirect_url="http://example.com/redirect"
),
"https://test/_matrix/client/v3/login/sso/redirect?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
)
def test_explicit_idp_id(self) -> None:
self.assertEqual(
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="oidc-github", client_redirect_url="http://example.com/redirect"
),
"https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
)
def test_tricky_redirect_uri(self) -> None:
self.assertEqual(
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="oidc-github",
client_redirect_url=TRICKY_TEST_CLIENT_REDIRECT_URL,
),
"https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=https%3A%2F%2Fx%3F%3Cab+c%3E%26q%22%2B%253D%252B%22%3D%22f%C3%B6%2526%3Do%22",
)

View file

@ -43,6 +43,7 @@ from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.api.urls import LoginSSORedirectURIBuilder
from synapse.appservice import ApplicationService
from synapse.http.client import RawHeaders
from synapse.module_api import ModuleApi
@ -69,6 +70,10 @@ try:
except ImportError:
HAS_JWT = False
import logging
logger = logging.getLogger(__name__)
# synapse server name: used to populate public_baseurl in some tests
SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
@ -77,7 +82,7 @@ SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
# https://....
BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
PUBLIC_BASEURL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
# CAS server used in some tests
CAS_SERVER = "https://fake.test"
@ -109,6 +114,23 @@ ADDITIONAL_LOGIN_FLOWS = [
]
def get_relative_uri_from_absolute_uri(absolute_uri: str) -> str:
"""
Peels off the path and query string from an absolute URI. Useful when interacting
with `make_request(...)` util function which expects a relative path instead of a
full URI.
"""
parsed_uri = urllib.parse.urlparse(absolute_uri)
# Sanity check that we're working with an absolute URI
assert parsed_uri.scheme == "http" or parsed_uri.scheme == "https"
relative_uri = parsed_uri.path
if parsed_uri.query:
relative_uri += "?" + parsed_uri.query
return relative_uri
class TestSpamChecker:
def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
@ -614,7 +636,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
config["public_baseurl"] = PUBLIC_BASEURL
config["cas_config"] = {
"enabled": True,
@ -653,6 +675,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
]
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d.update(build_synapse_client_resource_tree(self.hs))
@ -725,6 +750,32 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=cas",
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
sso_login_redirect_uri = location_headers[0]
# it should redirect us to the standard login SSO redirect flow
self.assertEqual(
sso_login_redirect_uri,
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="cas", client_redirect_url=TEST_CLIENT_REDIRECT_URL
),
)
# follow the redirect
channel = self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
# We have to set the Host header to match the `public_baseurl` to avoid
# the extra redirect in the `SsoRedirectServlet` in order for the
# cookies to be visible.
custom_headers=[
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
],
)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
@ -750,6 +801,32 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
sso_login_redirect_uri = location_headers[0]
# it should redirect us to the standard login SSO redirect flow
self.assertEqual(
sso_login_redirect_uri,
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="saml", client_redirect_url=TEST_CLIENT_REDIRECT_URL
),
)
# follow the redirect
channel = self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
# We have to set the Host header to match the `public_baseurl` to avoid
# the extra redirect in the `SsoRedirectServlet` in order for the
# cookies to be visible.
custom_headers=[
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
],
)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
@ -773,13 +850,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# pick the default OIDC provider
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc",
f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
sso_login_redirect_uri = location_headers[0]
# it should redirect us to the standard login SSO redirect flow
self.assertEqual(
sso_login_redirect_uri,
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="oidc", client_redirect_url=TEST_CLIENT_REDIRECT_URL
),
)
with fake_oidc_server.patch_homeserver(hs=self.hs):
# follow the redirect
channel = self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
# We have to set the Host header to match the `public_baseurl` to avoid
# the extra redirect in the `SsoRedirectServlet` in order for the
# cookies to be visible.
custom_headers=[
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
],
)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
@ -838,12 +940,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None:
"""An unknown IdP should cause a 400"""
"""An unknown IdP should cause a 404"""
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
sso_login_redirect_uri = location_headers[0]
# it should redirect us to the standard login SSO redirect flow
self.assertEqual(
sso_login_redirect_uri,
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="xyz", client_redirect_url="http://x"
),
)
# follow the redirect
channel = self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
# We have to set the Host header to match the `public_baseurl` to avoid
# the extra redirect in the `SsoRedirectServlet` in order for the
# cookies to be visible.
custom_headers=[
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
],
)
self.assertEqual(channel.code, 404, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404"""
@ -1473,7 +1601,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
config["public_baseurl"] = PUBLIC_BASEURL
config["oidc_config"] = {}
config["oidc_config"].update(TEST_OIDC_CONFIG)

View file

@ -889,7 +889,7 @@ class RestHelper:
"GET",
uri,
)
assert channel.code == 302
assert channel.code == 302, f"Expected 302 for {uri}, got {channel.code}"
# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
@ -901,17 +901,18 @@ class RestHelper:
location = get_location(channel)
parts = urllib.parse.urlsplit(location)
next_uri = urllib.parse.urlunsplit(("", "") + parts[2:])
channel = make_request(
self.reactor,
self.site,
"GET",
urllib.parse.urlunsplit(("", "") + parts[2:]),
next_uri,
custom_headers=[
("Host", parts[1]),
],
)
assert channel.code == 302
assert channel.code == 302, f"Expected 302 for {next_uri}, got {channel.code}"
channel.extract_cookies(cookies)
return get_location(channel)

View file

@ -18,8 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import Optional, Tuple
from typing import List, Optional, Tuple
from twisted.internet.task import deferLater
from twisted.test.proto_helpers import MemoryReactor
@ -104,33 +103,43 @@ class TestTaskScheduler(HomeserverTestCase):
)
)
# This is to give the time to the active tasks to finish
self.reactor.advance(1)
# Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one
# is still scheduled.
tasks = [
self.get_success(self.task_scheduler.get_task(task_id))
for task_id in task_ids
]
def get_tasks_of_status(status: TaskStatus) -> List[ScheduledTask]:
tasks = (
self.get_success(self.task_scheduler.get_task(task_id))
for task_id in task_ids
)
return [t for t in tasks if t is not None and t.status == status]
# At this point, there should be MAX_CONCURRENT_RUNNING_TASKS active tasks and
# one scheduled task.
self.assertEquals(
len(
[t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE]
),
len(get_tasks_of_status(TaskStatus.ACTIVE)),
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
)
self.assertEquals(
len(get_tasks_of_status(TaskStatus.SCHEDULED)),
1,
)
scheduled_tasks = [
t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
]
self.assertEquals(len(scheduled_tasks), 1)
# We need to wait for the next run of the scheduler loop
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
# Give the time to the active tasks to finish
self.reactor.advance(1)
# Check that the last task has been properly executed after the next scheduler loop run
# Check that MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one
# is still scheduled.
self.assertEquals(
len(get_tasks_of_status(TaskStatus.COMPLETE)),
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
)
scheduled_tasks = get_tasks_of_status(TaskStatus.SCHEDULED)
self.assertEquals(len(scheduled_tasks), 1)
# The scheduled task should start 0.1s after the first of the active tasks
# finishes
self.reactor.advance(0.1)
self.assertEquals(len(get_tasks_of_status(TaskStatus.ACTIVE)), 1)
# ... and should finally complete after another second
self.reactor.advance(1)
prev_scheduled_task = self.get_success(
self.task_scheduler.get_task(scheduled_tasks[0].id)
)