#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import asyncio
from asyncio import Future
from http import HTTPStatus
from typing import Any, Awaitable, Dict, List, Optional, Tuple, TypeVar, cast
from unittest.mock import Mock

import attr
from parameterized import parameterized

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
from synapse.events.auto_accept_invites import InviteAutoAccepter
from synapse.federation.federation_base import event_from_pdu_json
from synapse.handlers.sync import JoinedSyncResult, SyncRequestKey, SyncVersion
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import StreamToken, UserID, UserInfo, create_requester
from synapse.util import Clock

from tests.handlers.test_sync import generate_sync_config
from tests.unittest import (
    FederatingHomeserverTestCase,
    HomeserverTestCase,
    TestCase,
    override_config,
)


class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
    """
    Integration test cases for auto-accepting invites.
    """

    servlets = [
        admin.register_servlets,
        login.register_servlets,
        room.register_servlets,
    ]

    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
        hs = self.setup_test_homeserver()
        # The federation handler is used to inject invites that appear to come
        # from a remote server (see `on_invite_request` below).
        self.handler = hs.get_federation_handler()
        self.store = hs.get_datastores().main
        return hs

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        self.sync_handler = self.hs.get_sync_handler()
        self.module_api = hs.get_module_api()

    @parameterized.expand(
        [
            [False],
            [True],
        ]
    )
    @override_config(
        {
            "auto_accept_invites": {
                "enabled": True,
            },
        }
    )
    def test_auto_accept_invites(self, direct_room: bool) -> None:
        """Test that a user automatically joins a room when invited, if the
        module is enabled.
        """
        # A local user who sends an invite
        inviting_user_id = self.register_user("inviter", "pass")
        inviting_user_tok = self.login("inviter", "pass")

        # A local user who receives an invite
        invited_user_id = self.register_user("invitee", "pass")
        self.login("invitee", "pass")

        # Create a room and send an invite to the other user
        room_id = self.helper.create_room_as(
            inviting_user_id,
            is_public=False,
            tok=inviting_user_tok,
        )

        self.helper.invite(
            room_id,
            inviting_user_id,
            invited_user_id,
            tok=inviting_user_tok,
            extra_data={"is_direct": direct_room},
        )

        # Check that the invite receiving user has automatically joined the room when syncing
        join_updates, _ = sync_join(self, invited_user_id)
        self.assertEqual(len(join_updates), 1)

        join_update: JoinedSyncResult = join_updates[0]
        self.assertEqual(join_update.room_id, room_id)

    @override_config(
        {
            "auto_accept_invites": {
                "enabled": False,
            },
        }
    )
    def test_module_not_enabled(self) -> None:
        """Test that a user does not automatically join a room when invited,
        if the module is not enabled.
        """
        # A local user who sends an invite
        inviting_user_id = self.register_user("inviter", "pass")
        inviting_user_tok = self.login("inviter", "pass")

        # A local user who receives an invite
        invited_user_id = self.register_user("invitee", "pass")
        self.login("invitee", "pass")

        # Create a room and send an invite to the other user
        room_id = self.helper.create_room_as(
            inviting_user_id, is_public=False, tok=inviting_user_tok
        )

        self.helper.invite(
            room_id,
            inviting_user_id,
            invited_user_id,
            tok=inviting_user_tok,
        )

        # Check that the invite receiving user has not automatically joined the room when syncing
        join_updates, _ = sync_join(self, invited_user_id)
        self.assertEqual(len(join_updates), 0)

    @override_config(
        {
            "auto_accept_invites": {
                "enabled": True,
            },
        }
    )
    def test_invite_from_remote_user(self) -> None:
        """Test that an invite from a remote user results in the invited user
        automatically joining the room.
        """
        # A remote user who sends the invite
        remote_server = "otherserver"
        remote_user = "@otheruser:" + remote_server

        # A local user who creates the room
        creator_user_id = self.register_user("creator", "pass")
        creator_user_tok = self.login("creator", "pass")

        # A local user who receives an invite
        invited_user_id = self.register_user("invitee", "pass")
        self.login("invitee", "pass")

        room_id = self.helper.create_room_as(
            room_creator=creator_user_id, tok=creator_user_tok
        )
        room_version = self.get_success(self.store.get_room_version(room_id))

        # Hand-build an invite event as if it had been sent by the remote server.
        invite_event = event_from_pdu_json(
            {
                "type": EventTypes.Member,
                "content": {"membership": "invite"},
                "room_id": room_id,
                "sender": remote_user,
                "state_key": invited_user_id,
                "depth": 32,
                "prev_events": [],
                "auth_events": [],
                "origin_server_ts": self.clock.time_msec(),
            },
            room_version,
        )
        self.get_success(
            self.handler.on_invite_request(
                remote_server,
                invite_event,
                invite_event.room_version,
            )
        )

        # Check that the invite receiving user has automatically joined the room when syncing
        join_updates, _ = sync_join(self, invited_user_id)
        self.assertEqual(len(join_updates), 1)

        join_update: JoinedSyncResult = join_updates[0]
        self.assertEqual(join_update.room_id, room_id)

    @parameterized.expand(
        [
            [False, False],
            [True, True],
        ]
    )
    @override_config(
        {
            "auto_accept_invites": {
                "enabled": True,
                "only_for_direct_messages": True,
            },
        }
    )
    def test_accept_invite_direct_message(
        self,
        direct_room: bool,
        expect_auto_join: bool,
    ) -> None:
        """Tests that, if the module is configured to only accept DM invites, invites
        to DM rooms are still automatically accepted, while invites to other rooms
        are not auto-accepted.
        """
        # A local user who sends an invite
        inviting_user_id = self.register_user("inviter", "pass")
        inviting_user_tok = self.login("inviter", "pass")

        # A local user who receives an invite
        invited_user_id = self.register_user("invitee", "pass")
        self.login("invitee", "pass")

        # Create a room and send an invite to the other user
        room_id = self.helper.create_room_as(
            inviting_user_id,
            is_public=False,
            tok=inviting_user_tok,
        )

        self.helper.invite(
            room_id,
            inviting_user_id,
            invited_user_id,
            tok=inviting_user_tok,
            extra_data={"is_direct": direct_room},
        )

        join_updates, _ = sync_join(self, invited_user_id)
        if expect_auto_join:
            # The invite receiving user should have automatically joined the room.
            self.assertEqual(len(join_updates), 1)

            join_update: JoinedSyncResult = join_updates[0]
            self.assertEqual(join_update.room_id, room_id)
        else:
            # The invite receiving user should not have joined any room.
            self.assertEqual(len(join_updates), 0)

    @parameterized.expand(
        [
            [False, True],
            [True, False],
        ]
    )
    @override_config(
        {
            "auto_accept_invites": {
                "enabled": True,
                "only_from_local_users": True,
            },
        }
    )
    def test_accept_invite_local_user(
        self, remote_inviter: bool, expect_auto_join: bool
    ) -> None:
        """Tests that, if the module is configured to only accept invites from local
        users, invites from local users are still automatically accepted, while
        invites from remote users are not auto-accepted.
        """
        # A local user who creates the room (and, for the local case, sends the invite)
        creator_user_id = self.register_user("inviter", "pass")
        creator_user_tok = self.login("inviter", "pass")

        # A local user who receives an invite
        invited_user_id = self.register_user("invitee", "pass")
        self.login("invitee", "pass")

        # Create a room and send an invite to the other user
        room_id = self.helper.create_room_as(
            creator_user_id, is_public=False, tok=creator_user_tok
        )

        if remote_inviter:
            room_version = self.get_success(self.store.get_room_version(room_id))

            # A remote user who sends the invite
            remote_server = "otherserver"
            remote_user = "@otheruser:" + remote_server

            invite_event = event_from_pdu_json(
                {
                    "type": EventTypes.Member,
                    "content": {"membership": "invite"},
                    "room_id": room_id,
                    "sender": remote_user,
                    "state_key": invited_user_id,
                    "depth": 32,
                    "prev_events": [],
                    "auth_events": [],
                    "origin_server_ts": self.clock.time_msec(),
                },
                room_version,
            )
            self.get_success(
                self.handler.on_invite_request(
                    remote_server,
                    invite_event,
                    invite_event.room_version,
                )
            )
        else:
            self.helper.invite(
                room_id,
                creator_user_id,
                invited_user_id,
                tok=creator_user_tok,
            )

        join_updates, _ = sync_join(self, invited_user_id)
        if expect_auto_join:
            # The invite receiving user should have automatically joined the room.
            self.assertEqual(len(join_updates), 1)

            join_update: JoinedSyncResult = join_updates[0]
            self.assertEqual(join_update.room_id, room_id)
        else:
            # The invite receiving user should not have joined any room.
            self.assertEqual(len(join_updates), 0)

    @override_config(
        {
            "auto_accept_invites": {
                "enabled": True,
            },
        }
    )
    def test_ignore_invite_for_missing_user(self) -> None:
        """Tests that receiving an invite for a missing user is ignored."""
        inviting_user_id = self.register_user("inviter", "pass")
        inviting_user_tok = self.login("inviter", "pass")

        # A local user ID that was never registered.
        invited_user_id = "@fake:" + self.hs.config.server.server_name

        # Create a room and send an invite to the other user
        room_id = self.helper.create_room_as(
            inviting_user_id,
            tok=inviting_user_tok,
        )

        self.helper.invite(
            room_id,
            inviting_user_id,
            invited_user_id,
            tok=inviting_user_tok,
        )

        # Sync as the inviter, since the invitee has no account to sync with.
        join_updates, _ = sync_join(self, inviting_user_id)
        self.assertEqual(len(join_updates), 1)
        # The most recent event in the room must still be the invite, i.e. the
        # target user was not made to join.
        self.assertEqual(
            join_updates[0].timeline.events[-1].content["membership"], "invite"
        )

    @override_config(
        {
            "auto_accept_invites": {
                "enabled": True,
            },
        }
    )
    def test_ignore_invite_for_deactivated_user(self) -> None:
        """Tests that receiving an invite for a deactivated user is ignored."""
        inviting_user_id = self.register_user("inviter", "pass", admin=True)
        inviting_user_tok = self.login("inviter", "pass")

        # A local user who receives an invite
        invited_user_id = self.register_user("invitee", "pass")

        # Create a room and send an invite to the other user
        room_id = self.helper.create_room_as(
            inviting_user_id,
            tok=inviting_user_tok,
        )

        # Deactivate the invitee via the admin API (the inviter is an admin).
        channel = self.make_request(
            "PUT",
            "/_synapse/admin/v2/users/%s" % invited_user_id,
            {"deactivated": True},
            access_token=inviting_user_tok,
        )
        self.assertEqual(channel.code, 200, channel.result)

        self.helper.invite(
            room_id,
            inviting_user_id,
            invited_user_id,
            tok=inviting_user_tok,
        )

        join_updates, _ = sync_join(self, inviting_user_id)
        self.assertEqual(len(join_updates), 1)
        # The most recent event in the room must still be the invite, i.e. the
        # target user was not made to join.
        self.assertEqual(
            join_updates[0].timeline.events[-1].content["membership"], "invite"
        )

    @override_config(
        {
            "auto_accept_invites": {
                "enabled": True,
            },
        }
    )
    def test_ignore_invite_for_suspended_user(self) -> None:
        """Tests that receiving an invite for a suspended user is ignored."""
        inviting_user_id = self.register_user("inviter", "pass", admin=True)
        inviting_user_tok = self.login("inviter", "pass")

        # A local user who receives an invite
        invited_user_id = self.register_user("invitee", "pass")

        # Create a room and send an invite to the other user
        room_id = self.helper.create_room_as(
            inviting_user_id,
            tok=inviting_user_tok,
        )

        # Suspend the invitee via the admin API (the inviter is an admin).
        channel = self.make_request(
            "PUT",
            f"/_synapse/admin/v1/suspend/{invited_user_id}",
            {"suspend": True},
            access_token=inviting_user_tok,
        )
        self.assertEqual(channel.code, 200, channel.result)

        self.helper.invite(
            room_id,
            inviting_user_id,
            invited_user_id,
            tok=inviting_user_tok,
        )

        join_updates, _ = sync_join(self, inviting_user_id)
        self.assertEqual(len(join_updates), 1)
        # The most recent event in the room must still be the invite, i.e. the
        # target user was not made to join.
        self.assertEqual(
            join_updates[0].timeline.events[-1].content["membership"], "invite"
        )

    @override_config(
        {
            "auto_accept_invites": {
                "enabled": True,
            },
        }
    )
    def test_ignore_invite_for_locked_user(self) -> None:
        """Tests that receiving an invite for a locked user is ignored."""
        inviting_user_id = self.register_user("inviter", "pass", admin=True)
        inviting_user_tok = self.login("inviter", "pass")

        # A local user who receives an invite
        invited_user_id = self.register_user("invitee", "pass")

        # Create a room and send an invite to the other user
        room_id = self.helper.create_room_as(
            inviting_user_id,
            tok=inviting_user_tok,
        )

        # Lock the invitee via the admin API (the inviter is an admin).
        channel = self.make_request(
            "PUT",
            f"/_synapse/admin/v2/users/{invited_user_id}",
            {"locked": True},
            access_token=inviting_user_tok,
        )
        self.assertEqual(channel.code, 200, channel.result)

        self.helper.invite(
            room_id,
            inviting_user_id,
            invited_user_id,
            tok=inviting_user_tok,
        )

        join_updates, _ = sync_join(self, inviting_user_id)
        self.assertEqual(len(join_updates), 1)
        # The most recent event in the room must still be the invite, i.e. the
        # target user was not made to join.
        self.assertEqual(
            join_updates[0].timeline.events[-1].content["membership"], "invite"
        )


# Module-level counter backing `generate_request_key`; the first key handed out is 1.
_request_key = 0


def generate_request_key() -> SyncRequestKey:
    """Hand out a new, never-before-used sync request key.

    Keys have the form ``("request_key", n)``, where ``n`` is taken from a
    module-level counter that increases by one on every call. Consecutive calls
    therefore return ``("request_key", 1)``, ``("request_key", 2)``, and so on.
    """
    global _request_key
    next_key = _request_key + 1
    _request_key = next_key
    return ("request_key", next_key)


def sync_join(
    testcase: HomeserverTestCase,
    user_id: str,
    since_token: Optional[StreamToken] = None,
) -> Tuple[List[JoinedSyncResult], StreamToken]:
    """Perform a sync request for the given user and return the user join updates
    they've received, as well as the next_batch token.

    The sync handler is fetched from ``testcase.hs`` on every call, so the test
    case does not need to store its own reference to it. Each call uses a fresh
    request key from `generate_request_key`.

    Args:
        testcase: The testcase that is currently being run.
        user_id: The ID of the user to generate a sync response for.
        since_token: An optional token indicating the point to sync from. It is
            passed straight through to ``wait_for_sync_for_user``.

    Returns:
        A tuple containing a list of join updates, and the sync response's
        next_batch token.
    """
    requester = create_requester(user_id)
    sync_config = generate_sync_config(requester.user.to_string())
    sync_handler = testcase.hs.get_sync_handler()
    sync_result = testcase.get_success(
        sync_handler.wait_for_sync_for_user(
            requester,
            sync_config,
            SyncVersion.SYNC_V2,
            generate_request_key(),
            since_token,
        )
    )

    return sync_result.joined, sync_result.next_batch


class InviteAutoAccepterInternalTestCase(TestCase):
    """
    Test cases which exercise the internals of the InviteAutoAccepter.
    """

    def setUp(self) -> None:
        self.module = create_module()
        self.user_id = "@peter:test"
        self.invitee = "@lesley:test"
        self.remote_invitee = "@thomas:remote"

        # We know our module API is a mock, but mypy doesn't.
        self.mocked_update_membership: Mock = self.module._api.update_room_membership  # type: ignore[assignment]

    async def test_accept_invite_with_failures(self) -> None:
        """Tests that receiving an invite for a local user makes the module attempt to
        make the invitee join the room. This test verifies that it works if the call to
        update membership returns exceptions before successfully completing and returning an event.
        """
        invite = MockEvent(
            sender="@inviter:test",
            state_key="@invitee:test",
            type="m.room.member",
            content={"membership": "invite"},
        )

        join_event = MockEvent(
            sender="someone",
            state_key="someone",
            type="m.room.member",
            content={"membership": "join"},
        )
        # the first two calls raise an exception while the third call is successful
        self.mocked_update_membership.side_effect = [
            SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
            SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
            make_awaitable(join_event),
        ]

        # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
        # EventBase.
        await self.module.on_new_event(event=invite)  # type: ignore[arg-type]

        await self.retry_assertions(
            self.mocked_update_membership,
            3,
            sender=invite.state_key,
            target=invite.state_key,
            room_id=invite.room_id,
            new_membership="join",
        )

    async def test_accept_invite_failures(self) -> None:
        """Tests that receiving an invite for a local user makes the module attempt to
        make the invitee join the room. This test verifies that if the update_membership call
        fails consistently, _retry_make_join will break the loop after the set number of retries and
        execution will continue.
        """
        invite = MockEvent(
            sender=self.user_id,
            state_key=self.invitee,
            type="m.room.member",
            content={"membership": "invite"},
        )
        self.mocked_update_membership.side_effect = SynapseError(
            HTTPStatus.FORBIDDEN, "Forbidden"
        )

        # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
        # EventBase.
        await self.module.on_new_event(event=invite)  # type: ignore[arg-type]

        await self.retry_assertions(
            self.mocked_update_membership,
            5,
            sender=invite.state_key,
            target=invite.state_key,
            room_id=invite.room_id,
            new_membership="join",
        )

    async def test_not_state(self) -> None:
        """Tests that receiving an invite that's not a state event does nothing."""
        invite = MockEvent(
            sender=self.user_id, type="m.room.member", content={"membership": "invite"}
        )

        # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
        # EventBase.
        await self.module.on_new_event(event=invite)  # type: ignore[arg-type]

        self.mocked_update_membership.assert_not_called()

    async def test_not_invite(self) -> None:
        """Tests that receiving a membership update that's not an invite does nothing."""
        invite = MockEvent(
            sender=self.user_id,
            state_key=self.user_id,
            type="m.room.member",
            content={"membership": "join"},
        )

        # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
        # EventBase.
        await self.module.on_new_event(event=invite)  # type: ignore[arg-type]

        self.mocked_update_membership.assert_not_called()

    async def test_not_membership(self) -> None:
        """Tests that receiving a state event that's not a membership update does
        nothing.
        """
        invite = MockEvent(
            sender=self.user_id,
            state_key=self.user_id,
            type="org.matrix.test",
            content={"foo": "bar"},
        )

        # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
        # EventBase.
        await self.module.on_new_event(event=invite)  # type: ignore[arg-type]

        self.mocked_update_membership.assert_not_called()

    def test_config_parse(self) -> None:
        """Tests that a correct configuration parses."""
        config = {
            "auto_accept_invites": {
                "enabled": True,
                "only_for_direct_messages": True,
                "only_from_local_users": True,
            }
        }
        parsed_config = AutoAcceptInvitesConfig()
        parsed_config.read_config(config)

        self.assertTrue(parsed_config.enabled)
        self.assertTrue(parsed_config.accept_invites_only_for_direct_messages)
        self.assertTrue(parsed_config.accept_invites_only_from_local_users)

    def test_runs_on_only_one_worker(self) -> None:
        """
        Tests that the module only runs on the specified worker.
        """
        # By default, we run on the main process...
        main_module = create_module(
            config_override={"auto_accept_invites": {"enabled": True}}, worker_name=None
        )
        cast(
            Mock, main_module._api.register_third_party_rules_callbacks
        ).assert_called_once()

        # ...and not on other workers (like synchrotrons), even when the module is
        # enabled. Without enabling it here, this check would pass trivially
        # regardless of which worker we are on.
        sync_module = create_module(
            config_override={"auto_accept_invites": {"enabled": True}},
            worker_name="synchrotron42",
        )
        cast(
            Mock, sync_module._api.register_third_party_rules_callbacks
        ).assert_not_called()

        # ...unless we configured them to be the designated worker.
        specified_module = create_module(
            config_override={
                "auto_accept_invites": {
                    "enabled": True,
                    "worker_to_run_on": "account_data1",
                }
            },
            worker_name="account_data1",
        )
        cast(
            Mock, specified_module._api.register_third_party_rules_callbacks
        ).assert_called_once()

    async def retry_assertions(
        self, mock: Mock, call_count: int, **kwargs: Any
    ) -> None:
        """
        Assert on `mock`, retrying so other coroutines have a chance to run.

        This is a hacky way to ensure that the assertions are not evaluated before
        the other coroutine has had a chance to call `update_room_membership`. When
        an assertion fails, it waits one second (via `asyncio.sleep`) and tries
        again. After 5 attempts in total it fails the test with the last error.

        Args:
            mock: the mocked function we want to assert on
            call_count: the number of times the mock should have been called
            kwargs: keyword arguments the mock's most recent call must have used
        """
        max_attempts = 5
        for attempt in range(1, max_attempts + 1):
            try:
                # Check that the mocked method is called the expected amount of times and with the right
                # arguments to attempt to make the user join the room.
                mock.assert_called_with(**kwargs)
                self.assertEqual(call_count, mock.call_count)
                return
            except AssertionError as e:
                if attempt == max_attempts:
                    # We've used up the tries. Force the test to fail, since we've
                    # already caught the exception.
                    self.fail(e)
                await asyncio.sleep(1)


@attr.s(auto_attribs=True)
class MockEvent:
    """Minimal stand-in for an event.

    It only exposes the properties the module under test uses.
    """

    sender: str
    type: str
    content: Dict[str, Any]
    room_id: str = "!someroom"
    # Left as None for non-state events (see `is_state`).
    state_key: Optional[str] = None

    def is_state(self) -> bool:
        """Return True iff a state key is present, i.e. this is a state event."""
        return not (self.state_key is None)

    @property
    def membership(self) -> str:
        """The ``membership`` field of the event content.

        Only meaningful for membership events.

        Raises:
            KeyError: if the content has no ``membership`` key.
        """
        value: str = self.content["membership"]
        return value


T = TypeVar("T")
TV = TypeVar("TV")


async def make_awaitable(value: T) -> T:
    """Return a coroutine that immediately resolves to ``value``.

    Like any coroutine object, the result can only be awaited once.
    """
    return value


def make_multiple_awaitable(result: TV) -> Awaitable[TV]:
    """
    Build an already-completed awaitable for mocking an `async` function.

    An asyncio Future is used (rather than a coroutine) because a finished
    Future can be awaited any number of times. This makes it safe to return the
    same object to several callers.
    """
    completed: Future[TV] = Future()
    completed.set_result(result)
    return completed


def create_module(
    config_override: Optional[Dict[str, Any]] = None, worker_name: Optional[str] = None
) -> InviteAutoAccepter:
    """Construct an InviteAutoAccepter on top of a mocked ModuleApi.

    The mock follows the ModuleApi spec. A few methods are given concrete
    behaviour because the tests need those capabilities.

    Args:
        config_override: raw config dict parsed with AutoAcceptInvitesConfig.
            An empty dict is used when omitted.
        worker_name: value exposed as ``worker_name`` on the mocked API.

    Returns:
        The InviteAutoAccepter wrapping the mocked API.
    """
    api = Mock(spec=ModuleApi)
    # An ID counts as local when its second ':'-separated segment is "test".
    api.is_mine.side_effect = lambda ident: ident.split(":")[1] == "test"
    api.worker_name = worker_name
    # A completed Future can be awaited repeatedly (see make_multiple_awaitable),
    # so every call to sleep() can share it.
    api.sleep.return_value = make_multiple_awaitable(None)
    # Every user lookup reports the same ordinary, active, non-admin account.
    api.get_userinfo_by_id.return_value = UserInfo(
        user_id=UserID.from_string("@user:test"),
        is_admin=False,
        is_guest=False,
        consent_server_notice_sent=None,
        consent_ts=None,
        consent_version=None,
        appservice_id=None,
        creation_ts=0,
        user_type=None,
        is_deactivated=False,
        locked=False,
        is_shadow_banned=False,
        approved=True,
        suspended=False,
    )

    raw_config = {} if config_override is None else config_override
    parsed = AutoAcceptInvitesConfig()
    parsed.read_config(raw_config)

    return InviteAutoAccepter(parsed, api)