diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index c0aff79141..42a374fa19 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -212,7 +212,8 @@ jobs: mv debs*/* debs/ tar -cvJf debs.tar.xz debs - name: Attach to release - uses: softprops/action-gh-release@v2 + # Pinned to work around https://github.com/softprops/action-gh-release/issues/445 + uses: softprops/action-gh-release@v2.0.5 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: diff --git a/changelog.d/17962.misc b/changelog.d/17962.misc new file mode 100644 index 0000000000..adf6348707 --- /dev/null +++ b/changelog.d/17962.misc @@ -0,0 +1 @@ +Fix new scheduled tasks jumping the queue. diff --git a/changelog.d/17965.feature b/changelog.d/17965.feature new file mode 100644 index 0000000000..e447a58986 --- /dev/null +++ b/changelog.d/17965.feature @@ -0,0 +1 @@ +Use stable `M_USER_LOCKED` error code for locked accounts, as per [Matrix 1.12](https://spec.matrix.org/v1.12/client-server-api/#account-locking). \ No newline at end of file diff --git a/changelog.d/17970.bugfix b/changelog.d/17970.bugfix new file mode 100644 index 0000000000..835079de3f --- /dev/null +++ b/changelog.d/17970.bugfix @@ -0,0 +1 @@ +Fix release process to not create duplicate releases. diff --git a/changelog.d/17972.misc b/changelog.d/17972.misc new file mode 100644 index 0000000000..e7f009d20d --- /dev/null +++ b/changelog.d/17972.misc @@ -0,0 +1 @@ +Consolidate SSO redirects through `/_matrix/client/v3/login/sso/redirect(/{idpId})`. diff --git a/poetry.lock b/poetry.lock index f43fe2489a..1a735e2fd6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1917,13 +1917,13 @@ test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"] [[package]] name = "pysaml2" -version = "7.3.1" +version = "7.5.0" description = "Python implementation of SAML Version 2 Standard" optional = true -python-versions = ">=3.6.2,<4.0.0" +python-versions = ">=3.9,<4.0" files = [ - {file = "pysaml2-7.3.1-py3-none-any.whl", hash = "sha256:2cc66e7a371d3f5ff9601f0ed93b5276cca816fce82bb38447d5a0651f2f5193"}, - {file = "pysaml2-7.3.1.tar.gz", hash = "sha256:eab22d187c6dd7707c58b5bb1688f9b8e816427667fc99d77f54399e15cd0a0a"}, + {file = "pysaml2-7.5.0-py3-none-any.whl", hash = "sha256:bc6627cc344476a83c757f440a73fda1369f13b6fda1b4e16bca63ffbabb5318"}, + {file = "pysaml2-7.5.0.tar.gz", hash = "sha256:f36871d4e5ee857c6b85532e942550d2cf90ea4ee943d75eb681044bbc4f54f7"}, ] [package.dependencies] @@ -1933,7 +1933,7 @@ pyopenssl = "*" python-dateutil = "*" pytz = "*" requests = ">=2,<3" -xmlschema = ">=1.2.1" +xmlschema = ">=2,<3" [package.extras] s2repoze = ["paste", "repoze.who", "zope.interface"] diff --git a/synapse/api/errors.py b/synapse/api/errors.py index e6efa7a424..71e4bb4971 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -87,8 +87,7 @@ class Codes(str, Enum): WEAK_PASSWORD = "M_WEAK_PASSWORD" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" - # USER_LOCKED = "M_USER_LOCKED" - USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED" + USER_LOCKED = "M_USER_LOCKED" NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED" CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA" diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 03a3e96f28..655b5edd7a 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -23,7 +23,8 @@ import hmac from hashlib import sha256 -from urllib.parse import urlencode +from typing import Optional +from urllib.parse import urlencode, urljoin from synapse.config import ConfigError from synapse.config.homeserver import HomeServerConfig @@ -66,3 +67,42 @@ class ConsentURIBuilder: urlencode({"u": user_id, "h": mac}), ) return consent_uri + + +class LoginSSORedirectURIBuilder: + def __init__(self, hs_config: HomeServerConfig): + self._public_baseurl = hs_config.server.public_baseurl + + def build_login_sso_redirect_uri( + self, *, idp_id: Optional[str], client_redirect_url: str + ) -> str: + """Build a `/login/sso/redirect` URI for the given identity provider. + + Builds `/_matrix/client/v3/login/sso/redirect/{idpId}?redirectUrl=xxx` when `idp_id` is specified. + Otherwise, builds `/_matrix/client/v3/login/sso/redirect?redirectUrl=xxx` when `idp_id` is `None`. + + Args: + idp_id: Optional ID of the identity provider + client_redirect_url: URL to redirect the user to after login + + Returns + The URI to follow when choosing a specific identity provider. + """ + base_url = urljoin( + self._public_baseurl, + f"{CLIENT_API_PREFIX}/v3/login/sso/redirect", + ) + + serialized_query_parameters = urlencode({"redirectUrl": client_redirect_url}) + + if idp_id: + resultant_url = urljoin( + # We have to add a trailing slash to the base URL to ensure that the + # last path segment is not stripped away when joining with another path. + f"{base_url}/", + f"{idp_id}?{serialized_query_parameters}", + ) + else: + resultant_url = f"{base_url}?{serialized_query_parameters}" + + return resultant_url diff --git a/synapse/config/cas.py b/synapse/config/cas.py index fa59c350c1..c32bf36951 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -20,7 +20,7 @@ # # -from typing import Any, List +from typing import Any, List, Optional from synapse.config.sso import SsoAttributeRequirement from synapse.types import JsonDict @@ -46,7 +46,9 @@ class CasConfig(Config): # TODO Update this to a _synapse URL. public_baseurl = self.root.server.public_baseurl - self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" + self.cas_service_url: Optional[str] = ( + public_baseurl + "_matrix/client/r0/login/cas/ticket" + ) self.cas_protocol_version = cas_config.get("protocol_version") if ( diff --git a/synapse/config/server.py b/synapse/config/server.py index ad7331de42..6b29983617 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -332,8 +332,14 @@ class ServerConfig(Config): logger.info("Using default public_baseurl %s", public_baseurl) else: self.serve_client_wellknown = True + # Ensure that public_baseurl ends with a trailing slash if public_baseurl[-1] != "/": public_baseurl += "/" + + # Scrutinize user-provided config + if not isinstance(public_baseurl, str): + raise ConfigError("Must be a string", ("public_baseurl",)) + self.public_baseurl = public_baseurl # check that public_baseurl is valid diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 7d51441e91..6ab5356660 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -495,7 +495,7 @@ class LockReleasedCommand(Command): class NewActiveTaskCommand(_SimpleCommand): - """Sent to inform instance handling background tasks that a new active task is available to run. + """Sent to inform instance handling background tasks that a new task is ready to run. Format:: diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6101226938..1fafbb48c3 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -727,7 +727,7 @@ class ReplicationCommandHandler: ) -> None: """Called when get a new NEW_ACTIVE_TASK command.""" if self._task_scheduler: - self._task_scheduler.launch_task_by_id(cmd.data) + self._task_scheduler.on_new_task(cmd.data) def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py index f26929bd60..5e599f85b0 100644 --- a/synapse/rest/synapse/client/pick_idp.py +++ b/synapse/rest/synapse/client/pick_idp.py @@ -21,6 +21,7 @@ import logging from typing import TYPE_CHECKING +from synapse.api.urls import LoginSSORedirectURIBuilder from synapse.http.server import ( DirectServeHtmlResource, finish_request, @@ -49,6 +50,8 @@ class PickIdpResource(DirectServeHtmlResource): hs.config.sso.sso_login_idp_picker_template ) self._server_name = hs.hostname + self._public_baseurl = hs.config.server.public_baseurl + self._login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config) async def _async_render_GET(self, request: SynapseRequest) -> None: client_redirect_url = parse_string( @@ -56,25 +59,23 @@ class PickIdpResource(DirectServeHtmlResource): ) idp = parse_string(request, "idp", required=False) - # if we need to pick an IdP, do so + # If we need to pick an IdP, do so if not idp: return await self._serve_id_picker(request, client_redirect_url) - # otherwise, redirect to the IdP's redirect URI - providers = self._sso_handler.get_identity_providers() - auth_provider = providers.get(idp) - if not auth_provider: - logger.info("Unknown idp %r", idp) - self._sso_handler.render_error( - request, "unknown_idp", "Unknown identity provider ID" + # Otherwise, redirect to the login SSO redirect endpoint for the given IdP + # (which will in turn take us to the the IdP's redirect URI). + # + # We could go directly to the IdP's redirect URI, but this way we ensure that + # the user goes through the same logic as normal flow. Additionally, if a proxy + # needs to intercept the request, it only needs to intercept the one endpoint. + sso_login_redirect_url = ( + self._login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id=idp, client_redirect_url=client_redirect_url ) - return - - sso_url = await auth_provider.handle_redirect_request( - request, client_redirect_url.encode("utf8") ) - logger.info("Redirecting to %s", sso_url) - request.redirect(sso_url) + logger.info("Redirecting to %s", sso_login_redirect_url) + request.redirect(sso_login_redirect_url) finish_request(request) async def _serve_id_picker( diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index 448960b297..3ed457bd30 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -174,9 +174,10 @@ class TaskScheduler: The id of the scheduled task """ status = TaskStatus.SCHEDULED + start_now = False if timestamp is None or timestamp < self._clock.time_msec(): timestamp = self._clock.time_msec() - status = TaskStatus.ACTIVE + start_now = True task = ScheduledTask( random_string(16), @@ -190,9 +191,11 @@ class TaskScheduler: ) await self._store.insert_scheduled_task(task) - if status == TaskStatus.ACTIVE: + # If the task is ready to run immediately, run the scheduling algorithm now + # rather than waiting + if start_now: if self._run_background_tasks: - await self._launch_task(task) + self._launch_scheduled_tasks() else: self._hs.get_replication_command_handler().send_new_active_task(task.id) @@ -300,23 +303,13 @@ class TaskScheduler: raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") await self._store.delete_scheduled_task(id) - def launch_task_by_id(self, id: str) -> None: - """Try launching the task with the given ID.""" - # Don't bother trying to launch new tasks if we're already at capacity. - if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: - return + def on_new_task(self, task_id: str) -> None: + """Handle a notification that a new ready-to-run task has been added to the queue""" + # Just run the scheduler + self._launch_scheduled_tasks() - run_as_background_process("launch_task_by_id", self._launch_task_by_id, id) - - async def _launch_task_by_id(self, id: str) -> None: - """Helper async function for `launch_task_by_id`.""" - task = await self.get_task(id) - if task: - await self._launch_task(task) - - @wrap_as_background_process("launch_scheduled_tasks") - async def _launch_scheduled_tasks(self) -> None: - """Retrieve and launch scheduled tasks that should be running at that time.""" + def _launch_scheduled_tasks(self) -> None: + """Retrieve and launch scheduled tasks that should be running at this time.""" # Don't bother trying to launch new tasks if we're already at capacity. if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: return @@ -326,20 +319,26 @@ class TaskScheduler: self._launching_new_tasks = True - try: - for task in await self.get_tasks( - statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS - ): - await self._launch_task(task) - for task in await self.get_tasks( - statuses=[TaskStatus.SCHEDULED], - max_timestamp=self._clock.time_msec(), - limit=self.MAX_CONCURRENT_RUNNING_TASKS, - ): - await self._launch_task(task) + async def inner() -> None: + try: + for task in await self.get_tasks( + statuses=[TaskStatus.ACTIVE], + limit=self.MAX_CONCURRENT_RUNNING_TASKS, + ): + # _launch_task will ignore tasks that we're already running, and + # will also do nothing if we're already at the maximum capacity. + await self._launch_task(task) + for task in await self.get_tasks( + statuses=[TaskStatus.SCHEDULED], + max_timestamp=self._clock.time_msec(), + limit=self.MAX_CONCURRENT_RUNNING_TASKS, + ): + await self._launch_task(task) - finally: - self._launching_new_tasks = False + finally: + self._launching_new_tasks = False + + run_as_background_process("launch_scheduled_tasks", inner) @wrap_as_background_process("clean_scheduled_tasks") async def _clean_scheduled_tasks(self) -> None: diff --git a/tests/api/test_urls.py b/tests/api/test_urls.py new file mode 100644 index 0000000000..ce156a05dc --- /dev/null +++ b/tests/api/test_urls.py @@ -0,0 +1,55 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# + + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.urls import LoginSSORedirectURIBuilder +from synapse.server import HomeServer +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + +# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + +TRICKY_TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' + + +class LoginSSORedirectURIBuilderTestCase(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config) + + def test_no_idp_id(self) -> None: + self.assertEqual( + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id=None, client_redirect_url="http://example.com/redirect" + ), + "https://test/_matrix/client/v3/login/sso/redirect?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect", + ) + + def test_explicit_idp_id(self) -> None: + self.assertEqual( + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="oidc-github", client_redirect_url="http://example.com/redirect" + ), + "https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect", + ) + + def test_tricky_redirect_uri(self) -> None: + self.assertEqual( + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="oidc-github", + client_redirect_url=TRICKY_TEST_CLIENT_REDIRECT_URL, + ), + "https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=https%3A%2F%2Fx%3F%3Cab+c%3E%26q%22%2B%253D%252B%22%3D%22f%C3%B6%2526%3Do%22", + ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index cbd6d8d4bf..1451fd7c29 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -43,6 +43,7 @@ from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType from synapse.api.errors import Codes +from synapse.api.urls import LoginSSORedirectURIBuilder from synapse.appservice import ApplicationService from synapse.http.client import RawHeaders from synapse.module_api import ModuleApi @@ -69,6 +70,10 @@ try: except ImportError: HAS_JWT = False +import logging + +logger = logging.getLogger(__name__) + # synapse server name: used to populate public_baseurl in some tests SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" @@ -77,7 +82,7 @@ SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" # FakeChannel.isSecure() returns False, so synapse will see the requested uri as # http://..., so using http in the public_baseurl stops Synapse trying to redirect to # https://.... -BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) +PUBLIC_BASEURL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) # CAS server used in some tests CAS_SERVER = "https://fake.test" @@ -109,6 +114,23 @@ ADDITIONAL_LOGIN_FLOWS = [ ] +def get_relative_uri_from_absolute_uri(absolute_uri: str) -> str: + """ + Peels off the path and query string from an absolute URI. Useful when interacting + with `make_request(...)` util function which expects a relative path instead of a + full URI. + """ + parsed_uri = urllib.parse.urlparse(absolute_uri) + # Sanity check that we're working with an absolute URI + assert parsed_uri.scheme == "http" or parsed_uri.scheme == "https" + + relative_uri = parsed_uri.path + if parsed_uri.query: + relative_uri += "?" + parsed_uri.query + + return relative_uri + + class TestSpamChecker: def __init__(self, config: None, api: ModuleApi): api.register_spam_checker_callbacks( @@ -614,7 +636,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def default_config(self) -> Dict[str, Any]: config = super().default_config() - config["public_baseurl"] = BASE_URL + config["public_baseurl"] = PUBLIC_BASEURL config["cas_config"] = { "enabled": True, @@ -653,6 +675,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ] return config + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config) + def create_resource_dict(self) -> Dict[str, Resource]: d = super().create_resource_dict() d.update(build_synapse_client_resource_tree(self.hs)) @@ -725,6 +750,32 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + "&idp=cas", shorthand=False, ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + sso_login_redirect_uri = location_headers[0] + + # it should redirect us to the standard login SSO redirect flow + self.assertEqual( + sso_login_redirect_uri, + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="cas", client_redirect_url=TEST_CLIENT_REDIRECT_URL + ), + ) + + # follow the redirect + channel = self.make_request( + "GET", + # We have to make this relative to be compatible with `make_request(...)` + get_relative_uri_from_absolute_uri(sso_login_redirect_uri), + # We have to set the Host header to match the `public_baseurl` to avoid + # the extra redirect in the `SsoRedirectServlet` in order for the + # cookies to be visible. + custom_headers=[ + ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME), + ], + ) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -750,6 +801,32 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + sso_login_redirect_uri = location_headers[0] + + # it should redirect us to the standard login SSO redirect flow + self.assertEqual( + sso_login_redirect_uri, + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="saml", client_redirect_url=TEST_CLIENT_REDIRECT_URL + ), + ) + + # follow the redirect + channel = self.make_request( + "GET", + # We have to make this relative to be compatible with `make_request(...)` + get_relative_uri_from_absolute_uri(sso_login_redirect_uri), + # We have to set the Host header to match the `public_baseurl` to avoid + # the extra redirect in the `SsoRedirectServlet` in order for the + # cookies to be visible. + custom_headers=[ + ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME), + ], + ) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -773,13 +850,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # pick the default OIDC provider channel = self.make_request( "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", + f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc", ) self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers + sso_login_redirect_uri = location_headers[0] + + # it should redirect us to the standard login SSO redirect flow + self.assertEqual( + sso_login_redirect_uri, + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="oidc", client_redirect_url=TEST_CLIENT_REDIRECT_URL + ), + ) + + with fake_oidc_server.patch_homeserver(hs=self.hs): + # follow the redirect + channel = self.make_request( + "GET", + # We have to make this relative to be compatible with `make_request(...)` + get_relative_uri_from_absolute_uri(sso_login_redirect_uri), + # We have to set the Host header to match the `public_baseurl` to avoid + # the extra redirect in the `SsoRedirectServlet` in order for the + # cookies to be visible. + custom_headers=[ + ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME), + ], + ) + + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers oidc_uri = location_headers[0] oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) @@ -838,12 +940,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(chan.json_body["user_id"], "@user1:test") def test_multi_sso_redirect_to_unknown(self) -> None: - """An unknown IdP should cause a 400""" + """An unknown IdP should cause a 404""" channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + sso_login_redirect_uri = location_headers[0] + + # it should redirect us to the standard login SSO redirect flow + self.assertEqual( + sso_login_redirect_uri, + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="xyz", client_redirect_url="http://x" + ), + ) + + # follow the redirect + channel = self.make_request( + "GET", + # We have to make this relative to be compatible with `make_request(...)` + get_relative_uri_from_absolute_uri(sso_login_redirect_uri), + # We have to set the Host header to match the `public_baseurl` to avoid + # the extra redirect in the `SsoRedirectServlet` in order for the + # cookies to be visible. + custom_headers=[ + ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME), + ], + ) + + self.assertEqual(channel.code, 404, channel.result) def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" @@ -1473,7 +1601,7 @@ class UsernamePickerTestCase(HomeserverTestCase): def default_config(self) -> Dict[str, Any]: config = super().default_config() - config["public_baseurl"] = BASE_URL + config["public_baseurl"] = PUBLIC_BASEURL config["oidc_config"] = {} config["oidc_config"].update(TEST_OIDC_CONFIG) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index a1c284726a..dbd6049f9f 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -889,7 +889,7 @@ class RestHelper: "GET", uri, ) - assert channel.code == 302 + assert channel.code == 302, f"Expected 302 for {uri}, got {channel.code}" # hit the redirect url again with the right Host header, which should now issue # a cookie and redirect to the SSO provider. @@ -901,17 +901,18 @@ class RestHelper: location = get_location(channel) parts = urllib.parse.urlsplit(location) + next_uri = urllib.parse.urlunsplit(("", "") + parts[2:]) channel = make_request( self.reactor, self.site, "GET", - urllib.parse.urlunsplit(("", "") + parts[2:]), + next_uri, custom_headers=[ ("Host", parts[1]), ], ) - assert channel.code == 302 + assert channel.code == 302, f"Expected 302 for {next_uri}, got {channel.code}" channel.extract_cookies(cookies) return get_location(channel) diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py index 30f0510c9f..9e403b948b 100644 --- a/tests/util/test_task_scheduler.py +++ b/tests/util/test_task_scheduler.py @@ -18,8 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # - -from typing import Optional, Tuple +from typing import List, Optional, Tuple from twisted.internet.task import deferLater from twisted.test.proto_helpers import MemoryReactor @@ -104,33 +103,43 @@ class TestTaskScheduler(HomeserverTestCase): ) ) - # This is to give the time to the active tasks to finish - self.reactor.advance(1) - - # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one - # is still scheduled. - tasks = [ - self.get_success(self.task_scheduler.get_task(task_id)) - for task_id in task_ids - ] + def get_tasks_of_status(status: TaskStatus) -> List[ScheduledTask]: + tasks = ( + self.get_success(self.task_scheduler.get_task(task_id)) + for task_id in task_ids + ) + return [t for t in tasks if t is not None and t.status == status] + # At this point, there should be MAX_CONCURRENT_RUNNING_TASKS active tasks and + # one scheduled task. self.assertEquals( - len( - [t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE] - ), + len(get_tasks_of_status(TaskStatus.ACTIVE)), TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS, ) + self.assertEquals( + len(get_tasks_of_status(TaskStatus.SCHEDULED)), + 1, + ) - scheduled_tasks = [ - t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE - ] - self.assertEquals(len(scheduled_tasks), 1) - - # We need to wait for the next run of the scheduler loop - self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) + # Give the time to the active tasks to finish self.reactor.advance(1) - # Check that the last task has been properly executed after the next scheduler loop run + # Check that MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one + # is still scheduled. + self.assertEquals( + len(get_tasks_of_status(TaskStatus.COMPLETE)), + TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS, + ) + scheduled_tasks = get_tasks_of_status(TaskStatus.SCHEDULED) + self.assertEquals(len(scheduled_tasks), 1) + + # The scheduled task should start 0.1s after the first of the active tasks + # finishes + self.reactor.advance(0.1) + self.assertEquals(len(get_tasks_of_status(TaskStatus.ACTIVE)), 1) + + # ... and should finally complete after another second + self.reactor.advance(1) prev_scheduled_task = self.get_success( self.task_scheduler.get_task(scheduled_tasks[0].id) )