mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Merge branch 'develop' into madlittlemods/17368-bust-_membership_stream_cache
This commit is contained in:
commit
f5f0e36ec1
18 changed files with 338 additions and 93 deletions
3
.github/workflows/release-artifacts.yml
vendored
3
.github/workflows/release-artifacts.yml
vendored
|
@ -212,7 +212,8 @@ jobs:
|
|||
mv debs*/* debs/
|
||||
tar -cvJf debs.tar.xz debs
|
||||
- name: Attach to release
|
||||
uses: softprops/action-gh-release@v2
|
||||
# Pinned to work around https://github.com/softprops/action-gh-release/issues/445
|
||||
uses: softprops/action-gh-release@v2.0.5
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
|
|
1
changelog.d/17962.misc
Normal file
1
changelog.d/17962.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix new scheduled tasks jumping the queue.
|
1
changelog.d/17965.feature
Normal file
1
changelog.d/17965.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Use stable `M_USER_LOCKED` error code for locked accounts, as per [Matrix 1.12](https://spec.matrix.org/v1.12/client-server-api/#account-locking).
|
1
changelog.d/17970.bugfix
Normal file
1
changelog.d/17970.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix release process to not create duplicate releases.
|
1
changelog.d/17972.misc
Normal file
1
changelog.d/17972.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Consolidate SSO redirects through `/_matrix/client/v3/login/sso/redirect(/{idpId})`.
|
12
poetry.lock
generated
12
poetry.lock
generated
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
|
@ -1917,13 +1917,13 @@ test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
|
|||
|
||||
[[package]]
|
||||
name = "pysaml2"
|
||||
version = "7.3.1"
|
||||
version = "7.5.0"
|
||||
description = "Python implementation of SAML Version 2 Standard"
|
||||
optional = true
|
||||
python-versions = ">=3.6.2,<4.0.0"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
files = [
|
||||
{file = "pysaml2-7.3.1-py3-none-any.whl", hash = "sha256:2cc66e7a371d3f5ff9601f0ed93b5276cca816fce82bb38447d5a0651f2f5193"},
|
||||
{file = "pysaml2-7.3.1.tar.gz", hash = "sha256:eab22d187c6dd7707c58b5bb1688f9b8e816427667fc99d77f54399e15cd0a0a"},
|
||||
{file = "pysaml2-7.5.0-py3-none-any.whl", hash = "sha256:bc6627cc344476a83c757f440a73fda1369f13b6fda1b4e16bca63ffbabb5318"},
|
||||
{file = "pysaml2-7.5.0.tar.gz", hash = "sha256:f36871d4e5ee857c6b85532e942550d2cf90ea4ee943d75eb681044bbc4f54f7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -1933,7 +1933,7 @@ pyopenssl = "*"
|
|||
python-dateutil = "*"
|
||||
pytz = "*"
|
||||
requests = ">=2,<3"
|
||||
xmlschema = ">=1.2.1"
|
||||
xmlschema = ">=2,<3"
|
||||
|
||||
[package.extras]
|
||||
s2repoze = ["paste", "repoze.who", "zope.interface"]
|
||||
|
|
|
@ -87,8 +87,7 @@ class Codes(str, Enum):
|
|||
WEAK_PASSWORD = "M_WEAK_PASSWORD"
|
||||
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
||||
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
||||
# USER_LOCKED = "M_USER_LOCKED"
|
||||
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
|
||||
USER_LOCKED = "M_USER_LOCKED"
|
||||
NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED"
|
||||
CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA"
|
||||
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
|
||||
import hmac
|
||||
from hashlib import sha256
|
||||
from urllib.parse import urlencode
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode, urljoin
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
|
@ -66,3 +67,42 @@ class ConsentURIBuilder:
|
|||
urlencode({"u": user_id, "h": mac}),
|
||||
)
|
||||
return consent_uri
|
||||
|
||||
|
||||
class LoginSSORedirectURIBuilder:
|
||||
def __init__(self, hs_config: HomeServerConfig):
|
||||
self._public_baseurl = hs_config.server.public_baseurl
|
||||
|
||||
def build_login_sso_redirect_uri(
|
||||
self, *, idp_id: Optional[str], client_redirect_url: str
|
||||
) -> str:
|
||||
"""Build a `/login/sso/redirect` URI for the given identity provider.
|
||||
|
||||
Builds `/_matrix/client/v3/login/sso/redirect/{idpId}?redirectUrl=xxx` when `idp_id` is specified.
|
||||
Otherwise, builds `/_matrix/client/v3/login/sso/redirect?redirectUrl=xxx` when `idp_id` is `None`.
|
||||
|
||||
Args:
|
||||
idp_id: Optional ID of the identity provider
|
||||
client_redirect_url: URL to redirect the user to after login
|
||||
|
||||
Returns
|
||||
The URI to follow when choosing a specific identity provider.
|
||||
"""
|
||||
base_url = urljoin(
|
||||
self._public_baseurl,
|
||||
f"{CLIENT_API_PREFIX}/v3/login/sso/redirect",
|
||||
)
|
||||
|
||||
serialized_query_parameters = urlencode({"redirectUrl": client_redirect_url})
|
||||
|
||||
if idp_id:
|
||||
resultant_url = urljoin(
|
||||
# We have to add a trailing slash to the base URL to ensure that the
|
||||
# last path segment is not stripped away when joining with another path.
|
||||
f"{base_url}/",
|
||||
f"{idp_id}?{serialized_query_parameters}",
|
||||
)
|
||||
else:
|
||||
resultant_url = f"{base_url}?{serialized_query_parameters}"
|
||||
|
||||
return resultant_url
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#
|
||||
#
|
||||
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from synapse.config.sso import SsoAttributeRequirement
|
||||
from synapse.types import JsonDict
|
||||
|
@ -46,7 +46,9 @@ class CasConfig(Config):
|
|||
|
||||
# TODO Update this to a _synapse URL.
|
||||
public_baseurl = self.root.server.public_baseurl
|
||||
self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket"
|
||||
self.cas_service_url: Optional[str] = (
|
||||
public_baseurl + "_matrix/client/r0/login/cas/ticket"
|
||||
)
|
||||
|
||||
self.cas_protocol_version = cas_config.get("protocol_version")
|
||||
if (
|
||||
|
|
|
@ -332,8 +332,14 @@ class ServerConfig(Config):
|
|||
logger.info("Using default public_baseurl %s", public_baseurl)
|
||||
else:
|
||||
self.serve_client_wellknown = True
|
||||
# Ensure that public_baseurl ends with a trailing slash
|
||||
if public_baseurl[-1] != "/":
|
||||
public_baseurl += "/"
|
||||
|
||||
# Scrutinize user-provided config
|
||||
if not isinstance(public_baseurl, str):
|
||||
raise ConfigError("Must be a string", ("public_baseurl",))
|
||||
|
||||
self.public_baseurl = public_baseurl
|
||||
|
||||
# check that public_baseurl is valid
|
||||
|
|
|
@ -495,7 +495,7 @@ class LockReleasedCommand(Command):
|
|||
|
||||
|
||||
class NewActiveTaskCommand(_SimpleCommand):
|
||||
"""Sent to inform instance handling background tasks that a new active task is available to run.
|
||||
"""Sent to inform instance handling background tasks that a new task is ready to run.
|
||||
|
||||
Format::
|
||||
|
||||
|
|
|
@ -727,7 +727,7 @@ class ReplicationCommandHandler:
|
|||
) -> None:
|
||||
"""Called when get a new NEW_ACTIVE_TASK command."""
|
||||
if self._task_scheduler:
|
||||
self._task_scheduler.launch_task_by_id(cmd.data)
|
||||
self._task_scheduler.on_new_task(cmd.data)
|
||||
|
||||
def new_connection(self, connection: IReplicationConnection) -> None:
|
||||
"""Called when we have a new connection."""
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from synapse.api.urls import LoginSSORedirectURIBuilder
|
||||
from synapse.http.server import (
|
||||
DirectServeHtmlResource,
|
||||
finish_request,
|
||||
|
@ -49,6 +50,8 @@ class PickIdpResource(DirectServeHtmlResource):
|
|||
hs.config.sso.sso_login_idp_picker_template
|
||||
)
|
||||
self._server_name = hs.hostname
|
||||
self._public_baseurl = hs.config.server.public_baseurl
|
||||
self._login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
|
||||
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
client_redirect_url = parse_string(
|
||||
|
@ -56,25 +59,23 @@ class PickIdpResource(DirectServeHtmlResource):
|
|||
)
|
||||
idp = parse_string(request, "idp", required=False)
|
||||
|
||||
# if we need to pick an IdP, do so
|
||||
# If we need to pick an IdP, do so
|
||||
if not idp:
|
||||
return await self._serve_id_picker(request, client_redirect_url)
|
||||
|
||||
# otherwise, redirect to the IdP's redirect URI
|
||||
providers = self._sso_handler.get_identity_providers()
|
||||
auth_provider = providers.get(idp)
|
||||
if not auth_provider:
|
||||
logger.info("Unknown idp %r", idp)
|
||||
self._sso_handler.render_error(
|
||||
request, "unknown_idp", "Unknown identity provider ID"
|
||||
# Otherwise, redirect to the login SSO redirect endpoint for the given IdP
|
||||
# (which will in turn take us to the the IdP's redirect URI).
|
||||
#
|
||||
# We could go directly to the IdP's redirect URI, but this way we ensure that
|
||||
# the user goes through the same logic as normal flow. Additionally, if a proxy
|
||||
# needs to intercept the request, it only needs to intercept the one endpoint.
|
||||
sso_login_redirect_url = (
|
||||
self._login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||
idp_id=idp, client_redirect_url=client_redirect_url
|
||||
)
|
||||
return
|
||||
|
||||
sso_url = await auth_provider.handle_redirect_request(
|
||||
request, client_redirect_url.encode("utf8")
|
||||
)
|
||||
logger.info("Redirecting to %s", sso_url)
|
||||
request.redirect(sso_url)
|
||||
logger.info("Redirecting to %s", sso_login_redirect_url)
|
||||
request.redirect(sso_login_redirect_url)
|
||||
finish_request(request)
|
||||
|
||||
async def _serve_id_picker(
|
||||
|
|
|
@ -174,9 +174,10 @@ class TaskScheduler:
|
|||
The id of the scheduled task
|
||||
"""
|
||||
status = TaskStatus.SCHEDULED
|
||||
start_now = False
|
||||
if timestamp is None or timestamp < self._clock.time_msec():
|
||||
timestamp = self._clock.time_msec()
|
||||
status = TaskStatus.ACTIVE
|
||||
start_now = True
|
||||
|
||||
task = ScheduledTask(
|
||||
random_string(16),
|
||||
|
@ -190,9 +191,11 @@ class TaskScheduler:
|
|||
)
|
||||
await self._store.insert_scheduled_task(task)
|
||||
|
||||
if status == TaskStatus.ACTIVE:
|
||||
# If the task is ready to run immediately, run the scheduling algorithm now
|
||||
# rather than waiting
|
||||
if start_now:
|
||||
if self._run_background_tasks:
|
||||
await self._launch_task(task)
|
||||
self._launch_scheduled_tasks()
|
||||
else:
|
||||
self._hs.get_replication_command_handler().send_new_active_task(task.id)
|
||||
|
||||
|
@ -300,23 +303,13 @@ class TaskScheduler:
|
|||
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
|
||||
await self._store.delete_scheduled_task(id)
|
||||
|
||||
def launch_task_by_id(self, id: str) -> None:
|
||||
"""Try launching the task with the given ID."""
|
||||
# Don't bother trying to launch new tasks if we're already at capacity.
|
||||
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
|
||||
return
|
||||
def on_new_task(self, task_id: str) -> None:
|
||||
"""Handle a notification that a new ready-to-run task has been added to the queue"""
|
||||
# Just run the scheduler
|
||||
self._launch_scheduled_tasks()
|
||||
|
||||
run_as_background_process("launch_task_by_id", self._launch_task_by_id, id)
|
||||
|
||||
async def _launch_task_by_id(self, id: str) -> None:
|
||||
"""Helper async function for `launch_task_by_id`."""
|
||||
task = await self.get_task(id)
|
||||
if task:
|
||||
await self._launch_task(task)
|
||||
|
||||
@wrap_as_background_process("launch_scheduled_tasks")
|
||||
async def _launch_scheduled_tasks(self) -> None:
|
||||
"""Retrieve and launch scheduled tasks that should be running at that time."""
|
||||
def _launch_scheduled_tasks(self) -> None:
|
||||
"""Retrieve and launch scheduled tasks that should be running at this time."""
|
||||
# Don't bother trying to launch new tasks if we're already at capacity.
|
||||
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
|
||||
return
|
||||
|
@ -326,20 +319,26 @@ class TaskScheduler:
|
|||
|
||||
self._launching_new_tasks = True
|
||||
|
||||
try:
|
||||
for task in await self.get_tasks(
|
||||
statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS
|
||||
):
|
||||
await self._launch_task(task)
|
||||
for task in await self.get_tasks(
|
||||
statuses=[TaskStatus.SCHEDULED],
|
||||
max_timestamp=self._clock.time_msec(),
|
||||
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
|
||||
):
|
||||
await self._launch_task(task)
|
||||
async def inner() -> None:
|
||||
try:
|
||||
for task in await self.get_tasks(
|
||||
statuses=[TaskStatus.ACTIVE],
|
||||
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
|
||||
):
|
||||
# _launch_task will ignore tasks that we're already running, and
|
||||
# will also do nothing if we're already at the maximum capacity.
|
||||
await self._launch_task(task)
|
||||
for task in await self.get_tasks(
|
||||
statuses=[TaskStatus.SCHEDULED],
|
||||
max_timestamp=self._clock.time_msec(),
|
||||
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
|
||||
):
|
||||
await self._launch_task(task)
|
||||
|
||||
finally:
|
||||
self._launching_new_tasks = False
|
||||
finally:
|
||||
self._launching_new_tasks = False
|
||||
|
||||
run_as_background_process("launch_scheduled_tasks", inner)
|
||||
|
||||
@wrap_as_background_process("clean_scheduled_tasks")
|
||||
async def _clean_scheduled_tasks(self) -> None:
|
||||
|
|
55
tests/api/test_urls.py
Normal file
55
tests/api/test_urls.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.urls import LoginSSORedirectURIBuilder
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
|
||||
TRICKY_TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
|
||||
|
||||
|
||||
class LoginSSORedirectURIBuilderTestCase(HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
|
||||
|
||||
def test_no_idp_id(self) -> None:
|
||||
self.assertEqual(
|
||||
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||
idp_id=None, client_redirect_url="http://example.com/redirect"
|
||||
),
|
||||
"https://test/_matrix/client/v3/login/sso/redirect?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
|
||||
)
|
||||
|
||||
def test_explicit_idp_id(self) -> None:
|
||||
self.assertEqual(
|
||||
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||
idp_id="oidc-github", client_redirect_url="http://example.com/redirect"
|
||||
),
|
||||
"https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
|
||||
)
|
||||
|
||||
def test_tricky_redirect_uri(self) -> None:
|
||||
self.assertEqual(
|
||||
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||
idp_id="oidc-github",
|
||||
client_redirect_url=TRICKY_TEST_CLIENT_REDIRECT_URL,
|
||||
),
|
||||
"https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=https%3A%2F%2Fx%3F%3Cab+c%3E%26q%22%2B%253D%252B%22%3D%22f%C3%B6%2526%3Do%22",
|
||||
)
|
|
@ -43,6 +43,7 @@ from twisted.web.resource import Resource
|
|||
import synapse.rest.admin
|
||||
from synapse.api.constants import ApprovalNoticeMedium, LoginType
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.api.urls import LoginSSORedirectURIBuilder
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.http.client import RawHeaders
|
||||
from synapse.module_api import ModuleApi
|
||||
|
@ -69,6 +70,10 @@ try:
|
|||
except ImportError:
|
||||
HAS_JWT = False
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# synapse server name: used to populate public_baseurl in some tests
|
||||
SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
|
||||
|
@ -77,7 +82,7 @@ SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
|
|||
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
|
||||
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
|
||||
# https://....
|
||||
BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
|
||||
PUBLIC_BASEURL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
|
||||
|
||||
# CAS server used in some tests
|
||||
CAS_SERVER = "https://fake.test"
|
||||
|
@ -109,6 +114,23 @@ ADDITIONAL_LOGIN_FLOWS = [
|
|||
]
|
||||
|
||||
|
||||
def get_relative_uri_from_absolute_uri(absolute_uri: str) -> str:
|
||||
"""
|
||||
Peels off the path and query string from an absolute URI. Useful when interacting
|
||||
with `make_request(...)` util function which expects a relative path instead of a
|
||||
full URI.
|
||||
"""
|
||||
parsed_uri = urllib.parse.urlparse(absolute_uri)
|
||||
# Sanity check that we're working with an absolute URI
|
||||
assert parsed_uri.scheme == "http" or parsed_uri.scheme == "https"
|
||||
|
||||
relative_uri = parsed_uri.path
|
||||
if parsed_uri.query:
|
||||
relative_uri += "?" + parsed_uri.query
|
||||
|
||||
return relative_uri
|
||||
|
||||
|
||||
class TestSpamChecker:
|
||||
def __init__(self, config: None, api: ModuleApi):
|
||||
api.register_spam_checker_callbacks(
|
||||
|
@ -614,7 +636,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|||
def default_config(self) -> Dict[str, Any]:
|
||||
config = super().default_config()
|
||||
|
||||
config["public_baseurl"] = BASE_URL
|
||||
config["public_baseurl"] = PUBLIC_BASEURL
|
||||
|
||||
config["cas_config"] = {
|
||||
"enabled": True,
|
||||
|
@ -653,6 +675,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
d = super().create_resource_dict()
|
||||
d.update(build_synapse_client_resource_tree(self.hs))
|
||||
|
@ -725,6 +750,32 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|||
+ "&idp=cas",
|
||||
shorthand=False,
|
||||
)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
sso_login_redirect_uri = location_headers[0]
|
||||
|
||||
# it should redirect us to the standard login SSO redirect flow
|
||||
self.assertEqual(
|
||||
sso_login_redirect_uri,
|
||||
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||
idp_id="cas", client_redirect_url=TEST_CLIENT_REDIRECT_URL
|
||||
),
|
||||
)
|
||||
|
||||
# follow the redirect
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
# We have to make this relative to be compatible with `make_request(...)`
|
||||
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||
# We have to set the Host header to match the `public_baseurl` to avoid
|
||||
# the extra redirect in the `SsoRedirectServlet` in order for the
|
||||
# cookies to be visible.
|
||||
custom_headers=[
|
||||
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
|
@ -750,6 +801,32 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||
+ "&idp=saml",
|
||||
)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
sso_login_redirect_uri = location_headers[0]
|
||||
|
||||
# it should redirect us to the standard login SSO redirect flow
|
||||
self.assertEqual(
|
||||
sso_login_redirect_uri,
|
||||
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||
idp_id="saml", client_redirect_url=TEST_CLIENT_REDIRECT_URL
|
||||
),
|
||||
)
|
||||
|
||||
# follow the redirect
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
# We have to make this relative to be compatible with `make_request(...)`
|
||||
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||
# We have to set the Host header to match the `public_baseurl` to avoid
|
||||
# the extra redirect in the `SsoRedirectServlet` in order for the
|
||||
# cookies to be visible.
|
||||
custom_headers=[
|
||||
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
|
@ -773,13 +850,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|||
# pick the default OIDC provider
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_synapse/client/pick_idp?redirectUrl="
|
||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||
+ "&idp=oidc",
|
||||
f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc",
|
||||
)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
sso_login_redirect_uri = location_headers[0]
|
||||
|
||||
# it should redirect us to the standard login SSO redirect flow
|
||||
self.assertEqual(
|
||||
sso_login_redirect_uri,
|
||||
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||
idp_id="oidc", client_redirect_url=TEST_CLIENT_REDIRECT_URL
|
||||
),
|
||||
)
|
||||
|
||||
with fake_oidc_server.patch_homeserver(hs=self.hs):
|
||||
# follow the redirect
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
# We have to make this relative to be compatible with `make_request(...)`
|
||||
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||
# We have to set the Host header to match the `public_baseurl` to avoid
|
||||
# the extra redirect in the `SsoRedirectServlet` in order for the
|
||||
# cookies to be visible.
|
||||
custom_headers=[
|
||||
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
oidc_uri = location_headers[0]
|
||||
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
|
||||
|
||||
|
@ -838,12 +940,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(chan.json_body["user_id"], "@user1:test")
|
||||
|
||||
def test_multi_sso_redirect_to_unknown(self) -> None:
|
||||
"""An unknown IdP should cause a 400"""
|
||||
"""An unknown IdP should cause a 404"""
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
|
||||
)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
sso_login_redirect_uri = location_headers[0]
|
||||
|
||||
# it should redirect us to the standard login SSO redirect flow
|
||||
self.assertEqual(
|
||||
sso_login_redirect_uri,
|
||||
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||
idp_id="xyz", client_redirect_url="http://x"
|
||||
),
|
||||
)
|
||||
|
||||
# follow the redirect
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
# We have to make this relative to be compatible with `make_request(...)`
|
||||
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||
# We have to set the Host header to match the `public_baseurl` to avoid
|
||||
# the extra redirect in the `SsoRedirectServlet` in order for the
|
||||
# cookies to be visible.
|
||||
custom_headers=[
|
||||
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404, channel.result)
|
||||
|
||||
def test_client_idp_redirect_to_unknown(self) -> None:
|
||||
"""If the client tries to pick an unknown IdP, return a 404"""
|
||||
|
@ -1473,7 +1601,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
|||
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
config = super().default_config()
|
||||
config["public_baseurl"] = BASE_URL
|
||||
config["public_baseurl"] = PUBLIC_BASEURL
|
||||
|
||||
config["oidc_config"] = {}
|
||||
config["oidc_config"].update(TEST_OIDC_CONFIG)
|
||||
|
|
|
@ -889,7 +889,7 @@ class RestHelper:
|
|||
"GET",
|
||||
uri,
|
||||
)
|
||||
assert channel.code == 302
|
||||
assert channel.code == 302, f"Expected 302 for {uri}, got {channel.code}"
|
||||
|
||||
# hit the redirect url again with the right Host header, which should now issue
|
||||
# a cookie and redirect to the SSO provider.
|
||||
|
@ -901,17 +901,18 @@ class RestHelper:
|
|||
|
||||
location = get_location(channel)
|
||||
parts = urllib.parse.urlsplit(location)
|
||||
next_uri = urllib.parse.urlunsplit(("", "") + parts[2:])
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"GET",
|
||||
urllib.parse.urlunsplit(("", "") + parts[2:]),
|
||||
next_uri,
|
||||
custom_headers=[
|
||||
("Host", parts[1]),
|
||||
],
|
||||
)
|
||||
|
||||
assert channel.code == 302
|
||||
assert channel.code == 302, f"Expected 302 for {next_uri}, got {channel.code}"
|
||||
channel.extract_cookies(cookies)
|
||||
return get_location(channel)
|
||||
|
||||
|
|
|
@ -18,8 +18,7 @@
|
|||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from twisted.internet.task import deferLater
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
@ -104,33 +103,43 @@ class TestTaskScheduler(HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
# This is to give the time to the active tasks to finish
|
||||
self.reactor.advance(1)
|
||||
|
||||
# Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one
|
||||
# is still scheduled.
|
||||
tasks = [
|
||||
self.get_success(self.task_scheduler.get_task(task_id))
|
||||
for task_id in task_ids
|
||||
]
|
||||
def get_tasks_of_status(status: TaskStatus) -> List[ScheduledTask]:
|
||||
tasks = (
|
||||
self.get_success(self.task_scheduler.get_task(task_id))
|
||||
for task_id in task_ids
|
||||
)
|
||||
return [t for t in tasks if t is not None and t.status == status]
|
||||
|
||||
# At this point, there should be MAX_CONCURRENT_RUNNING_TASKS active tasks and
|
||||
# one scheduled task.
|
||||
self.assertEquals(
|
||||
len(
|
||||
[t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE]
|
||||
),
|
||||
len(get_tasks_of_status(TaskStatus.ACTIVE)),
|
||||
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
|
||||
)
|
||||
self.assertEquals(
|
||||
len(get_tasks_of_status(TaskStatus.SCHEDULED)),
|
||||
1,
|
||||
)
|
||||
|
||||
scheduled_tasks = [
|
||||
t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
|
||||
]
|
||||
self.assertEquals(len(scheduled_tasks), 1)
|
||||
|
||||
# We need to wait for the next run of the scheduler loop
|
||||
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
|
||||
# Give the time to the active tasks to finish
|
||||
self.reactor.advance(1)
|
||||
|
||||
# Check that the last task has been properly executed after the next scheduler loop run
|
||||
# Check that MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one
|
||||
# is still scheduled.
|
||||
self.assertEquals(
|
||||
len(get_tasks_of_status(TaskStatus.COMPLETE)),
|
||||
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
|
||||
)
|
||||
scheduled_tasks = get_tasks_of_status(TaskStatus.SCHEDULED)
|
||||
self.assertEquals(len(scheduled_tasks), 1)
|
||||
|
||||
# The scheduled task should start 0.1s after the first of the active tasks
|
||||
# finishes
|
||||
self.reactor.advance(0.1)
|
||||
self.assertEquals(len(get_tasks_of_status(TaskStatus.ACTIVE)), 1)
|
||||
|
||||
# ... and should finally complete after another second
|
||||
self.reactor.advance(1)
|
||||
prev_scheduled_task = self.get_success(
|
||||
self.task_scheduler.get_task(scheduled_tasks[0].id)
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue