mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Merge branch 'develop' into madlittlemods/17368-bust-_membership_stream_cache
This commit is contained in:
commit
f5f0e36ec1
18 changed files with 338 additions and 93 deletions
3
.github/workflows/release-artifacts.yml
vendored
3
.github/workflows/release-artifacts.yml
vendored
|
@ -212,7 +212,8 @@ jobs:
|
||||||
mv debs*/* debs/
|
mv debs*/* debs/
|
||||||
tar -cvJf debs.tar.xz debs
|
tar -cvJf debs.tar.xz debs
|
||||||
- name: Attach to release
|
- name: Attach to release
|
||||||
uses: softprops/action-gh-release@v2
|
# Pinned to work around https://github.com/softprops/action-gh-release/issues/445
|
||||||
|
uses: softprops/action-gh-release@v2.0.5
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
with:
|
with:
|
||||||
|
|
1
changelog.d/17962.misc
Normal file
1
changelog.d/17962.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix new scheduled tasks jumping the queue.
|
1
changelog.d/17965.feature
Normal file
1
changelog.d/17965.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Use stable `M_USER_LOCKED` error code for locked accounts, as per [Matrix 1.12](https://spec.matrix.org/v1.12/client-server-api/#account-locking).
|
1
changelog.d/17970.bugfix
Normal file
1
changelog.d/17970.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix release process to not create duplicate releases.
|
1
changelog.d/17972.misc
Normal file
1
changelog.d/17972.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Consolidate SSO redirects through `/_matrix/client/v3/login/sso/redirect(/{idpId})`.
|
12
poetry.lock
generated
12
poetry.lock
generated
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "annotated-types"
|
name = "annotated-types"
|
||||||
|
@ -1917,13 +1917,13 @@ test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pysaml2"
|
name = "pysaml2"
|
||||||
version = "7.3.1"
|
version = "7.5.0"
|
||||||
description = "Python implementation of SAML Version 2 Standard"
|
description = "Python implementation of SAML Version 2 Standard"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.6.2,<4.0.0"
|
python-versions = ">=3.9,<4.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "pysaml2-7.3.1-py3-none-any.whl", hash = "sha256:2cc66e7a371d3f5ff9601f0ed93b5276cca816fce82bb38447d5a0651f2f5193"},
|
{file = "pysaml2-7.5.0-py3-none-any.whl", hash = "sha256:bc6627cc344476a83c757f440a73fda1369f13b6fda1b4e16bca63ffbabb5318"},
|
||||||
{file = "pysaml2-7.3.1.tar.gz", hash = "sha256:eab22d187c6dd7707c58b5bb1688f9b8e816427667fc99d77f54399e15cd0a0a"},
|
{file = "pysaml2-7.5.0.tar.gz", hash = "sha256:f36871d4e5ee857c6b85532e942550d2cf90ea4ee943d75eb681044bbc4f54f7"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1933,7 +1933,7 @@ pyopenssl = "*"
|
||||||
python-dateutil = "*"
|
python-dateutil = "*"
|
||||||
pytz = "*"
|
pytz = "*"
|
||||||
requests = ">=2,<3"
|
requests = ">=2,<3"
|
||||||
xmlschema = ">=1.2.1"
|
xmlschema = ">=2,<3"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
s2repoze = ["paste", "repoze.who", "zope.interface"]
|
s2repoze = ["paste", "repoze.who", "zope.interface"]
|
||||||
|
|
|
@ -87,8 +87,7 @@ class Codes(str, Enum):
|
||||||
WEAK_PASSWORD = "M_WEAK_PASSWORD"
|
WEAK_PASSWORD = "M_WEAK_PASSWORD"
|
||||||
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
||||||
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
||||||
# USER_LOCKED = "M_USER_LOCKED"
|
USER_LOCKED = "M_USER_LOCKED"
|
||||||
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
|
|
||||||
NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED"
|
NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED"
|
||||||
CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA"
|
CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA"
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,8 @@
|
||||||
|
|
||||||
import hmac
|
import hmac
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from urllib.parse import urlencode
|
from typing import Optional
|
||||||
|
from urllib.parse import urlencode, urljoin
|
||||||
|
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
@ -66,3 +67,42 @@ class ConsentURIBuilder:
|
||||||
urlencode({"u": user_id, "h": mac}),
|
urlencode({"u": user_id, "h": mac}),
|
||||||
)
|
)
|
||||||
return consent_uri
|
return consent_uri
|
||||||
|
|
||||||
|
|
||||||
|
class LoginSSORedirectURIBuilder:
|
||||||
|
def __init__(self, hs_config: HomeServerConfig):
|
||||||
|
self._public_baseurl = hs_config.server.public_baseurl
|
||||||
|
|
||||||
|
def build_login_sso_redirect_uri(
|
||||||
|
self, *, idp_id: Optional[str], client_redirect_url: str
|
||||||
|
) -> str:
|
||||||
|
"""Build a `/login/sso/redirect` URI for the given identity provider.
|
||||||
|
|
||||||
|
Builds `/_matrix/client/v3/login/sso/redirect/{idpId}?redirectUrl=xxx` when `idp_id` is specified.
|
||||||
|
Otherwise, builds `/_matrix/client/v3/login/sso/redirect?redirectUrl=xxx` when `idp_id` is `None`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idp_id: Optional ID of the identity provider
|
||||||
|
client_redirect_url: URL to redirect the user to after login
|
||||||
|
|
||||||
|
Returns
|
||||||
|
The URI to follow when choosing a specific identity provider.
|
||||||
|
"""
|
||||||
|
base_url = urljoin(
|
||||||
|
self._public_baseurl,
|
||||||
|
f"{CLIENT_API_PREFIX}/v3/login/sso/redirect",
|
||||||
|
)
|
||||||
|
|
||||||
|
serialized_query_parameters = urlencode({"redirectUrl": client_redirect_url})
|
||||||
|
|
||||||
|
if idp_id:
|
||||||
|
resultant_url = urljoin(
|
||||||
|
# We have to add a trailing slash to the base URL to ensure that the
|
||||||
|
# last path segment is not stripped away when joining with another path.
|
||||||
|
f"{base_url}/",
|
||||||
|
f"{idp_id}?{serialized_query_parameters}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
resultant_url = f"{base_url}?{serialized_query_parameters}"
|
||||||
|
|
||||||
|
return resultant_url
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
||||||
from typing import Any, List
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from synapse.config.sso import SsoAttributeRequirement
|
from synapse.config.sso import SsoAttributeRequirement
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
@ -46,7 +46,9 @@ class CasConfig(Config):
|
||||||
|
|
||||||
# TODO Update this to a _synapse URL.
|
# TODO Update this to a _synapse URL.
|
||||||
public_baseurl = self.root.server.public_baseurl
|
public_baseurl = self.root.server.public_baseurl
|
||||||
self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket"
|
self.cas_service_url: Optional[str] = (
|
||||||
|
public_baseurl + "_matrix/client/r0/login/cas/ticket"
|
||||||
|
)
|
||||||
|
|
||||||
self.cas_protocol_version = cas_config.get("protocol_version")
|
self.cas_protocol_version = cas_config.get("protocol_version")
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -332,8 +332,14 @@ class ServerConfig(Config):
|
||||||
logger.info("Using default public_baseurl %s", public_baseurl)
|
logger.info("Using default public_baseurl %s", public_baseurl)
|
||||||
else:
|
else:
|
||||||
self.serve_client_wellknown = True
|
self.serve_client_wellknown = True
|
||||||
|
# Ensure that public_baseurl ends with a trailing slash
|
||||||
if public_baseurl[-1] != "/":
|
if public_baseurl[-1] != "/":
|
||||||
public_baseurl += "/"
|
public_baseurl += "/"
|
||||||
|
|
||||||
|
# Scrutinize user-provided config
|
||||||
|
if not isinstance(public_baseurl, str):
|
||||||
|
raise ConfigError("Must be a string", ("public_baseurl",))
|
||||||
|
|
||||||
self.public_baseurl = public_baseurl
|
self.public_baseurl = public_baseurl
|
||||||
|
|
||||||
# check that public_baseurl is valid
|
# check that public_baseurl is valid
|
||||||
|
|
|
@ -495,7 +495,7 @@ class LockReleasedCommand(Command):
|
||||||
|
|
||||||
|
|
||||||
class NewActiveTaskCommand(_SimpleCommand):
|
class NewActiveTaskCommand(_SimpleCommand):
|
||||||
"""Sent to inform instance handling background tasks that a new active task is available to run.
|
"""Sent to inform instance handling background tasks that a new task is ready to run.
|
||||||
|
|
||||||
Format::
|
Format::
|
||||||
|
|
||||||
|
|
|
@ -727,7 +727,7 @@ class ReplicationCommandHandler:
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Called when get a new NEW_ACTIVE_TASK command."""
|
"""Called when get a new NEW_ACTIVE_TASK command."""
|
||||||
if self._task_scheduler:
|
if self._task_scheduler:
|
||||||
self._task_scheduler.launch_task_by_id(cmd.data)
|
self._task_scheduler.on_new_task(cmd.data)
|
||||||
|
|
||||||
def new_connection(self, connection: IReplicationConnection) -> None:
|
def new_connection(self, connection: IReplicationConnection) -> None:
|
||||||
"""Called when we have a new connection."""
|
"""Called when we have a new connection."""
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from synapse.api.urls import LoginSSORedirectURIBuilder
|
||||||
from synapse.http.server import (
|
from synapse.http.server import (
|
||||||
DirectServeHtmlResource,
|
DirectServeHtmlResource,
|
||||||
finish_request,
|
finish_request,
|
||||||
|
@ -49,6 +50,8 @@ class PickIdpResource(DirectServeHtmlResource):
|
||||||
hs.config.sso.sso_login_idp_picker_template
|
hs.config.sso.sso_login_idp_picker_template
|
||||||
)
|
)
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
|
self._public_baseurl = hs.config.server.public_baseurl
|
||||||
|
self._login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
|
||||||
|
|
||||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||||
client_redirect_url = parse_string(
|
client_redirect_url = parse_string(
|
||||||
|
@ -56,25 +59,23 @@ class PickIdpResource(DirectServeHtmlResource):
|
||||||
)
|
)
|
||||||
idp = parse_string(request, "idp", required=False)
|
idp = parse_string(request, "idp", required=False)
|
||||||
|
|
||||||
# if we need to pick an IdP, do so
|
# If we need to pick an IdP, do so
|
||||||
if not idp:
|
if not idp:
|
||||||
return await self._serve_id_picker(request, client_redirect_url)
|
return await self._serve_id_picker(request, client_redirect_url)
|
||||||
|
|
||||||
# otherwise, redirect to the IdP's redirect URI
|
# Otherwise, redirect to the login SSO redirect endpoint for the given IdP
|
||||||
providers = self._sso_handler.get_identity_providers()
|
# (which will in turn take us to the the IdP's redirect URI).
|
||||||
auth_provider = providers.get(idp)
|
#
|
||||||
if not auth_provider:
|
# We could go directly to the IdP's redirect URI, but this way we ensure that
|
||||||
logger.info("Unknown idp %r", idp)
|
# the user goes through the same logic as normal flow. Additionally, if a proxy
|
||||||
self._sso_handler.render_error(
|
# needs to intercept the request, it only needs to intercept the one endpoint.
|
||||||
request, "unknown_idp", "Unknown identity provider ID"
|
sso_login_redirect_url = (
|
||||||
|
self._login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||||
|
idp_id=idp, client_redirect_url=client_redirect_url
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
sso_url = await auth_provider.handle_redirect_request(
|
|
||||||
request, client_redirect_url.encode("utf8")
|
|
||||||
)
|
)
|
||||||
logger.info("Redirecting to %s", sso_url)
|
logger.info("Redirecting to %s", sso_login_redirect_url)
|
||||||
request.redirect(sso_url)
|
request.redirect(sso_login_redirect_url)
|
||||||
finish_request(request)
|
finish_request(request)
|
||||||
|
|
||||||
async def _serve_id_picker(
|
async def _serve_id_picker(
|
||||||
|
|
|
@ -174,9 +174,10 @@ class TaskScheduler:
|
||||||
The id of the scheduled task
|
The id of the scheduled task
|
||||||
"""
|
"""
|
||||||
status = TaskStatus.SCHEDULED
|
status = TaskStatus.SCHEDULED
|
||||||
|
start_now = False
|
||||||
if timestamp is None or timestamp < self._clock.time_msec():
|
if timestamp is None or timestamp < self._clock.time_msec():
|
||||||
timestamp = self._clock.time_msec()
|
timestamp = self._clock.time_msec()
|
||||||
status = TaskStatus.ACTIVE
|
start_now = True
|
||||||
|
|
||||||
task = ScheduledTask(
|
task = ScheduledTask(
|
||||||
random_string(16),
|
random_string(16),
|
||||||
|
@ -190,9 +191,11 @@ class TaskScheduler:
|
||||||
)
|
)
|
||||||
await self._store.insert_scheduled_task(task)
|
await self._store.insert_scheduled_task(task)
|
||||||
|
|
||||||
if status == TaskStatus.ACTIVE:
|
# If the task is ready to run immediately, run the scheduling algorithm now
|
||||||
|
# rather than waiting
|
||||||
|
if start_now:
|
||||||
if self._run_background_tasks:
|
if self._run_background_tasks:
|
||||||
await self._launch_task(task)
|
self._launch_scheduled_tasks()
|
||||||
else:
|
else:
|
||||||
self._hs.get_replication_command_handler().send_new_active_task(task.id)
|
self._hs.get_replication_command_handler().send_new_active_task(task.id)
|
||||||
|
|
||||||
|
@ -300,23 +303,13 @@ class TaskScheduler:
|
||||||
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
|
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
|
||||||
await self._store.delete_scheduled_task(id)
|
await self._store.delete_scheduled_task(id)
|
||||||
|
|
||||||
def launch_task_by_id(self, id: str) -> None:
|
def on_new_task(self, task_id: str) -> None:
|
||||||
"""Try launching the task with the given ID."""
|
"""Handle a notification that a new ready-to-run task has been added to the queue"""
|
||||||
# Don't bother trying to launch new tasks if we're already at capacity.
|
# Just run the scheduler
|
||||||
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
|
self._launch_scheduled_tasks()
|
||||||
return
|
|
||||||
|
|
||||||
run_as_background_process("launch_task_by_id", self._launch_task_by_id, id)
|
def _launch_scheduled_tasks(self) -> None:
|
||||||
|
"""Retrieve and launch scheduled tasks that should be running at this time."""
|
||||||
async def _launch_task_by_id(self, id: str) -> None:
|
|
||||||
"""Helper async function for `launch_task_by_id`."""
|
|
||||||
task = await self.get_task(id)
|
|
||||||
if task:
|
|
||||||
await self._launch_task(task)
|
|
||||||
|
|
||||||
@wrap_as_background_process("launch_scheduled_tasks")
|
|
||||||
async def _launch_scheduled_tasks(self) -> None:
|
|
||||||
"""Retrieve and launch scheduled tasks that should be running at that time."""
|
|
||||||
# Don't bother trying to launch new tasks if we're already at capacity.
|
# Don't bother trying to launch new tasks if we're already at capacity.
|
||||||
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
|
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
|
||||||
return
|
return
|
||||||
|
@ -326,20 +319,26 @@ class TaskScheduler:
|
||||||
|
|
||||||
self._launching_new_tasks = True
|
self._launching_new_tasks = True
|
||||||
|
|
||||||
try:
|
async def inner() -> None:
|
||||||
for task in await self.get_tasks(
|
try:
|
||||||
statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS
|
for task in await self.get_tasks(
|
||||||
):
|
statuses=[TaskStatus.ACTIVE],
|
||||||
await self._launch_task(task)
|
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
|
||||||
for task in await self.get_tasks(
|
):
|
||||||
statuses=[TaskStatus.SCHEDULED],
|
# _launch_task will ignore tasks that we're already running, and
|
||||||
max_timestamp=self._clock.time_msec(),
|
# will also do nothing if we're already at the maximum capacity.
|
||||||
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
|
await self._launch_task(task)
|
||||||
):
|
for task in await self.get_tasks(
|
||||||
await self._launch_task(task)
|
statuses=[TaskStatus.SCHEDULED],
|
||||||
|
max_timestamp=self._clock.time_msec(),
|
||||||
|
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
|
||||||
|
):
|
||||||
|
await self._launch_task(task)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self._launching_new_tasks = False
|
self._launching_new_tasks = False
|
||||||
|
|
||||||
|
run_as_background_process("launch_scheduled_tasks", inner)
|
||||||
|
|
||||||
@wrap_as_background_process("clean_scheduled_tasks")
|
@wrap_as_background_process("clean_scheduled_tasks")
|
||||||
async def _clean_scheduled_tasks(self) -> None:
|
async def _clean_scheduled_tasks(self) -> None:
|
||||||
|
|
55
tests/api/test_urls.py
Normal file
55
tests/api/test_urls.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
#
|
||||||
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
#
|
||||||
|
# Copyright (C) 2024 New Vector, Ltd
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# See the GNU Affero General Public License for more details:
|
||||||
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.api.urls import LoginSSORedirectURIBuilder
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
|
||||||
|
TRICKY_TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
|
||||||
|
|
||||||
|
|
||||||
|
class LoginSSORedirectURIBuilderTestCase(HomeserverTestCase):
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
|
||||||
|
|
||||||
|
def test_no_idp_id(self) -> None:
|
||||||
|
self.assertEqual(
|
||||||
|
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||||
|
idp_id=None, client_redirect_url="http://example.com/redirect"
|
||||||
|
),
|
||||||
|
"https://test/_matrix/client/v3/login/sso/redirect?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_explicit_idp_id(self) -> None:
|
||||||
|
self.assertEqual(
|
||||||
|
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||||
|
idp_id="oidc-github", client_redirect_url="http://example.com/redirect"
|
||||||
|
),
|
||||||
|
"https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tricky_redirect_uri(self) -> None:
|
||||||
|
self.assertEqual(
|
||||||
|
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||||
|
idp_id="oidc-github",
|
||||||
|
client_redirect_url=TRICKY_TEST_CLIENT_REDIRECT_URL,
|
||||||
|
),
|
||||||
|
"https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=https%3A%2F%2Fx%3F%3Cab+c%3E%26q%22%2B%253D%252B%22%3D%22f%C3%B6%2526%3Do%22",
|
||||||
|
)
|
|
@ -43,6 +43,7 @@ from twisted.web.resource import Resource
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import ApprovalNoticeMedium, LoginType
|
from synapse.api.constants import ApprovalNoticeMedium, LoginType
|
||||||
from synapse.api.errors import Codes
|
from synapse.api.errors import Codes
|
||||||
|
from synapse.api.urls import LoginSSORedirectURIBuilder
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.http.client import RawHeaders
|
from synapse.http.client import RawHeaders
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
|
@ -69,6 +70,10 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
HAS_JWT = False
|
HAS_JWT = False
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# synapse server name: used to populate public_baseurl in some tests
|
# synapse server name: used to populate public_baseurl in some tests
|
||||||
SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
|
SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
|
||||||
|
@ -77,7 +82,7 @@ SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
|
||||||
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
|
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
|
||||||
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
|
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
|
||||||
# https://....
|
# https://....
|
||||||
BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
|
PUBLIC_BASEURL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
|
||||||
|
|
||||||
# CAS server used in some tests
|
# CAS server used in some tests
|
||||||
CAS_SERVER = "https://fake.test"
|
CAS_SERVER = "https://fake.test"
|
||||||
|
@ -109,6 +114,23 @@ ADDITIONAL_LOGIN_FLOWS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_relative_uri_from_absolute_uri(absolute_uri: str) -> str:
|
||||||
|
"""
|
||||||
|
Peels off the path and query string from an absolute URI. Useful when interacting
|
||||||
|
with `make_request(...)` util function which expects a relative path instead of a
|
||||||
|
full URI.
|
||||||
|
"""
|
||||||
|
parsed_uri = urllib.parse.urlparse(absolute_uri)
|
||||||
|
# Sanity check that we're working with an absolute URI
|
||||||
|
assert parsed_uri.scheme == "http" or parsed_uri.scheme == "https"
|
||||||
|
|
||||||
|
relative_uri = parsed_uri.path
|
||||||
|
if parsed_uri.query:
|
||||||
|
relative_uri += "?" + parsed_uri.query
|
||||||
|
|
||||||
|
return relative_uri
|
||||||
|
|
||||||
|
|
||||||
class TestSpamChecker:
|
class TestSpamChecker:
|
||||||
def __init__(self, config: None, api: ModuleApi):
|
def __init__(self, config: None, api: ModuleApi):
|
||||||
api.register_spam_checker_callbacks(
|
api.register_spam_checker_callbacks(
|
||||||
|
@ -614,7 +636,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
def default_config(self) -> Dict[str, Any]:
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
|
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = PUBLIC_BASEURL
|
||||||
|
|
||||||
config["cas_config"] = {
|
config["cas_config"] = {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
|
@ -653,6 +675,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
|
||||||
|
|
||||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||||
d = super().create_resource_dict()
|
d = super().create_resource_dict()
|
||||||
d.update(build_synapse_client_resource_tree(self.hs))
|
d.update(build_synapse_client_resource_tree(self.hs))
|
||||||
|
@ -725,6 +750,32 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
+ "&idp=cas",
|
+ "&idp=cas",
|
||||||
shorthand=False,
|
shorthand=False,
|
||||||
)
|
)
|
||||||
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
|
assert location_headers
|
||||||
|
sso_login_redirect_uri = location_headers[0]
|
||||||
|
|
||||||
|
# it should redirect us to the standard login SSO redirect flow
|
||||||
|
self.assertEqual(
|
||||||
|
sso_login_redirect_uri,
|
||||||
|
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||||
|
idp_id="cas", client_redirect_url=TEST_CLIENT_REDIRECT_URL
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# follow the redirect
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
# We have to make this relative to be compatible with `make_request(...)`
|
||||||
|
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||||
|
# We have to set the Host header to match the `public_baseurl` to avoid
|
||||||
|
# the extra redirect in the `SsoRedirectServlet` in order for the
|
||||||
|
# cookies to be visible.
|
||||||
|
custom_headers=[
|
||||||
|
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
location_headers = channel.headers.getRawHeaders("Location")
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
|
@ -750,6 +801,32 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||||
+ "&idp=saml",
|
+ "&idp=saml",
|
||||||
)
|
)
|
||||||
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
|
assert location_headers
|
||||||
|
sso_login_redirect_uri = location_headers[0]
|
||||||
|
|
||||||
|
# it should redirect us to the standard login SSO redirect flow
|
||||||
|
self.assertEqual(
|
||||||
|
sso_login_redirect_uri,
|
||||||
|
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||||
|
idp_id="saml", client_redirect_url=TEST_CLIENT_REDIRECT_URL
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# follow the redirect
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
# We have to make this relative to be compatible with `make_request(...)`
|
||||||
|
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||||
|
# We have to set the Host header to match the `public_baseurl` to avoid
|
||||||
|
# the extra redirect in the `SsoRedirectServlet` in order for the
|
||||||
|
# cookies to be visible.
|
||||||
|
custom_headers=[
|
||||||
|
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
location_headers = channel.headers.getRawHeaders("Location")
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
|
@ -773,13 +850,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
# pick the default OIDC provider
|
# pick the default OIDC provider
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/_synapse/client/pick_idp?redirectUrl="
|
f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc",
|
||||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
|
||||||
+ "&idp=oidc",
|
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
location_headers = channel.headers.getRawHeaders("Location")
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
|
sso_login_redirect_uri = location_headers[0]
|
||||||
|
|
||||||
|
# it should redirect us to the standard login SSO redirect flow
|
||||||
|
self.assertEqual(
|
||||||
|
sso_login_redirect_uri,
|
||||||
|
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||||
|
idp_id="oidc", client_redirect_url=TEST_CLIENT_REDIRECT_URL
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with fake_oidc_server.patch_homeserver(hs=self.hs):
|
||||||
|
# follow the redirect
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
# We have to make this relative to be compatible with `make_request(...)`
|
||||||
|
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||||
|
# We have to set the Host header to match the `public_baseurl` to avoid
|
||||||
|
# the extra redirect in the `SsoRedirectServlet` in order for the
|
||||||
|
# cookies to be visible.
|
||||||
|
custom_headers=[
|
||||||
|
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
|
assert location_headers
|
||||||
oidc_uri = location_headers[0]
|
oidc_uri = location_headers[0]
|
||||||
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
|
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
|
||||||
|
|
||||||
|
@ -838,12 +940,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(chan.json_body["user_id"], "@user1:test")
|
self.assertEqual(chan.json_body["user_id"], "@user1:test")
|
||||||
|
|
||||||
def test_multi_sso_redirect_to_unknown(self) -> None:
|
def test_multi_sso_redirect_to_unknown(self) -> None:
|
||||||
"""An unknown IdP should cause a 400"""
|
"""An unknown IdP should cause a 404"""
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
|
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
|
assert location_headers
|
||||||
|
sso_login_redirect_uri = location_headers[0]
|
||||||
|
|
||||||
|
# it should redirect us to the standard login SSO redirect flow
|
||||||
|
self.assertEqual(
|
||||||
|
sso_login_redirect_uri,
|
||||||
|
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
|
||||||
|
idp_id="xyz", client_redirect_url="http://x"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# follow the redirect
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
# We have to make this relative to be compatible with `make_request(...)`
|
||||||
|
get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
|
||||||
|
# We have to set the Host header to match the `public_baseurl` to avoid
|
||||||
|
# the extra redirect in the `SsoRedirectServlet` in order for the
|
||||||
|
# cookies to be visible.
|
||||||
|
custom_headers=[
|
||||||
|
("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 404, channel.result)
|
||||||
|
|
||||||
def test_client_idp_redirect_to_unknown(self) -> None:
|
def test_client_idp_redirect_to_unknown(self) -> None:
|
||||||
"""If the client tries to pick an unknown IdP, return a 404"""
|
"""If the client tries to pick an unknown IdP, return a 404"""
|
||||||
|
@ -1473,7 +1601,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
def default_config(self) -> Dict[str, Any]:
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = PUBLIC_BASEURL
|
||||||
|
|
||||||
config["oidc_config"] = {}
|
config["oidc_config"] = {}
|
||||||
config["oidc_config"].update(TEST_OIDC_CONFIG)
|
config["oidc_config"].update(TEST_OIDC_CONFIG)
|
||||||
|
|
|
@ -889,7 +889,7 @@ class RestHelper:
|
||||||
"GET",
|
"GET",
|
||||||
uri,
|
uri,
|
||||||
)
|
)
|
||||||
assert channel.code == 302
|
assert channel.code == 302, f"Expected 302 for {uri}, got {channel.code}"
|
||||||
|
|
||||||
# hit the redirect url again with the right Host header, which should now issue
|
# hit the redirect url again with the right Host header, which should now issue
|
||||||
# a cookie and redirect to the SSO provider.
|
# a cookie and redirect to the SSO provider.
|
||||||
|
@ -901,17 +901,18 @@ class RestHelper:
|
||||||
|
|
||||||
location = get_location(channel)
|
location = get_location(channel)
|
||||||
parts = urllib.parse.urlsplit(location)
|
parts = urllib.parse.urlsplit(location)
|
||||||
|
next_uri = urllib.parse.urlunsplit(("", "") + parts[2:])
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"GET",
|
"GET",
|
||||||
urllib.parse.urlunsplit(("", "") + parts[2:]),
|
next_uri,
|
||||||
custom_headers=[
|
custom_headers=[
|
||||||
("Host", parts[1]),
|
("Host", parts[1]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert channel.code == 302
|
assert channel.code == 302, f"Expected 302 for {next_uri}, got {channel.code}"
|
||||||
channel.extract_cookies(cookies)
|
channel.extract_cookies(cookies)
|
||||||
return get_location(channel)
|
return get_location(channel)
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,7 @@
|
||||||
# [This file includes modifications made by New Vector Limited]
|
# [This file includes modifications made by New Vector Limited]
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from twisted.internet.task import deferLater
|
from twisted.internet.task import deferLater
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
@ -104,33 +103,43 @@ class TestTaskScheduler(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# This is to give the time to the active tasks to finish
|
def get_tasks_of_status(status: TaskStatus) -> List[ScheduledTask]:
|
||||||
self.reactor.advance(1)
|
tasks = (
|
||||||
|
self.get_success(self.task_scheduler.get_task(task_id))
|
||||||
# Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one
|
for task_id in task_ids
|
||||||
# is still scheduled.
|
)
|
||||||
tasks = [
|
return [t for t in tasks if t is not None and t.status == status]
|
||||||
self.get_success(self.task_scheduler.get_task(task_id))
|
|
||||||
for task_id in task_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
|
# At this point, there should be MAX_CONCURRENT_RUNNING_TASKS active tasks and
|
||||||
|
# one scheduled task.
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
len(
|
len(get_tasks_of_status(TaskStatus.ACTIVE)),
|
||||||
[t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE]
|
|
||||||
),
|
|
||||||
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
|
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
|
||||||
)
|
)
|
||||||
|
self.assertEquals(
|
||||||
|
len(get_tasks_of_status(TaskStatus.SCHEDULED)),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
scheduled_tasks = [
|
# Give the time to the active tasks to finish
|
||||||
t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
|
|
||||||
]
|
|
||||||
self.assertEquals(len(scheduled_tasks), 1)
|
|
||||||
|
|
||||||
# We need to wait for the next run of the scheduler loop
|
|
||||||
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
|
|
||||||
self.reactor.advance(1)
|
self.reactor.advance(1)
|
||||||
|
|
||||||
# Check that the last task has been properly executed after the next scheduler loop run
|
# Check that MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one
|
||||||
|
# is still scheduled.
|
||||||
|
self.assertEquals(
|
||||||
|
len(get_tasks_of_status(TaskStatus.COMPLETE)),
|
||||||
|
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
|
||||||
|
)
|
||||||
|
scheduled_tasks = get_tasks_of_status(TaskStatus.SCHEDULED)
|
||||||
|
self.assertEquals(len(scheduled_tasks), 1)
|
||||||
|
|
||||||
|
# The scheduled task should start 0.1s after the first of the active tasks
|
||||||
|
# finishes
|
||||||
|
self.reactor.advance(0.1)
|
||||||
|
self.assertEquals(len(get_tasks_of_status(TaskStatus.ACTIVE)), 1)
|
||||||
|
|
||||||
|
# ... and should finally complete after another second
|
||||||
|
self.reactor.advance(1)
|
||||||
prev_scheduled_task = self.get_success(
|
prev_scheduled_task = self.get_success(
|
||||||
self.task_scheduler.get_task(scheduled_tasks[0].id)
|
self.task_scheduler.get_task(scheduled_tasks[0].id)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue