Add cancellation support to @cached and @cachedList decorators (#12183)

These decorators mostly support cancellation already. Add cancellation
tests and fix use of finished logging contexts by delaying cancellation,
as suggested by @erikjohnston.

Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
Sean Quah 2022-03-14 19:04:29 +00:00 committed by GitHub
parent 605d161d7d
commit 2fcf4b3f6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 157 additions and 2 deletions

1
changelog.d/12183.misc Normal file
View file

@ -0,0 +1 @@
Add cancellation support to `@cached` and `@cachedList` decorators.

View file

@ -41,6 +41,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import delay_cancellation
from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -350,6 +351,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
ret = cache.set(cache_key, ret, callback=invalidate_callback) ret = cache.set(cache_key, ret, callback=invalidate_callback)
# We started a new call to `self.orig`, so we must always wait for it to
# complete. Otherwise we might mark our current logging context as
# finished while `self.orig` is still using it in the background.
ret = delay_cancellation(ret)
return make_deferred_yieldable(ret) return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped) wrapped = cast(_CachedFunction, _wrapped)
@ -510,6 +516,11 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
lambda _: results, unwrapFirstError lambda _: results, unwrapFirstError
) )
if missing:
# We started a new call to `self.orig`, so we must always wait for it to
# complete. Otherwise we might mark our current logging context as
# finished while `self.orig` is still using it in the background.
d = delay_cancellation(d)
return make_deferred_yieldable(d) return make_deferred_yieldable(d)
else: else:
return defer.succeed(results) return defer.succeed(results)

View file

@ -17,7 +17,7 @@ from typing import Set
from unittest import mock from unittest import mock
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred from twisted.internet.defer import CancelledError, Deferred
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.logging.context import ( from synapse.logging.context import (
@ -28,7 +28,7 @@ from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
) )
from synapse.util.caches import descriptors from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached, lru_cache from synapse.util.caches.descriptors import cached, cachedList, lru_cache
from tests import unittest from tests import unittest
from tests.test_utils import get_awaitable_result from tests.test_utils import get_awaitable_result
@ -493,6 +493,74 @@ class DescriptorTestCase(unittest.TestCase):
obj.invalidate() obj.invalidate()
top_invalidate.assert_called_once() top_invalidate.assert_called_once()
def test_cancel(self):
"""Test that cancelling a lookup does not cancel other lookups"""
complete_lookup: "Deferred[None]" = Deferred()
class Cls:
@cached()
async def fn(self, arg1):
await complete_lookup
return str(arg1)
obj = Cls()
d1 = obj.fn(123)
d2 = obj.fn(123)
self.assertFalse(d1.called)
self.assertFalse(d2.called)
# Cancel `d1`, which is the lookup that caused `fn` to run.
d1.cancel()
# `d2` should complete normally.
complete_lookup.callback(None)
self.failureResultOf(d1, CancelledError)
self.assertEqual(d2.result, "123")
def test_cancel_logcontexts(self):
"""Test that cancellation does not break logcontexts.
* The `CancelledError` must be raised with the correct logcontext.
* The inner lookup must not resume with a finished logcontext.
* The inner lookup must not restore a finished logcontext when done.
"""
complete_lookup: "Deferred[None]" = Deferred()
class Cls:
inner_context_was_finished = False
@cached()
async def fn(self, arg1):
await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished
return str(arg1)
obj = Cls()
async def do_lookup():
with LoggingContext("c1") as c1:
try:
await obj.fn(123)
self.fail("No CancelledError thrown")
except CancelledError:
self.assertEqual(
current_context(),
c1,
"CancelledError was not raised with the correct logcontext",
)
# suppress the error and succeed
d = defer.ensureDeferred(do_lookup())
d.cancel()
complete_lookup.callback(None)
self.successResultOf(d)
self.assertFalse(
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
class CacheDecoratorTestCase(unittest.HomeserverTestCase): class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached """More tests for @cached
@ -865,3 +933,78 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.fn.invalidate((10, 2)) obj.fn.invalidate((10, 2))
invalidate0.assert_called_once() invalidate0.assert_called_once()
invalidate1.assert_called_once() invalidate1.assert_called_once()
def test_cancel(self):
"""Test that cancelling a lookup does not cancel other lookups"""
complete_lookup: "Deferred[None]" = Deferred()
class Cls:
@cached()
def fn(self, arg1):
pass
@cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args):
await complete_lookup
return {arg: str(arg) for arg in args}
obj = Cls()
d1 = obj.list_fn([123, 456])
d2 = obj.list_fn([123, 456, 789])
self.assertFalse(d1.called)
self.assertFalse(d2.called)
d1.cancel()
# `d2` should complete normally.
complete_lookup.callback(None)
self.failureResultOf(d1, CancelledError)
self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
def test_cancel_logcontexts(self):
"""Test that cancellation does not break logcontexts.
* The `CancelledError` must be raised with the correct logcontext.
* The inner lookup must not resume with a finished logcontext.
* The inner lookup must not restore a finished logcontext when done.
"""
complete_lookup: "Deferred[None]" = Deferred()
class Cls:
inner_context_was_finished = False
@cached()
def fn(self, arg1):
pass
@cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args):
await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished
return {arg: str(arg) for arg in args}
obj = Cls()
async def do_lookup():
with LoggingContext("c1") as c1:
try:
await obj.list_fn([123])
self.fail("No CancelledError thrown")
except CancelledError:
self.assertEqual(
current_context(),
c1,
"CancelledError was not raised with the correct logcontext",
)
# suppress the error and succeed
d = defer.ensureDeferred(do_lookup())
d.cancel()
complete_lookup.callback(None)
self.successResultOf(d)
self.assertFalse(
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
)
self.assertEqual(current_context(), SENTINEL_CONTEXT)