mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Add a module callback to set username at registration (#11790)
This is in the context of mainlining the Tchap fork of Synapse. Currently in Tchap usernames are derived from the user's email address (extracted from the UIA results, more specifically the m.login.email.identity step). This change also exports the check_username method from the registration handler as part of the module API, so that a module can check if the username it's trying to generate is correct and doesn't conflict with an existing one, and fallback gracefully if not. Co-authored-by: David Robertson <davidr@element.io>
This commit is contained in:
parent
2897fb6b4f
commit
2d3bd9aa67
6 changed files with 231 additions and 3 deletions
1
changelog.d/11790.feature
Normal file
1
changelog.d/11790.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add a module callback to set username at registration.
|
|
@ -105,6 +105,68 @@ device ID), and the (now deactivated) access token.
|
||||||
|
|
||||||
If multiple modules implement this callback, Synapse runs them all in order.
|
If multiple modules implement this callback, Synapse runs them all in order.
|
||||||
|
|
||||||
|
### `get_username_for_registration`
|
||||||
|
|
||||||
|
_First introduced in Synapse v1.52.0_
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def get_username_for_registration(
|
||||||
|
uia_results: Dict[str, Any],
|
||||||
|
params: Dict[str, Any],
|
||||||
|
) -> Optional[str]
|
||||||
|
```
|
||||||
|
|
||||||
|
Called when registering a new user. The module can return a username to set for the user
|
||||||
|
being registered by returning it as a string, or `None` if it doesn't wish to force a
|
||||||
|
username for this user. If a username is returned, it will be used as the local part of a
|
||||||
|
user's full Matrix ID (e.g. it's `alice` in `@alice:example.com`).
|
||||||
|
|
||||||
|
This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
|
||||||
|
has been completed by the user. It is not called when registering a user via SSO. It is
|
||||||
|
passed two dictionaries, which include the information that the user has provided during
|
||||||
|
the registration process.
|
||||||
|
|
||||||
|
The first dictionary contains the results of the [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
|
||||||
|
flow followed by the user. Its keys are the identifiers of every step involved in the flow,
|
||||||
|
associated with either a boolean value indicating whether the step was correctly completed,
|
||||||
|
or additional information (e.g. email address, phone number...). A list of most existing
|
||||||
|
identifiers can be found in the [Matrix specification](https://spec.matrix.org/v1.1/client-server-api/#authentication-types).
|
||||||
|
Here's an example featuring all currently supported keys:
|
||||||
|
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"m.login.dummy": True, # Dummy authentication
|
||||||
|
"m.login.terms": True, # User has accepted the terms of service for the homeserver
|
||||||
|
"m.login.recaptcha": True, # User has completed the recaptcha challenge
|
||||||
|
"m.login.email.identity": { # User has provided and verified an email address
|
||||||
|
"medium": "email",
|
||||||
|
"address": "alice@example.com",
|
||||||
|
"validated_at": 1642701357084,
|
||||||
|
},
|
||||||
|
"m.login.msisdn": { # User has provided and verified a phone number
|
||||||
|
"medium": "msisdn",
|
||||||
|
"address": "33123456789",
|
||||||
|
"validated_at": 1642701357084,
|
||||||
|
},
|
||||||
|
"org.matrix.msc3231.login.registration_token": "sometoken", # User has registered through the flow described in MSC3231
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The second dictionary contains the parameters provided by the user's client in the request
|
||||||
|
to `/_matrix/client/v3/register`. See the [Matrix specification](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3register)
|
||||||
|
for a complete list of these parameters.
|
||||||
|
|
||||||
|
If the module cannot, or does not wish to, generate a username for this user, it must
|
||||||
|
return `None`.
|
||||||
|
|
||||||
|
If multiple modules implement this callback, they will be considered in order. If a
|
||||||
|
callback returns `None`, Synapse falls through to the next one. The value of the first
|
||||||
|
callback that does not return `None` will be used. If this happens, Synapse will not call
|
||||||
|
any of the subsequent implementations of this callback. If every callback return `None`,
|
||||||
|
the username provided by the user is used, if any (otherwise one is automatically
|
||||||
|
generated).
|
||||||
|
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
|
|
||||||
The example module below implements authentication checkers for two different login types:
|
The example module below implements authentication checkers for two different login types:
|
||||||
|
|
|
@ -2060,6 +2060,10 @@ CHECK_AUTH_CALLBACK = Callable[
|
||||||
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
|
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
|
||||||
|
[JsonDict, JsonDict],
|
||||||
|
Awaitable[Optional[str]],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class PasswordAuthProvider:
|
class PasswordAuthProvider:
|
||||||
|
@ -2072,6 +2076,9 @@ class PasswordAuthProvider:
|
||||||
# lists of callbacks
|
# lists of callbacks
|
||||||
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
|
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
|
||||||
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
|
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
|
||||||
|
self.get_username_for_registration_callbacks: List[
|
||||||
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||||
|
] = []
|
||||||
|
|
||||||
# Mapping from login type to login parameters
|
# Mapping from login type to login parameters
|
||||||
self._supported_login_types: Dict[str, Iterable[str]] = {}
|
self._supported_login_types: Dict[str, Iterable[str]] = {}
|
||||||
|
@ -2086,6 +2093,9 @@ class PasswordAuthProvider:
|
||||||
auth_checkers: Optional[
|
auth_checkers: Optional[
|
||||||
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
|
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
|
||||||
] = None,
|
] = None,
|
||||||
|
get_username_for_registration: Optional[
|
||||||
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||||
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Register check_3pid_auth callback
|
# Register check_3pid_auth callback
|
||||||
if check_3pid_auth is not None:
|
if check_3pid_auth is not None:
|
||||||
|
@ -2130,6 +2140,11 @@ class PasswordAuthProvider:
|
||||||
# Add the new method to the list of auth_checker_callbacks for this login type
|
# Add the new method to the list of auth_checker_callbacks for this login type
|
||||||
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
|
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
|
||||||
|
|
||||||
|
if get_username_for_registration is not None:
|
||||||
|
self.get_username_for_registration_callbacks.append(
|
||||||
|
get_username_for_registration,
|
||||||
|
)
|
||||||
|
|
||||||
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
||||||
"""Get the login types supported by this password provider
|
"""Get the login types supported by this password provider
|
||||||
|
|
||||||
|
@ -2285,3 +2300,46 @@ class PasswordAuthProvider:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
async def get_username_for_registration(
|
||||||
|
self,
|
||||||
|
uia_results: JsonDict,
|
||||||
|
params: JsonDict,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Defines the username to use when registering the user, using the credentials
|
||||||
|
and parameters provided during the UIA flow.
|
||||||
|
|
||||||
|
Stops at the first callback that returns a string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uia_results: The credentials provided during the UIA flow.
|
||||||
|
params: The parameters provided by the registration request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The localpart to use when registering this user, or None if no module
|
||||||
|
returned a localpart.
|
||||||
|
"""
|
||||||
|
for callback in self.get_username_for_registration_callbacks:
|
||||||
|
try:
|
||||||
|
res = await callback(uia_results, params)
|
||||||
|
|
||||||
|
if isinstance(res, str):
|
||||||
|
return res
|
||||||
|
elif res is not None:
|
||||||
|
# mypy complains that this line is unreachable because it assumes the
|
||||||
|
# data returned by the module fits the expected type. We just want
|
||||||
|
# to make sure this is the case.
|
||||||
|
logger.warning( # type: ignore[unreachable]
|
||||||
|
"Ignoring non-string value returned by"
|
||||||
|
" get_username_for_registration callback %s: %s",
|
||||||
|
callback,
|
||||||
|
res,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Module raised an exception in get_username_for_registration: %s",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
raise SynapseError(code=500, msg="Internal Server Error")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
|
@ -71,6 +71,7 @@ from synapse.handlers.account_validity import (
|
||||||
from synapse.handlers.auth import (
|
from synapse.handlers.auth import (
|
||||||
CHECK_3PID_AUTH_CALLBACK,
|
CHECK_3PID_AUTH_CALLBACK,
|
||||||
CHECK_AUTH_CALLBACK,
|
CHECK_AUTH_CALLBACK,
|
||||||
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
|
||||||
ON_LOGGED_OUT_CALLBACK,
|
ON_LOGGED_OUT_CALLBACK,
|
||||||
AuthHandler,
|
AuthHandler,
|
||||||
)
|
)
|
||||||
|
@ -177,6 +178,7 @@ class ModuleApi:
|
||||||
self._presence_stream = hs.get_event_sources().sources.presence
|
self._presence_stream = hs.get_event_sources().sources.presence
|
||||||
self._state = hs.get_state_handler()
|
self._state = hs.get_state_handler()
|
||||||
self._clock: Clock = hs.get_clock()
|
self._clock: Clock = hs.get_clock()
|
||||||
|
self._registration_handler = hs.get_registration_handler()
|
||||||
self._send_email_handler = hs.get_send_email_handler()
|
self._send_email_handler = hs.get_send_email_handler()
|
||||||
self.custom_template_dir = hs.config.server.custom_template_directory
|
self.custom_template_dir = hs.config.server.custom_template_directory
|
||||||
|
|
||||||
|
@ -310,6 +312,9 @@ class ModuleApi:
|
||||||
auth_checkers: Optional[
|
auth_checkers: Optional[
|
||||||
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
|
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
|
||||||
] = None,
|
] = None,
|
||||||
|
get_username_for_registration: Optional[
|
||||||
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||||
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Registers callbacks for password auth provider capabilities.
|
"""Registers callbacks for password auth provider capabilities.
|
||||||
|
|
||||||
|
@ -319,6 +324,7 @@ class ModuleApi:
|
||||||
check_3pid_auth=check_3pid_auth,
|
check_3pid_auth=check_3pid_auth,
|
||||||
on_logged_out=on_logged_out,
|
on_logged_out=on_logged_out,
|
||||||
auth_checkers=auth_checkers,
|
auth_checkers=auth_checkers,
|
||||||
|
get_username_for_registration=get_username_for_registration,
|
||||||
)
|
)
|
||||||
|
|
||||||
def register_background_update_controller_callbacks(
|
def register_background_update_controller_callbacks(
|
||||||
|
@ -1202,6 +1208,22 @@ class ModuleApi:
|
||||||
"""
|
"""
|
||||||
return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
|
return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
|
||||||
|
|
||||||
|
async def check_username(self, username: str) -> None:
|
||||||
|
"""Checks if the provided username uses the grammar defined in the Matrix
|
||||||
|
specification, and is already being used by an existing user.
|
||||||
|
|
||||||
|
Added in Synapse v1.52.0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: The username to check. This is the local part of the user's full
|
||||||
|
Matrix user ID, i.e. it's "alice" if the full user ID is "@alice:foo.com".
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError with the errcode "M_USER_IN_USE" if the username is already in
|
||||||
|
use.
|
||||||
|
"""
|
||||||
|
await self._registration_handler.check_username(username)
|
||||||
|
|
||||||
|
|
||||||
class PublicRoomListManager:
|
class PublicRoomListManager:
|
||||||
"""Contains methods for adding to, removing from and querying whether a room
|
"""Contains methods for adding to, removing from and querying whether a room
|
||||||
|
|
|
@ -425,6 +425,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
self.ratelimiter = hs.get_registration_ratelimiter()
|
self.ratelimiter = hs.get_registration_ratelimiter()
|
||||||
self.password_policy_handler = hs.get_password_policy_handler()
|
self.password_policy_handler = hs.get_password_policy_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.password_auth_provider = hs.get_password_auth_provider()
|
||||||
self._registration_enabled = self.hs.config.registration.enable_registration
|
self._registration_enabled = self.hs.config.registration.enable_registration
|
||||||
self._refresh_tokens_enabled = (
|
self._refresh_tokens_enabled = (
|
||||||
hs.config.registration.refreshable_access_token_lifetime is not None
|
hs.config.registration.refreshable_access_token_lifetime is not None
|
||||||
|
@ -638,7 +639,16 @@ class RegisterRestServlet(RestServlet):
|
||||||
if not password_hash:
|
if not password_hash:
|
||||||
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
desired_username = params.get("username", None)
|
desired_username = await (
|
||||||
|
self.password_auth_provider.get_username_for_registration(
|
||||||
|
auth_result,
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if desired_username is None:
|
||||||
|
desired_username = params.get("username", None)
|
||||||
|
|
||||||
guest_access_token = params.get("guest_access_token", None)
|
guest_access_token = params.get("guest_access_token", None)
|
||||||
|
|
||||||
if desired_username is not None:
|
if desired_username is not None:
|
||||||
|
|
|
@ -20,10 +20,11 @@ from unittest.mock import Mock
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
|
from synapse.api.constants import LoginType
|
||||||
from synapse.handlers.auth import load_legacy_password_auth_providers
|
from synapse.handlers.auth import load_legacy_password_auth_providers
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.rest.client import devices, login, logout
|
from synapse.rest.client import devices, login, logout, register
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import FakeChannel
|
from tests.server import FakeChannel
|
||||||
|
@ -156,6 +157,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
devices.register_servlets,
|
devices.register_servlets,
|
||||||
logout.register_servlets,
|
logout.register_servlets,
|
||||||
|
register.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -745,6 +747,79 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
on_logged_out.assert_called_once()
|
on_logged_out.assert_called_once()
|
||||||
self.assertTrue(self.called)
|
self.assertTrue(self.called)
|
||||||
|
|
||||||
|
def test_username(self):
|
||||||
|
"""Tests that the get_username_for_registration callback can define the username
|
||||||
|
of a user when registering.
|
||||||
|
"""
|
||||||
|
self._setup_get_username_for_registration()
|
||||||
|
|
||||||
|
username = "rin"
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/register",
|
||||||
|
{
|
||||||
|
"username": username,
|
||||||
|
"password": "bar",
|
||||||
|
"auth": {"type": LoginType.DUMMY},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
|
# Our callback takes the username and appends "-foo" to it, check that's what we
|
||||||
|
# have.
|
||||||
|
mxid = channel.json_body["user_id"]
|
||||||
|
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
|
||||||
|
|
||||||
|
def test_username_uia(self):
|
||||||
|
"""Tests that the get_username_for_registration callback is only called at the
|
||||||
|
end of the UIA flow.
|
||||||
|
"""
|
||||||
|
m = self._setup_get_username_for_registration()
|
||||||
|
|
||||||
|
# Initiate the UIA flow.
|
||||||
|
username = "rin"
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"register",
|
||||||
|
{"username": username, "type": "m.login.password", "password": "bar"},
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 401)
|
||||||
|
self.assertIn("session", channel.json_body)
|
||||||
|
|
||||||
|
# Check that the callback hasn't been called yet.
|
||||||
|
m.assert_not_called()
|
||||||
|
|
||||||
|
# Finish the UIA flow.
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"register",
|
||||||
|
{"auth": {"session": session, "type": LoginType.DUMMY}},
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
mxid = channel.json_body["user_id"]
|
||||||
|
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
|
||||||
|
|
||||||
|
# Check that the callback has been called.
|
||||||
|
m.assert_called_once()
|
||||||
|
|
||||||
|
def _setup_get_username_for_registration(self) -> Mock:
|
||||||
|
"""Registers a get_username_for_registration callback that appends "-foo" to the
|
||||||
|
username the client is trying to register.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_username_for_registration(uia_results, params):
|
||||||
|
self.assertIn(LoginType.DUMMY, uia_results)
|
||||||
|
username = params["username"]
|
||||||
|
return username + "-foo"
|
||||||
|
|
||||||
|
m = Mock(side_effect=get_username_for_registration)
|
||||||
|
|
||||||
|
password_auth_provider = self.hs.get_password_auth_provider()
|
||||||
|
password_auth_provider.get_username_for_registration_callbacks.append(m)
|
||||||
|
|
||||||
|
return m
|
||||||
|
|
||||||
def _get_login_flows(self) -> JsonDict:
|
def _get_login_flows(self) -> JsonDict:
|
||||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
Loading…
Reference in a new issue