import asyncio
import functools
import itertools
import logging
import sys
import wrapt
from redis import asyncio as aioredis
import redis
import random
import string
import time
import difflib
import json
import subprocess
import pytest
import os
from typing import Iterable, Union
from enum import Enum


def tmp_file_name():
    return "".join(random.choices(string.ascii_letters, k=10))


def chunked(n, iterable):
    """Transform iterable into iterator of chunks of size n"""
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk


def eprint(*args, **kwargs):
    """Print to stderr"""
    print(*args, file=sys.stderr, **kwargs)


def gen_test_data(n, start=0, seed=None):
    for i in range(start, n):
        yield "k-" + str(i), "v-" + str(i) + ("-" + str(seed) if seed else "")


def batch_fill_data(client, gen, batch_size=100):
    for group in chunked(batch_size, gen):
        client.mset({k: v for k, v in group})


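# Illustrative sketch (not used by the tests): how gen_test_data and batch_fill_data
# combine to seed an instance. `sync_client` is a hypothetical redis.Redis connection
# supplied by the caller.
def _example_fill_small_dataset(sync_client):
    # Writes keys k-0..k-999 with seed-tagged values, in MSET batches of 100.
    batch_fill_data(sync_client, gen_test_data(1000, seed=1), batch_size=100)

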
async def tick_timer(func, timeout=5, step=0.1):
    """
    Async generator with automatic break when all asserts pass

    async for object, breaker in tick_timer(func):
        with breaker:
            assert conditions on object

    If the generator times out, the last failed assert is raised
    """

    class ticker_breaker:
        def __init__(self):
            self.exc = None
            self.entered = False

        def __enter__(self):
            self.entered = True

        def __exit__(self, exc_type, exc_value, trace):
            if exc_value:
                self.exc = exc_value
                return True

    last_error = None
    start = time.time()
    while time.time() - start < timeout:
        breaker = ticker_breaker()
        yield (await func(), breaker)
        if breaker.entered and not breaker.exc:
            return

        last_error = breaker.exc
        await asyncio.sleep(step)

    if last_error:
        raise TimeoutError("Timed out!") from last_error
    raise TimeoutError("Timed out!")


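# Illustrative sketch (not used by the tests): polling with tick_timer until an
# assertion holds. `client` is a hypothetical aioredis.Redis connection.
async def _example_wait_for_enough_keys(client, min_keys=100):
    async for size, breaker in tick_timer(client.dbsize, timeout=10):
        with breaker:
            assert size >= min_keys

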
async def info_tick_timer(client: aioredis.Redis, section=None, **kwargs):
    async for x in tick_timer(lambda: client.info(section), **kwargs):
        yield x


# Wait until a process becomes "responsive":
# for a master, this waits until it finishes loading a snapshot if it's busy doing so;
# for a replica, it waits until the full sync stage finishes and the stable sync state is reached.
async def wait_available_async(
    clients: Union[aioredis.Redis, Iterable[aioredis.Redis]], timeout=120
):
    """Block until instance exits loading phase"""
    if not isinstance(clients, aioredis.Redis):
        # Syntactic sugar to seamlessly handle an array of clients.
        return await asyncio.gather(*(wait_available_async(c) for c in clients))

    # First we make sure that ping passes
    start = time.time()
    while (time.time() - start) < timeout:
        try:
            await clients.ping()
            break
        except aioredis.BusyLoadingError as e:
            assert "Dragonfly is loading the dataset in memory" in str(e)
    timeout -= time.time() - start
    if timeout <= 0:
        raise TimeoutError("Timed out!")

    # Second, for replicas, we make sure they reached stable state replication
    async for info, breaker in info_tick_timer(clients, "REPLICATION", timeout=timeout):
        with breaker:
            assert info["role"] == "master" or "slave_repl_offset" in info, info


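# Illustrative sketch (not used by the tests): making sure a master and its replicas
# are responsive before issuing commands. The client objects are hypothetical
# aioredis.Redis connections; a list of clients is handled by the same helper.
async def _example_wait_for_setup(master_client, replica_clients):
    await wait_available_async(master_client)
    await wait_available_async(replica_clients)

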
class SizeChange(Enum):
    SHRINK = 0
    NO_CHANGE = 1
    GROW = 2


class ValueType(Enum):
    STRING = 0
    LIST = 1
    SET = 2
    HSET = 3
    ZSET = 4
    JSON = 5


class CommandGenerator:
    """Class for generating complex command sequences"""

    def __init__(self, target_keys, val_size, batch_size, max_multikey, unsupported_types=[]):
        self.key_cnt_target = target_keys
        self.val_size = val_size
        self.batch_size = min(batch_size, target_keys)
        self.max_multikey = max_multikey
        self.unsupported_types = unsupported_types

        # Key management
        self.key_sets = [set() for _ in ValueType]
        self.key_cursor = 0
        self.key_cnt = 0

        # Grow factors
        self.diff_speed = 5
        self.base_diff_prob = 0.2
        self.min_diff_prob = 0.1

    def keys(self):
        return itertools.chain(*self.key_sets)

    def keys_and_types(self):
        return ((k, t) for t in list(ValueType) for k in self.set_for_type(t))

    def set_for_type(self, t: ValueType):
        return self.key_sets[t.value]

    def add_key(self, t: ValueType):
        """Add new key of type t"""
        k, self.key_cursor = self.key_cursor, self.key_cursor + 1
        self.set_for_type(t).add(k)
        return k

    def random_type(self):
        return random.choice([t for t in ValueType if (t not in self.unsupported_types)])

    def randomize_nonempty_set(self):
        """Return random non-empty set and its type"""
        if not any(self.key_sets):
            return None, None

        t = self.random_type()
        s = self.set_for_type(t)

        if len(s) == 0:
            return self.randomize_nonempty_set()
        else:
            return s, t

    def randomize_key(self, t=None, pop=False):
        """Return random key and its type"""
        if t is None:
            s, t = self.randomize_nonempty_set()
        else:
            s = self.set_for_type(t)

        if s is None or len(s) == 0:
            return None, None

        k = s.pop()
        if not pop:
            s.add(k)

        return k, t

    def generate_val(self, t: ValueType):
        """Generate filler value of configured size for type t"""

        def rand_str(k=3, s=""):
            # Use small k value to reduce mem usage and increase number of ops
            return s.join(random.choices(string.ascii_letters, k=k))

        if t == ValueType.STRING:
            # Random string for MSET
            return (rand_str(self.val_size),)
        elif t == ValueType.LIST:
            # Random sequence of k-letter elements for LPUSH
            return tuple(rand_str() for _ in range(self.val_size // 4))
        elif t == ValueType.SET:
            # Random sequence of k-letter elements for SADD
            return tuple(rand_str() for _ in range(self.val_size // 4))
        elif t == ValueType.HSET:
            # Random sequence of k-letter keys + int and two start values for HSET
            elements = (
                (rand_str(), random.randint(0, self.val_size)) for _ in range(self.val_size // 5)
            )
            return ("v0", 0, "v1", 0) + tuple(itertools.chain(*elements))
        elif t == ValueType.ZSET:
            # Random sequence of k-letter members and int scores for ZADD.
            # The length of the sequence varies between val_size/4 and 130.
            # This ensures that we test both the listpack-based ZSET implementation and our custom BPTree.
            value_sizes = [self.val_size // 4, 130]
            probabilities = [8, 1]
            value_size = random.choices(value_sizes, probabilities)[0]
            elements = ((random.randint(0, self.val_size), rand_str()) for _ in range(value_size))
            return tuple(itertools.chain(*elements))

        elif t == ValueType.JSON:
            # JSON object with keys:
            # - arr (array of random strings)
            # - ints (array of objects {i: random integer})
            # - i (random integer)
            ints = [{"i": random.randint(0, 100)} for i in range(self.val_size // 6)]
            strs = [rand_str() for _ in range(self.val_size // 6)]
            return "$", json.dumps({"arr": strs, "ints": ints, "i": random.randint(0, 100)})
        else:
            assert False, "Invalid ValueType"

    def gen_shrink_cmd(self):
        """
        Generate a command that shrinks data: DEL of random keys or an almost immediate (<=50ms) PEXPIRE.
        """
        if random.random() < 0.3:
            key, _ = self.randomize_key(pop=True)
            if key is None:
                return None, 0
            return ("PEXPIRE", f"k{key}", f"{random.randint(0, 50)}"), -1
        else:
            keys_gen = (
                self.randomize_key(pop=True) for _ in range(random.randint(1, self.max_multikey))
            )
            keys = [f"k{k}" for k, _ in keys_gen if k is not None]

            if len(keys) == 0:
                return None, 0
            return ("DEL", *keys), -len(keys)

    UPDATE_ACTIONS = [
        ("APPEND {k} {val}", ValueType.STRING),
        ("SETRANGE {k} 10 {val}", ValueType.STRING),
        ("LPUSH {k} {val}", ValueType.LIST),
        ("LPOP {k}", ValueType.LIST),
        ("SADD {k} {val}", ValueType.SET),
        ("SPOP {k}", ValueType.SET),
        ("HSETNX {k} v0 {val}", ValueType.HSET),
        ("HINCRBY {k} v1 1", ValueType.HSET),
        ("ZPOPMIN {k} 1", ValueType.ZSET),
        ("ZADD {k} 0 {val}", ValueType.ZSET),
        ("JSON.NUMINCRBY {k} $..i 1", ValueType.JSON),
        ("JSON.ARRPOP {k} $.arr", ValueType.JSON),
        ('JSON.ARRAPPEND {k} $.arr "{val}"', ValueType.JSON),
    ]

    def gen_update_cmd(self):
        """
        Generate a command that makes no change to the keyset: a random pick from UPDATE_ACTIONS.
        """
        cmd, t = random.choice(self.UPDATE_ACTIONS)
        k, _ = self.randomize_key(t)
        val = "".join(random.choices(string.ascii_letters, k=3))
        return cmd.format(k=f"k{k}", val=val).split() if k is not None else None, 0

    GROW_ACTIONS = {
        ValueType.STRING: "MSET",
        ValueType.LIST: "LPUSH",
        ValueType.SET: "SADD",
        ValueType.HSET: "HMSET",
        ValueType.ZSET: "ZADD",
        ValueType.JSON: "JSON.MSET",
    }

    def gen_grow_cmd(self):
        """
        Generate a command that grows the keyset: initialize a key of random type with a filler value.
        """
        # TODO: Implement COPY in Dragonfly.
        t = self.random_type()
        if t in [ValueType.STRING, ValueType.JSON]:
            count = random.randint(1, self.max_multikey)
        else:
            count = 1

        keys = (self.add_key(t) for _ in range(count))
        payload = itertools.chain(*((f"k{k}",) + self.generate_val(t) for k in keys))
        filtered_payload = filter(lambda p: p is not None, payload)

        return (self.GROW_ACTIONS[t],) + tuple(filtered_payload), count

    def make(self, action):
        """Create command for action and return it together with the number of keys added (or removed)"""
        if action == SizeChange.SHRINK:
            return self.gen_shrink_cmd()
        elif action == SizeChange.NO_CHANGE:
            return self.gen_update_cmd()
        else:
            return self.gen_grow_cmd()

    def reset(self):
        self.key_sets = [set() for _ in ValueType]
        self.key_cursor = 0
        self.key_cnt = 0

    def size_change_probs(self):
        """Calculate probabilities of size change actions"""
        # Relative distance to key target
        dist = (self.key_cnt_target - self.key_cnt) / self.key_cnt_target
        # Shrink has roughly twice the expected number of changed keys compared to grow
        return [
            max(self.base_diff_prob - self.diff_speed * dist, self.min_diff_prob),
            15.0,
            max(self.base_diff_prob + 2 * self.diff_speed * dist, self.min_diff_prob),
        ]

    def generate(self):
        """Generate the next batch of commands, return it and the ratio of current keys to target"""
        changes = []
        cmds = []
        while len(cmds) < self.batch_size:
            # Re-calculate changes in small groups
            if len(changes) == 0:
                changes = random.choices(list(SizeChange), weights=self.size_change_probs(), k=20)

            cmd, delta = self.make(changes.pop())
            if cmd is not None:
                cmds.append(cmd)
                self.key_cnt += delta
        return cmds, self.key_cnt / self.key_cnt_target


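# Illustrative sketch (not used by the tests): driving CommandGenerator directly.
# It yields raw command tuples/lists that can be fed to any Redis-compatible client.
def _example_generate_batches():
    gen = CommandGenerator(
        target_keys=1000, val_size=50, batch_size=100, max_multikey=5, unsupported_types=[]
    )
    batches = []
    deviation = 0.0
    # Keep generating until we are within 5% of the key target.
    while abs(1 - deviation) > 0.05:
        cmds, deviation = gen.generate()
        batches.append(cmds)
    return batches

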
class DataCapture:
    """
    Captured state of single database.
    """

    def __init__(self, entries):
        self.entries = entries

    def compare(self, other):
        if self.entries == other.entries:
            return True

        self._print_diff(other)
        return False

    def _print_diff(self, other):
        eprint("=== DIFF ===")
        printed = 0
        diff = difflib.ndiff(self.entries, other.entries)
        for line in diff:
            if line.startswith(" "):
                continue
            eprint(line)
            if printed >= 20:
                eprint("... omitted ...")
                break
            printed += 1
        eprint("=== END DIFF ===")


class DflySeeder:
    """
    Data seeder with support for multiple types and commands.

    Usage:

    Create a seeder with a target number of keys (100k) of specified size (200) working on 5 dbs:

        seeder = DflySeeder(keys=100_000, val_size=200, dbcount=5)

    Stop when we are within 5% of the target number of keys (i.e. above 95_000).
    Because it's probabilistic we might never reach exactly 100_000:

        await seeder.run(target_deviation=0.05)

    Run 3000 commands in stable state, create a capture and compare it to the
    replica on port 1112:

        await seeder.run(target_ops=3000)
        capture = await seeder.capture()
        assert await seeder.compare(capture, port=1112)
    """

    def __init__(
        self,
        port=6379,
        keys=1000,
        val_size=50,
        batch_size=100,
        max_multikey=5,
        dbcount=1,
        multi_transaction_probability=0.3,
        log_file=None,
        unsupported_types=[],
        stop_on_failure=True,
        cluster_mode=False,
    ):
        if cluster_mode:
            max_multikey = 1
            multi_transaction_probability = 0
            unsupported_types.append(ValueType.JSON)  # Cluster aio client doesn't support JSON

        self.cluster_mode = cluster_mode
        self.gen = CommandGenerator(keys, val_size, batch_size, max_multikey, unsupported_types)
        self.port = port
        self.dbcount = dbcount
        self.multi_transaction_probability = multi_transaction_probability
        self.stop_flag = False
        self.stop_on_failure = stop_on_failure

        self.log_file = log_file
        if self.log_file is not None:
            open(self.log_file, "w").close()

    async def run(self, target_ops=None, target_deviation=None):
        """
        Run a seeding cycle on all dbs either until stop() is called, a fixed number of commands
        is reached (target_ops), or the key count is within an allowed deviation from the target
        (target_deviation).
        """
        logging.debug(f"Running ops:{target_ops} deviation:{target_deviation}")
        self.stop_flag = False
        queues = [asyncio.Queue(maxsize=3) for _ in range(self.dbcount)]
        producer = asyncio.create_task(
            self._generator_task(queues, target_ops=target_ops, target_deviation=target_deviation)
        )
        consumers = [
            asyncio.create_task(self._executor_task(i, queue)) for i, queue in enumerate(queues)
        ]

        time_start = time.time()

        cmdcount = await producer
        for consumer in consumers:
            await consumer

        took = time.time() - time_start
        qps = round(cmdcount * self.dbcount / took, 2)
        logging.debug(f"Filling took: {took}, QPS: {qps}")

    def stop(self):
        """Stop all invocations to run"""
        self.stop_flag = True

    def reset(self):
        """Reset internal state. Needs to be called after flush or restart"""
        self.gen.reset()

    async def capture(self, port=None):
        """Create DataCapture for all dbs"""

        if port is None:
            port = self.port
        logging.debug(f"Starting capture from {port=}")
        keys = sorted(list(self.gen.keys_and_types()))

        captures = await asyncio.gather(
            *(self._capture_db(port=port, target_db=db, keys=keys) for db in range(self.dbcount))
        )
        return captures

    async def compare(self, initial_captures, port=6379):
        """Compare data capture with all dbs of instance and return True if all dbs are correct"""
        print(f"comparing capture to {port}")
        target_captures = await self.capture(port=port)

        for db, target_capture, initial_capture in zip(
            range(self.dbcount), target_captures, initial_captures
        ):
            print(f"comparing capture to {port}, db: {db}")
            if not initial_capture.compare(target_capture):
                eprint(f">>> Inconsistent data on port {port}, db {db}")
                return False
        return True

    def target(self, key_cnt):
        self.gen.key_cnt_target = key_cnt

    def _make_client(self, **kwargs):
        if self.cluster_mode:
            return aioredis.RedisCluster(host="127.0.0.1", **kwargs)
        else:
            return aioredis.Redis(**kwargs)

    async def _close_client(self, client):
        if not self.cluster_mode:
            await client.connection_pool.disconnect()
        await client.close()

    async def _capture_db(self, port, target_db, keys):
        client = self._make_client(port=port, db=target_db)
        capture = DataCapture(await self._capture_entries(client, keys))

        await self._close_client(client)

        return capture

    async def _generator_task(self, queues, target_ops=None, target_deviation=None):
        cpu_time = 0
        submitted = 0
        batches = 0
        deviation = 0.0

        file = None
        if self.log_file:
            file = open(self.log_file, "a")

        def should_run():
            if self.stop_flag:
                return False
            if target_ops is not None and submitted >= target_ops:
                return False
            if target_deviation is not None and (
                deviation > 1 or abs(1 - deviation) < target_deviation
            ):
                return False
            return True

        def stringify_cmd(cmd):
            if isinstance(cmd, tuple):
                return " ".join(str(c) for c in cmd)
            else:
                return str(cmd)

        while should_run():
            start_time = time.time()
            blob, deviation = self.gen.generate()
            is_multi_transaction = random.random() < self.multi_transaction_probability
            tx_data = (blob, is_multi_transaction)
            cpu_time += time.time() - start_time

            await asyncio.gather(*(q.put(tx_data) for q in queues))
            submitted += len(blob)
            batches += 1

            if file is not None:
                pattern = "MULTI\n{}\nEXEC\n" if is_multi_transaction else "{}\n"
                file.write(pattern.format("\n".join(stringify_cmd(cmd) for cmd in blob)))

            print(".", end="", flush=True)
            await asyncio.sleep(0.0)

        print("\ncpu time", cpu_time, "batches", batches, "commands", submitted)

        await asyncio.gather(*(q.put(None) for q in queues))
        for q in queues:
            await q.join()

        if file is not None:
            file.flush()

        return submitted

    async def _executor_task(self, db, queue):
        client = self._make_client(port=self.port, db=db)

        while True:
            tx_data = await queue.get()
            if tx_data is None:
                queue.task_done()
                break

            pipe = client.pipeline(transaction=tx_data[1])
            for cmd in tx_data[0]:
                pipe.execute_command(*cmd)

            try:
                await pipe.execute()
            except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError) as e:
                if self.stop_on_failure:
                    await self._close_client(client)
                    raise SystemExit(e)
            except Exception as e:
                await self._close_client(client)
                raise SystemExit(e)
            queue.task_done()

        await self._close_client(client)

    CAPTURE_COMMANDS = {
        ValueType.STRING: lambda pipe, k: pipe.get(k),
        ValueType.LIST: lambda pipe, k: pipe.lrange(k, 0, -1),
        ValueType.SET: lambda pipe, k: pipe.smembers(k),
        ValueType.HSET: lambda pipe, k: pipe.hgetall(k),
        ValueType.ZSET: lambda pipe, k: pipe.zrange(k, start=0, end=-1, withscores=True),
        ValueType.JSON: lambda pipe, k: pipe.execute_command("JSON.GET", k, "$"),
    }

    CAPTURE_EXTRACTORS = {
        ValueType.STRING: lambda res, tostr: (tostr(res),),
        ValueType.LIST: lambda res, tostr: (tostr(s) for s in res),
        ValueType.SET: lambda res, tostr: sorted(tostr(s) for s in res),
        ValueType.HSET: lambda res, tostr: sorted(
            tostr(k) + "=" + tostr(v) for k, v in res.items()
        ),
        ValueType.ZSET: lambda res, tostr: (tostr(s) + "-" + str(f) for (s, f) in res),
        ValueType.JSON: lambda res, tostr: (tostr(res),),
    }

    async def _capture_entries(self, client, keys):
        def tostr(b):
            return b.decode("utf-8") if isinstance(b, bytes) else str(b)

        entries = []
        for group in chunked(self.gen.batch_size * 2, keys):
            pipe = client.pipeline(transaction=False)
            for k, t in group:
                self.CAPTURE_COMMANDS[t](pipe, f"k{k}")

            results = await pipe.execute()
            for (k, t), res in zip(group, results):
                out = f"{t.name} k{k}: " + " ".join(self.CAPTURE_EXTRACTORS[t](res, tostr))
                entries.append(out)

        return entries


class DflySeederFactory:
    """
    Used to pass params to a DflySeeder.
    """

    def __init__(self, log_file=None):
        self.log_file = log_file

    def __repr__(self) -> str:
        return f"DflySeederFactory(log_file={self.log_file})"

    def create(self, **kwargs):
        return DflySeeder(log_file=self.log_file, **kwargs)


def gen_ca_cert(ca_key_path, ca_cert_path):
    # We first need to generate the tls certificates to be used by the server

    # Generate CA (certificate authority) key and self-signed certificate
    # In production, the CA should be generated by a third party authority
    # Expires in one day and is not encrypted (-nodes)
    # X.509 format for the key
    step = rf'openssl req -x509 -newkey rsa:4096 -days 1 -nodes -keyout {ca_key_path} -out {ca_cert_path} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=AcmeStudios/CN=Gr/emailAddress=acme@gmail.com"'
    subprocess.run(step, shell=True)


def gen_certificate(
    ca_key_path, ca_certificate_path, certificate_request_path, private_key_path, certificate_path
):
    # Generate Dragonfly's private key and certificate signing request (CSR)
    step1 = rf'openssl req -newkey rsa:4096 -nodes -keyout {private_key_path} -out {certificate_request_path} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=Comp/CN=Gr/emailAddress=does_not_exist@gmail.com"'
    subprocess.run(step1, shell=True)

    # Use CA's private key to sign dragonfly's CSR and get back the signed certificate
    step2 = rf"openssl x509 -req -in {certificate_request_path} -days 1 -CA {ca_certificate_path} -CAkey {ca_key_path} -CAcreateserial -out {certificate_path}"
    subprocess.run(step2, shell=True)


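# Illustrative sketch (not used by the tests): generating a throwaway CA and a server
# certificate signed by it. All file paths below are hypothetical temporary files.
def _example_gen_tls_files(tmp_dir="/tmp"):
    ca_key = os.path.join(tmp_dir, "ca-key.pem")
    ca_cert = os.path.join(tmp_dir, "ca-cert.pem")
    gen_ca_cert(ca_key, ca_cert)

    server_key = os.path.join(tmp_dir, "server-key.pem")
    server_csr = os.path.join(tmp_dir, "server.csr")
    server_cert = os.path.join(tmp_dir, "server-cert.pem")
    gen_certificate(ca_key, ca_cert, server_csr, server_key, server_cert)
    return ca_cert, server_key, server_cert

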
class EnvironCntx:
    def __init__(self, **kwargs):
        self.updates = kwargs
        self.undo = {}

    def __enter__(self):
        for k, v in self.updates.items():
            if k in os.environ:
                self.undo[k] = os.environ[k]
            os.environ[k] = v

    def __exit__(self, exc_type, exc_value, exc_traceback):
        for k, v in self.updates.items():
            if k in self.undo:
                os.environ[k] = self.undo[k]
            else:
                del os.environ[k]


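# Illustrative sketch (not used by the tests): temporarily overriding an environment
# variable inside a with-block; EnvironCntx restores the previous state on exit.
# SOME_VAR is an arbitrary placeholder name.
def _example_environ_override():
    with EnvironCntx(SOME_VAR="temporary-value"):
        assert os.environ["SOME_VAR"] == "temporary-value"

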
async def is_saving(c_client: aioredis.Redis):
    return "saving:1" in (await c_client.execute_command("INFO PERSISTENCE"))


def assert_eventually(wrapped=None, *, times=100):
    if wrapped is None:
        return functools.partial(assert_eventually, times=times)

    @wrapt.decorator
    async def wrapper(wrapped, instance, args, kwargs):
        for attempt in range(times):
            try:
                result = await wrapped(*args, **kwargs)
                return result
            except AssertionError:
                if attempt == times - 1:
                    raise
                await asyncio.sleep(0.1)

    return wrapper(wrapped)


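# Illustrative sketch (not used by the tests): retrying a flaky assertion with the
# decorator above; the body is re-run with a short sleep until it stops raising
# AssertionError or the attempts are exhausted. `client` is a hypothetical
# aioredis.Redis connection and the INFO field checked is just an example.
@assert_eventually(times=50)
async def _example_assert_replica_caught_up(client):
    info = await client.info("replication")
    assert info.get("master_link_status") == "up"

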
def skip_if_not_in_github():
    if os.getenv("GITHUB_ACTIONS") is None:
        pytest.skip("Redis server not found")


class ExpirySeeder:
    def __init__(self):
        self.stop_flag = False
        self.i = 0
        self.batch_size = 200

    async def run(self, client):
        while not self.stop_flag:
            pipeline = client.pipeline(transaction=True)
            for i in range(0, self.batch_size):
                pipeline.execute_command(f"SET tmp{self.i} bar{self.i} EX 3")
                self.i = self.i + 1
            await pipeline.execute()

    async def wait_until_n_inserts(self, count):
        while not self.i > count:
            await asyncio.sleep(0.5)

    def stop(self):
        self.stop_flag = True


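# Illustrative sketch (not used by the tests): running ExpirySeeder in the background
# while a test does other work. `client` is a hypothetical aioredis.Redis connection.
async def _example_seed_expiring_keys(client):
    seeder = ExpirySeeder()
    task = asyncio.create_task(seeder.run(client))
    await seeder.wait_until_n_inserts(1000)
    seeder.stop()
    await task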