Merge branch 'develop' into madlittlemods/17368-bust-_membership_stream_cache

This commit is contained in:
Eric Eastwood 2024-12-02 09:49:53 -06:00
commit f5f0e36ec1
18 changed files with 338 additions and 93 deletions

View file

@ -212,7 +212,8 @@ jobs:
mv debs*/* debs/ mv debs*/* debs/
tar -cvJf debs.tar.xz debs tar -cvJf debs.tar.xz debs
- name: Attach to release - name: Attach to release
uses: softprops/action-gh-release@v2 # Pinned to work around https://github.com/softprops/action-gh-release/issues/445
uses: softprops/action-gh-release@v2.0.5
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with: with:

1
changelog.d/17962.misc Normal file
View file

@ -0,0 +1 @@
Fix new scheduled tasks jumping the queue.

View file

@ -0,0 +1 @@
Use stable `M_USER_LOCKED` error code for locked accounts, as per [Matrix 1.12](https://spec.matrix.org/v1.12/client-server-api/#account-locking).

1
changelog.d/17970.bugfix Normal file
View file

@ -0,0 +1 @@
Fix release process to not create duplicate releases.

1
changelog.d/17972.misc Normal file
View file

@ -0,0 +1 @@
Consolidate SSO redirects through `/_matrix/client/v3/login/sso/redirect(/{idpId})`.

12
poetry.lock generated
View file

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]] [[package]]
name = "annotated-types" name = "annotated-types"
@ -1917,13 +1917,13 @@ test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
[[package]] [[package]]
name = "pysaml2" name = "pysaml2"
version = "7.3.1" version = "7.5.0"
description = "Python implementation of SAML Version 2 Standard" description = "Python implementation of SAML Version 2 Standard"
optional = true optional = true
python-versions = ">=3.6.2,<4.0.0" python-versions = ">=3.9,<4.0"
files = [ files = [
{file = "pysaml2-7.3.1-py3-none-any.whl", hash = "sha256:2cc66e7a371d3f5ff9601f0ed93b5276cca816fce82bb38447d5a0651f2f5193"}, {file = "pysaml2-7.5.0-py3-none-any.whl", hash = "sha256:bc6627cc344476a83c757f440a73fda1369f13b6fda1b4e16bca63ffbabb5318"},
{file = "pysaml2-7.3.1.tar.gz", hash = "sha256:eab22d187c6dd7707c58b5bb1688f9b8e816427667fc99d77f54399e15cd0a0a"}, {file = "pysaml2-7.5.0.tar.gz", hash = "sha256:f36871d4e5ee857c6b85532e942550d2cf90ea4ee943d75eb681044bbc4f54f7"},
] ]
[package.dependencies] [package.dependencies]
@ -1933,7 +1933,7 @@ pyopenssl = "*"
python-dateutil = "*" python-dateutil = "*"
pytz = "*" pytz = "*"
requests = ">=2,<3" requests = ">=2,<3"
xmlschema = ">=1.2.1" xmlschema = ">=2,<3"
[package.extras] [package.extras]
s2repoze = ["paste", "repoze.who", "zope.interface"] s2repoze = ["paste", "repoze.who", "zope.interface"]

View file

@ -87,8 +87,7 @@ class Codes(str, Enum):
WEAK_PASSWORD = "M_WEAK_PASSWORD" WEAK_PASSWORD = "M_WEAK_PASSWORD"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE" INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED" USER_DEACTIVATED = "M_USER_DEACTIVATED"
# USER_LOCKED = "M_USER_LOCKED" USER_LOCKED = "M_USER_LOCKED"
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED" NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED"
CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA" CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA"

View file

@ -23,7 +23,8 @@
import hmac import hmac
from hashlib import sha256 from hashlib import sha256
from urllib.parse import urlencode from typing import Optional
from urllib.parse import urlencode, urljoin
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -66,3 +67,42 @@ class ConsentURIBuilder:
urlencode({"u": user_id, "h": mac}), urlencode({"u": user_id, "h": mac}),
) )
return consent_uri return consent_uri
class LoginSSORedirectURIBuilder:
def __init__(self, hs_config: HomeServerConfig):
self._public_baseurl = hs_config.server.public_baseurl
def build_login_sso_redirect_uri(
self, *, idp_id: Optional[str], client_redirect_url: str
) -> str:
"""Build a `/login/sso/redirect` URI for the given identity provider.
Builds `/_matrix/client/v3/login/sso/redirect/{idpId}?redirectUrl=xxx` when `idp_id` is specified.
Otherwise, builds `/_matrix/client/v3/login/sso/redirect?redirectUrl=xxx` when `idp_id` is `None`.
Args:
idp_id: Optional ID of the identity provider
client_redirect_url: URL to redirect the user to after login
Returns
The URI to follow when choosing a specific identity provider.
"""
base_url = urljoin(
self._public_baseurl,
f"{CLIENT_API_PREFIX}/v3/login/sso/redirect",
)
serialized_query_parameters = urlencode({"redirectUrl": client_redirect_url})
if idp_id:
resultant_url = urljoin(
# We have to add a trailing slash to the base URL to ensure that the
# last path segment is not stripped away when joining with another path.
f"{base_url}/",
f"{idp_id}?{serialized_query_parameters}",
)
else:
resultant_url = f"{base_url}?{serialized_query_parameters}"
return resultant_url

View file

@ -20,7 +20,7 @@
# #
# #
from typing import Any, List from typing import Any, List, Optional
from synapse.config.sso import SsoAttributeRequirement from synapse.config.sso import SsoAttributeRequirement
from synapse.types import JsonDict from synapse.types import JsonDict
@ -46,7 +46,9 @@ class CasConfig(Config):
# TODO Update this to a _synapse URL. # TODO Update this to a _synapse URL.
public_baseurl = self.root.server.public_baseurl public_baseurl = self.root.server.public_baseurl
self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" self.cas_service_url: Optional[str] = (
public_baseurl + "_matrix/client/r0/login/cas/ticket"
)
self.cas_protocol_version = cas_config.get("protocol_version") self.cas_protocol_version = cas_config.get("protocol_version")
if ( if (

View file

@ -332,8 +332,14 @@ class ServerConfig(Config):
logger.info("Using default public_baseurl %s", public_baseurl) logger.info("Using default public_baseurl %s", public_baseurl)
else: else:
self.serve_client_wellknown = True self.serve_client_wellknown = True
# Ensure that public_baseurl ends with a trailing slash
if public_baseurl[-1] != "/": if public_baseurl[-1] != "/":
public_baseurl += "/" public_baseurl += "/"
# Scrutinize user-provided config
if not isinstance(public_baseurl, str):
raise ConfigError("Must be a string", ("public_baseurl",))
self.public_baseurl = public_baseurl self.public_baseurl = public_baseurl
# check that public_baseurl is valid # check that public_baseurl is valid

View file

@ -495,7 +495,7 @@ class LockReleasedCommand(Command):
class NewActiveTaskCommand(_SimpleCommand): class NewActiveTaskCommand(_SimpleCommand):
"""Sent to inform instance handling background tasks that a new active task is available to run. """Sent to inform instance handling background tasks that a new task is ready to run.
Format:: Format::

View file

@ -727,7 +727,7 @@ class ReplicationCommandHandler:
) -> None: ) -> None:
"""Called when get a new NEW_ACTIVE_TASK command.""" """Called when get a new NEW_ACTIVE_TASK command."""
if self._task_scheduler: if self._task_scheduler:
self._task_scheduler.launch_task_by_id(cmd.data) self._task_scheduler.on_new_task(cmd.data)
def new_connection(self, connection: IReplicationConnection) -> None: def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection.""" """Called when we have a new connection."""

View file

@ -21,6 +21,7 @@
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from synapse.api.urls import LoginSSORedirectURIBuilder
from synapse.http.server import ( from synapse.http.server import (
DirectServeHtmlResource, DirectServeHtmlResource,
finish_request, finish_request,
@ -49,6 +50,8 @@ class PickIdpResource(DirectServeHtmlResource):
hs.config.sso.sso_login_idp_picker_template hs.config.sso.sso_login_idp_picker_template
) )
self._server_name = hs.hostname self._server_name = hs.hostname
self._public_baseurl = hs.config.server.public_baseurl
self._login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
async def _async_render_GET(self, request: SynapseRequest) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
client_redirect_url = parse_string( client_redirect_url = parse_string(
@ -56,25 +59,23 @@ class PickIdpResource(DirectServeHtmlResource):
) )
idp = parse_string(request, "idp", required=False) idp = parse_string(request, "idp", required=False)
# if we need to pick an IdP, do so # If we need to pick an IdP, do so
if not idp: if not idp:
return await self._serve_id_picker(request, client_redirect_url) return await self._serve_id_picker(request, client_redirect_url)
# otherwise, redirect to the IdP's redirect URI # Otherwise, redirect to the login SSO redirect endpoint for the given IdP
providers = self._sso_handler.get_identity_providers() # (which will in turn take us to the the IdP's redirect URI).
auth_provider = providers.get(idp) #
if not auth_provider: # We could go directly to the IdP's redirect URI, but this way we ensure that
logger.info("Unknown idp %r", idp) # the user goes through the same logic as normal flow. Additionally, if a proxy
self._sso_handler.render_error( # needs to intercept the request, it only needs to intercept the one endpoint.
request, "unknown_idp", "Unknown identity provider ID" sso_login_redirect_url = (
self._login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id=idp, client_redirect_url=client_redirect_url
) )
return
sso_url = await auth_provider.handle_redirect_request(
request, client_redirect_url.encode("utf8")
) )
logger.info("Redirecting to %s", sso_url) logger.info("Redirecting to %s", sso_login_redirect_url)
request.redirect(sso_url) request.redirect(sso_login_redirect_url)
finish_request(request) finish_request(request)
async def _serve_id_picker( async def _serve_id_picker(

View file

@ -174,9 +174,10 @@ class TaskScheduler:
The id of the scheduled task The id of the scheduled task
""" """
status = TaskStatus.SCHEDULED status = TaskStatus.SCHEDULED
start_now = False
if timestamp is None or timestamp < self._clock.time_msec(): if timestamp is None or timestamp < self._clock.time_msec():
timestamp = self._clock.time_msec() timestamp = self._clock.time_msec()
status = TaskStatus.ACTIVE start_now = True
task = ScheduledTask( task = ScheduledTask(
random_string(16), random_string(16),
@ -190,9 +191,11 @@ class TaskScheduler:
) )
await self._store.insert_scheduled_task(task) await self._store.insert_scheduled_task(task)
if status == TaskStatus.ACTIVE: # If the task is ready to run immediately, run the scheduling algorithm now
# rather than waiting
if start_now:
if self._run_background_tasks: if self._run_background_tasks:
await self._launch_task(task) self._launch_scheduled_tasks()
else: else:
self._hs.get_replication_command_handler().send_new_active_task(task.id) self._hs.get_replication_command_handler().send_new_active_task(task.id)
@ -300,23 +303,13 @@ class TaskScheduler:
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
await self._store.delete_scheduled_task(id) await self._store.delete_scheduled_task(id)
def launch_task_by_id(self, id: str) -> None: def on_new_task(self, task_id: str) -> None:
"""Try launching the task with the given ID.""" """Handle a notification that a new ready-to-run task has been added to the queue"""
# Don't bother trying to launch new tasks if we're already at capacity. # Just run the scheduler
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: self._launch_scheduled_tasks()
return
run_as_background_process("launch_task_by_id", self._launch_task_by_id, id) def _launch_scheduled_tasks(self) -> None:
"""Retrieve and launch scheduled tasks that should be running at this time."""
async def _launch_task_by_id(self, id: str) -> None:
"""Helper async function for `launch_task_by_id`."""
task = await self.get_task(id)
if task:
await self._launch_task(task)
@wrap_as_background_process("launch_scheduled_tasks")
async def _launch_scheduled_tasks(self) -> None:
"""Retrieve and launch scheduled tasks that should be running at that time."""
# Don't bother trying to launch new tasks if we're already at capacity. # Don't bother trying to launch new tasks if we're already at capacity.
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
return return
@ -326,20 +319,26 @@ class TaskScheduler:
self._launching_new_tasks = True self._launching_new_tasks = True
try: async def inner() -> None:
for task in await self.get_tasks( try:
statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS for task in await self.get_tasks(
): statuses=[TaskStatus.ACTIVE],
await self._launch_task(task) limit=self.MAX_CONCURRENT_RUNNING_TASKS,
for task in await self.get_tasks( ):
statuses=[TaskStatus.SCHEDULED], # _launch_task will ignore tasks that we're already running, and
max_timestamp=self._clock.time_msec(), # will also do nothing if we're already at the maximum capacity.
limit=self.MAX_CONCURRENT_RUNNING_TASKS, await self._launch_task(task)
): for task in await self.get_tasks(
await self._launch_task(task) statuses=[TaskStatus.SCHEDULED],
max_timestamp=self._clock.time_msec(),
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
):
await self._launch_task(task)
finally: finally:
self._launching_new_tasks = False self._launching_new_tasks = False
run_as_background_process("launch_scheduled_tasks", inner)
@wrap_as_background_process("clean_scheduled_tasks") @wrap_as_background_process("clean_scheduled_tasks")
async def _clean_scheduled_tasks(self) -> None: async def _clean_scheduled_tasks(self) -> None:

55
tests/api/test_urls.py Normal file
View file

@ -0,0 +1,55 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.urls import LoginSSORedirectURIBuilder
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
TRICKY_TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="%26=o"'
class LoginSSORedirectURIBuilderTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
def test_no_idp_id(self) -> None:
self.assertEqual(
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id=None, client_redirect_url="http://example.com/redirect"
),
"https://test/_matrix/client/v3/login/sso/redirect?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
)
def test_explicit_idp_id(self) -> None:
self.assertEqual(
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="oidc-github", client_redirect_url="http://example.com/redirect"
),
"https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
)
def test_tricky_redirect_uri(self) -> None:
self.assertEqual(
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="oidc-github",
client_redirect_url=TRICKY_TEST_CLIENT_REDIRECT_URL,
),
"https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=https%3A%2F%2Fx%3F%3Cab+c%3E%26q%22%2B%253D%252B%22%3D%22f%C3%B6%2526%3Do%22",
)

View file

@ -43,6 +43,7 @@ from twisted.web.resource import Resource
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.api.urls import LoginSSORedirectURIBuilder
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.http.client import RawHeaders from synapse.http.client import RawHeaders
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
@ -69,6 +70,10 @@ try:
except ImportError: except ImportError:
HAS_JWT = False HAS_JWT = False
import logging
logger = logging.getLogger(__name__)
# synapse server name: used to populate public_baseurl in some tests # synapse server name: used to populate public_baseurl in some tests
SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
@ -77,7 +82,7 @@ SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as # FakeChannel.isSecure() returns False, so synapse will see the requested uri as
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to # http://..., so using http in the public_baseurl stops Synapse trying to redirect to
# https://.... # https://....
BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) PUBLIC_BASEURL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
# CAS server used in some tests # CAS server used in some tests
CAS_SERVER = "https://fake.test" CAS_SERVER = "https://fake.test"
@ -109,6 +114,23 @@ ADDITIONAL_LOGIN_FLOWS = [
] ]
def get_relative_uri_from_absolute_uri(absolute_uri: str) -> str:
"""
Peels off the path and query string from an absolute URI. Useful when interacting
with `make_request(...)` util function which expects a relative path instead of a
full URI.
"""
parsed_uri = urllib.parse.urlparse(absolute_uri)
# Sanity check that we're working with an absolute URI
assert parsed_uri.scheme == "http" or parsed_uri.scheme == "https"
relative_uri = parsed_uri.path
if parsed_uri.query:
relative_uri += "?" + parsed_uri.query
return relative_uri
class TestSpamChecker: class TestSpamChecker:
def __init__(self, config: None, api: ModuleApi): def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks( api.register_spam_checker_callbacks(
@ -614,7 +636,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def default_config(self) -> Dict[str, Any]: def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["public_baseurl"] = BASE_URL config["public_baseurl"] = PUBLIC_BASEURL
config["cas_config"] = { config["cas_config"] = {
"enabled": True, "enabled": True,
@ -653,6 +675,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
] ]
return config return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
def create_resource_dict(self) -> Dict[str, Resource]: def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict() d = super().create_resource_dict()
d.update(build_synapse_client_resource_tree(self.hs)) d.update(build_synapse_client_resource_tree(self.hs))
@ -725,6 +750,32 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=cas", + "&idp=cas",
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
sso_login_redirect_uri = location_headers[0]
# it should redirect us to the standard login SSO redirect flow
self.assertEqual(
sso_login_redirect_uri,
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="cas", client_redirect_url=TEST_CLIENT_REDIRECT_URL
),
)
# follow the redirect
channel = self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
# We have to set the Host header to match the `public_baseurl` to avoid
# the extra redirect in the `SsoRedirectServlet` in order for the
# cookies to be visible.
custom_headers=[
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
],
)
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
@ -750,6 +801,32 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml", + "&idp=saml",
) )
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
sso_login_redirect_uri = location_headers[0]
# it should redirect us to the standard login SSO redirect flow
self.assertEqual(
sso_login_redirect_uri,
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="saml", client_redirect_url=TEST_CLIENT_REDIRECT_URL
),
)
# follow the redirect
channel = self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
# We have to set the Host header to match the `public_baseurl` to avoid
# the extra redirect in the `SsoRedirectServlet` in order for the
# cookies to be visible.
custom_headers=[
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
],
)
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
@ -773,13 +850,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# pick the default OIDC provider # pick the default OIDC provider
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_synapse/client/pick_idp?redirectUrl=" f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc",
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc",
) )
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
sso_login_redirect_uri = location_headers[0]
# it should redirect us to the standard login SSO redirect flow
self.assertEqual(
sso_login_redirect_uri,
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="oidc", client_redirect_url=TEST_CLIENT_REDIRECT_URL
),
)
with fake_oidc_server.patch_homeserver(hs=self.hs):
# follow the redirect
channel = self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
# We have to set the Host header to match the `public_baseurl` to avoid
# the extra redirect in the `SsoRedirectServlet` in order for the
# cookies to be visible.
custom_headers=[
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
],
)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0] oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
@ -838,12 +940,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(chan.json_body["user_id"], "@user1:test") self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None: def test_multi_sso_redirect_to_unknown(self) -> None:
"""An unknown IdP should cause a 400""" """An unknown IdP should cause a 404"""
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
) )
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
sso_login_redirect_uri = location_headers[0]
# it should redirect us to the standard login SSO redirect flow
self.assertEqual(
sso_login_redirect_uri,
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="xyz", client_redirect_url="http://x"
),
)
# follow the redirect
channel = self.make_request(
"GET",
# We have to make this relative to be compatible with `make_request(...)`
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
# We have to set the Host header to match the `public_baseurl` to avoid
# the extra redirect in the `SsoRedirectServlet` in order for the
# cookies to be visible.
custom_headers=[
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
],
)
self.assertEqual(channel.code, 404, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None: def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404""" """If the client tries to pick an unknown IdP, return a 404"""
@ -1473,7 +1601,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
def default_config(self) -> Dict[str, Any]: def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["public_baseurl"] = BASE_URL config["public_baseurl"] = PUBLIC_BASEURL
config["oidc_config"] = {} config["oidc_config"] = {}
config["oidc_config"].update(TEST_OIDC_CONFIG) config["oidc_config"].update(TEST_OIDC_CONFIG)

View file

@ -889,7 +889,7 @@ class RestHelper:
"GET", "GET",
uri, uri,
) )
assert channel.code == 302 assert channel.code == 302, f"Expected 302 for {uri}, got {channel.code}"
# hit the redirect url again with the right Host header, which should now issue # hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider. # a cookie and redirect to the SSO provider.
@ -901,17 +901,18 @@ class RestHelper:
location = get_location(channel) location = get_location(channel)
parts = urllib.parse.urlsplit(location) parts = urllib.parse.urlsplit(location)
next_uri = urllib.parse.urlunsplit(("", "") + parts[2:])
channel = make_request( channel = make_request(
self.reactor, self.reactor,
self.site, self.site,
"GET", "GET",
urllib.parse.urlunsplit(("", "") + parts[2:]), next_uri,
custom_headers=[ custom_headers=[
("Host", parts[1]), ("Host", parts[1]),
], ],
) )
assert channel.code == 302 assert channel.code == 302, f"Expected 302 for {next_uri}, got {channel.code}"
channel.extract_cookies(cookies) channel.extract_cookies(cookies)
return get_location(channel) return get_location(channel)

View file

@ -18,8 +18,7 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
from typing import List, Optional, Tuple
from typing import Optional, Tuple
from twisted.internet.task import deferLater from twisted.internet.task import deferLater
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -104,33 +103,43 @@ class TestTaskScheduler(HomeserverTestCase):
) )
) )
# This is to give the time to the active tasks to finish def get_tasks_of_status(status: TaskStatus) -> List[ScheduledTask]:
self.reactor.advance(1) tasks = (
self.get_success(self.task_scheduler.get_task(task_id))
# Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one for task_id in task_ids
# is still scheduled. )
tasks = [ return [t for t in tasks if t is not None and t.status == status]
self.get_success(self.task_scheduler.get_task(task_id))
for task_id in task_ids
]
# At this point, there should be MAX_CONCURRENT_RUNNING_TASKS active tasks and
# one scheduled task.
self.assertEquals( self.assertEquals(
len( len(get_tasks_of_status(TaskStatus.ACTIVE)),
[t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE]
),
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS, TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
) )
self.assertEquals(
len(get_tasks_of_status(TaskStatus.SCHEDULED)),
1,
)
scheduled_tasks = [ # Give the time to the active tasks to finish
t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
]
self.assertEquals(len(scheduled_tasks), 1)
# We need to wait for the next run of the scheduler loop
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
self.reactor.advance(1) self.reactor.advance(1)
# Check that the last task has been properly executed after the next scheduler loop run # Check that MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one
# is still scheduled.
self.assertEquals(
len(get_tasks_of_status(TaskStatus.COMPLETE)),
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
)
scheduled_tasks = get_tasks_of_status(TaskStatus.SCHEDULED)
self.assertEquals(len(scheduled_tasks), 1)
# The scheduled task should start 0.1s after the first of the active tasks
# finishes
self.reactor.advance(0.1)
self.assertEquals(len(get_tasks_of_status(TaskStatus.ACTIVE)), 1)
# ... and should finally complete after another second
self.reactor.advance(1)
prev_scheduled_task = self.get_success( prev_scheduled_task = self.get_success(
self.task_scheduler.get_task(scheduled_tasks[0].id) self.task_scheduler.get_task(scheduled_tasks[0].id)
) )