From 027b4af5aca1fe5493ee15317ae3517d20931876 Mon Sep 17 00:00:00 2001
From: Mathieu Velten <mathieu.velten@beta.gouv.fr>
Date: Tue, 16 Jan 2024 21:28:50 +0100
Subject: [PATCH 1/4] Fix a race when registering via email 3pid

---
 changelog.d/16827.bugfix           |  1 +
 synapse/rest/client/register.py    | 22 ++++++-
 tests/rest/client/test_register.py | 92 +++++++++++++++++++++++++++++-
 3 files changed, 113 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/16827.bugfix

diff --git a/changelog.d/16827.bugfix b/changelog.d/16827.bugfix
new file mode 100644
index 0000000000..e0ed9e262a
--- /dev/null
+++ b/changelog.d/16827.bugfix
@@ -0,0 +1 @@
+Fix a race when registering via email 3pid where 2 different user ids would be created.
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 634ebed2be..01541a1a9b 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -75,6 +75,8 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+USER_REGISTRATION_LOCK_NAME = "user_registration"
+
 
 class EmailRegisterRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/register/email/requestToken$")
@@ -417,6 +419,7 @@ class RegisterRestServlet(RestServlet):
         self.macaroon_gen = hs.get_macaroon_generator()
         self.ratelimiter = hs.get_registration_ratelimiter()
         self.password_policy_handler = hs.get_password_policy_handler()
+        self._worker_lock_handler = hs.get_worker_locks_handler()
         self.clock = hs.get_clock()
         self.password_auth_provider = hs.get_password_auth_provider()
         self._registration_enabled = self.hs.config.registration.enable_registration
@@ -508,6 +511,23 @@ class RegisterRestServlet(RestServlet):
                 "An access token should not be provided on requests to /register (except if type is m.login.application_service)",
             )
 
+        # Take a global lock when doing user registration to avoid races,
+        # for example when doing 3pid email binding.
+        async with self._worker_lock_handler.acquire_lock(
+            USER_REGISTRATION_LOCK_NAME, ""
+        ):
+            return await self._do_user_register(
+                desired_username, client_addr, body, should_issue_refresh_token, request
+            )
+
+    async def _do_user_register(
+        self,
+        desired_username: Optional[str],
+        address: str,
+        body: JsonDict,
+        should_issue_refresh_token: bool,
+        request: SynapseRequest,
+    ) -> Tuple[int, JsonDict]:
         # == Normal User Registration == (everyone else)
         if not self._registration_enabled:
             raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
@@ -702,7 +722,7 @@ class RegisterRestServlet(RestServlet):
                 guest_access_token=guest_access_token,
                 threepid=threepid,
                 default_display_name=display_name,
-                address=client_addr,
+                address=address,
                 user_agent_ips=entries,
             )
             # Necessary due to auth checks prior to the threepid being
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 859051cdda..1a2c0594fb 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -21,7 +21,8 @@
 #
 import datetime
 import os
-from typing import Any, Dict, List, Tuple
+import re
+from typing import Any, Dict, List, Optional, Tuple
 
 import pkg_resources
 
@@ -42,6 +43,7 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
+from tests.server import ThreadedMemoryReactorClock
 from tests.unittest import override_config
 
 
@@ -1248,3 +1250,91 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
             f"{self.url}?token={token}",
         )
         self.assertEqual(channel.code, 200, msg=channel.result)
+
+
+class EmailRegisterRestServletTestCase(unittest.HomeserverTestCase):
+    servlets = [register.register_servlets]
+
+    def make_homeserver(
+        self, reactor: ThreadedMemoryReactorClock, clock: Clock
+    ) -> HomeServer:
+        hs = super().make_homeserver(reactor, clock)
+
+        async def send_email(
+            email_address: str,
+            subject: str,
+            app_name: str,
+            html: str,
+            text: str,
+            additional_headers: Optional[Dict[str, str]] = None,
+        ) -> None:
+            self.email_attempts.append(text)
+
+        self.email_attempts: List[str] = []
+        hs.get_send_email_handler().send_email = send_email  # type: ignore[method-assign]
+        return hs
+
+    @unittest.override_config(
+        {
+            "public_baseurl": "https://test_server",
+            "registrations_require_3pid": ["email"],
+            "disable_msisdn_registration": True,
+            "email": {
+                "smtp_host": "mail_server",
+                "smtp_port": 2525,
+                "notif_from": "sender@host",
+            },
+        }
+    )
+    def test_email_3pid_registration_race(self) -> None:
+        channel = self.make_request("POST", b"register", {"password": "password"})
+        session = channel.json_body["session"]
+
+        # request a token to be sent by email for validation
+        channel = self.make_request(
+            "POST",
+            b"register/email/requestToken",
+            {
+                "client_secret": "client_secret",
+                "email": "email@email",
+                "send_attempt": 1,
+            },
+        )
+        sid = channel.json_body["sid"]
+
+        email_text = self.email_attempts[0]
+        match = re.search("https://test_server(.*)", email_text)
+        assert match is not None
+        validation_url = match.group(1)
+
+        # "Click" the link in the email to validate the adress
+        self.make_request("GET", validation_url.encode("utf-8"))
+
+        # launch 2 simultaneous register request, only one account
+        # should be created after that.
+        register_content = {
+            "auth": {
+                "session": session,
+                "threepid_creds": {
+                    "client_secret": "client_secret",
+                    "sid": sid,
+                },
+                "type": "m.login.email.identity",
+            },
+            "password": "password",
+        }
+        register1_channel = self.make_request(
+            "POST", b"register", register_content, await_result=False
+        )
+        register2_channel = self.make_request(
+            "POST", b"register", register_content, await_result=False
+        )
+        while (
+            not register1_channel.is_finished() or not register2_channel.is_finished()
+        ):
+            self.pump()
+
+        self.assertEqual(
+            register1_channel.json_body["user_id"],
+            register2_channel.json_body["user_id"],
+        )

From d733e5adfff2318f709aef0a82b57efdbd169a5b Mon Sep 17 00:00:00 2001
From: Mathieu Velten <mathieu.velten@beta.gouv.fr>
Date: Thu, 4 Jul 2024 15:36:55 +0200
Subject: [PATCH 2/4] Add some lock logs

---
 synapse/handlers/worker_lock.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
index 7e578cf462..72c71d5858 100644
--- a/synapse/handlers/worker_lock.py
+++ b/synapse/handlers/worker_lock.py
@@ -19,6 +19,7 @@
 #
 #
 
+import logging
 import random
 from types import TracebackType
 from typing import (
@@ -48,6 +49,8 @@ if TYPE_CHECKING:
     from synapse.logging.opentracing import opentracing
     from synapse.server import HomeServer
 
+logger = logging.getLogger(__name__)
+
 
 # This lock is used to avoid creating an event while we are purging the room.
 # We take a read lock when creating an event, and a write one when purging a room.
@@ -247,6 +250,8 @@ class WaitingLock:
                 except Exception:
                     pass
 
+        logger.warn(f"lock taken: {self.lock_name}, {self.lock_key}")
+
         return await self._inner_lock.__aenter__()
 
     async def __aexit__(
@@ -261,6 +266,7 @@ class WaitingLock:
 
         try:
             r = await self._inner_lock.__aexit__(exc_type, exc, tb)
+            logger.warn(f"lock released: {self.lock_name}, {self.lock_key}")
         finally:
             self._lock_span.__exit__(exc_type, exc, tb)
 

From 8ba92f3d57ee29e4f6671759b4b35a26ea891061 Mon Sep 17 00:00:00 2001
From: Mathieu Velten <mathieu.velten@beta.gouv.fr>
Date: Wed, 24 Jul 2024 17:37:17 +0200
Subject: [PATCH 3/4] Try

---
 synapse/handlers/worker_lock.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
index 72c71d5858..d743413e86 100644
--- a/synapse/handlers/worker_lock.py
+++ b/synapse/handlers/worker_lock.py
@@ -247,6 +247,7 @@ class WaitingLock:
                             timeout=self._get_next_retry_interval(),
                             reactor=self.reactor,
                         )
+                        self._retry_interval = 0.1
                 except Exception:
                     pass
 

From 9c807132e72af212f8d598de570488995efee78e Mon Sep 17 00:00:00 2001
From: Mathieu Velten <mathieu.velten@beta.gouv.fr>
Date: Thu, 25 Jul 2024 16:09:48 +0200
Subject: [PATCH 4/4] Add comment

---
 synapse/handlers/worker_lock.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
index d743413e86..c87d02519b 100644
--- a/synapse/handlers/worker_lock.py
+++ b/synapse/handlers/worker_lock.py
@@ -247,6 +247,8 @@ class WaitingLock:
                             timeout=self._get_next_retry_interval(),
                             reactor=self.reactor,
                         )
+                        # Let's reset retry interval since we got notified, we
+                        # should only increase it if we hit the previous one
                         self._retry_interval = 0.1
                 except Exception:
                     pass