mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Pass around the reactor explicitly (#3385)
This commit is contained in:
parent
c2eff937ac
commit
77ac14b960
25 changed files with 141 additions and 93 deletions
|
@ -33,6 +33,7 @@ import logging
|
|||
import bcrypt
|
||||
import pymacaroons
|
||||
import simplejson
|
||||
import attr
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
|
||||
|
@ -854,7 +855,11 @@ class AuthHandler(BaseHandler):
|
|||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||
bcrypt.gensalt(self.bcrypt_rounds))
|
||||
|
||||
return make_deferred_yieldable(threads.deferToThread(_do_hash))
|
||||
return make_deferred_yieldable(
|
||||
threads.deferToThreadPool(
|
||||
self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash
|
||||
),
|
||||
)
|
||||
|
||||
def validate_hash(self, password, stored_hash):
|
||||
"""Validates that self.hash(password) == stored_hash.
|
||||
|
@ -874,16 +879,21 @@ class AuthHandler(BaseHandler):
|
|||
)
|
||||
|
||||
if stored_hash:
|
||||
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
|
||||
return make_deferred_yieldable(
|
||||
threads.deferToThreadPool(
|
||||
self.hs.get_reactor(),
|
||||
self.hs.get_reactor().getThreadPool(),
|
||||
_do_validate_hash,
|
||||
),
|
||||
)
|
||||
else:
|
||||
return defer.succeed(False)
|
||||
|
||||
|
||||
class MacaroonGeneartor(object):
|
||||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.config.server_name
|
||||
self.macaroon_secret_key = hs.config.macaroon_secret_key
|
||||
@attr.s
|
||||
class MacaroonGenerator(object):
|
||||
|
||||
hs = attr.ib()
|
||||
|
||||
def generate_access_token(self, user_id, extra_caveats=None):
|
||||
extra_caveats = extra_caveats or []
|
||||
|
@ -901,7 +911,7 @@ class MacaroonGeneartor(object):
|
|||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = login")
|
||||
now = self.clock.time_msec()
|
||||
now = self.hs.get_clock().time_msec()
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
return macaroon.serialize()
|
||||
|
@ -913,9 +923,9 @@ class MacaroonGeneartor(object):
|
|||
|
||||
def _generate_base_macaroon(self, user_id):
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.server_name,
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.macaroon_secret_key)
|
||||
key=self.hs.config.macaroon_secret_key)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
return macaroon
|
||||
|
|
|
@ -806,6 +806,7 @@ class EventCreationHandler(object):
|
|||
# If we're a worker we need to hit out to the master.
|
||||
if self.config.worker_app:
|
||||
yield send_event_to_master(
|
||||
self.hs.get_clock(),
|
||||
self.http_client,
|
||||
host=self.config.worker_replication_host,
|
||||
port=self.config.worker_replication_http_port,
|
||||
|
|
|
@ -19,7 +19,6 @@ from twisted.internet import defer
|
|||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.async import sleep
|
||||
from synapse.types import get_localpart_from_id
|
||||
|
||||
from six import iteritems
|
||||
|
@ -174,7 +173,7 @@ class UserDirectoryHandler(object):
|
|||
logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
|
||||
yield self._handle_initial_room(room_id)
|
||||
num_processed_rooms += 1
|
||||
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||
|
||||
logger.info("Processed all rooms.")
|
||||
|
||||
|
@ -188,7 +187,7 @@ class UserDirectoryHandler(object):
|
|||
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
|
||||
yield self._handle_local_user(user_id)
|
||||
num_processed_users += 1
|
||||
yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
|
||||
yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
|
||||
|
||||
logger.info("Processed all users")
|
||||
|
||||
|
@ -236,7 +235,7 @@ class UserDirectoryHandler(object):
|
|||
count = 0
|
||||
for user_id in user_ids:
|
||||
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
|
||||
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||
|
||||
if not self.is_mine_id(user_id):
|
||||
count += 1
|
||||
|
@ -251,7 +250,7 @@ class UserDirectoryHandler(object):
|
|||
continue
|
||||
|
||||
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
|
||||
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||
count += 1
|
||||
|
||||
user_set = (user_id, other_user_id)
|
||||
|
|
|
@ -98,8 +98,8 @@ class SimpleHttpClient(object):
|
|||
method, uri, *args, **kwargs
|
||||
)
|
||||
add_timeout_to_deferred(
|
||||
request_deferred,
|
||||
60, cancelled_to_request_timed_out_error,
|
||||
request_deferred, 60, self.hs.get_reactor(),
|
||||
cancelled_to_request_timed_out_error,
|
||||
)
|
||||
response = yield make_deferred_yieldable(request_deferred)
|
||||
|
||||
|
@ -115,7 +115,7 @@ class SimpleHttpClient(object):
|
|||
"Error sending request to %s %s: %s %s",
|
||||
method, redact_uri(uri), type(e).__name__, e.message
|
||||
)
|
||||
raise e
|
||||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_urlencoded_get_json(self, uri, args={}, headers=None):
|
||||
|
|
|
@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
|
|||
from synapse.http import cancelled_to_request_timed_out_error
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
import synapse.metrics
|
||||
from synapse.util.async import sleep, add_timeout_to_deferred
|
||||
from synapse.util.async import add_timeout_to_deferred
|
||||
from synapse.util import logcontext
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
import synapse.util.retryutils
|
||||
|
@ -193,6 +193,7 @@ class MatrixFederationHttpClient(object):
|
|||
add_timeout_to_deferred(
|
||||
request_deferred,
|
||||
timeout / 1000. if timeout else 60,
|
||||
self.hs.get_reactor(),
|
||||
cancelled_to_request_timed_out_error,
|
||||
)
|
||||
response = yield make_deferred_yieldable(
|
||||
|
@ -234,7 +235,7 @@ class MatrixFederationHttpClient(object):
|
|||
delay = min(delay, 2)
|
||||
delay *= random.uniform(0.8, 1.4)
|
||||
|
||||
yield sleep(delay)
|
||||
yield self.clock.sleep(delay)
|
||||
retries_left -= 1
|
||||
else:
|
||||
raise
|
||||
|
|
|
@ -161,6 +161,7 @@ class Notifier(object):
|
|||
self.user_to_user_stream = {}
|
||||
self.room_to_user_streams = {}
|
||||
|
||||
self.hs = hs
|
||||
self.event_sources = hs.get_event_sources()
|
||||
self.store = hs.get_datastore()
|
||||
self.pending_new_room_events = []
|
||||
|
@ -340,6 +341,7 @@ class Notifier(object):
|
|||
add_timeout_to_deferred(
|
||||
listener.deferred,
|
||||
(end_time - now) / 1000.,
|
||||
self.hs.get_reactor(),
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
yield listener.deferred
|
||||
|
@ -561,6 +563,7 @@ class Notifier(object):
|
|||
add_timeout_to_deferred(
|
||||
listener.deferred.addTimeout,
|
||||
(end_time - now) / 1000.,
|
||||
self.hs.get_reactor(),
|
||||
)
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
|
|
|
@ -21,7 +21,6 @@ from synapse.api.errors import (
|
|||
from synapse.events import FrozenEvent
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.types import Requester, UserID
|
||||
|
@ -33,11 +32,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_event_to_master(client, host, port, requester, event, context,
|
||||
def send_event_to_master(clock, client, host, port, requester, event, context,
|
||||
ratelimit, extra_users):
|
||||
"""Send event to be handled on the master
|
||||
|
||||
Args:
|
||||
clock (synapse.util.Clock)
|
||||
client (SimpleHttpClient)
|
||||
host (str): host of master
|
||||
port (int): port on master listening for HTTP replication
|
||||
|
@ -77,7 +77,7 @@ def send_event_to_master(client, host, port, requester, event, context,
|
|||
|
||||
# If we timed out we probably don't need to worry about backing
|
||||
# off too much, but lets just wait a little anyway.
|
||||
yield sleep(1)
|
||||
yield clock.sleep(1)
|
||||
except MatrixCodeMessageException as e:
|
||||
# We convert to SynapseError as we know that it was a SynapseError
|
||||
# on the master process that we should send to the client. (And
|
||||
|
|
|
@ -58,6 +58,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
|
|||
|
||||
class MediaRepository(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.client = MatrixFederationHttpClient(hs)
|
||||
self.clock = hs.get_clock()
|
||||
|
@ -94,7 +95,7 @@ class MediaRepository(object):
|
|||
storage_providers.append(provider)
|
||||
|
||||
self.media_storage = MediaStorage(
|
||||
self.primary_base_path, self.filepaths, storage_providers,
|
||||
self.hs, self.primary_base_path, self.filepaths, storage_providers,
|
||||
)
|
||||
|
||||
self.clock.looping_call(
|
||||
|
|
|
@ -37,13 +37,15 @@ class MediaStorage(object):
|
|||
"""Responsible for storing/fetching files from local sources.
|
||||
|
||||
Args:
|
||||
hs (synapse.server.Homeserver)
|
||||
local_media_directory (str): Base path where we store media on disk
|
||||
filepaths (MediaFilePaths)
|
||||
storage_providers ([StorageProvider]): List of StorageProvider that are
|
||||
used to fetch and store files.
|
||||
"""
|
||||
|
||||
def __init__(self, local_media_directory, filepaths, storage_providers):
|
||||
def __init__(self, hs, local_media_directory, filepaths, storage_providers):
|
||||
self.hs = hs
|
||||
self.local_media_directory = local_media_directory
|
||||
self.filepaths = filepaths
|
||||
self.storage_providers = storage_providers
|
||||
|
@ -175,7 +177,8 @@ class MediaStorage(object):
|
|||
res = yield provider.fetch(path, file_info)
|
||||
if res:
|
||||
with res:
|
||||
consumer = BackgroundFileConsumer(open(local_path, "w"))
|
||||
consumer = BackgroundFileConsumer(
|
||||
open(local_path, "w"), self.hs.get_reactor())
|
||||
yield res.write_to_consumer(consumer)
|
||||
yield consumer.wait()
|
||||
defer.returnValue(local_path)
|
||||
|
|
|
@ -40,7 +40,7 @@ from synapse.federation.transport.client import TransportLayerClient
|
|||
from synapse.federation.transaction_queue import TransactionQueue
|
||||
from synapse.handlers import Handlers
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
|
||||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||
from synapse.handlers.devicemessage import DeviceMessageHandler
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
|
@ -165,15 +165,19 @@ class HomeServer(object):
|
|||
'server_notices_sender',
|
||||
]
|
||||
|
||||
def __init__(self, hostname, **kwargs):
|
||||
def __init__(self, hostname, reactor=None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
hostname : The hostname for the server.
|
||||
"""
|
||||
if not reactor:
|
||||
from twisted.internet import reactor
|
||||
|
||||
self._reactor = reactor
|
||||
self.hostname = hostname
|
||||
self._building = {}
|
||||
|
||||
self.clock = Clock()
|
||||
self.clock = Clock(reactor)
|
||||
self.distributor = Distributor()
|
||||
self.ratelimiter = Ratelimiter()
|
||||
|
||||
|
@ -186,6 +190,12 @@ class HomeServer(object):
|
|||
self.datastore = DataStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def get_reactor(self):
|
||||
"""
|
||||
Fetch the Twisted reactor in use by this HomeServer.
|
||||
"""
|
||||
return self._reactor
|
||||
|
||||
def get_ip_from_request(self, request):
|
||||
# X-Forwarded-For is handled by our custom request type.
|
||||
return request.getClientIP()
|
||||
|
@ -261,7 +271,7 @@ class HomeServer(object):
|
|||
return AuthHandler(self)
|
||||
|
||||
def build_macaroon_generator(self):
|
||||
return MacaroonGeneartor(self)
|
||||
return MacaroonGenerator(self)
|
||||
|
||||
def build_device_handler(self):
|
||||
return DeviceHandler(self)
|
||||
|
@ -328,6 +338,7 @@ class HomeServer(object):
|
|||
|
||||
return adbapi.ConnectionPool(
|
||||
name,
|
||||
cp_reactor=self.get_reactor(),
|
||||
**self.db_config.get("args", {})
|
||||
)
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import synapse.util.async
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from . import engines
|
||||
|
@ -92,7 +91,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
logger.info("Starting background schema updates")
|
||||
|
||||
while True:
|
||||
yield synapse.util.async.sleep(
|
||||
yield self.hs.get_clock().sleep(
|
||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
|
||||
|
||||
try:
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import Cache
|
||||
from . import background_updates
|
||||
|
@ -70,7 +70,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
self._client_ip_looper = self._clock.looping_call(
|
||||
self._update_client_ips_batch, 5 * 1000
|
||||
)
|
||||
reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
|
||||
self.hs.get_reactor().addSystemEventTrigger(
|
||||
"before", "shutdown", self._update_client_ips_batch
|
||||
)
|
||||
|
||||
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
|
||||
now=None):
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
from synapse.storage._base import SQLBaseStore, LoggingTransaction
|
||||
from twisted.internet import defer
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
|
||||
import logging
|
||||
|
@ -800,7 +799,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
)
|
||||
if caught_up:
|
||||
break
|
||||
yield sleep(5)
|
||||
yield self.hs.get_clock().sleep(5)
|
||||
finally:
|
||||
self._doing_notif_rotation = False
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events.utils import prune_event
|
||||
|
@ -265,7 +265,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
except Exception:
|
||||
logger.exception("Failed to callback")
|
||||
with PreserveLoggingContext():
|
||||
reactor.callFromThread(fire, event_list, row_dict)
|
||||
self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
|
||||
except Exception as e:
|
||||
logger.exception("do_fetch")
|
||||
|
||||
|
@ -278,7 +278,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
if event_list:
|
||||
with PreserveLoggingContext():
|
||||
reactor.callFromThread(fire, event_list)
|
||||
self.hs.get_reactor().callFromThread(fire, event_list)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
|
||||
|
|
|
@ -13,15 +13,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
from twisted.internet import defer, reactor, task
|
||||
|
||||
import time
|
||||
import logging
|
||||
|
||||
from itertools import islice
|
||||
|
||||
import attr
|
||||
from twisted.internet import defer, task
|
||||
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -31,16 +30,24 @@ def unwrapFirstError(failure):
|
|||
return failure.value.subFailure
|
||||
|
||||
|
||||
@attr.s
|
||||
class Clock(object):
|
||||
"""A small utility that obtains current time-of-day so that time may be
|
||||
mocked during unit-tests.
|
||||
|
||||
TODO(paul): Also move the sleep() functionality into it
|
||||
"""
|
||||
A Clock wraps a Twisted reactor and provides utilities on top of it.
|
||||
"""
|
||||
_reactor = attr.ib()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def sleep(self, seconds):
|
||||
d = defer.Deferred()
|
||||
with PreserveLoggingContext():
|
||||
self._reactor.callLater(seconds, d.callback, seconds)
|
||||
res = yield d
|
||||
defer.returnValue(res)
|
||||
|
||||
def time(self):
|
||||
"""Returns the current system time in seconds since epoch."""
|
||||
return time.time()
|
||||
return self._reactor.seconds()
|
||||
|
||||
def time_msec(self):
|
||||
"""Returns the current system time in miliseconds since epoch."""
|
||||
|
@ -56,6 +63,7 @@ class Clock(object):
|
|||
msec(float): How long to wait between calls in milliseconds.
|
||||
"""
|
||||
call = task.LoopingCall(f)
|
||||
call.clock = self._reactor
|
||||
call.start(msec / 1000.0, now=False)
|
||||
return call
|
||||
|
||||
|
@ -73,7 +81,7 @@ class Clock(object):
|
|||
callback(*args, **kwargs)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
|
||||
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
|
||||
|
||||
def cancel_call_later(self, timer, ignore_errs=False):
|
||||
try:
|
||||
|
|
|
@ -13,14 +13,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import CancelledError
|
||||
from twisted.python import failure
|
||||
|
||||
from .logcontext import (
|
||||
PreserveLoggingContext, make_deferred_yieldable, run_in_background
|
||||
)
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
from synapse.util import logcontext, unwrapFirstError, Clock
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
@ -31,15 +31,6 @@ from six.moves import range
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def sleep(seconds):
|
||||
d = defer.Deferred()
|
||||
with PreserveLoggingContext():
|
||||
reactor.callLater(seconds, d.callback, seconds)
|
||||
res = yield d
|
||||
defer.returnValue(res)
|
||||
|
||||
|
||||
class ObservableDeferred(object):
|
||||
"""Wraps a deferred object so that we can add observer deferreds. These
|
||||
observer deferreds do not affect the callback chain of the original
|
||||
|
@ -172,13 +163,18 @@ class Linearizer(object):
|
|||
# do some work.
|
||||
|
||||
"""
|
||||
def __init__(self, name=None):
|
||||
def __init__(self, name=None, clock=None):
|
||||
if name is None:
|
||||
self.name = id(self)
|
||||
else:
|
||||
self.name = name
|
||||
self.key_to_defer = {}
|
||||
|
||||
if not clock:
|
||||
from twisted.internet import reactor
|
||||
clock = Clock(reactor)
|
||||
self._clock = clock
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def queue(self, key):
|
||||
# If there is already a deferred in the queue, we pull it out so that
|
||||
|
@ -219,7 +215,7 @@ class Linearizer(object):
|
|||
# the context manager, but it needs to happen while we hold the
|
||||
# lock, and the context manager's exit code must be synchronous,
|
||||
# so actually this is the only sensible place.
|
||||
yield sleep(0)
|
||||
yield self._clock.sleep(0)
|
||||
|
||||
else:
|
||||
logger.info("Acquired uncontended linearizer lock %r for key %r",
|
||||
|
@ -396,7 +392,7 @@ class DeferredTimeoutError(Exception):
|
|||
"""
|
||||
|
||||
|
||||
def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
|
||||
def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
|
||||
"""
|
||||
Add a timeout to a deferred by scheduling it to be cancelled after
|
||||
timeout seconds.
|
||||
|
@ -411,6 +407,7 @@ def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
|
|||
Args:
|
||||
deferred (defer.Deferred): deferred to be timed out
|
||||
timeout (Number): seconds to time out after
|
||||
reactor (twisted.internet.reactor): the Twisted reactor to use
|
||||
|
||||
on_timeout_cancel (callable): A callable which is called immediately
|
||||
after the deferred times out, and not if this deferred is
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import threads, reactor
|
||||
from twisted.internet import threads
|
||||
|
||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||
|
||||
|
@ -27,6 +27,7 @@ class BackgroundFileConsumer(object):
|
|||
Args:
|
||||
file_obj (file): The file like object to write to. Closed when
|
||||
finished.
|
||||
reactor (twisted.internet.reactor): the Twisted reactor to use
|
||||
"""
|
||||
|
||||
# For PushProducers pause if we have this many unwritten slices
|
||||
|
@ -34,9 +35,11 @@ class BackgroundFileConsumer(object):
|
|||
# And resume once the size of the queue is less than this
|
||||
_RESUME_ON_QUEUE_SIZE = 2
|
||||
|
||||
def __init__(self, file_obj):
|
||||
def __init__(self, file_obj, reactor):
|
||||
self._file_obj = file_obj
|
||||
|
||||
self._reactor = reactor
|
||||
|
||||
# Producer we're registered with
|
||||
self._producer = None
|
||||
|
||||
|
@ -71,7 +74,10 @@ class BackgroundFileConsumer(object):
|
|||
self._producer = producer
|
||||
self.streaming = streaming
|
||||
self._finished_deferred = run_in_background(
|
||||
threads.deferToThread, self._writer
|
||||
threads.deferToThreadPool,
|
||||
self._reactor,
|
||||
self._reactor.getThreadPool(),
|
||||
self._writer,
|
||||
)
|
||||
if not streaming:
|
||||
self._producer.resumeProducing()
|
||||
|
@ -109,7 +115,7 @@ class BackgroundFileConsumer(object):
|
|||
# producer.
|
||||
if self._producer and self._paused_producer:
|
||||
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
|
||||
reactor.callFromThread(self._resume_paused_producer)
|
||||
self._reactor.callFromThread(self._resume_paused_producer)
|
||||
|
||||
bytes = self._bytes_queue.get()
|
||||
|
||||
|
@ -121,7 +127,7 @@ class BackgroundFileConsumer(object):
|
|||
# If its a pull producer then we need to explicitly ask for
|
||||
# more stuff.
|
||||
if not self.streaming and self._producer:
|
||||
reactor.callFromThread(self._producer.resumeProducing)
|
||||
self._reactor.callFromThread(self._producer.resumeProducing)
|
||||
except Exception as e:
|
||||
self._write_exception = e
|
||||
raise
|
||||
|
|
|
@ -17,7 +17,6 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.logcontext import (
|
||||
run_in_background, make_deferred_yieldable,
|
||||
PreserveLoggingContext,
|
||||
|
@ -153,7 +152,7 @@ class _PerHostRatelimiter(object):
|
|||
"Ratelimit [%s]: sleeping req",
|
||||
id(request_id),
|
||||
)
|
||||
ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
|
||||
ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)
|
||||
|
||||
self.sleeping_requests.add(request_id)
|
||||
|
||||
|
|
|
@ -19,10 +19,10 @@ import signedjson.sign
|
|||
from mock import Mock
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.crypto import keyring
|
||||
from synapse.util import async, logcontext
|
||||
from synapse.util import logcontext, Clock
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from tests import unittest, utils
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import defer, reactor
|
||||
|
||||
|
||||
class MockPerspectiveServer(object):
|
||||
|
@ -118,6 +118,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_verify_json_objects_for_server_awaits_previous_requests(self):
|
||||
clock = Clock(reactor)
|
||||
key1 = signedjson.key.generate_signing_key(1)
|
||||
|
||||
kr = keyring.Keyring(self.hs)
|
||||
|
@ -167,7 +168,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
|
||||
# wait a tick for it to send the request to the perspectives server
|
||||
# (it first tries the datastore)
|
||||
yield async.sleep(1) # XXX find out why this takes so long!
|
||||
yield clock.sleep(1) # XXX find out why this takes so long!
|
||||
self.http_client.post_json.assert_called_once()
|
||||
|
||||
self.assertIs(LoggingContext.current_context(), context_11)
|
||||
|
@ -183,7 +184,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||
[("server10", json1)],
|
||||
)
|
||||
yield async.sleep(1)
|
||||
yield clock.sleep(1)
|
||||
self.http_client.post_json.assert_not_called()
|
||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import defer, reactor
|
||||
from mock import Mock, call
|
||||
|
||||
from synapse.util import async
|
||||
from synapse.util import Clock
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from tests import unittest
|
||||
from tests.utils import MockClock
|
||||
|
@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||
def test_logcontexts_with_async_result(self):
|
||||
@defer.inlineCallbacks
|
||||
def cb():
|
||||
yield async.sleep(0)
|
||||
yield Clock(reactor).sleep(0)
|
||||
defer.returnValue("yay")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import defer, reactor
|
||||
|
||||
from synapse.rest.media.v1._base import FileInfo
|
||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
||||
|
@ -38,6 +38,7 @@ class MediaStorageTests(unittest.TestCase):
|
|||
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
|
||||
|
||||
hs = Mock()
|
||||
hs.get_reactor = Mock(return_value=reactor)
|
||||
hs.config.media_store_path = self.primary_base_path
|
||||
|
||||
storage_providers = [FileStorageProviderBackend(
|
||||
|
@ -46,7 +47,7 @@ class MediaStorageTests(unittest.TestCase):
|
|||
|
||||
self.filepaths = MediaFilePaths(self.primary_base_path)
|
||||
self.media_storage = MediaStorage(
|
||||
self.primary_base_path, self.filepaths, storage_providers,
|
||||
hs, self.primary_base_path, self.filepaths, storage_providers,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
|
|
|
@ -30,7 +30,7 @@ class FileConsumerTests(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_pull_consumer(self):
|
||||
string_file = StringIO()
|
||||
consumer = BackgroundFileConsumer(string_file)
|
||||
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
|
||||
|
||||
try:
|
||||
producer = DummyPullProducer()
|
||||
|
@ -54,7 +54,7 @@ class FileConsumerTests(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_push_consumer(self):
|
||||
string_file = BlockingStringWrite()
|
||||
consumer = BackgroundFileConsumer(string_file)
|
||||
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
|
||||
|
||||
try:
|
||||
producer = NonCallableMock(spec_set=[])
|
||||
|
@ -80,7 +80,7 @@ class FileConsumerTests(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_push_producer_feedback(self):
|
||||
string_file = BlockingStringWrite()
|
||||
consumer = BackgroundFileConsumer(string_file)
|
||||
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
|
||||
|
||||
try:
|
||||
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
|
||||
|
|
|
@ -12,10 +12,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.util import async, logcontext
|
||||
|
||||
from synapse.util import logcontext, Clock
|
||||
from tests import unittest
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import defer, reactor
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from six.moves import range
|
||||
|
@ -53,7 +54,7 @@ class LinearizerTestCase(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
logcontext.LoggingContext.current_context(), lc)
|
||||
if sleep:
|
||||
yield async.sleep(0)
|
||||
yield Clock(reactor).sleep(0)
|
||||
|
||||
self.assertEqual(
|
||||
logcontext.LoggingContext.current_context(), lc)
|
||||
|
|
|
@ -3,8 +3,7 @@ from twisted.internet import defer
|
|||
from twisted.internet import reactor
|
||||
from .. import unittest
|
||||
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util import logcontext
|
||||
from synapse.util import logcontext, Clock
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
|
@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_sleep(self):
|
||||
clock = Clock(reactor)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def competing_callback():
|
||||
with LoggingContext() as competing_context:
|
||||
competing_context.request = "competing"
|
||||
yield sleep(0)
|
||||
yield clock.sleep(0)
|
||||
self._check_test_key("competing")
|
||||
|
||||
reactor.callLater(0, competing_callback)
|
||||
|
||||
with LoggingContext() as context_one:
|
||||
context_one.request = "one"
|
||||
yield sleep(0)
|
||||
yield clock.sleep(0)
|
||||
self._check_test_key("one")
|
||||
|
||||
def _test_run_in_background(self, function):
|
||||
|
@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase):
|
|||
def test_run_in_background_with_blocking_fn(self):
|
||||
@defer.inlineCallbacks
|
||||
def blocking_function():
|
||||
yield sleep(0)
|
||||
yield Clock(reactor).sleep(0)
|
||||
|
||||
return self._test_run_in_background(blocking_function)
|
||||
|
||||
|
|
|
@ -37,11 +37,15 @@ USE_POSTGRES_FOR_TESTS = False
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||
def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
|
||||
**kargs):
|
||||
"""Setup a homeserver suitable for running tests against. Keyword arguments
|
||||
are passed to the Homeserver constructor. If no datastore is supplied a
|
||||
datastore backed by an in-memory sqlite db will be given to the HS.
|
||||
"""
|
||||
if reactor is None:
|
||||
from twisted.internet import reactor
|
||||
|
||||
if config is None:
|
||||
config = Mock()
|
||||
config.signing_key = [MockKey()]
|
||||
|
@ -110,6 +114,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|||
database_engine=db_engine,
|
||||
room_list_handler=object(),
|
||||
tls_server_context_factory=Mock(),
|
||||
reactor=reactor,
|
||||
**kargs
|
||||
)
|
||||
db_conn = hs.get_db_conn()
|
||||
|
|
Loading…
Reference in a new issue