Add missing type hints to test.util.caches (#14529)

This commit is contained in:
Patrick Cloke 2022-11-22 17:35:54 -05:00 committed by GitHub
parent 7f78b383ca
commit 4ae967cf63
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 76 additions and 66 deletions

1
changelog.d/14529.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints.

View file

@ -59,11 +59,6 @@ exclude = (?x)
|tests/server_notices/test_resource_limits_server_notices.py
|tests/test_state.py
|tests/test_terms_auth.py
|tests/util/caches/test_cached_call.py
|tests/util/caches/test_deferred_cache.py
|tests/util/caches/test_descriptors.py
|tests/util/caches/test_response_cache.py
|tests/util/caches/test_ttlcache.py
|tests/util/test_async_helpers.py
|tests/util/test_batching_queue.py
|tests/util/test_dict_cache.py
@ -133,6 +128,12 @@ disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True
[mypy-tests.util.caches.*]
disallow_untyped_defs = True
[mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False
[mypy-tests.utils]
disallow_untyped_defs = True

View file

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import NoReturn
from unittest.mock import Mock
from twisted.internet import defer
@ -23,14 +24,14 @@ from tests.unittest import TestCase
class CachedCallTestCase(TestCase):
def test_get(self):
def test_get(self) -> None:
"""
Happy-path test case: makes a couple of calls and makes sure they behave
correctly
"""
d = Deferred()
d: "Deferred[int]" = Deferred()
async def f():
async def f() -> int:
return await d
slow_call = Mock(side_effect=f)
@ -43,7 +44,7 @@ class CachedCallTestCase(TestCase):
# now fire off a couple of calls
completed_results = []
async def r():
async def r() -> None:
res = await cached_call.get()
completed_results.append(res)
@ -69,12 +70,12 @@ class CachedCallTestCase(TestCase):
self.assertEqual(r3, 123)
slow_call.assert_not_called()
def test_fast_call(self):
def test_fast_call(self) -> None:
"""
Test the behaviour when the underlying function completes immediately
"""
async def f():
async def f() -> int:
return 12
fast_call = Mock(side_effect=f)
@ -92,12 +93,12 @@ class CachedCallTestCase(TestCase):
class RetryOnExceptionCachedCallTestCase(TestCase):
def test_get(self):
def test_get(self) -> None:
# set up the RetryOnExceptionCachedCall around a function which will fail
# (after a while)
d = Deferred()
d: "Deferred[int]" = Deferred()
async def f1():
async def f1() -> NoReturn:
await d
raise ValueError("moo")
@ -110,7 +111,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
# now fire off a couple of calls
completed_results = []
async def r():
async def r() -> None:
try:
await cached_call.get()
except Exception as e1:
@ -137,7 +138,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
# to the getter
d = Deferred()
async def f2():
async def f2() -> int:
return await d
slow_call.reset_mock()

View file

@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
from typing import List, Tuple
from twisted.internet import defer
@ -22,20 +23,20 @@ from tests.unittest import TestCase
class DeferredCacheTestCase(TestCase):
def test_empty(self):
cache = DeferredCache("test")
def test_empty(self) -> None:
cache: DeferredCache[str, int] = DeferredCache("test")
with self.assertRaises(KeyError):
cache.get("foo")
def test_hit(self):
cache = DeferredCache("test")
def test_hit(self) -> None:
cache: DeferredCache[str, int] = DeferredCache("test")
cache.prefill("foo", 123)
self.assertEqual(self.successResultOf(cache.get("foo")), 123)
def test_hit_deferred(self):
cache = DeferredCache("test")
origin_d = defer.Deferred()
def test_hit_deferred(self) -> None:
cache: DeferredCache[str, int] = DeferredCache("test")
origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d)
# get should return an incomplete deferred
@ -43,7 +44,7 @@ class DeferredCacheTestCase(TestCase):
self.assertFalse(get_d.called)
# add a callback that will make sure that the set_d gets called before the get_d
def check1(r):
def check1(r: str) -> str:
self.assertTrue(set_d.called)
return r
@ -55,16 +56,16 @@ class DeferredCacheTestCase(TestCase):
self.assertEqual(self.successResultOf(set_d), 99)
self.assertEqual(self.successResultOf(get_d), 99)
def test_callbacks(self):
def test_callbacks(self) -> None:
"""Invalidation callbacks are called at the right time"""
cache = DeferredCache("test")
cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
origin_d = defer.Deferred()
origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
@ -89,15 +90,15 @@ class DeferredCacheTestCase(TestCase):
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"set", "get"})
def test_set_fail(self):
cache = DeferredCache("test")
def test_set_fail(self) -> None:
cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
origin_d = defer.Deferred()
origin_d: defer.Deferred = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
@ -126,9 +127,9 @@ class DeferredCacheTestCase(TestCase):
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"prefill", "get2"})
def test_get_immediate(self):
cache = DeferredCache("test")
d1 = defer.Deferred()
def test_get_immediate(self) -> None:
cache: DeferredCache[str, int] = DeferredCache("test")
d1: "defer.Deferred[int]" = defer.Deferred()
cache.set("key1", d1)
# get_immediate should return default
@ -142,27 +143,27 @@ class DeferredCacheTestCase(TestCase):
v = cache.get_immediate("key1", 1)
self.assertEqual(v, 2)
def test_invalidate(self):
cache = DeferredCache("test")
def test_invalidate(self) -> None:
cache: DeferredCache[Tuple[str], int] = DeferredCache("test")
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))
with self.assertRaises(KeyError):
cache.get(("foo",))
def test_invalidate_all(self):
cache = DeferredCache("testcache")
def test_invalidate_all(self) -> None:
cache: DeferredCache[str, str] = DeferredCache("testcache")
callback_record = [False, False]
def record_callback(idx):
def record_callback(idx: int) -> None:
callback_record[idx] = True
# add a couple of pending entries
d1 = defer.Deferred()
d1: "defer.Deferred[str]" = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
d2 = defer.Deferred()
d2: "defer.Deferred[str]" = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return pending deferreds
@ -193,8 +194,8 @@ class DeferredCacheTestCase(TestCase):
with self.assertRaises(KeyError):
cache.get("key1", None)
def test_eviction(self):
cache = DeferredCache(
def test_eviction(self) -> None:
cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
@ -208,8 +209,8 @@ class DeferredCacheTestCase(TestCase):
cache.get(2)
cache.get(3)
def test_eviction_lru(self):
cache = DeferredCache(
def test_eviction_lru(self) -> None:
cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
@ -227,8 +228,8 @@ class DeferredCacheTestCase(TestCase):
cache.get(1)
cache.get(3)
def test_eviction_iterable(self):
cache = DeferredCache(
def test_eviction_iterable(self) -> None:
cache: DeferredCache[int, List[str]] = DeferredCache(
"test",
max_entries=3,
apply_cache_factor_from_config=False,

View file

@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Iterable, Set, Tuple
from typing import Iterable, Set, Tuple, cast
from unittest import mock
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.interfaces import IReactorTime
from synapse.api.errors import SynapseError
from synapse.logging.context import (
@ -37,8 +38,8 @@ logger = logging.getLogger(__name__)
def run_on_reactor():
d = defer.Deferred()
reactor.callLater(0, d.callback, 0)
d: "Deferred[int]" = defer.Deferred()
cast(IReactorTime, reactor).callLater(0, d.callback, 0)
return make_deferred_yieldable(d)
@ -224,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase):
callbacks: Set[str] = set()
# set off an asynchronous request
obj.result = origin_d = defer.Deferred()
origin_d: Deferred = defer.Deferred()
obj.result = origin_d
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
self.assertFalse(d1.called)
@ -262,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase):
"""Check that logcontexts are set and restored correctly when
using the cache."""
complete_lookup = defer.Deferred()
complete_lookup: Deferred = defer.Deferred()
class Cls:
@descriptors.cached()
@ -772,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
assert current_context().name == "c1"
context = current_context()
assert isinstance(context, LoggingContext)
assert context.name == "c1"
# we want this to behave like an asynchronous function
await run_on_reactor()
assert current_context().name == "c1"
context = current_context()
assert isinstance(context, LoggingContext)
assert context.name == "c1"
return self.mock(args1, arg2)
with LoggingContext("c1") as c1:
@ -834,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
return self.mock(args1)
obj = Cls()
deferred_result = Deferred()
deferred_result: "Deferred[dict]" = Deferred()
obj.mock.return_value = deferred_result
# start off several concurrent lookups of the same key

View file

@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase):
(These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
"""
def setUp(self):
def setUp(self) -> None:
self.reactor, self.clock = get_clock()
def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
@ -49,7 +49,7 @@ class ResponseCacheTestCase(TestCase):
await self.clock.sleep(1)
return o
def test_cache_hit(self):
def test_cache_hit(self) -> None:
cache = self.with_cache("keeping_cache", ms=9001)
expected_result = "howdy"
@ -74,7 +74,7 @@ class ResponseCacheTestCase(TestCase):
"cache should still have the result",
)
def test_cache_miss(self):
def test_cache_miss(self) -> None:
cache = self.with_cache("trashing_cache", ms=0)
expected_result = "howdy"
@ -90,7 +90,7 @@ class ResponseCacheTestCase(TestCase):
)
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
def test_cache_expire(self):
def test_cache_expire(self) -> None:
cache = self.with_cache("short_cache", ms=1000)
expected_result = "howdy"
@ -115,7 +115,7 @@ class ResponseCacheTestCase(TestCase):
self.reactor.pump((2,))
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
def test_cache_wait_hit(self):
def test_cache_wait_hit(self) -> None:
cache = self.with_cache("neutral_cache")
expected_result = "howdy"
@ -131,7 +131,7 @@ class ResponseCacheTestCase(TestCase):
self.assertEqual(expected_result, self.successResultOf(wrap_d))
def test_cache_wait_expire(self):
def test_cache_wait_expire(self) -> None:
cache = self.with_cache("medium_cache", ms=3000)
expected_result = "howdy"
@ -162,7 +162,7 @@ class ResponseCacheTestCase(TestCase):
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
@parameterized.expand([(True,), (False,)])
def test_cache_context_nocache(self, should_cache: bool):
def test_cache_context_nocache(self, should_cache: bool) -> None:
"""If the callback clears the should_cache bit, the result should not be cached"""
cache = self.with_cache("medium_cache", ms=3000)
@ -170,7 +170,7 @@ class ResponseCacheTestCase(TestCase):
call_count = 0
async def non_caching(o: str, cache_context: ResponseCacheContext[int]):
async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str:
nonlocal call_count
call_count += 1
await self.clock.sleep(1)

View file

@ -20,11 +20,11 @@ from tests import unittest
class CacheTestCase(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.mock_timer = Mock(side_effect=lambda: 100.0)
self.cache = TTLCache("test_cache", self.mock_timer)
self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer)
def test_get(self):
def test_get(self) -> None:
"""simple set/get tests"""
self.cache.set("one", "1", 10)
self.cache.set("two", "2", 20)
@ -59,7 +59,7 @@ class CacheTestCase(unittest.TestCase):
self.assertEqual(self.cache._metrics.hits, 4)
self.assertEqual(self.cache._metrics.misses, 5)
def test_expiry(self):
def test_expiry(self) -> None:
self.cache.set("one", "1", 10)
self.cache.set("two", "2", 20)
self.cache.set("three", "3", 30)