Take a snapshot of the state of the room before performing updates

This commit is contained in:
Mark Haines 2014-08-22 17:00:10 +01:00
parent d12a7c3939
commit 1379dcae6f
6 changed files with 162 additions and 58 deletions

View file

@ -34,7 +34,7 @@ class Auth(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def check(self, event, raises=False): def check(self, event, snapshot, raises=False):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Returns: Returns:
@ -46,7 +46,11 @@ class Auth(object):
try: try:
if event.type in [RoomTopicEvent.TYPE, MessageEvent.TYPE, if event.type in [RoomTopicEvent.TYPE, MessageEvent.TYPE,
FeedbackEvent.TYPE]: FeedbackEvent.TYPE]:
yield self.check_joined_room(event.room_id, event.user_id) self._check_joined_room(
member=snapshot.membership_state,
user_id=snapshot.user_id,
room_id=snapshot.room_id,
)
defer.returnValue(True) defer.returnValue(True)
elif event.type == RoomMemberEvent.TYPE: elif event.type == RoomMemberEvent.TYPE:
allowed = yield self.is_membership_change_allowed(event) allowed = yield self.is_membership_change_allowed(event)
@ -67,14 +71,16 @@ class Auth(object):
room_id=room_id, room_id=room_id,
user_id=user_id user_id=user_id
) )
if not member or member.membership != Membership.JOIN: self._check_joined_room(member, user_id, room_id)
raise AuthError(403, "User %s not in room %s" %
(user_id, room_id))
defer.returnValue(member) defer.returnValue(member)
except AttributeError: except AttributeError:
pass pass
defer.returnValue(None) defer.returnValue(None)
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def is_membership_change_allowed(self, event): def is_membership_change_allowed(self, event):
# does this room even exist # does this room even exist

View file

@ -51,19 +51,20 @@ class FederationEventHandler(object):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_event(self, event): def handle_new_event(self, event, snapshot):
""" Takes in an event from the client to server side, that has already """ Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any been authed and handled by the state module, and sends it to any
remote home servers that may be interested. remote home servers that may be interested.
Args: Args:
event event
snapshot (.storage.Snapshot): THe snapshot the event happened after
Returns: Returns:
Deferred: Resolved when it has successfully been queued for Deferred: Resolved when it has successfully been queued for
processing. processing.
""" """
yield self.fill_out_prev_events(event) yield self.fill_out_prev_events(event, snapshot)
pdu = self.pdu_codec.pdu_from_event(event) pdu = self.pdu_codec.pdu_from_event(event)
@ -137,13 +138,11 @@ class FederationEventHandler(object):
yield self.event_handler.on_receive(new_state_event) yield self.event_handler.on_receive(new_state_event)
@defer.inlineCallbacks @defer.inlineCallbacks
def fill_out_prev_events(self, event): def fill_out_prev_events(self, event, snapshot):
if hasattr(event, "prev_events"): if hasattr(event, "prev_events"):
return return
results = yield self.store.get_latest_pdus_in_context( results = snapshot.prev_pdus
event.room_id
)
es = [ es = [
"%s@%s" % (p_id, origin) for p_id, origin, _ in results "%s@%s" % (p_id, origin) for p_id, origin, _ in results

View file

@ -85,9 +85,10 @@ class MessageHandler(BaseHandler):
if stamp_event: if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec()) event.content["hsob_ts"] = int(self.clock.time_msec())
with (yield self.room_lock.lock(event.room_id)): snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
if not suppress_auth: if not suppress_auth:
yield self.auth.check(event, raises=True) yield self.auth.check(event, snapshot, raises=True)
# store message in db # store message in db
store_id = yield self.store.persist_event(event) store_id = yield self.store.persist_event(event)
@ -98,7 +99,7 @@ class MessageHandler(BaseHandler):
self.notifier.on_new_room_event(event, store_id) self.notifier.on_new_room_event(event, store_id)
yield self.hs.get_federation().handle_new_event(event) yield self.hs.get_federation().handle_new_event(event, snapshot)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, user_id=None, room_id=None, pagin_config=None,
@ -135,8 +136,9 @@ class MessageHandler(BaseHandler):
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
with (yield self.room_lock.lock(event.room_id)): snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
yield self.auth.check(event, raises=True)
yield self.auth.check(event, snapshot, raises=True)
if stamp_event: if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec()) event.content["hsob_ts"] = int(self.clock.time_msec())
@ -151,7 +153,7 @@ class MessageHandler(BaseHandler):
) )
self.notifier.on_new_room_event(event, store_id) self.notifier.on_new_room_event(event, store_id)
yield self.hs.get_federation().handle_new_event(event) yield self.hs.get_federation().handle_new_event(event, snapshot)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None, def get_room_data(self, user_id=None, room_id=None,
@ -220,8 +222,10 @@ class MessageHandler(BaseHandler):
if stamp_event: if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec()) event.content["hsob_ts"] = int(self.clock.time_msec())
with (yield self.room_lock.lock(event.room_id)): snapshot = yield self.store.snapshot_room(event.room_id, user_id)
yield self.auth.check(event, raises=True)
yield self.auth.check(event, snapshot, raises=True)
# store message in db # store message in db
store_id = yield self.store.persist_event(event) store_id = yield self.store.persist_event(event)
@ -229,7 +233,7 @@ class MessageHandler(BaseHandler):
event.destinations = yield self.store.get_joined_hosts_for_room( event.destinations = yield self.store.get_joined_hosts_for_room(
event.room_id event.room_id
) )
yield self.hs.get_federation().handle_new_event(event) yield self.hs.get_federation().handle_new_event(event, snapshot)
self.notifier.on_new_room_event(event, store_id) self.notifier.on_new_room_event(event, store_id)
@ -503,6 +507,11 @@ class RoomMemberHandler(BaseHandler):
SynapseError if there was a problem changing the membership. SynapseError if there was a problem changing the membership.
""" """
snapshot = yield self.store.snapshot_room(
event.room_id, event.user_id,
RoomMemberEvent.TYPE, event.target_user_id
)
## TODO(markjh): get prev state from snapshot.
prev_state = yield self.store.get_room_member( prev_state = yield self.store.get_room_member(
event.target_user_id, event.room_id event.target_user_id, event.room_id
) )
@ -523,24 +532,22 @@ class RoomMemberHandler(BaseHandler):
# if this HS is not currently in the room, i.e. we have to do the # if this HS is not currently in the room, i.e. we have to do the
# invite/join dance. # invite/join dance.
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
yield self._do_join(event, do_auth=do_auth) yield self._do_join(event, snapshot, do_auth=do_auth)
else: else:
# This is not a JOIN, so we can handle it normally. # This is not a JOIN, so we can handle it normally.
if do_auth: if do_auth:
yield self.auth.check(event, raises=True) yield self.auth.check(event, snapshot, raises=True)
prev_state = yield self.store.get_room_member(
event.target_user_id, event.room_id
)
if prev_state and prev_state.membership == event.membership: if prev_state and prev_state.membership == event.membership:
# double same action, treat this event as a NOOP. # double same action, treat this event as a NOOP.
defer.returnValue({}) defer.returnValue({})
return return
yield self.state_handler.handle_new_event(event) yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update( yield self._do_local_membership_update(
event, event,
membership=event.content["membership"], membership=event.content["membership"],
snapshot=snapshot,
) )
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@ -570,12 +577,16 @@ class RoomMemberHandler(BaseHandler):
content=content, content=content,
) )
yield self._do_join(new_event, room_host=host, do_auth=True) snapshot = yield store.snapshot_room(
room_id, joinee, RoomMemberEvent.TYPE, event.target_user_id
)
yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, room_host=None, do_auth=True): def _do_join(self, event, snapshot, room_host=None, do_auth=True):
joinee = self.hs.parse_userid(event.target_user_id) joinee = self.hs.parse_userid(event.target_user_id)
# room_id = RoomID.from_string(event.room_id, self.hs) # room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id room_id = event.room_id
@ -597,6 +608,7 @@ class RoomMemberHandler(BaseHandler):
elif room_host: elif room_host:
should_do_dance = True should_do_dance = True
else: else:
# TODO(markjh): get prev_state from snapshot
prev_state = yield self.store.get_room_member( prev_state = yield self.store.get_room_member(
joinee.to_string(), room_id joinee.to_string(), room_id
) )
@ -624,12 +636,13 @@ class RoomMemberHandler(BaseHandler):
logger.debug("Doing normal join") logger.debug("Doing normal join")
if do_auth: if do_auth:
yield self.auth.check(event, raises=True) yield self.auth.check(event, snapshot, raises=True)
yield self.state_handler.handle_new_event(event) yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update( yield self._do_local_membership_update(
event, event,
membership=event.content["membership"], membership=event.content["membership"],
snapshot=snapshot,
) )
user = self.hs.parse_userid(event.user_id) user = self.hs.parse_userid(event.user_id)
@ -674,7 +687,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue([r.room_id for r in rooms]) defer.returnValue([r.room_id for r in rooms])
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_local_membership_update(self, event, membership): def _do_local_membership_update(self, event, membership, snapshot):
# store membership # store membership
store_id = yield self.store.persist_event(event) store_id = yield self.store.persist_event(event)
@ -700,7 +713,7 @@ class RoomMemberHandler(BaseHandler):
event.destinations = list(set(destinations)) event.destinations = list(set(destinations))
yield self.hs.get_federation().handle_new_event(event) yield self.hs.get_federation().handle_new_event(event, snapshot)
self.notifier.on_new_room_event(event, store_id) self.notifier.on_new_room_event(event, store_id)

View file

@ -187,6 +187,70 @@ class DataStore(RoomMemberStore, RoomStore,
defer.returnValue(self.min_token) defer.returnValue(self.min_token)
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
"""Snapshot the room for an update by a user
Args:
room_id (synapse.types.RoomId): The room to snapshot.
user_id (synapse.types.UserId): The user to snapshot the room for.
state_type (str): Optional state type to snapshot.
state_key (str): Optional state key to snapshot.
Returns:
synapse.storage.Snapshot: A snapshot of the state of the room.
"""
def _snapshot(txn):
membership_state = self._get_room_member(txn, user_id)
prev_pdus = self._get_latest_pdus_in_context(
txn, room_id
)
if state_type is not None and state_key is not None:
prev_state_pdu = self._get_current_state_pdu(
txn, room_id, state_type, state_key
)
else:
prev_state_pdu = None
return Snapshot(
store=self,
room_id=room_id,
user_id=user_id,
prev_pdus=prev_pdus,
membership_state=membership_state,
state_type=state_type,
state_key=state_key,
prev_state_pdu=prev_state_pdu,
)
return self._db_pool.runInteraction(_snapshot)
class Snapshot(object):
"""Snapshot of the state of a room
Args:
store (DataStore): The datastore.
room_id (RoomId): The room of the snapshot.
user_id (UserId): The user this snapshot is for.
prev_pdus (list): The list of PDU ids this snapshot is after.
membership_state (RoomMemberEvent): The current state of the user in
the room.
state_type (str, optional): State type captured by the snapshot
state_key (str, optional): State key captured by the snapshot
prev_state_pdu (PduEntry, optional): pdu id of
the previous value of the state type and key in the room.
"""
def __init__(self, store, room_id, user_id, prev_pdus,
membership_state, state_type=None, state_key=None,
prev_state_pdu=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
self.prev_pdus = prev_pdus
self.membership_state
self.state_type = state_type
self.state_key = state_key
self.prev_state_pdu = prev_state_pdu
def schema_path(schema): def schema_path(schema):
""" Get a filesystem path for the named database schema """ Get a filesystem path for the named database schema

View file

@ -45,6 +45,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"get_room_member", "get_room_member",
"get_room", "get_room",
"store_room", "store_room",
"snapshot_room",
]), ]),
resource_for_federation=NonCallableMock(), resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]), http_client=NonCallableMock(spec_set=[]),
@ -75,6 +76,10 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.handlers.profile_handler = ProfileHandler(self.hs) self.handlers.profile_handler = ProfileHandler(self.hs)
self.room_member_handler = self.handlers.room_member_handler self.room_member_handler = self.handlers.room_member_handler
self.snapshot = Mock()
self.datastore.snapshot_room.return_value = self.snapshot
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite(self): def test_invite(self):
room_id = "!foo:red" room_id = "!foo:red"
@ -104,8 +109,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
# Actual invocation # Actual invocation
yield self.room_member_handler.change_membership(event) yield self.room_member_handler.change_membership(event)
self.state_handler.handle_new_event.assert_called_once_with(event) self.state_handler.handle_new_event.assert_called_once_with(
self.federation.handle_new_event.assert_called_once_with(event) event, self.snapshot,
)
self.federation.handle_new_event.assert_called_once_with(
event, self.snapshot,
)
self.assertEquals( self.assertEquals(
set(["blue", "red", "green"]), set(["blue", "red", "green"]),
@ -116,7 +125,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
event event
) )
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
event, store_id) event, store_id
)
self.assertFalse(self.datastore.get_room.called) self.assertFalse(self.datastore.get_room.called)
self.assertFalse(self.datastore.store_room.called) self.assertFalse(self.datastore.store_room.called)
@ -148,6 +158,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.datastore.get_joined_hosts_for_room.side_effect = get_joined self.datastore.get_joined_hosts_for_room.side_effect = get_joined
store_id = "store_id_fooo" store_id = "store_id_fooo"
self.datastore.persist_event.return_value = defer.succeed(store_id) self.datastore.persist_event.return_value = defer.succeed(store_id)
self.datastore.get_room.return_value = defer.succeed(1) # Not None. self.datastore.get_room.return_value = defer.succeed(1) # Not None.
@ -163,8 +174,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
# Actual invocation # Actual invocation
yield self.room_member_handler.change_membership(event) yield self.room_member_handler.change_membership(event)
self.state_handler.handle_new_event.assert_called_once_with(event) self.state_handler.handle_new_event.assert_called_once_with(
self.federation.handle_new_event.assert_called_once_with(event) event, self.snapshot
)
self.federation.handle_new_event.assert_called_once_with(
event, self.snapshot
)
self.assertEquals( self.assertEquals(
set(["red", "green"]), set(["red", "green"]),

View file

@ -127,6 +127,13 @@ class MemoryDataStore(object):
self.current_state = {} self.current_state = {}
self.events = [] self.events = []
Snapshot = namedtuple("Snapshot", "room_id user_id membership_state")
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
return self.Snapshot(
room_id, user_id, self.get_room_member(user_id, room_id)
)
def register(self, user_id, token, password_hash): def register(self, user_id, token, password_hash):
if user_id in self.tokens_to_users.values(): if user_id in self.tokens_to_users.values():
raise StoreError(400, "User in use.") raise StoreError(400, "User in use.")