mirror of
https://github.com/element-hq/synapse.git
synced 2025-01-20 18:42:33 +00:00
Change EventContext to use the Storage class (#6564)
This commit is contained in:
parent
0b5dbadd96
commit
fa780e9721
15 changed files with 64 additions and 53 deletions
1
changelog.d/6564.misc
Normal file
1
changelog.d/6564.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Change `EventContext` to use the `Storage` class, in preparation for moving state database queries to a separate data store.
|
|
@ -79,7 +79,7 @@ class Auth(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_from_context(self, room_version, event, context, do_sig_check=True):
|
def check_from_context(self, room_version, event, context, do_sig_check=True):
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
auth_events_ids = yield self.compute_auth_events(
|
auth_events_ids = yield self.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
|
|
|
@ -149,7 +149,7 @@ class EventContext:
|
||||||
# the prev_state_ids, so if we're a state event we include the event
|
# the prev_state_ids, so if we're a state event we include the event
|
||||||
# id that we replaced in the state.
|
# id that we replaced in the state.
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
prev_state_ids = yield self.get_prev_state_ids(store)
|
prev_state_ids = yield self.get_prev_state_ids()
|
||||||
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
||||||
else:
|
else:
|
||||||
prev_state_id = None
|
prev_state_id = None
|
||||||
|
@ -167,12 +167,13 @@ class EventContext:
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def deserialize(store, input):
|
def deserialize(storage, input):
|
||||||
"""Converts a dict that was produced by `serialize` back into a
|
"""Converts a dict that was produced by `serialize` back into a
|
||||||
EventContext.
|
EventContext.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
store (DataStore): Used to convert AS ID to AS object
|
storage (Storage): Used to convert AS ID to AS object and fetch
|
||||||
|
state.
|
||||||
input (dict): A dict produced by `serialize`
|
input (dict): A dict produced by `serialize`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -181,6 +182,7 @@ class EventContext:
|
||||||
context = _AsyncEventContextImpl(
|
context = _AsyncEventContextImpl(
|
||||||
# We use the state_group and prev_state_id stuff to pull the
|
# We use the state_group and prev_state_id stuff to pull the
|
||||||
# current_state_ids out of the DB and construct prev_state_ids.
|
# current_state_ids out of the DB and construct prev_state_ids.
|
||||||
|
storage=storage,
|
||||||
prev_state_id=input["prev_state_id"],
|
prev_state_id=input["prev_state_id"],
|
||||||
event_type=input["event_type"],
|
event_type=input["event_type"],
|
||||||
event_state_key=input["event_state_key"],
|
event_state_key=input["event_state_key"],
|
||||||
|
@ -193,7 +195,7 @@ class EventContext:
|
||||||
|
|
||||||
app_service_id = input["app_service_id"]
|
app_service_id = input["app_service_id"]
|
||||||
if app_service_id:
|
if app_service_id:
|
||||||
context.app_service = store.get_app_service_by_id(app_service_id)
|
context.app_service = storage.main.get_app_service_by_id(app_service_id)
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
@ -216,7 +218,7 @@ class EventContext:
|
||||||
return self._state_group
|
return self._state_group
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_current_state_ids(self, store):
|
def get_current_state_ids(self):
|
||||||
"""
|
"""
|
||||||
Gets the room state map, including this event - ie, the state in ``state_group``
|
Gets the room state map, including this event - ie, the state in ``state_group``
|
||||||
|
|
||||||
|
@ -234,11 +236,11 @@ class EventContext:
|
||||||
if self.rejected:
|
if self.rejected:
|
||||||
raise RuntimeError("Attempt to access state_ids of rejected event")
|
raise RuntimeError("Attempt to access state_ids of rejected event")
|
||||||
|
|
||||||
yield self._ensure_fetched(store)
|
yield self._ensure_fetched()
|
||||||
return self._current_state_ids
|
return self._current_state_ids
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_prev_state_ids(self, store):
|
def get_prev_state_ids(self):
|
||||||
"""
|
"""
|
||||||
Gets the room state map, excluding this event.
|
Gets the room state map, excluding this event.
|
||||||
|
|
||||||
|
@ -250,7 +252,7 @@ class EventContext:
|
||||||
Maps a (type, state_key) to the event ID of the state event matching
|
Maps a (type, state_key) to the event ID of the state event matching
|
||||||
this tuple.
|
this tuple.
|
||||||
"""
|
"""
|
||||||
yield self._ensure_fetched(store)
|
yield self._ensure_fetched()
|
||||||
return self._prev_state_ids
|
return self._prev_state_ids
|
||||||
|
|
||||||
def get_cached_current_state_ids(self):
|
def get_cached_current_state_ids(self):
|
||||||
|
@ -270,7 +272,7 @@ class EventContext:
|
||||||
|
|
||||||
return self._current_state_ids
|
return self._current_state_ids
|
||||||
|
|
||||||
def _ensure_fetched(self, store):
|
def _ensure_fetched(self):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
|
||||||
|
@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext):
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
|
||||||
|
_storage (Storage)
|
||||||
|
|
||||||
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
||||||
been calculated. None if we haven't started calculating yet
|
been calculated. None if we haven't started calculating yet
|
||||||
|
|
||||||
|
@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext):
|
||||||
that was replaced.
|
that was replaced.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# This needs to have a default as we're inheriting
|
||||||
|
_storage = attr.ib(default=None)
|
||||||
_prev_state_id = attr.ib(default=None)
|
_prev_state_id = attr.ib(default=None)
|
||||||
_event_type = attr.ib(default=None)
|
_event_type = attr.ib(default=None)
|
||||||
_event_state_key = attr.ib(default=None)
|
_event_state_key = attr.ib(default=None)
|
||||||
_fetching_state_deferred = attr.ib(default=None)
|
_fetching_state_deferred = attr.ib(default=None)
|
||||||
|
|
||||||
def _ensure_fetched(self, store):
|
def _ensure_fetched(self):
|
||||||
if not self._fetching_state_deferred:
|
if not self._fetching_state_deferred:
|
||||||
self._fetching_state_deferred = run_in_background(
|
self._fetching_state_deferred = run_in_background(self._fill_out_state)
|
||||||
self._fill_out_state, store
|
|
||||||
)
|
|
||||||
|
|
||||||
return make_deferred_yieldable(self._fetching_state_deferred)
|
return make_deferred_yieldable(self._fetching_state_deferred)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _fill_out_state(self, store):
|
def _fill_out_state(self):
|
||||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||||
attributes by loading from the database.
|
attributes by loading from the database.
|
||||||
"""
|
"""
|
||||||
if self.state_group is None:
|
if self.state_group is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
|
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
|
||||||
|
self.state_group
|
||||||
|
)
|
||||||
if self._prev_state_id and self._event_state_key is not None:
|
if self._prev_state_id and self._event_state_key is not None:
|
||||||
self._prev_state_ids = dict(self._current_state_ids)
|
self._prev_state_ids = dict(self._current_state_ids)
|
||||||
|
|
||||||
|
|
|
@ -53,7 +53,7 @@ class ThirdPartyEventRules(object):
|
||||||
if self.third_party_rules is None:
|
if self.third_party_rules is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
|
|
||||||
# Retrieve the state events from the database.
|
# Retrieve the state events from the database.
|
||||||
state_events = {}
|
state_events = {}
|
||||||
|
|
|
@ -134,7 +134,7 @@ class BaseHandler(object):
|
||||||
guest_access = event.content.get("guest_access", "forbidden")
|
guest_access = event.content.get("guest_access", "forbidden")
|
||||||
if guest_access != "can_join":
|
if guest_access != "can_join":
|
||||||
if context:
|
if context:
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
current_state = yield self.store.get_events(
|
current_state = yield self.store.get_events(
|
||||||
list(current_state_ids.values())
|
list(current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
|
@ -718,7 +718,7 @@ class FederationHandler(BaseHandler):
|
||||||
# changing their profile info.
|
# changing their profile info.
|
||||||
newly_joined = True
|
newly_joined = True
|
||||||
|
|
||||||
prev_state_ids = await context.get_prev_state_ids(self.store)
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
|
|
||||||
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
||||||
if prev_state_id:
|
if prev_state_id:
|
||||||
|
@ -1418,7 +1418,7 @@ class FederationHandler(BaseHandler):
|
||||||
user = UserID.from_string(event.state_key)
|
user = UserID.from_string(event.state_key)
|
||||||
yield self.user_joined_room(user, event.room_id)
|
yield self.user_joined_room(user, event.room_id)
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
|
|
||||||
state_ids = list(prev_state_ids.values())
|
state_ids = list(prev_state_ids.values())
|
||||||
auth_chain = yield self.store.get_auth_chain(state_ids)
|
auth_chain = yield self.store.get_auth_chain(state_ids)
|
||||||
|
@ -1927,7 +1927,7 @@ class FederationHandler(BaseHandler):
|
||||||
context = yield self.state_handler.compute_event_context(event, old_state=state)
|
context = yield self.state_handler.compute_event_context(event, old_state=state)
|
||||||
|
|
||||||
if not auth_events:
|
if not auth_events:
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
auth_events_ids = yield self.auth.compute_auth_events(
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
|
@ -2336,12 +2336,12 @@ class FederationHandler(BaseHandler):
|
||||||
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
|
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
|
||||||
}
|
}
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
current_state_ids = dict(current_state_ids)
|
current_state_ids = dict(current_state_ids)
|
||||||
|
|
||||||
current_state_ids.update(state_updates)
|
current_state_ids.update(state_updates)
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
prev_state_ids = dict(prev_state_ids)
|
prev_state_ids = dict(prev_state_ids)
|
||||||
|
|
||||||
prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
|
prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
|
||||||
|
@ -2625,7 +2625,7 @@ class FederationHandler(BaseHandler):
|
||||||
event.content["third_party_invite"]["signed"]["token"],
|
event.content["third_party_invite"]["signed"]["token"],
|
||||||
)
|
)
|
||||||
original_invite = None
|
original_invite = None
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
original_invite_id = prev_state_ids.get(key)
|
original_invite_id = prev_state_ids.get(key)
|
||||||
if original_invite_id:
|
if original_invite_id:
|
||||||
original_invite = yield self.store.get_event(
|
original_invite = yield self.store.get_event(
|
||||||
|
@ -2673,7 +2673,7 @@ class FederationHandler(BaseHandler):
|
||||||
signed = event.content["third_party_invite"]["signed"]
|
signed = event.content["third_party_invite"]["signed"]
|
||||||
token = signed["token"]
|
token = signed["token"]
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
|
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
|
||||||
|
|
||||||
invite_event = None
|
invite_event = None
|
||||||
|
|
|
@ -515,7 +515,7 @@ class EventCreationHandler(object):
|
||||||
# federation as well as those created locally. As of room v3, aliases events
|
# federation as well as those created locally. As of room v3, aliases events
|
||||||
# can be created by users that are not in the room, therefore we have to
|
# can be created by users that are not in the room, therefore we have to
|
||||||
# tolerate them in event_auth.check().
|
# tolerate them in event_auth.check().
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
|
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
|
||||||
prev_event = (
|
prev_event = (
|
||||||
yield self.store.get_event(prev_event_id, allow_none=True)
|
yield self.store.get_event(prev_event_id, allow_none=True)
|
||||||
|
@ -665,7 +665,7 @@ class EventCreationHandler(object):
|
||||||
If so, returns the version of the event in context.
|
If so, returns the version of the event in context.
|
||||||
Otherwise, returns None.
|
Otherwise, returns None.
|
||||||
"""
|
"""
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
prev_event_id = prev_state_ids.get((event.type, event.state_key))
|
prev_event_id = prev_state_ids.get((event.type, event.state_key))
|
||||||
if not prev_event_id:
|
if not prev_event_id:
|
||||||
return
|
return
|
||||||
|
@ -914,7 +914,7 @@ class EventCreationHandler(object):
|
||||||
def is_inviter_member_event(e):
|
def is_inviter_member_event(e):
|
||||||
return e.type == EventTypes.Member and e.sender == event.sender
|
return e.type == EventTypes.Member and e.sender == event.sender
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
|
|
||||||
state_to_include_ids = [
|
state_to_include_ids = [
|
||||||
e_id
|
e_id
|
||||||
|
@ -967,7 +967,7 @@ class EventCreationHandler(object):
|
||||||
if original_event.room_id != event.room_id:
|
if original_event.room_id != event.room_id:
|
||||||
raise SynapseError(400, "Cannot redact event from a different room")
|
raise SynapseError(400, "Cannot redact event from a different room")
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
auth_events_ids = yield self.auth.compute_auth_events(
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
|
@ -989,7 +989,7 @@ class EventCreationHandler(object):
|
||||||
event.internal_metadata.recheck_redaction = False
|
event.internal_metadata.recheck_redaction = False
|
||||||
|
|
||||||
if event.type == EventTypes.Create:
|
if event.type == EventTypes.Create:
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
if prev_state_ids:
|
if prev_state_ids:
|
||||||
raise AuthError(403, "Changing the room create event is forbidden")
|
raise AuthError(403, "Changing the room create event is forbidden")
|
||||||
|
|
||||||
|
|
|
@ -184,7 +184,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
requester, tombstone_event, tombstone_context
|
requester, tombstone_event, tombstone_context
|
||||||
)
|
)
|
||||||
|
|
||||||
old_room_state = yield tombstone_context.get_current_state_ids(self.store)
|
old_room_state = yield tombstone_context.get_current_state_ids()
|
||||||
|
|
||||||
# update any aliases
|
# update any aliases
|
||||||
yield self._move_aliases_to_new_room(
|
yield self._move_aliases_to_new_room(
|
||||||
|
|
|
@ -193,7 +193,7 @@ class RoomMemberHandler(object):
|
||||||
requester, event, context, extra_users=[target], ratelimit=ratelimit
|
requester, event, context, extra_users=[target], ratelimit=ratelimit
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
|
|
||||||
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
||||||
|
|
||||||
|
@ -601,7 +601,7 @@ class RoomMemberHandler(object):
|
||||||
if prev_event is not None:
|
if prev_event is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
guest_can_join = yield self._can_guest_join(prev_state_ids)
|
guest_can_join = yield self._can_guest_join(prev_state_ids)
|
||||||
|
|
|
@ -116,7 +116,7 @@ class BulkPushRuleEvaluator(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_power_levels_and_sender_level(self, event, context):
|
def _get_power_levels_and_sender_level(self, event, context):
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
pl_event_id = prev_state_ids.get(POWER_KEY)
|
pl_event_id = prev_state_ids.get(POWER_KEY)
|
||||||
if pl_event_id:
|
if pl_event_id:
|
||||||
# fastpath: if there's a power level event, that's all we need, and
|
# fastpath: if there's a power level event, that's all we need, and
|
||||||
|
@ -304,7 +304,7 @@ class RulesForRoom(object):
|
||||||
|
|
||||||
push_rules_delta_state_cache_metric.inc_hits()
|
push_rules_delta_state_cache_metric.inc_hits()
|
||||||
else:
|
else:
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
push_rules_delta_state_cache_metric.inc_misses()
|
push_rules_delta_state_cache_metric.inc_misses()
|
||||||
|
|
||||||
push_rules_state_size_counter.inc(len(current_state_ids))
|
push_rules_state_size_counter.inc(len(current_state_ids))
|
||||||
|
|
|
@ -51,6 +51,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
|
super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.storage = hs.get_storage()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.federation_handler = hs.get_handlers().federation_handler
|
self.federation_handler = hs.get_handlers().federation_handler
|
||||||
|
|
||||||
|
@ -100,7 +101,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
EventType = event_type_from_format_version(format_ver)
|
EventType = event_type_from_format_version(format_ver)
|
||||||
event = EventType(event_dict, internal_metadata, rejected_reason)
|
event = EventType(event_dict, internal_metadata, rejected_reason)
|
||||||
|
|
||||||
context = EventContext.deserialize(self.store, event_payload["context"])
|
context = EventContext.deserialize(
|
||||||
|
self.storage, event_payload["context"]
|
||||||
|
)
|
||||||
|
|
||||||
event_and_contexts.append((event, context))
|
event_and_contexts.append((event, context))
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||||
|
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.storage = hs.get_storage()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -100,7 +101,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||||
event = EventType(event_dict, internal_metadata, rejected_reason)
|
event = EventType(event_dict, internal_metadata, rejected_reason)
|
||||||
|
|
||||||
requester = Requester.deserialize(self.store, content["requester"])
|
requester = Requester.deserialize(self.store, content["requester"])
|
||||||
context = EventContext.deserialize(self.store, content["context"])
|
context = EventContext.deserialize(self.storage, content["context"])
|
||||||
|
|
||||||
ratelimit = content["ratelimit"]
|
ratelimit = content["ratelimit"]
|
||||||
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
||||||
|
|
|
@ -244,7 +244,7 @@ class PushRulesWorkerStore(
|
||||||
# To do this we set the state_group to a new object as object() != object()
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
result = yield self._bulk_get_push_rules_for_room(
|
result = yield self._bulk_get_push_rules_for_room(
|
||||||
event.room_id, state_group, current_state_ids, event=event
|
event.room_id, state_group, current_state_ids, event=event
|
||||||
)
|
)
|
||||||
|
|
|
@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
# To do this we set the state_group to a new object as object() != object()
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
result = yield self._get_joined_users_from_context(
|
result = yield self._get_joined_users_from_context(
|
||||||
event.room_id, state_group, current_state_ids, event=event, context=context
|
event.room_id, state_group, current_state_ids, event=event, context=context
|
||||||
)
|
)
|
||||||
|
|
|
@ -209,7 +209,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
ctx_c = context_store["C"]
|
ctx_c = context_store["C"]
|
||||||
ctx_d = context_store["D"]
|
ctx_d = context_store["D"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
|
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
||||||
self.assertEqual(2, len(prev_state_ids))
|
self.assertEqual(2, len(prev_state_ids))
|
||||||
|
|
||||||
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
|
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
|
||||||
|
@ -253,7 +253,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
ctx_c = context_store["C"]
|
ctx_c = context_store["C"]
|
||||||
ctx_d = context_store["D"]
|
ctx_d = context_store["D"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
|
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
|
{"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
@ -312,7 +312,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
ctx_c = context_store["C"]
|
ctx_c = context_store["C"]
|
||||||
ctx_e = context_store["E"]
|
ctx_e = context_store["E"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
|
prev_state_ids = yield ctx_e.get_prev_state_ids()
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
|
{"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
@ -387,7 +387,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
ctx_b = context_store["B"]
|
ctx_b = context_store["B"]
|
||||||
ctx_d = context_store["D"]
|
ctx_d = context_store["D"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
|
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
|
{"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
@ -419,10 +419,10 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
context = yield self.state.compute_event_context(event, old_state=old_state)
|
context = yield self.state.compute_event_context(event, old_state=old_state)
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
|
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
self.assertCountEqual(
|
self.assertCountEqual(
|
||||||
(e.event_id for e in old_state), current_state_ids.values()
|
(e.event_id for e in old_state), current_state_ids.values()
|
||||||
)
|
)
|
||||||
|
@ -442,10 +442,10 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
context = yield self.state.compute_event_context(event, old_state=old_state)
|
context = yield self.state.compute_event_context(event, old_state=old_state)
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
|
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
self.assertCountEqual(
|
self.assertCountEqual(
|
||||||
(e.event_id for e in old_state + [event]), current_state_ids.values()
|
(e.event_id for e in old_state + [event]), current_state_ids.values()
|
||||||
)
|
)
|
||||||
|
@ -479,7 +479,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]), set(current_state_ids.values())
|
set([e.event_id for e in old_state]), set(current_state_ids.values())
|
||||||
|
@ -511,7 +511,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]), set(prev_state_ids.values())
|
set([e.event_id for e in old_state]), set(prev_state_ids.values())
|
||||||
|
@ -552,7 +552,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
|
||||||
)
|
)
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
|
|
||||||
self.assertEqual(len(current_state_ids), 6)
|
self.assertEqual(len(current_state_ids), 6)
|
||||||
|
|
||||||
|
@ -594,7 +594,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
|
||||||
)
|
)
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
|
|
||||||
self.assertEqual(len(current_state_ids), 6)
|
self.assertEqual(len(current_state_ids), 6)
|
||||||
|
|
||||||
|
@ -649,7 +649,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
|
||||||
)
|
)
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
|
|
||||||
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
|
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
|
||||||
|
|
||||||
|
@ -677,7 +677,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
|
||||||
)
|
)
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids(self.store)
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
|
|
||||||
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
|
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue