diff --git a/changelog.d/18231.feature b/changelog.d/18231.feature new file mode 100644 index 0000000000..7fa65e4fa6 --- /dev/null +++ b/changelog.d/18231.feature @@ -0,0 +1 @@ +Add an access token introspection cache to make Matrix Authentication Service integration (MSC3861) more efficient. \ No newline at end of file diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index e6bf271a1f..74e526123f 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -19,6 +19,7 @@ # # import logging +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from urllib.parse import urlencode @@ -47,6 +48,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.types import Requester, UserID, create_requester from synapse.util import json_decoder from synapse.util.caches.cached_call import RetryOnExceptionCachedCall +from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: from synapse.rest.admin.experimental_features import ExperimentalFeature @@ -76,6 +78,61 @@ def scope_to_list(scope: str) -> List[str]: return scope.strip().split(" ") +@dataclass +class IntrospectionResult: + _inner: IntrospectionToken + + # when we retrieved this token, + # in milliseconds since the Unix epoch + retrieved_at_ms: int + + def is_active(self, now_ms: int) -> bool: + if not self._inner.get("active"): + return False + + expires_in = self._inner.get("expires_in") + if expires_in is None: + return True + if not isinstance(expires_in, int): + raise InvalidClientTokenError("token `expires_in` is not an int") + + absolute_expiry_ms = expires_in * 1000 + self.retrieved_at_ms + return now_ms < absolute_expiry_ms + + def get_scope_list(self) -> List[str]: + value = self._inner.get("scope") + if not isinstance(value, str): + return [] + return scope_to_list(value) + + def get_sub(self) -> Optional[str]: + value = self._inner.get("sub") + if not isinstance(value, str): + return None + return value + + def get_username(self) -> Optional[str]: + value = self._inner.get("username") + if not isinstance(value, str): + return None + return value + + def get_name(self) -> Optional[str]: + value = self._inner.get("name") + if not isinstance(value, str): + return None + return value + + def get_device_id(self) -> Optional[str]: + value = self._inner.get("device_id") + if value is not None and not isinstance(value, str): + raise AuthError( + 500, + "Invalid device ID in introspection result", + ) + return value + + class PrivateKeyJWTWithKid(PrivateKeyJWT): # type: ignore[misc] """An implementation of the private_key_jwt client auth method that includes a kid header. @@ -121,6 +178,31 @@ class MSC3861DelegatedAuth(BaseAuth): self._hostname = hs.hostname self._admin_token: Callable[[], Optional[str]] = self._config.admin_token + # # Token Introspection Cache + # This remembers what users/devices are represented by which access tokens, + # in order to reduce overall system load: + # - on Synapse (as requests are relatively expensive) + # - on the network + # - on MAS + # + # Since there is no invalidation mechanism currently, + # the entries expire after 2 minutes. + # This does mean tokens can be treated as valid by Synapse + # for longer than reality. + # + # Ideally, tokens should logically be invalidated in the following circumstances: + # - If a session logout happens. + # In this case, MAS will delete the device within Synapse + # anyway and this is good enough as an invalidation. + # - If the client refreshes their token in MAS. + # In this case, the device still exists and it's not the end of the world for + # the old access token to continue working for a short time. + self._introspection_cache: ResponseCache[str] = ResponseCache( + self._clock, + "token_introspection", + timeout_ms=120_000, + ) + self._issuer_metadata = RetryOnExceptionCachedCall[OpenIDProviderMetadata]( self._load_metadata ) @@ -193,7 +275,7 @@ class MSC3861DelegatedAuth(BaseAuth): metadata = await self._issuer_metadata.get() return metadata.get("introspection_endpoint") - async def _introspect_token(self, token: str) -> IntrospectionToken: + async def _introspect_token(self, token: str) -> IntrospectionResult: """ Send a token to the introspection endpoint and returns the introspection response @@ -266,7 +348,9 @@ class MSC3861DelegatedAuth(BaseAuth): "The introspection endpoint returned an invalid JSON response." ) - return IntrospectionToken(**resp) + return IntrospectionResult( + IntrospectionToken(**resp), retrieved_at_ms=self._clock.time_msec() + ) async def is_server_admin(self, requester: Requester) -> bool: return "urn:synapse:admin:*" in requester.scope @@ -344,7 +428,9 @@ class MSC3861DelegatedAuth(BaseAuth): ) try: - introspection_result = await self._introspect_token(token) + introspection_result = await self._introspection_cache.wrap( + token, self._introspect_token, token + ) except Exception: logger.exception("Failed to introspect token") raise SynapseError(503, "Unable to introspect the access token") @@ -353,11 +439,11 @@ class MSC3861DelegatedAuth(BaseAuth): # TODO: introspection verification should be more extensive, especially: # - verify the audience - if not introspection_result.get("active"): + if not introspection_result.is_active(self._clock.time_msec()): raise InvalidClientTokenError("Token is not active") # Let's look at the scope - scope: List[str] = scope_to_list(introspection_result.get("scope", "")) + scope: List[str] = introspection_result.get_scope_list() # Determine type of user based on presence of particular scopes has_user_scope = SCOPE_MATRIX_API in scope @@ -367,7 +453,7 @@ class MSC3861DelegatedAuth(BaseAuth): raise InvalidClientTokenError("No scope in token granting user rights") # Match via the sub claim - sub: Optional[str] = introspection_result.get("sub") + sub: Optional[str] = introspection_result.get_sub() if sub is None: raise InvalidClientTokenError( "Invalid sub claim in the introspection result" @@ -381,7 +467,7 @@ class MSC3861DelegatedAuth(BaseAuth): # or the external_id was never recorded # TODO: claim mapping should be configurable - username: Optional[str] = introspection_result.get("username") + username: Optional[str] = introspection_result.get_username() if username is None or not isinstance(username, str): raise AuthError( 500, @@ -399,7 +485,7 @@ class MSC3861DelegatedAuth(BaseAuth): # TODO: claim mapping should be configurable # If present, use the name claim as the displayname - name: Optional[str] = introspection_result.get("name") + name: Optional[str] = introspection_result.get_name() await self.store.register_user( user_id=user_id.to_string(), create_profile_with_displayname=name @@ -414,15 +500,8 @@ class MSC3861DelegatedAuth(BaseAuth): # MAS 0.15+ will give us the device ID as an explicit value for compatibility sessions # If present, we get it from here, if not we get it in thee scope - device_id = introspection_result.get("device_id") - if device_id is not None: - # We got the device ID explicitly, just sanity check that it's a string - if not isinstance(device_id, str): - raise AuthError( - 500, - "Invalid device ID in introspection result", - ) - else: + device_id = introspection_result.get_device_id() + if device_id is None: # Find device_ids in scope # We only allow a single device_id in the scope, so we find them all in the # scope list, and raise if there are more than one. The OIDC server should be diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 5f8c25557a..034a1594d9 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -539,6 +539,44 @@ class MSC3861OAuthDelegation(HomeserverTestCase): error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) self.assertEqual(error.value.code, 503) + def test_cached_expired_introspection(self) -> None: + """The handler should raise an error if the introspection response gives + an expiry time, the introspection response is cached and then the entry is + re-requested after it has expired.""" + + self.http_client.request = introspection_mock = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join( + [ + MATRIX_USER_SCOPE, + f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC", + ] + ), + "username": USERNAME, + "expires_in": 60, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + # The first CS-API request causes a successful introspection + self.get_success(self.auth.get_user_by_req(request)) + self.assertEqual(introspection_mock.call_count, 1) + + # Sleep for 60 seconds so the token expires. + self.reactor.advance(60.0) + + # Now the CS-API request fails because the token expired + self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError) + # Ensure another introspection request was not sent + self.assertEqual(introspection_mock.call_count, 1) + def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: # We only generate a master key to simplify the test. master_signing_key = generate_signing_key(device_id)