mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Take a snapshot of the state of the room before performing updates
This commit is contained in:
parent
d12a7c3939
commit
1379dcae6f
6 changed files with 162 additions and 58 deletions
|
@ -34,7 +34,7 @@ class Auth(object):
|
|||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check(self, event, raises=False):
|
||||
def check(self, event, snapshot, raises=False):
|
||||
""" Checks if this event is correctly authed.
|
||||
|
||||
Returns:
|
||||
|
@ -46,7 +46,11 @@ class Auth(object):
|
|||
try:
|
||||
if event.type in [RoomTopicEvent.TYPE, MessageEvent.TYPE,
|
||||
FeedbackEvent.TYPE]:
|
||||
yield self.check_joined_room(event.room_id, event.user_id)
|
||||
self._check_joined_room(
|
||||
member=snapshot.membership_state,
|
||||
user_id=snapshot.user_id,
|
||||
room_id=snapshot.room_id,
|
||||
)
|
||||
defer.returnValue(True)
|
||||
elif event.type == RoomMemberEvent.TYPE:
|
||||
allowed = yield self.is_membership_change_allowed(event)
|
||||
|
@ -67,14 +71,16 @@ class Auth(object):
|
|||
room_id=room_id,
|
||||
user_id=user_id
|
||||
)
|
||||
if not member or member.membership != Membership.JOIN:
|
||||
raise AuthError(403, "User %s not in room %s" %
|
||||
(user_id, room_id))
|
||||
self._check_joined_room(member, user_id, room_id)
|
||||
defer.returnValue(member)
|
||||
except AttributeError:
|
||||
pass
|
||||
defer.returnValue(None)
|
||||
|
||||
def _check_joined_room(self, member, user_id, room_id):
|
||||
if not member or member.membership != Membership.JOIN:
|
||||
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_membership_change_allowed(self, event):
|
||||
# does this room even exist
|
||||
|
|
|
@ -51,19 +51,20 @@ class FederationEventHandler(object):
|
|||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_event(self, event):
|
||||
def handle_new_event(self, event, snapshot):
|
||||
""" Takes in an event from the client to server side, that has already
|
||||
been authed and handled by the state module, and sends it to any
|
||||
remote home servers that may be interested.
|
||||
|
||||
Args:
|
||||
event
|
||||
snapshot (.storage.Snapshot): THe snapshot the event happened after
|
||||
|
||||
Returns:
|
||||
Deferred: Resolved when it has successfully been queued for
|
||||
processing.
|
||||
"""
|
||||
yield self.fill_out_prev_events(event)
|
||||
yield self.fill_out_prev_events(event, snapshot)
|
||||
|
||||
pdu = self.pdu_codec.pdu_from_event(event)
|
||||
|
||||
|
@ -137,13 +138,11 @@ class FederationEventHandler(object):
|
|||
yield self.event_handler.on_receive(new_state_event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fill_out_prev_events(self, event):
|
||||
def fill_out_prev_events(self, event, snapshot):
|
||||
if hasattr(event, "prev_events"):
|
||||
return
|
||||
|
||||
results = yield self.store.get_latest_pdus_in_context(
|
||||
event.room_id
|
||||
)
|
||||
results = snapshot.prev_pdus
|
||||
|
||||
es = [
|
||||
"%s@%s" % (p_id, origin) for p_id, origin, _ in results
|
||||
|
|
|
@ -85,9 +85,10 @@ class MessageHandler(BaseHandler):
|
|||
if stamp_event:
|
||||
event.content["hsob_ts"] = int(self.clock.time_msec())
|
||||
|
||||
with (yield self.room_lock.lock(event.room_id)):
|
||||
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
|
||||
|
||||
if not suppress_auth:
|
||||
yield self.auth.check(event, raises=True)
|
||||
yield self.auth.check(event, snapshot, raises=True)
|
||||
|
||||
# store message in db
|
||||
store_id = yield self.store.persist_event(event)
|
||||
|
@ -98,7 +99,7 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
self.notifier.on_new_room_event(event, store_id)
|
||||
|
||||
yield self.hs.get_federation().handle_new_event(event)
|
||||
yield self.hs.get_federation().handle_new_event(event, snapshot)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
|
||||
|
@ -135,8 +136,9 @@ class MessageHandler(BaseHandler):
|
|||
SynapseError if something went wrong.
|
||||
"""
|
||||
|
||||
with (yield self.room_lock.lock(event.room_id)):
|
||||
yield self.auth.check(event, raises=True)
|
||||
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
|
||||
|
||||
yield self.auth.check(event, snapshot, raises=True)
|
||||
|
||||
if stamp_event:
|
||||
event.content["hsob_ts"] = int(self.clock.time_msec())
|
||||
|
@ -151,7 +153,7 @@ class MessageHandler(BaseHandler):
|
|||
)
|
||||
self.notifier.on_new_room_event(event, store_id)
|
||||
|
||||
yield self.hs.get_federation().handle_new_event(event)
|
||||
yield self.hs.get_federation().handle_new_event(event, snapshot)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_data(self, user_id=None, room_id=None,
|
||||
|
@ -220,8 +222,10 @@ class MessageHandler(BaseHandler):
|
|||
if stamp_event:
|
||||
event.content["hsob_ts"] = int(self.clock.time_msec())
|
||||
|
||||
with (yield self.room_lock.lock(event.room_id)):
|
||||
yield self.auth.check(event, raises=True)
|
||||
snapshot = yield self.store.snapshot_room(event.room_id, user_id)
|
||||
|
||||
|
||||
yield self.auth.check(event, snapshot, raises=True)
|
||||
|
||||
# store message in db
|
||||
store_id = yield self.store.persist_event(event)
|
||||
|
@ -229,7 +233,7 @@ class MessageHandler(BaseHandler):
|
|||
event.destinations = yield self.store.get_joined_hosts_for_room(
|
||||
event.room_id
|
||||
)
|
||||
yield self.hs.get_federation().handle_new_event(event)
|
||||
yield self.hs.get_federation().handle_new_event(event, snapshot)
|
||||
|
||||
self.notifier.on_new_room_event(event, store_id)
|
||||
|
||||
|
@ -503,6 +507,11 @@ class RoomMemberHandler(BaseHandler):
|
|||
SynapseError if there was a problem changing the membership.
|
||||
"""
|
||||
|
||||
snapshot = yield self.store.snapshot_room(
|
||||
event.room_id, event.user_id,
|
||||
RoomMemberEvent.TYPE, event.target_user_id
|
||||
)
|
||||
## TODO(markjh): get prev state from snapshot.
|
||||
prev_state = yield self.store.get_room_member(
|
||||
event.target_user_id, event.room_id
|
||||
)
|
||||
|
@ -523,24 +532,22 @@ class RoomMemberHandler(BaseHandler):
|
|||
# if this HS is not currently in the room, i.e. we have to do the
|
||||
# invite/join dance.
|
||||
if event.membership == Membership.JOIN:
|
||||
yield self._do_join(event, do_auth=do_auth)
|
||||
yield self._do_join(event, snapshot, do_auth=do_auth)
|
||||
else:
|
||||
# This is not a JOIN, so we can handle it normally.
|
||||
if do_auth:
|
||||
yield self.auth.check(event, raises=True)
|
||||
yield self.auth.check(event, snapshot, raises=True)
|
||||
|
||||
prev_state = yield self.store.get_room_member(
|
||||
event.target_user_id, event.room_id
|
||||
)
|
||||
if prev_state and prev_state.membership == event.membership:
|
||||
# double same action, treat this event as a NOOP.
|
||||
defer.returnValue({})
|
||||
return
|
||||
|
||||
yield self.state_handler.handle_new_event(event)
|
||||
yield self.state_handler.handle_new_event(event, snapshot)
|
||||
yield self._do_local_membership_update(
|
||||
event,
|
||||
membership=event.content["membership"],
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
defer.returnValue({"room_id": room_id})
|
||||
|
@ -570,12 +577,16 @@ class RoomMemberHandler(BaseHandler):
|
|||
content=content,
|
||||
)
|
||||
|
||||
yield self._do_join(new_event, room_host=host, do_auth=True)
|
||||
snapshot = yield store.snapshot_room(
|
||||
room_id, joinee, RoomMemberEvent.TYPE, event.target_user_id
|
||||
)
|
||||
|
||||
yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
|
||||
|
||||
defer.returnValue({"room_id": room_id})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_join(self, event, room_host=None, do_auth=True):
|
||||
def _do_join(self, event, snapshot, room_host=None, do_auth=True):
|
||||
joinee = self.hs.parse_userid(event.target_user_id)
|
||||
# room_id = RoomID.from_string(event.room_id, self.hs)
|
||||
room_id = event.room_id
|
||||
|
@ -597,6 +608,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
elif room_host:
|
||||
should_do_dance = True
|
||||
else:
|
||||
# TODO(markjh): get prev_state from snapshot
|
||||
prev_state = yield self.store.get_room_member(
|
||||
joinee.to_string(), room_id
|
||||
)
|
||||
|
@ -624,12 +636,13 @@ class RoomMemberHandler(BaseHandler):
|
|||
logger.debug("Doing normal join")
|
||||
|
||||
if do_auth:
|
||||
yield self.auth.check(event, raises=True)
|
||||
yield self.auth.check(event, snapshot, raises=True)
|
||||
|
||||
yield self.state_handler.handle_new_event(event)
|
||||
yield self.state_handler.handle_new_event(event, snapshot)
|
||||
yield self._do_local_membership_update(
|
||||
event,
|
||||
membership=event.content["membership"],
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
user = self.hs.parse_userid(event.user_id)
|
||||
|
@ -674,7 +687,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
defer.returnValue([r.room_id for r in rooms])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_local_membership_update(self, event, membership):
|
||||
def _do_local_membership_update(self, event, membership, snapshot):
|
||||
# store membership
|
||||
store_id = yield self.store.persist_event(event)
|
||||
|
||||
|
@ -700,7 +713,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
|
||||
event.destinations = list(set(destinations))
|
||||
|
||||
yield self.hs.get_federation().handle_new_event(event)
|
||||
yield self.hs.get_federation().handle_new_event(event, snapshot)
|
||||
self.notifier.on_new_room_event(event, store_id)
|
||||
|
||||
|
||||
|
|
|
@ -187,6 +187,70 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
defer.returnValue(self.min_token)
|
||||
|
||||
|
||||
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
|
||||
"""Snapshot the room for an update by a user
|
||||
Args:
|
||||
room_id (synapse.types.RoomId): The room to snapshot.
|
||||
user_id (synapse.types.UserId): The user to snapshot the room for.
|
||||
state_type (str): Optional state type to snapshot.
|
||||
state_key (str): Optional state key to snapshot.
|
||||
Returns:
|
||||
synapse.storage.Snapshot: A snapshot of the state of the room.
|
||||
"""
|
||||
def _snapshot(txn):
|
||||
membership_state = self._get_room_member(txn, user_id)
|
||||
prev_pdus = self._get_latest_pdus_in_context(
|
||||
txn, room_id
|
||||
)
|
||||
if state_type is not None and state_key is not None:
|
||||
prev_state_pdu = self._get_current_state_pdu(
|
||||
txn, room_id, state_type, state_key
|
||||
)
|
||||
else:
|
||||
prev_state_pdu = None
|
||||
|
||||
return Snapshot(
|
||||
store=self,
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
prev_pdus=prev_pdus,
|
||||
membership_state=membership_state,
|
||||
state_type=state_type,
|
||||
state_key=state_key,
|
||||
prev_state_pdu=prev_state_pdu,
|
||||
)
|
||||
|
||||
return self._db_pool.runInteraction(_snapshot)
|
||||
|
||||
|
||||
class Snapshot(object):
|
||||
"""Snapshot of the state of a room
|
||||
Args:
|
||||
store (DataStore): The datastore.
|
||||
room_id (RoomId): The room of the snapshot.
|
||||
user_id (UserId): The user this snapshot is for.
|
||||
prev_pdus (list): The list of PDU ids this snapshot is after.
|
||||
membership_state (RoomMemberEvent): The current state of the user in
|
||||
the room.
|
||||
state_type (str, optional): State type captured by the snapshot
|
||||
state_key (str, optional): State key captured by the snapshot
|
||||
prev_state_pdu (PduEntry, optional): pdu id of
|
||||
the previous value of the state type and key in the room.
|
||||
"""
|
||||
|
||||
def __init__(self, store, room_id, user_id, prev_pdus,
|
||||
membership_state, state_type=None, state_key=None,
|
||||
prev_state_pdu=None):
|
||||
self.store = store
|
||||
self.room_id = room_id
|
||||
self.user_id = user_id
|
||||
self.prev_pdus = prev_pdus
|
||||
self.membership_state
|
||||
self.state_type = state_type
|
||||
self.state_key = state_key
|
||||
self.prev_state_pdu = prev_state_pdu
|
||||
|
||||
|
||||
def schema_path(schema):
|
||||
""" Get a filesystem path for the named database schema
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
"get_room_member",
|
||||
"get_room",
|
||||
"store_room",
|
||||
"snapshot_room",
|
||||
]),
|
||||
resource_for_federation=NonCallableMock(),
|
||||
http_client=NonCallableMock(spec_set=[]),
|
||||
|
@ -75,6 +76,10 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
self.handlers.profile_handler = ProfileHandler(self.hs)
|
||||
self.room_member_handler = self.handlers.room_member_handler
|
||||
|
||||
self.snapshot = Mock()
|
||||
self.datastore.snapshot_room.return_value = self.snapshot
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_invite(self):
|
||||
room_id = "!foo:red"
|
||||
|
@ -104,8 +109,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
# Actual invocation
|
||||
yield self.room_member_handler.change_membership(event)
|
||||
|
||||
self.state_handler.handle_new_event.assert_called_once_with(event)
|
||||
self.federation.handle_new_event.assert_called_once_with(event)
|
||||
self.state_handler.handle_new_event.assert_called_once_with(
|
||||
event, self.snapshot,
|
||||
)
|
||||
self.federation.handle_new_event.assert_called_once_with(
|
||||
event, self.snapshot,
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
set(["blue", "red", "green"]),
|
||||
|
@ -116,7 +125,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
event
|
||||
)
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
event, store_id)
|
||||
event, store_id
|
||||
)
|
||||
|
||||
self.assertFalse(self.datastore.get_room.called)
|
||||
self.assertFalse(self.datastore.store_room.called)
|
||||
|
@ -148,6 +158,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
|
||||
self.datastore.get_joined_hosts_for_room.side_effect = get_joined
|
||||
|
||||
|
||||
store_id = "store_id_fooo"
|
||||
self.datastore.persist_event.return_value = defer.succeed(store_id)
|
||||
self.datastore.get_room.return_value = defer.succeed(1) # Not None.
|
||||
|
@ -163,8 +174,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
# Actual invocation
|
||||
yield self.room_member_handler.change_membership(event)
|
||||
|
||||
self.state_handler.handle_new_event.assert_called_once_with(event)
|
||||
self.federation.handle_new_event.assert_called_once_with(event)
|
||||
self.state_handler.handle_new_event.assert_called_once_with(
|
||||
event, self.snapshot
|
||||
)
|
||||
self.federation.handle_new_event.assert_called_once_with(
|
||||
event, self.snapshot
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
set(["red", "green"]),
|
||||
|
|
|
@ -127,6 +127,13 @@ class MemoryDataStore(object):
|
|||
self.current_state = {}
|
||||
self.events = []
|
||||
|
||||
Snapshot = namedtuple("Snapshot", "room_id user_id membership_state")
|
||||
|
||||
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
|
||||
return self.Snapshot(
|
||||
room_id, user_id, self.get_room_member(user_id, room_id)
|
||||
)
|
||||
|
||||
def register(self, user_id, token, password_hash):
|
||||
if user_id in self.tokens_to_users.values():
|
||||
raise StoreError(400, "User in use.")
|
||||
|
|
Loading…
Reference in a new issue