Fixup from review comments.

This commit is contained in:
Erik Johnston 2019-07-04 11:07:09 +01:00
parent d0b849c86d
commit c061d4f237
2 changed files with 27 additions and 22 deletions

View file

@ -99,7 +99,7 @@ class AdminHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def exfiltrate_user_data(self, user_id, writer): def export_user_data(self, user_id, writer):
"""Write all data we have on the user to the given writer. """Write all data we have on the user to the given writer.
Args: Args:
@ -107,7 +107,8 @@ class AdminHandler(BaseHandler):
writer (ExfiltrationWriter) writer (ExfiltrationWriter)
Returns: Returns:
defer.Deferred defer.Deferred: Resolves when all data for a user has been written.
The returned value is that returned by `writer.finished()`.
""" """
# Get all rooms the user is in or has been in # Get all rooms the user is in or has been in
rooms = yield self.store.get_rooms_for_user_where_membership_is( rooms = yield self.store.get_rooms_for_user_where_membership_is(
@ -134,7 +135,7 @@ class AdminHandler(BaseHandler):
forgotten = yield self.store.did_forget(user_id, room_id) forgotten = yield self.store.did_forget(user_id, room_id)
if forgotten: if forgotten:
logger.info("[%s] User forgot room %d, ignoring", room_id) logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
continue continue
if room_id not in rooms_user_has_been_in: if room_id not in rooms_user_has_been_in:
@ -172,9 +173,10 @@ class AdminHandler(BaseHandler):
# dict[str, set[str]]. # dict[str, set[str]].
event_to_unseen_prevs = {} event_to_unseen_prevs = {}
# The reverse mapping to above, i.e. map from unseen event to parent # The reverse mapping to above, i.e. map from unseen event to events
# events. dict[str, set[str]] # that have the unseen event in their prev_events, i.e. the unseen
unseen_event_to_parents = {} # events "children". dict[str, set[str]]
unseen_to_child_events = {}
# We fetch events in the room the user could see by fetching *all* # We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most # events that we have and then filtering, this isn't the most
@ -200,14 +202,14 @@ class AdminHandler(BaseHandler):
if unseen_events: if unseen_events:
event_to_unseen_prevs[event.event_id] = unseen_events event_to_unseen_prevs[event.event_id] = unseen_events
for unseen in unseen_events: for unseen in unseen_events:
unseen_event_to_parents.setdefault(unseen, set()).add( unseen_to_child_events.setdefault(unseen, set()).add(
event.event_id event.event_id
) )
# Now check if this event is an unseen prev event, if so # Now check if this event is an unseen prev event, if so
# then we remove this event from the appropriate dicts. # then we remove this event from the appropriate dicts.
for event_id in unseen_event_to_parents.pop(event.event_id, []): for child_id in unseen_to_child_events.pop(event.event_id, []):
event_to_unseen_prevs.get(event_id, set()).discard( event_to_unseen_prevs.get(child_id, set()).discard(
event.event_id event.event_id
) )
@ -233,7 +235,7 @@ class AdminHandler(BaseHandler):
class ExfiltrationWriter(object): class ExfiltrationWriter(object):
"""Interface used to specify how to write exfiltrated data. """Interface used to specify how to write exported data.
""" """
def write_events(self, room_id, events): def write_events(self, room_id, events):
@ -254,7 +256,7 @@ class ExfiltrationWriter(object):
Args: Args:
room_id (str) room_id (str)
event_id (str) event_id (str)
state (list[FrozenEvent]) state (dict[tuple[str, str], FrozenEvent])
""" """
pass pass
@ -264,13 +266,16 @@ class ExfiltrationWriter(object):
Args: Args:
room_id (str) room_id (str)
event (FrozenEvent) event (FrozenEvent)
state (list[dict]): A subset of the state at the invite, with a state (dict[tuple[str, str], dict]): A subset of the state at the
subset of the event keys (type, state_key, content and sender) invite, with a subset of the event keys (type, state_key
content and sender)
""" """
def finished(self): def finished(self):
"""Called when exfiltration is complete, and the return valus is passed """Called when all data has succesfully been exported and written.
to the requester.
This functions return value is passed to the caller of
`export_user_data`.
""" """
pass pass
@ -281,7 +286,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
Returns the directory location on completion. Returns the directory location on completion.
Args: Args:
user_id (str): The user whose data is being exfiltrated. user_id (str): The user whose data is being exported.
directory (str|None): The directory to write the data to. If None then directory (str|None): The directory to write the data to. If None then
will write to a temporary directory. will write to a temporary directory.
""" """
@ -293,7 +298,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
self.base_directory = directory self.base_directory = directory
else: else:
self.base_directory = tempfile.mkdtemp( self.base_directory = tempfile.mkdtemp(
prefix="synapse-exfiltrate__%s__" % (user_id,) prefix="synapse-exported__%s__" % (user_id,)
) )
os.makedirs(self.base_directory, exist_ok=True) os.makedirs(self.base_directory, exist_ok=True)

View file

@ -55,7 +55,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
writer = Mock() writer = Mock()
self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) self.get_success(self.admin_handler.export_user_data(self.user2, writer))
writer.write_events.assert_called() writer.write_events.assert_called()
@ -94,7 +94,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
writer = Mock() writer = Mock()
self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) self.get_success(self.admin_handler.export_user_data(self.user2, writer))
writer.write_events.assert_called() writer.write_events.assert_called()
@ -127,7 +127,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
writer = Mock() writer = Mock()
self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) self.get_success(self.admin_handler.export_user_data(self.user2, writer))
writer.write_events.assert_called() writer.write_events.assert_called()
@ -169,7 +169,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
writer = Mock() writer = Mock()
self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) self.get_success(self.admin_handler.export_user_data(self.user2, writer))
writer.write_events.assert_called_once() writer.write_events.assert_called_once()
@ -198,7 +198,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
writer = Mock() writer = Mock()
self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) self.get_success(self.admin_handler.export_user_data(self.user2, writer))
writer.write_events.assert_not_called() writer.write_events.assert_not_called()
writer.write_state.assert_not_called() writer.write_state.assert_not_called()