mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Support MSC4140: Delayed events (Futures) (#17326)
This commit is contained in:
parent
75e2c17d2a
commit
5173741c71
21 changed files with 1772 additions and 12 deletions
2
changelog.d/17326.feature
Normal file
2
changelog.d/17326.feature
Normal file
|
@ -0,0 +1,2 @@
|
|||
Add initial implementation of delayed events as proposed by [MSC4140](https://github.com/matrix-org/matrix-spec-proposals/pull/4140).
|
||||
|
|
@ -111,6 +111,9 @@ server_notices:
|
|||
system_mxid_avatar_url: ""
|
||||
room_name: "Server Alert"
|
||||
|
||||
# Enable delayed events (msc4140)
|
||||
max_event_delay_duration: 24h
|
||||
|
||||
|
||||
# Disable sync cache so that initial `/sync` requests are up-to-date.
|
||||
caches:
|
||||
|
|
|
@ -761,6 +761,19 @@ email:
|
|||
password_reset: "[%(server_name)s] Password reset"
|
||||
email_validation: "[%(server_name)s] Validate your email"
|
||||
```
|
||||
---
|
||||
### `max_event_delay_duration`
|
||||
|
||||
The maximum allowed duration by which sent events can be delayed, as per
|
||||
[MSC4140](https://github.com/matrix-org/matrix-spec-proposals/pull/4140).
|
||||
Must be a positive value if set.
|
||||
|
||||
Defaults to no duration (`null`), which disallows sending delayed events.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
max_event_delay_duration: 24h
|
||||
```
|
||||
|
||||
## Homeserver blocking
|
||||
Useful options for Synapse admins.
|
||||
|
|
|
@ -290,6 +290,7 @@ information.
|
|||
Additionally, the following REST endpoints can be handled for GET requests:
|
||||
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
|
||||
^/_matrix/client/unstable/org.matrix.msc4140/delayed_events
|
||||
|
||||
Pagination requests can also be handled, but all requests for a given
|
||||
room must be routed to the same instance. Additionally, care must be taken to
|
||||
|
|
|
@ -223,6 +223,7 @@ test_packages=(
|
|||
./tests/msc3930
|
||||
./tests/msc3902
|
||||
./tests/msc3967
|
||||
./tests/msc4140
|
||||
)
|
||||
|
||||
# Enable dirty runs, so tests will reuse the same container where possible.
|
||||
|
|
|
@ -65,6 +65,7 @@ from synapse.storage.databases.main.appservice import (
|
|||
)
|
||||
from synapse.storage.databases.main.censor_events import CensorEventsStore
|
||||
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
|
||||
from synapse.storage.databases.main.delayed_events import DelayedEventsStore
|
||||
from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
|
||||
from synapse.storage.databases.main.devices import DeviceWorkerStore
|
||||
from synapse.storage.databases.main.directory import DirectoryWorkerStore
|
||||
|
@ -161,6 +162,7 @@ class GenericWorkerStore(
|
|||
TaskSchedulerWorkerStore,
|
||||
ExperimentalFeaturesStore,
|
||||
SlidingSyncStore,
|
||||
DelayedEventsStore,
|
||||
):
|
||||
# Properties that multiple storage classes define. Tell mypy what the
|
||||
# expected type is.
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
# Copyright (C) 2023-2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
|
@ -780,6 +780,17 @@ class ServerConfig(Config):
|
|||
else:
|
||||
self.delete_stale_devices_after = None
|
||||
|
||||
# The maximum allowed delay duration for delayed events (MSC4140).
|
||||
max_event_delay_duration = config.get("max_event_delay_duration")
|
||||
if max_event_delay_duration is not None:
|
||||
self.max_event_delay_ms: Optional[int] = self.parse_duration(
|
||||
max_event_delay_duration
|
||||
)
|
||||
if self.max_event_delay_ms <= 0:
|
||||
raise ConfigError("max_event_delay_duration must be a positive value")
|
||||
else:
|
||||
self.max_event_delay_ms = None
|
||||
|
||||
def has_tls_listener(self) -> bool:
|
||||
return any(listener.is_tls() for listener in self.listeners)
|
||||
|
||||
|
|
484
synapse/handlers/delayed_events.py
Normal file
484
synapse/handlers/delayed_events.py
Normal file
|
@ -0,0 +1,484 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Set, Tuple
|
||||
|
||||
from twisted.internet.interfaces import IDelayedCall
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import ShadowBanError
|
||||
from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
|
||||
from synapse.logging.opentracing import set_tag
|
||||
from synapse.metrics import event_processing_positions
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.http.delayed_events import (
|
||||
ReplicationAddedDelayedEventRestServlet,
|
||||
)
|
||||
from synapse.storage.databases.main.delayed_events import (
|
||||
DelayedEventDetails,
|
||||
DelayID,
|
||||
EventType,
|
||||
StateKey,
|
||||
Timestamp,
|
||||
UserLocalpart,
|
||||
)
|
||||
from synapse.storage.databases.main.state_deltas import StateDelta
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
Requester,
|
||||
RoomID,
|
||||
UserID,
|
||||
create_requester,
|
||||
)
|
||||
from synapse.util.events import generate_fake_event_id
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DelayedEventsHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._config = hs.config
|
||||
self._clock = hs.get_clock()
|
||||
self._request_ratelimiter = hs.get_request_ratelimiter()
|
||||
self._event_creation_handler = hs.get_event_creation_handler()
|
||||
self._room_member_handler = hs.get_room_member_handler()
|
||||
|
||||
self._next_delayed_event_call: Optional[IDelayedCall] = None
|
||||
|
||||
# The current position in the current_state_delta stream
|
||||
self._event_pos: Optional[int] = None
|
||||
|
||||
# Guard to ensure we only process event deltas one at a time
|
||||
self._event_processing = False
|
||||
|
||||
if hs.config.worker.worker_app is None:
|
||||
self._repl_client = None
|
||||
|
||||
async def _schedule_db_events() -> None:
|
||||
# We kick this off to pick up outstanding work from before the last restart.
|
||||
# Block until we're up to date.
|
||||
await self._unsafe_process_new_event()
|
||||
hs.get_notifier().add_replication_callback(self.notify_new_event)
|
||||
# Kick off again (without blocking) to catch any missed notifications
|
||||
# that may have fired before the callback was added.
|
||||
self._clock.call_later(0, self.notify_new_event)
|
||||
|
||||
# Delayed events that are already marked as processed on startup might not have been
|
||||
# sent properly on the last run of the server, so unmark them to send them again.
|
||||
# Caveat: this will double-send delayed events that successfully persisted, but failed
|
||||
# to be removed from the DB table of delayed events.
|
||||
# TODO: To avoid double-sending, scan the timeline to find which of these events were
|
||||
# already sent. To do so, must store delay_ids in sent events to retrieve them later.
|
||||
await self._store.unprocess_delayed_events()
|
||||
|
||||
events, next_send_ts = await self._store.process_timeout_delayed_events(
|
||||
self._get_current_ts()
|
||||
)
|
||||
|
||||
if next_send_ts:
|
||||
self._schedule_next_at(next_send_ts)
|
||||
|
||||
# Can send the events in background after having awaited on marking them as processed
|
||||
run_as_background_process(
|
||||
"_send_events",
|
||||
self._send_events,
|
||||
events,
|
||||
)
|
||||
|
||||
self._initialized_from_db = run_as_background_process(
|
||||
"_schedule_db_events", _schedule_db_events
|
||||
)
|
||||
else:
|
||||
self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs)
|
||||
|
||||
@property
|
||||
def _is_master(self) -> bool:
|
||||
return self._repl_client is None
|
||||
|
||||
def notify_new_event(self) -> None:
|
||||
"""
|
||||
Called when there may be more state event deltas to process,
|
||||
which should cancel pending delayed events for the same state.
|
||||
"""
|
||||
if self._event_processing:
|
||||
return
|
||||
|
||||
self._event_processing = True
|
||||
|
||||
async def process() -> None:
|
||||
try:
|
||||
await self._unsafe_process_new_event()
|
||||
finally:
|
||||
self._event_processing = False
|
||||
|
||||
run_as_background_process("delayed_events.notify_new_event", process)
|
||||
|
||||
async def _unsafe_process_new_event(self) -> None:
|
||||
# If self._event_pos is None then means we haven't fetched it from the DB yet
|
||||
if self._event_pos is None:
|
||||
self._event_pos = await self._store.get_delayed_events_stream_pos()
|
||||
room_max_stream_ordering = self._store.get_room_max_stream_ordering()
|
||||
if self._event_pos > room_max_stream_ordering:
|
||||
# apparently, we've processed more events than exist in the database!
|
||||
# this can happen if events are removed with history purge or similar.
|
||||
logger.warning(
|
||||
"Event stream ordering appears to have gone backwards (%i -> %i): "
|
||||
"rewinding delayed events processor",
|
||||
self._event_pos,
|
||||
room_max_stream_ordering,
|
||||
)
|
||||
self._event_pos = room_max_stream_ordering
|
||||
|
||||
# Loop round handling deltas until we're up to date
|
||||
while True:
|
||||
with Measure(self._clock, "delayed_events_delta"):
|
||||
room_max_stream_ordering = self._store.get_room_max_stream_ordering()
|
||||
if self._event_pos == room_max_stream_ordering:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"Processing delayed events %s->%s",
|
||||
self._event_pos,
|
||||
room_max_stream_ordering,
|
||||
)
|
||||
(
|
||||
max_pos,
|
||||
deltas,
|
||||
) = await self._storage_controllers.state.get_current_state_deltas(
|
||||
self._event_pos, room_max_stream_ordering
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Handling %d state deltas for delayed events processing",
|
||||
len(deltas),
|
||||
)
|
||||
await self._handle_state_deltas(deltas)
|
||||
|
||||
self._event_pos = max_pos
|
||||
|
||||
# Expose current event processing position to prometheus
|
||||
event_processing_positions.labels("delayed_events").set(max_pos)
|
||||
|
||||
await self._store.update_delayed_events_stream_pos(max_pos)
|
||||
|
||||
async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None:
|
||||
"""
|
||||
Process current state deltas to cancel pending delayed events
|
||||
that target the same state.
|
||||
"""
|
||||
for delta in deltas:
|
||||
logger.debug(
|
||||
"Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
|
||||
)
|
||||
|
||||
next_send_ts = await self._store.cancel_delayed_state_events(
|
||||
room_id=delta.room_id,
|
||||
event_type=delta.event_type,
|
||||
state_key=delta.state_key,
|
||||
)
|
||||
|
||||
if self._next_send_ts_changed(next_send_ts):
|
||||
self._schedule_next_at_or_none(next_send_ts)
|
||||
|
||||
async def add(
|
||||
self,
|
||||
requester: Requester,
|
||||
*,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: Optional[str],
|
||||
origin_server_ts: Optional[int],
|
||||
content: JsonDict,
|
||||
delay: int,
|
||||
) -> str:
|
||||
"""
|
||||
Creates a new delayed event and schedules its delivery.
|
||||
|
||||
Args:
|
||||
requester: The requester of the delayed event, who will be its owner.
|
||||
room_id: The ID of the room where the event should be sent to.
|
||||
event_type: The type of event to be sent.
|
||||
state_key: The state key of the event to be sent, or None if it is not a state event.
|
||||
origin_server_ts: The custom timestamp to send the event with.
|
||||
If None, the timestamp will be the actual time when the event is sent.
|
||||
content: The content of the event to be sent.
|
||||
delay: How long (in milliseconds) to wait before automatically sending the event.
|
||||
|
||||
Returns: The ID of the added delayed event.
|
||||
|
||||
Raises:
|
||||
SynapseError: if the delayed event fails validation checks.
|
||||
"""
|
||||
await self._request_ratelimiter.ratelimit(requester)
|
||||
|
||||
self._event_creation_handler.validator.validate_builder(
|
||||
self._event_creation_handler.event_builder_factory.for_room_version(
|
||||
await self._store.get_room_version(room_id),
|
||||
{
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": str(requester.user),
|
||||
**({"state_key": state_key} if state_key is not None else {}),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
creation_ts = self._get_current_ts()
|
||||
|
||||
delay_id, next_send_ts = await self._store.add_delayed_event(
|
||||
user_localpart=requester.user.localpart,
|
||||
device_id=requester.device_id,
|
||||
creation_ts=creation_ts,
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
state_key=state_key,
|
||||
origin_server_ts=origin_server_ts,
|
||||
content=content,
|
||||
delay=delay,
|
||||
)
|
||||
|
||||
if self._repl_client is not None:
|
||||
# NOTE: If this throws, the delayed event will remain in the DB and
|
||||
# will be picked up once the main worker gets another delayed event.
|
||||
await self._repl_client(
|
||||
instance_name=MAIN_PROCESS_INSTANCE_NAME,
|
||||
next_send_ts=next_send_ts,
|
||||
)
|
||||
elif self._next_send_ts_changed(next_send_ts):
|
||||
self._schedule_next_at(next_send_ts)
|
||||
|
||||
return delay_id
|
||||
|
||||
def on_added(self, next_send_ts: int) -> None:
|
||||
next_send_ts = Timestamp(next_send_ts)
|
||||
if self._next_send_ts_changed(next_send_ts):
|
||||
self._schedule_next_at(next_send_ts)
|
||||
|
||||
async def cancel(self, requester: Requester, delay_id: str) -> None:
|
||||
"""
|
||||
Cancels the scheduled delivery of the matching delayed event.
|
||||
|
||||
Args:
|
||||
requester: The owner of the delayed event to act on.
|
||||
delay_id: The ID of the delayed event to act on.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if no matching delayed event could be found.
|
||||
"""
|
||||
assert self._is_master
|
||||
await self._request_ratelimiter.ratelimit(requester)
|
||||
await self._initialized_from_db
|
||||
|
||||
next_send_ts = await self._store.cancel_delayed_event(
|
||||
delay_id=delay_id,
|
||||
user_localpart=requester.user.localpart,
|
||||
)
|
||||
|
||||
if self._next_send_ts_changed(next_send_ts):
|
||||
self._schedule_next_at_or_none(next_send_ts)
|
||||
|
||||
async def restart(self, requester: Requester, delay_id: str) -> None:
|
||||
"""
|
||||
Restarts the scheduled delivery of the matching delayed event.
|
||||
|
||||
Args:
|
||||
requester: The owner of the delayed event to act on.
|
||||
delay_id: The ID of the delayed event to act on.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if no matching delayed event could be found.
|
||||
"""
|
||||
assert self._is_master
|
||||
await self._request_ratelimiter.ratelimit(requester)
|
||||
await self._initialized_from_db
|
||||
|
||||
next_send_ts = await self._store.restart_delayed_event(
|
||||
delay_id=delay_id,
|
||||
user_localpart=requester.user.localpart,
|
||||
current_ts=self._get_current_ts(),
|
||||
)
|
||||
|
||||
if self._next_send_ts_changed(next_send_ts):
|
||||
self._schedule_next_at(next_send_ts)
|
||||
|
||||
async def send(self, requester: Requester, delay_id: str) -> None:
|
||||
"""
|
||||
Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.
|
||||
|
||||
Args:
|
||||
requester: The owner of the delayed event to act on.
|
||||
delay_id: The ID of the delayed event to act on.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if no matching delayed event could be found.
|
||||
"""
|
||||
assert self._is_master
|
||||
await self._request_ratelimiter.ratelimit(requester)
|
||||
await self._initialized_from_db
|
||||
|
||||
event, next_send_ts = await self._store.process_target_delayed_event(
|
||||
delay_id=delay_id,
|
||||
user_localpart=requester.user.localpart,
|
||||
)
|
||||
|
||||
if self._next_send_ts_changed(next_send_ts):
|
||||
self._schedule_next_at_or_none(next_send_ts)
|
||||
|
||||
await self._send_event(
|
||||
DelayedEventDetails(
|
||||
delay_id=DelayID(delay_id),
|
||||
user_localpart=UserLocalpart(requester.user.localpart),
|
||||
room_id=event.room_id,
|
||||
type=event.type,
|
||||
state_key=event.state_key,
|
||||
origin_server_ts=event.origin_server_ts,
|
||||
content=event.content,
|
||||
device_id=event.device_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def _send_on_timeout(self) -> None:
|
||||
self._next_delayed_event_call = None
|
||||
|
||||
events, next_send_ts = await self._store.process_timeout_delayed_events(
|
||||
self._get_current_ts()
|
||||
)
|
||||
|
||||
if next_send_ts:
|
||||
self._schedule_next_at(next_send_ts)
|
||||
|
||||
await self._send_events(events)
|
||||
|
||||
async def _send_events(self, events: List[DelayedEventDetails]) -> None:
|
||||
sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set()
|
||||
for event in events:
|
||||
if event.state_key is not None:
|
||||
state_info = (event.room_id, event.type, event.state_key)
|
||||
if state_info in sent_state:
|
||||
continue
|
||||
else:
|
||||
state_info = None
|
||||
try:
|
||||
# TODO: send in background if message event or non-conflicting state event
|
||||
await self._send_event(event)
|
||||
if state_info is not None:
|
||||
sent_state.add(state_info)
|
||||
except Exception:
|
||||
logger.exception("Failed to send delayed event")
|
||||
|
||||
for room_id, event_type, state_key in sent_state:
|
||||
await self._store.delete_processed_delayed_state_events(
|
||||
room_id=str(room_id),
|
||||
event_type=event_type,
|
||||
state_key=state_key,
|
||||
)
|
||||
|
||||
def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None:
|
||||
if next_send_ts is not None:
|
||||
self._schedule_next_at(next_send_ts)
|
||||
elif self._next_delayed_event_call is not None:
|
||||
self._next_delayed_event_call.cancel()
|
||||
self._next_delayed_event_call = None
|
||||
|
||||
def _schedule_next_at(self, next_send_ts: Timestamp) -> None:
|
||||
delay = next_send_ts - self._get_current_ts()
|
||||
delay_sec = delay / 1000 if delay > 0 else 0
|
||||
|
||||
if self._next_delayed_event_call is None:
|
||||
self._next_delayed_event_call = self._clock.call_later(
|
||||
delay_sec,
|
||||
run_as_background_process,
|
||||
"_send_on_timeout",
|
||||
self._send_on_timeout,
|
||||
)
|
||||
else:
|
||||
self._next_delayed_event_call.reset(delay_sec)
|
||||
|
||||
async def get_all_for_user(self, requester: Requester) -> List[JsonDict]:
|
||||
"""Return all pending delayed events requested by the given user."""
|
||||
await self._request_ratelimiter.ratelimit(requester)
|
||||
return await self._store.get_all_delayed_events_for_user(
|
||||
requester.user.localpart
|
||||
)
|
||||
|
||||
async def _send_event(
|
||||
self,
|
||||
event: DelayedEventDetails,
|
||||
txn_id: Optional[str] = None,
|
||||
) -> None:
|
||||
user_id = UserID(event.user_localpart, self._config.server.server_name)
|
||||
user_id_str = user_id.to_string()
|
||||
# Create a new requester from what data is currently available
|
||||
requester = create_requester(
|
||||
user_id,
|
||||
is_guest=await self._store.is_guest(user_id_str),
|
||||
device_id=event.device_id,
|
||||
)
|
||||
|
||||
try:
|
||||
if event.state_key is not None and event.type == EventTypes.Member:
|
||||
membership = event.content.get("membership")
|
||||
assert membership is not None
|
||||
event_id, _ = await self._room_member_handler.update_membership(
|
||||
requester,
|
||||
target=UserID.from_string(event.state_key),
|
||||
room_id=event.room_id.to_string(),
|
||||
action=membership,
|
||||
content=event.content,
|
||||
origin_server_ts=event.origin_server_ts,
|
||||
)
|
||||
else:
|
||||
event_dict: JsonDict = {
|
||||
"type": event.type,
|
||||
"content": event.content,
|
||||
"room_id": event.room_id.to_string(),
|
||||
"sender": user_id_str,
|
||||
}
|
||||
|
||||
if event.origin_server_ts is not None:
|
||||
event_dict["origin_server_ts"] = event.origin_server_ts
|
||||
|
||||
if event.state_key is not None:
|
||||
event_dict["state_key"] = event.state_key
|
||||
|
||||
(
|
||||
sent_event,
|
||||
_,
|
||||
) = await self._event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
event_dict,
|
||||
txn_id=txn_id,
|
||||
)
|
||||
event_id = sent_event.event_id
|
||||
except ShadowBanError:
|
||||
event_id = generate_fake_event_id()
|
||||
finally:
|
||||
# TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
|
||||
try:
|
||||
await self._store.delete_processed_delayed_event(
|
||||
event.delay_id, event.user_localpart
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete processed delayed event")
|
||||
|
||||
set_tag("event_id", event_id)
|
||||
|
||||
def _get_current_ts(self) -> Timestamp:
|
||||
return Timestamp(self._clock.time_msec())
|
||||
|
||||
def _next_send_ts_changed(self, next_send_ts: Optional[Timestamp]) -> bool:
|
||||
# The DB alone knows if the next send time changed after adding/modifying
|
||||
# a delayed event, but if we were to ever miss updating our delayed call's
|
||||
# firing time, we may miss other updates. So, keep track of changes to the
|
||||
# the next send time here instead of in the DB.
|
||||
cached_next_send_ts = (
|
||||
int(self._next_delayed_event_call.getTime() * 1000)
|
||||
if self._next_delayed_event_call is not None
|
||||
else None
|
||||
)
|
||||
return next_send_ts != cached_next_send_ts
|
|
@ -23,6 +23,7 @@ from typing import TYPE_CHECKING
|
|||
from synapse.http.server import JsonResource
|
||||
from synapse.replication.http import (
|
||||
account_data,
|
||||
delayed_events,
|
||||
devices,
|
||||
federation,
|
||||
login,
|
||||
|
@ -64,3 +65,4 @@ class ReplicationRestResource(JsonResource):
|
|||
login.register_servlets(hs, self)
|
||||
register.register_servlets(hs, self)
|
||||
devices.register_servlets(hs, self)
|
||||
delayed_events.register_servlets(hs, self)
|
||||
|
|
48
synapse/replication/http/delayed_events.py
Normal file
48
synapse/replication/http/delayed_events.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import JsonDict, JsonMapping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplicationAddedDelayedEventRestServlet(ReplicationEndpoint):
|
||||
"""Handle a delayed event being added by another worker.
|
||||
|
||||
Request format:
|
||||
|
||||
POST /_synapse/replication/delayed_event_added/
|
||||
|
||||
{}
|
||||
"""
|
||||
|
||||
NAME = "added_delayed_event"
|
||||
PATH_ARGS = ()
|
||||
CACHE = False
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.handler = hs.get_delayed_events_handler()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(next_send_ts: int) -> JsonDict: # type: ignore[override]
|
||||
return {"next_send_ts": next_send_ts}
|
||||
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, content: JsonDict
|
||||
) -> Tuple[int, Dict[str, Optional[JsonMapping]]]:
|
||||
self.handler.on_added(int(content["next_send_ts"]))
|
||||
|
||||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationAddedDelayedEventRestServlet(hs).register(http_server)
|
|
@ -2,7 +2,7 @@
|
|||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
# Copyright (C) 2023-2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
|
@ -31,6 +31,7 @@ from synapse.rest.client import (
|
|||
auth,
|
||||
auth_issuer,
|
||||
capabilities,
|
||||
delayed_events,
|
||||
devices,
|
||||
directory,
|
||||
events,
|
||||
|
@ -81,6 +82,7 @@ CLIENT_SERVLET_FUNCTIONS: Tuple[RegisterServletsFunc, ...] = (
|
|||
room.register_deprecated_servlets,
|
||||
events.register_servlets,
|
||||
room.register_servlets,
|
||||
delayed_events.register_servlets,
|
||||
login.register_servlets,
|
||||
profile.register_servlets,
|
||||
presence.register_servlets,
|
||||
|
|
97
synapse/rest/client/delayed_events.py
Normal file
97
synapse/rest/client/delayed_events.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
# This module contains REST servlets to do with delayed events: /delayed_events/<paths>
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _UpdateDelayedEventAction(Enum):
|
||||
CANCEL = "cancel"
|
||||
RESTART = "restart"
|
||||
SEND = "send"
|
||||
|
||||
|
||||
class UpdateDelayedEventServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)$",
|
||||
releases=(),
|
||||
)
|
||||
CATEGORY = "Delayed event management requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, delay_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
try:
|
||||
action = str(body["action"])
|
||||
except KeyError:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"'action' is missing",
|
||||
Codes.MISSING_PARAM,
|
||||
)
|
||||
try:
|
||||
enum_action = _UpdateDelayedEventAction(action)
|
||||
except ValueError:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"'action' is not one of "
|
||||
+ ", ".join(f"'{m.value}'" for m in _UpdateDelayedEventAction),
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
if enum_action == _UpdateDelayedEventAction.CANCEL:
|
||||
await self.delayed_events_handler.cancel(requester, delay_id)
|
||||
elif enum_action == _UpdateDelayedEventAction.RESTART:
|
||||
await self.delayed_events_handler.restart(requester, delay_id)
|
||||
elif enum_action == _UpdateDelayedEventAction.SEND:
|
||||
await self.delayed_events_handler.send(requester, delay_id)
|
||||
return 200, {}
|
||||
|
||||
|
||||
class DelayedEventsServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
r"/org\.matrix\.msc4140/delayed_events$",
|
||||
releases=(),
|
||||
)
|
||||
CATEGORY = "Delayed event management requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
# TODO: Support Pagination stream API ("from" query parameter)
|
||||
delayed_events = await self.delayed_events_handler.get_all_for_user(requester)
|
||||
|
||||
ret = {"delayed_events": delayed_events}
|
||||
return 200, ret
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
# The following can't currently be instantiated on workers.
|
||||
if hs.config.worker.worker_app is None:
|
||||
UpdateDelayedEventServlet(hs).register(http_server)
|
||||
DelayedEventsServlet(hs).register(http_server)
|
|
@ -2,7 +2,7 @@
|
|||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
# Copyright (C) 2023-2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
|
@ -195,7 +195,9 @@ class RoomStateEventRestServlet(RestServlet):
|
|||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self.message_handler = hs.get_message_handler()
|
||||
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self._max_event_delay_ms = hs.config.server.max_event_delay_ms
|
||||
|
||||
def register(self, http_server: HttpServer) -> None:
|
||||
# /rooms/$roomid/state/$eventtype
|
||||
|
@ -291,6 +293,22 @@ class RoomStateEventRestServlet(RestServlet):
|
|||
if requester.app_service:
|
||||
origin_server_ts = parse_integer(request, "ts")
|
||||
|
||||
delay = _parse_request_delay(request, self._max_event_delay_ms)
|
||||
if delay is not None:
|
||||
delay_id = await self.delayed_events_handler.add(
|
||||
requester,
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
state_key=state_key,
|
||||
origin_server_ts=origin_server_ts,
|
||||
content=content,
|
||||
delay=delay,
|
||||
)
|
||||
|
||||
set_tag("delay_id", delay_id)
|
||||
ret = {"delay_id": delay_id}
|
||||
return 200, ret
|
||||
|
||||
try:
|
||||
if event_type == EventTypes.Member:
|
||||
membership = content.get("membership", None)
|
||||
|
@ -341,7 +359,9 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self._max_event_delay_ms = hs.config.server.max_event_delay_ms
|
||||
|
||||
def register(self, http_server: HttpServer) -> None:
|
||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||
|
@ -358,6 +378,26 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
origin_server_ts = None
|
||||
if requester.app_service:
|
||||
origin_server_ts = parse_integer(request, "ts")
|
||||
|
||||
delay = _parse_request_delay(request, self._max_event_delay_ms)
|
||||
if delay is not None:
|
||||
delay_id = await self.delayed_events_handler.add(
|
||||
requester,
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
state_key=None,
|
||||
origin_server_ts=origin_server_ts,
|
||||
content=content,
|
||||
delay=delay,
|
||||
)
|
||||
|
||||
set_tag("delay_id", delay_id)
|
||||
ret = {"delay_id": delay_id}
|
||||
return 200, ret
|
||||
|
||||
event_dict: JsonDict = {
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
|
@ -365,8 +405,6 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
"sender": requester.user.to_string(),
|
||||
}
|
||||
|
||||
if requester.app_service:
|
||||
origin_server_ts = parse_integer(request, "ts")
|
||||
if origin_server_ts is not None:
|
||||
event_dict["origin_server_ts"] = origin_server_ts
|
||||
|
||||
|
@ -411,6 +449,49 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
)
|
||||
|
||||
|
||||
def _parse_request_delay(
|
||||
request: SynapseRequest,
|
||||
max_delay: Optional[int],
|
||||
) -> Optional[int]:
|
||||
"""Parses from the request string the delay parameter for
|
||||
delayed event requests, and checks it for correctness.
|
||||
|
||||
Args:
|
||||
request: the twisted HTTP request.
|
||||
max_delay: the maximum allowed value of the delay parameter,
|
||||
or None if no delay parameter is allowed.
|
||||
Returns:
|
||||
The value of the requested delay, or None if it was absent.
|
||||
|
||||
Raises:
|
||||
SynapseError: if the delay parameter is present and forbidden,
|
||||
or if it exceeds the maximum allowed value.
|
||||
"""
|
||||
delay = parse_integer(request, "org.matrix.msc4140.delay")
|
||||
if delay is None:
|
||||
return None
|
||||
if max_delay is None:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Delayed events are not supported on this server",
|
||||
Codes.UNKNOWN,
|
||||
{
|
||||
"org.matrix.msc4140.errcode": "M_MAX_DELAY_UNSUPPORTED",
|
||||
},
|
||||
)
|
||||
if delay > max_delay:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"The requested delay exceeds the allowed maximum.",
|
||||
Codes.UNKNOWN,
|
||||
{
|
||||
"org.matrix.msc4140.errcode": "M_MAX_DELAY_EXCEEDED",
|
||||
"org.matrix.msc4140.max_delay": max_delay,
|
||||
},
|
||||
)
|
||||
return delay
|
||||
|
||||
|
||||
# TODO: Needs unit testing for room ID + alias joins
|
||||
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
|
||||
CATEGORY = "Event sending requests"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
# Copyright (C) 2023-2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
|
@ -171,6 +171,8 @@ class VersionsRestServlet(RestServlet):
|
|||
is not None
|
||||
)
|
||||
),
|
||||
# MSC4140: Delayed events
|
||||
"org.matrix.msc4140": True,
|
||||
# MSC4151: Report room API (Client-Server API)
|
||||
"org.matrix.msc4151": self.config.experimental.msc4151_enabled,
|
||||
# Simplified sliding sync
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
# Copyright (C) 2023-2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
|
@ -68,6 +68,7 @@ from synapse.handlers.appservice import ApplicationServicesHandler
|
|||
from synapse.handlers.auth import AuthHandler, PasswordAuthProvider
|
||||
from synapse.handlers.cas import CasHandler
|
||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||
from synapse.handlers.delayed_events import DelayedEventsHandler
|
||||
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
|
||||
from synapse.handlers.devicemessage import DeviceMessageHandler
|
||||
from synapse.handlers.directory import DirectoryHandler
|
||||
|
@ -251,6 +252,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
"account_validity",
|
||||
"auth",
|
||||
"deactivate_account",
|
||||
"delayed_events",
|
||||
"message",
|
||||
"pagination",
|
||||
"profile",
|
||||
|
@ -964,3 +966,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
register_threadpool("media", media_threadpool)
|
||||
|
||||
return media_threadpool
|
||||
|
||||
@cache_in_self
|
||||
def get_delayed_events_handler(self) -> DelayedEventsHandler:
|
||||
return DelayedEventsHandler(self)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
# Copyright (C) 2023-2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
|
@ -44,6 +44,7 @@ from .appservice import ApplicationServiceStore, ApplicationServiceTransactionSt
|
|||
from .cache import CacheInvalidationWorkerStore
|
||||
from .censor_events import CensorEventsStore
|
||||
from .client_ips import ClientIpWorkerStore
|
||||
from .delayed_events import DelayedEventsStore
|
||||
from .deviceinbox import DeviceInboxStore
|
||||
from .devices import DeviceStore
|
||||
from .directory import DirectoryStore
|
||||
|
@ -158,6 +159,7 @@ class DataStore(
|
|||
SessionStore,
|
||||
TaskSchedulerWorkerStore,
|
||||
SlidingSyncStore,
|
||||
DelayedEventsStore,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
523
synapse/storage/databases/main/delayed_events.py
Normal file
523
synapse/storage/databases/main/delayed_events.py
Normal file
|
@ -0,0 +1,523 @@
|
|||
import logging
|
||||
from typing import List, NewType, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction, StoreError
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.types import JsonDict, RoomID
|
||||
from synapse.util import json_encoder, stringutils as stringutils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DelayID = NewType("DelayID", str)
|
||||
UserLocalpart = NewType("UserLocalpart", str)
|
||||
DeviceID = NewType("DeviceID", str)
|
||||
EventType = NewType("EventType", str)
|
||||
StateKey = NewType("StateKey", str)
|
||||
|
||||
Delay = NewType("Delay", int)
|
||||
Timestamp = NewType("Timestamp", int)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class EventDetails:
|
||||
room_id: RoomID
|
||||
type: EventType
|
||||
state_key: Optional[StateKey]
|
||||
origin_server_ts: Optional[Timestamp]
|
||||
content: JsonDict
|
||||
device_id: Optional[DeviceID]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class DelayedEventDetails(EventDetails):
|
||||
delay_id: DelayID
|
||||
user_localpart: UserLocalpart
|
||||
|
||||
|
||||
class DelayedEventsStore(SQLBaseStore):
|
||||
async def get_delayed_events_stream_pos(self) -> int:
|
||||
"""
|
||||
Gets the stream position of the background process to watch for state events
|
||||
that target the same piece of state as any pending delayed events.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="delayed_events_stream_pos",
|
||||
keyvalues={},
|
||||
retcol="stream_id",
|
||||
desc="get_delayed_events_stream_pos",
|
||||
)
|
||||
|
||||
async def update_delayed_events_stream_pos(self, stream_id: Optional[int]) -> None:
|
||||
"""
|
||||
Updates the stream position of the background process to watch for state events
|
||||
that target the same piece of state as any pending delayed events.
|
||||
|
||||
Must only be used by the worker running the background process.
|
||||
"""
|
||||
await self.db_pool.simple_update_one(
|
||||
table="delayed_events_stream_pos",
|
||||
keyvalues={},
|
||||
updatevalues={"stream_id": stream_id},
|
||||
desc="update_delayed_events_stream_pos",
|
||||
)
|
||||
|
||||
async def add_delayed_event(
|
||||
self,
|
||||
*,
|
||||
user_localpart: str,
|
||||
device_id: Optional[str],
|
||||
creation_ts: Timestamp,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: Optional[str],
|
||||
origin_server_ts: Optional[int],
|
||||
content: JsonDict,
|
||||
delay: int,
|
||||
) -> Tuple[DelayID, Timestamp]:
|
||||
"""
|
||||
Inserts a new delayed event in the DB.
|
||||
|
||||
Returns: The generated ID assigned to the added delayed event,
|
||||
and the send time of the next delayed event to be sent,
|
||||
which is either the event just added or one added earlier.
|
||||
"""
|
||||
delay_id = _generate_delay_id()
|
||||
send_ts = Timestamp(creation_ts + delay)
|
||||
|
||||
def add_delayed_event_txn(txn: LoggingTransaction) -> Timestamp:
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="delayed_events",
|
||||
values={
|
||||
"delay_id": delay_id,
|
||||
"user_localpart": user_localpart,
|
||||
"device_id": device_id,
|
||||
"delay": delay,
|
||||
"send_ts": send_ts,
|
||||
"room_id": room_id,
|
||||
"event_type": event_type,
|
||||
"state_key": state_key,
|
||||
"origin_server_ts": origin_server_ts,
|
||||
"content": json_encoder.encode(content),
|
||||
},
|
||||
)
|
||||
|
||||
next_send_ts = self._get_next_delayed_event_send_ts_txn(txn)
|
||||
assert next_send_ts is not None
|
||||
return next_send_ts
|
||||
|
||||
next_send_ts = await self.db_pool.runInteraction(
|
||||
"add_delayed_event", add_delayed_event_txn
|
||||
)
|
||||
|
||||
return delay_id, next_send_ts
|
||||
|
||||
async def restart_delayed_event(
|
||||
self,
|
||||
*,
|
||||
delay_id: str,
|
||||
user_localpart: str,
|
||||
current_ts: Timestamp,
|
||||
) -> Timestamp:
|
||||
"""
|
||||
Restarts the send time of the matching delayed event,
|
||||
as long as it hasn't already been marked for processing.
|
||||
|
||||
Args:
|
||||
delay_id: The ID of the delayed event to restart.
|
||||
user_localpart: The localpart of the delayed event's owner.
|
||||
current_ts: The current time, which will be used to calculate the new send time.
|
||||
|
||||
Returns: The send time of the next delayed event to be sent,
|
||||
which is either the event just restarted, or another one
|
||||
with an earlier send time than the restarted one's new send time.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if there is no matching delayed event.
|
||||
"""
|
||||
|
||||
def restart_delayed_event_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Timestamp:
|
||||
txn.execute(
|
||||
"""
|
||||
UPDATE delayed_events
|
||||
SET send_ts = ? + delay
|
||||
WHERE delay_id = ? AND user_localpart = ?
|
||||
AND NOT is_processed
|
||||
""",
|
||||
(
|
||||
current_ts,
|
||||
delay_id,
|
||||
user_localpart,
|
||||
),
|
||||
)
|
||||
if txn.rowcount == 0:
|
||||
raise NotFoundError("Delayed event not found")
|
||||
|
||||
next_send_ts = self._get_next_delayed_event_send_ts_txn(txn)
|
||||
assert next_send_ts is not None
|
||||
return next_send_ts
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"restart_delayed_event", restart_delayed_event_txn
|
||||
)
|
||||
|
||||
async def get_all_delayed_events_for_user(
|
||||
self,
|
||||
user_localpart: str,
|
||||
) -> List[JsonDict]:
|
||||
"""Returns all pending delayed events owned by the given user."""
|
||||
# TODO: Support Pagination stream API ("next_batch" field)
|
||||
rows = await self.db_pool.execute(
|
||||
"get_all_delayed_events_for_user",
|
||||
"""
|
||||
SELECT
|
||||
delay_id,
|
||||
room_id,
|
||||
event_type,
|
||||
state_key,
|
||||
delay,
|
||||
send_ts,
|
||||
content
|
||||
FROM delayed_events
|
||||
WHERE user_localpart = ? AND NOT is_processed
|
||||
ORDER BY send_ts
|
||||
""",
|
||||
user_localpart,
|
||||
)
|
||||
return [
|
||||
{
|
||||
"delay_id": DelayID(row[0]),
|
||||
"room_id": str(RoomID.from_string(row[1])),
|
||||
"type": EventType(row[2]),
|
||||
**({"state_key": StateKey(row[3])} if row[3] is not None else {}),
|
||||
"delay": Delay(row[4]),
|
||||
"running_since": Timestamp(row[5] - row[4]),
|
||||
"content": db_to_json(row[6]),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def process_timeout_delayed_events(
|
||||
self, current_ts: Timestamp
|
||||
) -> Tuple[
|
||||
List[DelayedEventDetails],
|
||||
Optional[Timestamp],
|
||||
]:
|
||||
"""
|
||||
Marks for processing all delayed events that should have been sent prior to the provided time
|
||||
that haven't already been marked as such.
|
||||
|
||||
Returns: The details of all newly-processed delayed events,
|
||||
and the send time of the next delayed event to be sent, if any.
|
||||
"""
|
||||
|
||||
def process_timeout_delayed_events_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[
|
||||
List[DelayedEventDetails],
|
||||
Optional[Timestamp],
|
||||
]:
|
||||
sql_cols = ", ".join(
|
||||
(
|
||||
"delay_id",
|
||||
"user_localpart",
|
||||
"room_id",
|
||||
"event_type",
|
||||
"state_key",
|
||||
"origin_server_ts",
|
||||
"send_ts",
|
||||
"content",
|
||||
"device_id",
|
||||
)
|
||||
)
|
||||
sql_update = "UPDATE delayed_events SET is_processed = TRUE"
|
||||
sql_where = "WHERE send_ts <= ? AND NOT is_processed"
|
||||
sql_args = (current_ts,)
|
||||
sql_order = "ORDER BY send_ts"
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# Do this only in Postgres because:
|
||||
# - SQLite's RETURNING emits rows in an arbitrary order
|
||||
# - https://www.sqlite.org/lang_returning.html#limitations_and_caveats
|
||||
# - SQLite does not support data-modifying statements in a WITH clause
|
||||
# - https://www.sqlite.org/lang_with.html
|
||||
# - https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-MODIFYING
|
||||
txn.execute(
|
||||
f"""
|
||||
WITH events_to_send AS (
|
||||
{sql_update} {sql_where} RETURNING *
|
||||
) SELECT {sql_cols} FROM events_to_send {sql_order}
|
||||
""",
|
||||
sql_args,
|
||||
)
|
||||
rows = txn.fetchall()
|
||||
else:
|
||||
txn.execute(
|
||||
f"SELECT {sql_cols} FROM delayed_events {sql_where} {sql_order}",
|
||||
sql_args,
|
||||
)
|
||||
rows = txn.fetchall()
|
||||
txn.execute(f"{sql_update} {sql_where}", sql_args)
|
||||
assert txn.rowcount == len(rows)
|
||||
|
||||
events = [
|
||||
DelayedEventDetails(
|
||||
RoomID.from_string(row[2]),
|
||||
EventType(row[3]),
|
||||
StateKey(row[4]) if row[4] is not None else None,
|
||||
# If no custom_origin_ts is set, use send_ts as the event's timestamp
|
||||
Timestamp(row[5] if row[5] is not None else row[6]),
|
||||
db_to_json(row[7]),
|
||||
DeviceID(row[8]) if row[8] is not None else None,
|
||||
DelayID(row[0]),
|
||||
UserLocalpart(row[1]),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
next_send_ts = self._get_next_delayed_event_send_ts_txn(txn)
|
||||
return events, next_send_ts
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"process_timeout_delayed_events", process_timeout_delayed_events_txn
|
||||
)
|
||||
|
||||
async def process_target_delayed_event(
|
||||
self,
|
||||
*,
|
||||
delay_id: str,
|
||||
user_localpart: str,
|
||||
) -> Tuple[
|
||||
EventDetails,
|
||||
Optional[Timestamp],
|
||||
]:
|
||||
"""
|
||||
Marks for processing the matching delayed event, regardless of its timeout time,
|
||||
as long as it has not already been marked as such.
|
||||
|
||||
Args:
|
||||
delay_id: The ID of the delayed event to restart.
|
||||
user_localpart: The localpart of the delayed event's owner.
|
||||
|
||||
Returns: The details of the matching delayed event,
|
||||
and the send time of the next delayed event to be sent, if any.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if there is no matching delayed event.
|
||||
"""
|
||||
|
||||
def process_target_delayed_event_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[
|
||||
EventDetails,
|
||||
Optional[Timestamp],
|
||||
]:
|
||||
sql_cols = ", ".join(
|
||||
(
|
||||
"room_id",
|
||||
"event_type",
|
||||
"state_key",
|
||||
"origin_server_ts",
|
||||
"content",
|
||||
"device_id",
|
||||
)
|
||||
)
|
||||
sql_update = "UPDATE delayed_events SET is_processed = TRUE"
|
||||
sql_where = "WHERE delay_id = ? AND user_localpart = ? AND NOT is_processed"
|
||||
sql_args = (delay_id, user_localpart)
|
||||
txn.execute(
|
||||
(
|
||||
f"{sql_update} {sql_where} RETURNING {sql_cols}"
|
||||
if self.database_engine.supports_returning
|
||||
else f"SELECT {sql_cols} FROM delayed_events {sql_where}"
|
||||
),
|
||||
sql_args,
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
raise NotFoundError("Delayed event not found")
|
||||
elif not self.database_engine.supports_returning:
|
||||
txn.execute(f"{sql_update} {sql_where}", sql_args)
|
||||
assert txn.rowcount == 1
|
||||
|
||||
event = EventDetails(
|
||||
RoomID.from_string(row[0]),
|
||||
EventType(row[1]),
|
||||
StateKey(row[2]) if row[2] is not None else None,
|
||||
Timestamp(row[3]) if row[3] is not None else None,
|
||||
db_to_json(row[4]),
|
||||
DeviceID(row[5]) if row[5] is not None else None,
|
||||
)
|
||||
|
||||
return event, self._get_next_delayed_event_send_ts_txn(txn)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"process_target_delayed_event", process_target_delayed_event_txn
|
||||
)
|
||||
|
||||
async def cancel_delayed_event(
|
||||
self,
|
||||
*,
|
||||
delay_id: str,
|
||||
user_localpart: str,
|
||||
) -> Optional[Timestamp]:
|
||||
"""
|
||||
Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed.
|
||||
|
||||
Args:
|
||||
delay_id: The ID of the delayed event to restart.
|
||||
user_localpart: The localpart of the delayed event's owner.
|
||||
|
||||
Returns: The send time of the next delayed event to be sent, if any.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if there is no matching delayed event.
|
||||
"""
|
||||
|
||||
def cancel_delayed_event_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Timestamp]:
|
||||
try:
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn,
|
||||
table="delayed_events",
|
||||
keyvalues={
|
||||
"delay_id": delay_id,
|
||||
"user_localpart": user_localpart,
|
||||
"is_processed": False,
|
||||
},
|
||||
)
|
||||
except StoreError:
|
||||
if txn.rowcount == 0:
|
||||
raise NotFoundError("Delayed event not found")
|
||||
else:
|
||||
raise
|
||||
|
||||
return self._get_next_delayed_event_send_ts_txn(txn)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"cancel_delayed_event", cancel_delayed_event_txn
|
||||
)
|
||||
|
||||
async def cancel_delayed_state_events(
|
||||
self,
|
||||
*,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: str,
|
||||
) -> Optional[Timestamp]:
|
||||
"""
|
||||
Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed.
|
||||
|
||||
Returns: The send time of the next delayed event to be sent, if any.
|
||||
"""
|
||||
|
||||
def cancel_delayed_state_events_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Timestamp]:
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="delayed_events",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"event_type": event_type,
|
||||
"state_key": state_key,
|
||||
"is_processed": False,
|
||||
},
|
||||
)
|
||||
return self._get_next_delayed_event_send_ts_txn(txn)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"cancel_delayed_state_events", cancel_delayed_state_events_txn
|
||||
)
|
||||
|
||||
async def delete_processed_delayed_event(
|
||||
self,
|
||||
delay_id: DelayID,
|
||||
user_localpart: UserLocalpart,
|
||||
) -> None:
|
||||
"""
|
||||
Delete the matching delayed event, as long as it has been marked as processed.
|
||||
|
||||
Throws:
|
||||
StoreError: if there is no matching delayed event, or if it has not yet been processed.
|
||||
"""
|
||||
return await self.db_pool.simple_delete_one(
|
||||
table="delayed_events",
|
||||
keyvalues={
|
||||
"delay_id": delay_id,
|
||||
"user_localpart": user_localpart,
|
||||
"is_processed": True,
|
||||
},
|
||||
desc="delete_processed_delayed_event",
|
||||
)
|
||||
|
||||
async def delete_processed_delayed_state_events(
|
||||
self,
|
||||
*,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: str,
|
||||
) -> None:
|
||||
"""
|
||||
Delete the matching delayed state events that have been marked as processed.
|
||||
"""
|
||||
await self.db_pool.simple_delete(
|
||||
table="delayed_events",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"event_type": event_type,
|
||||
"state_key": state_key,
|
||||
"is_processed": True,
|
||||
},
|
||||
desc="delete_processed_delayed_state_events",
|
||||
)
|
||||
|
||||
async def unprocess_delayed_events(self) -> None:
|
||||
"""
|
||||
Unmark all delayed events for processing.
|
||||
"""
|
||||
await self.db_pool.simple_update(
|
||||
table="delayed_events",
|
||||
keyvalues={"is_processed": True},
|
||||
updatevalues={"is_processed": False},
|
||||
desc="unprocess_delayed_events",
|
||||
)
|
||||
|
||||
async def get_next_delayed_event_send_ts(self) -> Optional[Timestamp]:
|
||||
"""
|
||||
Returns the send time of the next delayed event to be sent, if any.
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_next_delayed_event_send_ts",
|
||||
self._get_next_delayed_event_send_ts_txn,
|
||||
db_autocommit=True,
|
||||
)
|
||||
|
||||
def _get_next_delayed_event_send_ts_txn(
|
||||
self, txn: LoggingTransaction
|
||||
) -> Optional[Timestamp]:
|
||||
result = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="delayed_events",
|
||||
keyvalues={"is_processed": False},
|
||||
retcol="MIN(send_ts)",
|
||||
allow_none=True,
|
||||
)
|
||||
return Timestamp(result) if result is not None else None
|
||||
|
||||
|
||||
def _generate_delay_id() -> DelayID:
|
||||
"""Generates an opaque string, for use as a delay ID"""
|
||||
|
||||
# We use the following format for delay IDs:
|
||||
# syd_<random string>
|
||||
# They are scoped to user localparts, so it is possible for
|
||||
# the same ID to exist for multiple users.
|
||||
|
||||
return DelayID(f"syd_{stringutils.random_string(20)}")
|
|
@ -2,7 +2,7 @@
|
|||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
# Copyright (C) 2023-2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
|
@ -19,7 +19,7 @@
|
|||
#
|
||||
#
|
||||
|
||||
SCHEMA_VERSION = 87 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 88 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
@ -149,6 +149,10 @@ Changes in SCHEMA_VERSION = 87
|
|||
- Add tables for storing the per-connection state for sliding sync requests:
|
||||
sliding_sync_connections, sliding_sync_connection_positions, sliding_sync_connection_required_state,
|
||||
sliding_sync_connection_room_configs, sliding_sync_connection_streams
|
||||
|
||||
Changes in SCHEMA_VERSION = 88
|
||||
- MSC4140: Add `delayed_events` table that keeps track of events that are to
|
||||
be posted in response to a resettable timeout or an on-demand action.
|
||||
"""
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
CREATE TABLE delayed_events (
|
||||
delay_id TEXT NOT NULL,
|
||||
user_localpart TEXT NOT NULL,
|
||||
device_id TEXT,
|
||||
delay BIGINT NOT NULL,
|
||||
send_ts BIGINT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
state_key TEXT,
|
||||
origin_server_ts BIGINT,
|
||||
content bytea NOT NULL,
|
||||
is_processed BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
PRIMARY KEY (user_localpart, delay_id)
|
||||
);
|
||||
|
||||
CREATE INDEX delayed_events_send_ts ON delayed_events (send_ts);
|
||||
CREATE INDEX delayed_events_is_processed ON delayed_events (is_processed);
|
||||
CREATE INDEX delayed_events_room_state_event_idx ON delayed_events (room_id, event_type, state_key) WHERE state_key IS NOT NULL;
|
||||
|
||||
CREATE TABLE delayed_events_stream_pos (
|
||||
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
|
||||
stream_id BIGINT NOT NULL,
|
||||
CHECK (Lock='X')
|
||||
);
|
||||
|
||||
-- Start processing events from the point this migration was run, rather
|
||||
-- than the beginning of time.
|
||||
INSERT INTO delayed_events_stream_pos (
|
||||
stream_id
|
||||
) SELECT COALESCE(MAX(stream_ordering), 0) from events;
|
346
tests/rest/client/test_delayed_events.py
Normal file
346
tests/rest/client/test_delayed_events.py
Normal file
|
@ -0,0 +1,346 @@
|
|||
"""Tests REST events for /delayed_events paths."""
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import List
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.rest.client import delayed_events, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
|
||||
|
||||
_HS_NAME = "red"
|
||||
_EVENT_TYPE = "com.example.test"
|
||||
|
||||
|
||||
class DelayedEventsTestCase(HomeserverTestCase):
|
||||
"""Tests getting and managing delayed events."""
|
||||
|
||||
servlets = [delayed_events.register_servlets, room.register_servlets]
|
||||
user_id = f"@sid1:{_HS_NAME}"
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
config["server_name"] = _HS_NAME
|
||||
config["max_event_delay_duration"] = "24h"
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_id = self.helper.create_room_as(
|
||||
self.user_id,
|
||||
extra_content={
|
||||
"preset": "trusted_private_chat",
|
||||
},
|
||||
)
|
||||
|
||||
def test_delayed_events_empty_on_startup(self) -> None:
|
||||
self.assertListEqual([], self._get_delayed_events())
|
||||
|
||||
def test_delayed_state_events_are_sent_on_timeout(self) -> None:
|
||||
state_key = "to_send_on_timeout"
|
||||
|
||||
setter_key = "setter"
|
||||
setter_expected = "on_timeout"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
|
||||
{
|
||||
setter_key: setter_expected,
|
||||
},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
events = self._get_delayed_events()
|
||||
self.assertEqual(1, len(events), events)
|
||||
content = self._get_delayed_event_content(events[0])
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
self.helper.get_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
"",
|
||||
state_key=state_key,
|
||||
expect_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
self.reactor.advance(1)
|
||||
self.assertListEqual([], self._get_delayed_events())
|
||||
content = self.helper.get_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
"",
|
||||
state_key=state_key,
|
||||
)
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
|
||||
def test_update_delayed_event_without_id(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/",
|
||||
)
|
||||
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
|
||||
|
||||
def test_update_delayed_event_without_body(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/abc",
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
Codes.NOT_JSON,
|
||||
channel.json_body["errcode"],
|
||||
)
|
||||
|
||||
def test_update_delayed_event_without_action(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/abc",
|
||||
{},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
Codes.MISSING_PARAM,
|
||||
channel.json_body["errcode"],
|
||||
)
|
||||
|
||||
def test_update_delayed_event_with_invalid_action(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/abc",
|
||||
{"action": "oops"},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
Codes.INVALID_PARAM,
|
||||
channel.json_body["errcode"],
|
||||
)
|
||||
|
||||
@parameterized.expand(["cancel", "restart", "send"])
|
||||
def test_update_delayed_event_without_match(self, action: str) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/abc",
|
||||
{"action": action},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
|
||||
|
||||
def test_cancel_delayed_state_event(self) -> None:
|
||||
state_key = "to_never_send"
|
||||
|
||||
setter_key = "setter"
|
||||
setter_expected = "none"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 1500),
|
||||
{
|
||||
setter_key: setter_expected,
|
||||
},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
delay_id = channel.json_body.get("delay_id")
|
||||
self.assertIsNotNone(delay_id)
|
||||
|
||||
self.reactor.advance(1)
|
||||
events = self._get_delayed_events()
|
||||
self.assertEqual(1, len(events), events)
|
||||
content = self._get_delayed_event_content(events[0])
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
self.helper.get_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
"",
|
||||
state_key=state_key,
|
||||
expect_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_id}",
|
||||
{"action": "cancel"},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertListEqual([], self._get_delayed_events())
|
||||
|
||||
self.reactor.advance(1)
|
||||
content = self.helper.get_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
"",
|
||||
state_key=state_key,
|
||||
expect_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
def test_send_delayed_state_event(self) -> None:
|
||||
state_key = "to_send_on_request"
|
||||
|
||||
setter_key = "setter"
|
||||
setter_expected = "on_send"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 100000),
|
||||
{
|
||||
setter_key: setter_expected,
|
||||
},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
delay_id = channel.json_body.get("delay_id")
|
||||
self.assertIsNotNone(delay_id)
|
||||
|
||||
self.reactor.advance(1)
|
||||
events = self._get_delayed_events()
|
||||
self.assertEqual(1, len(events), events)
|
||||
content = self._get_delayed_event_content(events[0])
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
self.helper.get_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
"",
|
||||
state_key=state_key,
|
||||
expect_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_id}",
|
||||
{"action": "send"},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertListEqual([], self._get_delayed_events())
|
||||
content = self.helper.get_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
"",
|
||||
state_key=state_key,
|
||||
)
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
|
||||
def test_restart_delayed_state_event(self) -> None:
|
||||
state_key = "to_send_on_restarted_timeout"
|
||||
|
||||
setter_key = "setter"
|
||||
setter_expected = "on_timeout"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 1500),
|
||||
{
|
||||
setter_key: setter_expected,
|
||||
},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
delay_id = channel.json_body.get("delay_id")
|
||||
self.assertIsNotNone(delay_id)
|
||||
|
||||
self.reactor.advance(1)
|
||||
events = self._get_delayed_events()
|
||||
self.assertEqual(1, len(events), events)
|
||||
content = self._get_delayed_event_content(events[0])
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
self.helper.get_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
"",
|
||||
state_key=state_key,
|
||||
expect_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_id}",
|
||||
{"action": "restart"},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
self.reactor.advance(1)
|
||||
events = self._get_delayed_events()
|
||||
self.assertEqual(1, len(events), events)
|
||||
content = self._get_delayed_event_content(events[0])
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
self.helper.get_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
"",
|
||||
state_key=state_key,
|
||||
expect_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
self.reactor.advance(1)
|
||||
self.assertListEqual([], self._get_delayed_events())
|
||||
content = self.helper.get_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
"",
|
||||
state_key=state_key,
|
||||
)
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
|
||||
def test_delayed_state_events_are_cancelled_by_more_recent_state(self) -> None:
|
||||
state_key = "to_be_cancelled"
|
||||
|
||||
setter_key = "setter"
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
|
||||
{
|
||||
setter_key: "on_timeout",
|
||||
},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
events = self._get_delayed_events()
|
||||
self.assertEqual(1, len(events), events)
|
||||
|
||||
setter_expected = "manual"
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
{
|
||||
setter_key: setter_expected,
|
||||
},
|
||||
None,
|
||||
state_key=state_key,
|
||||
)
|
||||
self.assertListEqual([], self._get_delayed_events())
|
||||
|
||||
self.reactor.advance(1)
|
||||
content = self.helper.get_state(
|
||||
self.room_id,
|
||||
_EVENT_TYPE,
|
||||
"",
|
||||
state_key=state_key,
|
||||
)
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
|
||||
def _get_delayed_events(self) -> List[JsonDict]:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
PATH_PREFIX,
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
key = "delayed_events"
|
||||
self.assertIn(key, channel.json_body)
|
||||
|
||||
events = channel.json_body[key]
|
||||
self.assertIsInstance(events, list)
|
||||
|
||||
return events
|
||||
|
||||
def _get_delayed_event_content(self, event: JsonDict) -> JsonDict:
|
||||
key = "content"
|
||||
self.assertIn(key, event)
|
||||
|
||||
content = event[key]
|
||||
self.assertIsInstance(content, dict)
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def _get_path_for_delayed_state(
|
||||
room_id: str, event_type: str, state_key: str, delay_ms: int
|
||||
) -> str:
|
||||
return f"rooms/{room_id}/state/{event_type}/{state_key}?org.matrix.msc4140.delay={delay_ms}"
|
|
@ -2291,6 +2291,106 @@ class RoomMessageFilterTestCase(RoomBase):
|
|||
self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
|
||||
|
||||
|
||||
class RoomDelayedEventTestCase(RoomBase):
|
||||
"""Tests delayed events."""
|
||||
|
||||
user_id = "@sid1:red"
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
@unittest.override_config({"max_event_delay_duration": "24h"})
|
||||
def test_send_delayed_invalid_event(self) -> None:
|
||||
"""Test sending a delayed event with invalid content."""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
(
|
||||
"rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
|
||||
% self.room_id
|
||||
).encode("ascii"),
|
||||
{},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertNotIn("org.matrix.msc4140.errcode", channel.json_body)
|
||||
|
||||
def test_delayed_event_unsupported_by_default(self) -> None:
|
||||
"""Test that sending a delayed event is unsupported with the default config."""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
(
|
||||
"rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
|
||||
% self.room_id
|
||||
).encode("ascii"),
|
||||
{"body": "test", "msgtype": "m.text"},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
"M_MAX_DELAY_UNSUPPORTED",
|
||||
channel.json_body.get("org.matrix.msc4140.errcode"),
|
||||
channel.json_body,
|
||||
)
|
||||
|
||||
@unittest.override_config({"max_event_delay_duration": "1000"})
|
||||
def test_delayed_event_exceeds_max_delay(self) -> None:
|
||||
"""Test that sending a delayed event fails if its delay is longer than allowed."""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
(
|
||||
"rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
|
||||
% self.room_id
|
||||
).encode("ascii"),
|
||||
{"body": "test", "msgtype": "m.text"},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
"M_MAX_DELAY_EXCEEDED",
|
||||
channel.json_body.get("org.matrix.msc4140.errcode"),
|
||||
channel.json_body,
|
||||
)
|
||||
|
||||
@unittest.override_config({"max_event_delay_duration": "24h"})
|
||||
def test_delayed_event_with_negative_delay(self) -> None:
|
||||
"""Test that sending a delayed event fails if its delay is negative."""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
(
|
||||
"rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=-2000"
|
||||
% self.room_id
|
||||
).encode("ascii"),
|
||||
{"body": "test", "msgtype": "m.text"},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
Codes.INVALID_PARAM, channel.json_body["errcode"], channel.json_body
|
||||
)
|
||||
|
||||
@unittest.override_config({"max_event_delay_duration": "24h"})
|
||||
def test_send_delayed_message_event(self) -> None:
|
||||
"""Test sending a valid delayed message event."""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
(
|
||||
"rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
|
||||
% self.room_id
|
||||
).encode("ascii"),
|
||||
{"body": "test", "msgtype": "m.text"},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
@unittest.override_config({"max_event_delay_duration": "24h"})
|
||||
def test_send_delayed_state_event(self) -> None:
|
||||
"""Test sending a valid delayed state event."""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
(
|
||||
"rooms/%s/state/m.room.topic/?org.matrix.msc4140.delay=2000"
|
||||
% self.room_id
|
||||
).encode("ascii"),
|
||||
{"topic": "This is a topic"},
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
|
||||
class RoomSearchTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
|
|
Loading…
Reference in a new issue