Add maybe_awaitable and fix __init__ bugs

This commit is contained in:
Erik Johnston 2019-10-11 15:26:09 +01:00
parent c3b0fbe9c3
commit 3c2d6c708c
2 changed files with 34 additions and 2 deletions

View file

@ -44,6 +44,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import UserAdminServlet from synapse.rest.admin.users import UserAdminServlet
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -310,7 +311,7 @@ class PurgeHistoryRestServlet(RestServlet):
errcode=Codes.BAD_JSON, errcode=Codes.BAD_JSON,
) )
purge_id = await self.pagination_handler.start_purge_history( purge_id = self.pagination_handler.start_purge_history(
room_id, token, delete_local_events=delete_local_events room_id, token, delete_local_events=delete_local_events
) )
@ -480,7 +481,9 @@ class ShutdownRoomRestServlet(RestServlet):
ratelimit=False, ratelimit=False,
) )
aliases_for_room = await self.store.get_aliases_for_room(room_id) aliases_for_room = await maybe_awaitable(
self.store.get_aliases_for_room(room_id)
)
await self.store.update_aliases_for_room( await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id room_id, new_room_id, requester_user_id

View file

@ -21,6 +21,8 @@ from typing import Dict, Sequence, Set, Union
from six.moves import range from six.moves import range
import attr
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
from twisted.python import failure from twisted.python import failure
@ -483,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addCallbacks(success_cb, failure_cb) deferred.addCallbacks(success_cb, failure_cb)
return new_d return new_d
@attr.s(slots=True, frozen=True)
class DoneAwaitable(object):
"""Simple awaitable that returns the provided value.
"""
value = attr.ib()
def __await__(self):
return self
def __iter__(self):
return self
def __next__(self):
raise StopIteration(self.value)
def maybe_awaitable(value):
"""Convert a value to an awaitable if not already an awaitable.
"""
if hasattr(value, "__await__"):
return value
return DoneAwaitable(value)