1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

feat: Add black formatter to the project (#1544)

Add black formatter and run it on pytests
This commit is contained in:
Kostas Kyrimis 2023-07-17 13:13:12 +03:00 committed by GitHub
parent 9448220607
commit 7944af3c62
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 793 additions and 566 deletions

View file

@ -24,3 +24,8 @@ repos:
hooks:
- id: clang-format
name: Clang formatting
- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black

12
pyproject.toml Normal file
View file

@ -0,0 +1,12 @@
[tool.black]
line-length = 100
include = '\.py$'
extend-exclude = '''
/(
| .git
| .__pycache__
| build-dbg
| build-opt
| helio
)/
'''

View file

@ -34,7 +34,7 @@ class DflyInstance:
self.args = args
self.params = params
self.proc = None
self._client : Optional[RedisClient] = None
self._client: Optional[RedisClient] = None
def client(self) -> RedisClient:
return RedisClient(port=self.port)
@ -70,8 +70,7 @@ class DflyInstance:
return
base_args = [f"--{v}" for v in self.params.args]
all_args = self.format_args(self.args) + base_args
print(
f"Starting instance on {self.port} with arguments {all_args} from {self.params.path}")
print(f"Starting instance on {self.port} with arguments {all_args} from {self.params.path}")
run_cmd = [self.params.path, *all_args]
if self.params.gdb:
@ -82,8 +81,7 @@ class DflyInstance:
if not self.params.existing_port:
return_code = self.proc.poll()
if return_code is not None:
raise Exception(
f"Failed to start instance, return code {return_code}")
raise Exception(f"Failed to start instance, return code {return_code}")
def __getitem__(self, k):
return self.args.get(k)
@ -93,11 +91,13 @@ class DflyInstance:
if self.params.existing_port:
return self.params.existing_port
return int(self.args.get("port", "6379"))
@property
def admin_port(self) -> int:
if self.params.existing_admin_port:
return self.params.existing_admin_port
return int(self.args.get("admin_port", "16379"))
@property
def mc_port(self) -> int:
if self.params.existing_mc_port:
@ -107,7 +107,7 @@ class DflyInstance:
@staticmethod
def format_args(args):
out = []
for (k, v) in args.items():
for k, v in args.items():
out.append(f"--{k}")
if v is not None:
out.append(str(v))
@ -118,7 +118,10 @@ class DflyInstance:
resp = await session.get(f"http://localhost:{self.port}/metrics")
data = await resp.text()
await session.close()
return {metric_family.name : metric_family for metric_family in text_string_to_metric_families(data)}
return {
metric_family.name: metric_family
for metric_family in text_string_to_metric_families(data)
}
class DflyInstanceFactory:
@ -142,7 +145,7 @@ class DflyInstanceFactory:
return instance
def start_all(self, instances):
""" Start multiple instances in parallel """
"""Start multiple instances in parallel"""
for instance in instances:
instance._start()
@ -162,17 +165,17 @@ class DflyInstanceFactory:
def dfly_args(*args):
""" Used to define a singular set of arguments for dragonfly test """
"""Used to define a singular set of arguments for dragonfly test"""
return pytest.mark.parametrize("df_factory", args, indirect=True)
def dfly_multi_test_args(*args):
""" Used to define multiple sets of arguments to test multiple dragonfly configurations """
"""Used to define multiple sets of arguments to test multiple dragonfly configurations"""
return pytest.mark.parametrize("df_factory", args, indirect=True)
class PortPicker():
""" A simple port manager to allocate available ports for tests """
class PortPicker:
"""A simple port manager to allocate available ports for tests"""
def __init__(self):
self.next_port = 5555
@ -185,5 +188,6 @@ class PortPicker():
def is_port_available(self, port):
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) != 0
return s.connect_ex(("localhost", port)) != 0

View file

@ -13,9 +13,9 @@ BASE_PORT = 30001
async def push_config(config, admin_connections):
print("Pushing config ", config)
await asyncio.gather(*(c_admin.execute_command(
"DFLYCLUSTER", "CONFIG", config)
for c_admin in admin_connections))
await asyncio.gather(
*(c_admin.execute_command("DFLYCLUSTER", "CONFIG", config) for c_admin in admin_connections)
)
async def get_node_id(admin_connection):
@ -38,15 +38,13 @@ class TestNotEmulated:
@dfly_args({"cluster_mode": "emulated"})
class TestEmulated:
def test_cluster_slots_command(self, cluster_client: redis.RedisCluster):
expected = {(0, 16383): {'primary': (
'127.0.0.1', 6379), 'replicas': []}}
expected = {(0, 16383): {"primary": ("127.0.0.1", 6379), "replicas": []}}
res = cluster_client.execute_command("CLUSTER SLOTS")
assert expected == res
def test_cluster_help_command(self, cluster_client: redis.RedisCluster):
# `target_nodes` is necessary because CLUSTER HELP is not mapped on redis-py
res = cluster_client.execute_command(
"CLUSTER HELP", target_nodes=redis.RedisCluster.RANDOM)
res = cluster_client.execute_command("CLUSTER HELP", target_nodes=redis.RedisCluster.RANDOM)
assert "HELP" in res
assert "SLOTS" in res
@ -61,31 +59,31 @@ class TestEmulated:
@dfly_args({"cluster_mode": "emulated", "cluster_announce_ip": "127.0.0.2"})
class TestEmulatedWithAnnounceIp:
def test_cluster_slots_command(self, cluster_client: redis.RedisCluster):
expected = {(0, 16383): {'primary': (
'127.0.0.2', 6379), 'replicas': []}}
expected = {(0, 16383): {"primary": ("127.0.0.2", 6379), "replicas": []}}
res = cluster_client.execute_command("CLUSTER SLOTS")
assert expected == res
def verify_slots_result(ip: str, port: int, answer: list, rep_ip: str = None, rep_port: int = None) -> bool:
def verify_slots_result(
ip: str, port: int, answer: list, rep_ip: str = None, rep_port: int = None
) -> bool:
def is_local_host(ip: str) -> bool:
return ip == '127.0.0.1' or ip == 'localhost'
return ip == "127.0.0.1" or ip == "localhost"
assert answer[0] == 0 # start shard
assert answer[1] == 16383 # last shard
assert answer[0] == 0 # start shard
assert answer[1] == 16383 # last shard
if rep_ip is not None:
assert len(answer) == 4 # the network info
rep_info = answer[3]
assert len(rep_info) == 3
ip_addr = str(rep_info[0], 'utf-8')
assert ip_addr == rep_ip or (
is_local_host(ip_addr) and is_local_host(ip))
ip_addr = str(rep_info[0], "utf-8")
assert ip_addr == rep_ip or (is_local_host(ip_addr) and is_local_host(ip))
assert rep_info[1] == rep_port
else:
assert len(answer) == 3
info = answer[2]
assert len(info) == 3
ip_addr = str(info[0], 'utf-8')
ip_addr = str(info[0], "utf-8")
assert ip_addr == ip or (is_local_host(ip_addr) and is_local_host(ip))
assert info[1] == port
return True
@ -94,7 +92,7 @@ def verify_slots_result(ip: str, port: int, answer: list, rep_ip: str = None, re
@dfly_args({"proactor_threads": 4, "cluster_mode": "emulated"})
async def test_cluster_slots_in_replicas(df_local_factory):
master = df_local_factory.create(port=BASE_PORT)
replica = df_local_factory.create(port=BASE_PORT+1, logtostdout=True)
replica = df_local_factory.create(port=BASE_PORT + 1, logtostdout=True)
df_local_factory.start_all([master, replica])
@ -103,45 +101,46 @@ async def test_cluster_slots_in_replicas(df_local_factory):
res = await c_replica.execute_command("CLUSTER SLOTS")
assert len(res) == 1
assert verify_slots_result(
ip="127.0.0.1", port=replica.port, answer=res[0])
assert verify_slots_result(ip="127.0.0.1", port=replica.port, answer=res[0])
res = await c_master.execute_command("CLUSTER SLOTS")
assert verify_slots_result(
ip="127.0.0.1", port=master.port, answer=res[0])
assert verify_slots_result(ip="127.0.0.1", port=master.port, answer=res[0])
# Connect replica to master
rc = await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
assert str(rc, 'utf-8') == "OK"
assert str(rc, "utf-8") == "OK"
await asyncio.sleep(0.5)
res = await c_replica.execute_command("CLUSTER SLOTS")
assert verify_slots_result(
ip="127.0.0.1", port=master.port, answer=res[0], rep_ip="127.0.0.1", rep_port=replica.port)
ip="127.0.0.1", port=master.port, answer=res[0], rep_ip="127.0.0.1", rep_port=replica.port
)
res = await c_master.execute_command("CLUSTER SLOTS")
assert verify_slots_result(
ip="127.0.0.1", port=master.port, answer=res[0], rep_ip="127.0.0.1", rep_port=replica.port)
ip="127.0.0.1", port=master.port, answer=res[0], rep_ip="127.0.0.1", rep_port=replica.port
)
@dfly_args({"cluster_mode": "emulated", "cluster_announce_ip": "127.0.0.2"})
async def test_cluster_info(async_client):
res = await async_client.execute_command("CLUSTER INFO")
assert len(res) == 16
assert res == {'cluster_current_epoch': '1',
'cluster_known_nodes': '1',
'cluster_my_epoch': '1',
'cluster_size': '1',
'cluster_slots_assigned': '16384',
'cluster_slots_fail': '0',
'cluster_slots_ok': '16384',
'cluster_slots_pfail': '0',
'cluster_state': 'ok',
'cluster_stats_messages_meet_received': '0',
'cluster_stats_messages_ping_received': '1',
'cluster_stats_messages_ping_sent': '1',
'cluster_stats_messages_pong_received': '1',
'cluster_stats_messages_pong_sent': '1',
'cluster_stats_messages_received': '1',
'cluster_stats_messages_sent': '1'
}
assert res == {
"cluster_current_epoch": "1",
"cluster_known_nodes": "1",
"cluster_my_epoch": "1",
"cluster_size": "1",
"cluster_slots_assigned": "16384",
"cluster_slots_fail": "0",
"cluster_slots_ok": "16384",
"cluster_slots_pfail": "0",
"cluster_state": "ok",
"cluster_stats_messages_meet_received": "0",
"cluster_stats_messages_ping_received": "1",
"cluster_stats_messages_ping_sent": "1",
"cluster_stats_messages_pong_received": "1",
"cluster_stats_messages_pong_sent": "1",
"cluster_stats_messages_received": "1",
"cluster_stats_messages_sent": "1",
}
@dfly_args({"cluster_mode": "emulated", "cluster_announce_ip": "127.0.0.2"})
@ -149,14 +148,14 @@ async def test_cluster_info(async_client):
async def test_cluster_nodes(async_client):
res = await async_client.execute_command("CLUSTER NODES")
assert len(res) == 1
info = res['127.0.0.2:6379']
info = res["127.0.0.2:6379"]
assert res is not None
assert info['connected'] == True
assert info['epoch'] == '0'
assert info['flags'] == 'myself,master'
assert info['last_ping_sent'] == '0'
assert info['slots'] == [['0', '16383']]
assert info['master_id'] == "-"
assert info["connected"] == True
assert info["epoch"] == "0"
assert info["flags"] == "myself,master"
assert info["last_ping_sent"] == "0"
assert info["slots"] == [["0", "16383"]]
assert info["master_id"] == "-"
"""
@ -166,11 +165,13 @@ Add a key to node0, then move the slot ownership to node1 and see that they both
intended.
Also add keys to each of them that are *not* moved, and see that they are unaffected by the move.
"""
@dfly_args({"proactor_threads": 4, "cluster_mode": "yes"})
async def test_cluster_slot_ownership_changes(df_local_factory):
# Start and configure cluster with 2 nodes
nodes = [
df_local_factory.create(port=BASE_PORT+i, admin_port=BASE_PORT+i+1000)
df_local_factory.create(port=BASE_PORT + i, admin_port=BASE_PORT + i + 1000)
for i in range(2)
]
@ -214,7 +215,10 @@ async def test_cluster_slot_ownership_changes(df_local_factory):
]
"""
await push_config(config.replace('LAST_SLOT_CUTOFF', '5259').replace('NEXT_SLOT_CUTOFF', '5260'), c_nodes_admin)
await push_config(
config.replace("LAST_SLOT_CUTOFF", "5259").replace("NEXT_SLOT_CUTOFF", "5260"),
c_nodes_admin,
)
# Slot for "KEY1" is 5259
@ -243,7 +247,10 @@ async def test_cluster_slot_ownership_changes(df_local_factory):
print("Moving ownership over 5259 ('KEY1') to other node")
await push_config(config.replace('LAST_SLOT_CUTOFF', '5258').replace('NEXT_SLOT_CUTOFF', '5259'), c_nodes_admin)
await push_config(
config.replace("LAST_SLOT_CUTOFF", "5258").replace("NEXT_SLOT_CUTOFF", "5259"),
c_nodes_admin,
)
# node0 should have removed "KEY1" as it no longer owns it
assert await c_nodes[0].execute_command("DBSIZE") == 1
@ -292,8 +299,8 @@ async def test_cluster_slot_ownership_changes(df_local_factory):
@dfly_args({"proactor_threads": 4, "cluster_mode": "yes"})
async def test_cluster_replica_sets_non_owned_keys(df_local_factory):
# Start and configure cluster with 1 master and 1 replica, both own all slots
master = df_local_factory.create(port=BASE_PORT, admin_port=BASE_PORT+1000)
replica = df_local_factory.create(port=BASE_PORT+1, admin_port=BASE_PORT+1001)
master = df_local_factory.create(port=BASE_PORT, admin_port=BASE_PORT + 1000)
replica = df_local_factory.create(port=BASE_PORT + 1, admin_port=BASE_PORT + 1001)
df_local_factory.start_all([master, replica])
c_master = aioredis.Redis(port=master.port)
@ -404,8 +411,8 @@ async def test_cluster_replica_sets_non_owned_keys(df_local_factory):
@dfly_args({"proactor_threads": 4, "cluster_mode": "yes"})
async def test_cluster_flush_slots_after_config_change(df_local_factory):
# Start and configure cluster with 1 master and 1 replica, both own all slots
master = df_local_factory.create(port=BASE_PORT, admin_port=BASE_PORT+1000)
replica = df_local_factory.create(port=BASE_PORT+1, admin_port=BASE_PORT+1001)
master = df_local_factory.create(port=BASE_PORT, admin_port=BASE_PORT + 1000)
replica = df_local_factory.create(port=BASE_PORT + 1, admin_port=BASE_PORT + 1001)
df_local_factory.start_all([master, replica])
c_master = aioredis.Redis(port=master.port)
@ -453,7 +460,7 @@ async def test_cluster_flush_slots_after_config_change(df_local_factory):
resp = await c_master_admin.execute_command("dflycluster", "getslotinfo", "slots", "0")
assert resp[0][0] == 0
slot_0_size = resp[0][2]
print(f'Slot 0 size = {slot_0_size}')
print(f"Slot 0 size = {slot_0_size}")
assert slot_0_size > 0
config = f"""
@ -512,7 +519,7 @@ async def test_cluster_flush_slots_after_config_change(df_local_factory):
async def test_cluster_native_client(df_local_factory):
# Start and configure cluster with 3 masters and 3 replicas
masters = [
df_local_factory.create(port=BASE_PORT+i, admin_port=BASE_PORT+i+1000)
df_local_factory.create(port=BASE_PORT + i, admin_port=BASE_PORT + i + 1000)
for i in range(3)
]
df_local_factory.start_all(masters)
@ -521,7 +528,7 @@ async def test_cluster_native_client(df_local_factory):
master_ids = await asyncio.gather(*(get_node_id(c) for c in c_masters_admin))
replicas = [
df_local_factory.create(port=BASE_PORT+100+i, admin_port=BASE_PORT+i+1100)
df_local_factory.create(port=BASE_PORT + 100 + i, admin_port=BASE_PORT + i + 1100)
for i in range(3)
]
df_local_factory.start_all(replicas)
@ -597,22 +604,23 @@ async def test_cluster_native_client(df_local_factory):
client = aioredis.RedisCluster(decode_responses=True, host="localhost", port=masters[0].port)
assert await client.set('key0', 'value') == True
assert await client.get('key0') == 'value'
assert await client.set("key0", "value") == True
assert await client.get("key0") == "value"
async def test_random_keys():
for i in range(100):
key = 'key' + str(random.randint(0, 100_000))
assert await client.set(key, 'value') == True
assert await client.get(key) == 'value'
key = "key" + str(random.randint(0, 100_000))
assert await client.set(key, "value") == True
assert await client.get(key) == "value"
await test_random_keys()
await asyncio.gather(*(wait_available_async(c) for c in c_replicas))
# Make sure that getting a value from a replica works as well.
replica_response = await client.execute_command(
'get', 'key0', target_nodes=aioredis.RedisCluster.REPLICAS)
assert 'value' in replica_response.values()
"get", "key0", target_nodes=aioredis.RedisCluster.REPLICAS
)
assert "value" in replica_response.values()
# Push new config
config = f"""

View file

@ -21,7 +21,7 @@ from tempfile import TemporaryDirectory
from . import DflyInstance, DflyInstanceFactory, DflyParams, PortPicker, dfly_args
from .utility import DflySeederFactory, gen_certificate
logging.getLogger('asyncio').setLevel(logging.WARNING)
logging.getLogger("asyncio").setLevel(logging.WARNING)
DATABASE_INDEX = 1
@ -55,7 +55,6 @@ def df_seeder_factory(request) -> DflySeederFactory:
if seed is None:
seed = random.randrange(sys.maxsize)
random.seed(int(seed))
print(f"--- Random seed: {seed}, check: {random.randrange(100)} ---")
@ -68,8 +67,7 @@ def df_factory(request, tmp_dir, test_env) -> DflyInstanceFactory:
Create an instance factory with supplied params.
"""
scripts_dir = os.path.dirname(os.path.abspath(__file__))
path = os.environ.get("DRAGONFLY_PATH", os.path.join(
scripts_dir, '../../build-dbg/dragonfly'))
path = os.environ.get("DRAGONFLY_PATH", os.path.join(scripts_dir, "../../build-dbg/dragonfly"))
args = request.param if request.param else {}
existing = request.config.getoption("--existing-port")
@ -83,7 +81,7 @@ def df_factory(request, tmp_dir, test_env) -> DflyInstanceFactory:
existing_port=int(existing) if existing else None,
existing_admin_port=int(existing_admin) if existing_admin else None,
existing_mc_port=int(existing_mc) if existing_mc else None,
env=test_env
env=test_env,
)
factory = DflyInstanceFactory(params, args)
@ -121,15 +119,15 @@ def df_server(df_factory: DflyInstanceFactory) -> DflyInstance:
# TODO: Investigate spurious open connection with cluster client
# if not instance['cluster_mode']:
# TODO: Investigate adding fine grain control over the pool by
# by adding a cache ontop of the clients connection pool and then evict
# properly with client.connection_pool.disconnect() avoiding non synced
# side effects
# assert clients_left == []
# TODO: Investigate adding fine grain control over the pool by
# by adding a cache ontop of the clients connection pool and then evict
# properly with client.connection_pool.disconnect() avoiding non synced
# side effects
# assert clients_left == []
# else:
# print("Cluster clients left: ", len(clients_left))
if instance['cluster_mode']:
if instance["cluster_mode"]:
print("Cluster clients left: ", len(clients_left))
@ -160,8 +158,7 @@ def cluster_client(df_server):
"""
Return a cluster client to the default instance with all entries flushed.
"""
client = redis.RedisCluster(decode_responses=True, host="localhost",
port=df_server.port)
client = redis.RedisCluster(decode_responses=True, host="localhost", port=df_server.port)
client.client_setname("default-cluster-fixture")
client.flushall()
@ -171,11 +168,17 @@ def cluster_client(df_server):
@pytest_asyncio.fixture(scope="function")
async def async_pool(df_server: DflyInstance):
pool = aioredis.ConnectionPool(host="localhost", port=df_server.port,
db=DATABASE_INDEX, decode_responses=True, max_connections=32)
pool = aioredis.ConnectionPool(
host="localhost",
port=df_server.port,
db=DATABASE_INDEX,
decode_responses=True,
max_connections=32,
)
yield pool
await pool.disconnect(inuse_connections=True)
@pytest_asyncio.fixture(scope="function")
async def async_client(async_pool):
"""
@ -197,25 +200,35 @@ def pytest_addoption(parser):
--existing-admin-port - to provide an admin port to an existing process instead of starting a new instance
--rand-seed - to set the global random seed
"""
parser.addoption("--gdb", action="store_true", default=False, help="Run instances in gdb")
parser.addoption("--df", action="append", default=[], help="Add arguments to dragonfly")
parser.addoption(
'--gdb', action='store_true', default=False, help='Run instances in gdb'
"--log-seeder", action="store", default=None, help="Store last generator commands in file"
)
parser.addoption(
'--df', action='append', default=[], help='Add arguments to dragonfly'
"--rand-seed",
action="store",
default=None,
help="Set seed for global random. Makes seeder predictable",
)
parser.addoption(
'--log-seeder', action='store', default=None, help='Store last generator commands in file'
"--existing-port",
action="store",
default=None,
help="Provide a port to the existing process for the test",
)
parser.addoption(
'--rand-seed', action='store', default=None, help='Set seed for global random. Makes seeder predictable'
"--existing-admin-port",
action="store",
default=None,
help="Provide an admin port to the existing process for the test",
)
parser.addoption(
'--existing-port', action='store', default=None, help='Provide a port to the existing process for the test')
parser.addoption(
'--existing-admin-port', action='store', default=None, help='Provide an admin port to the existing process for the test')
parser.addoption(
'--existing-mc-port', action='store', default=None, help='Provide a port to the existing memcached process for the test'
"--existing-mc-port",
action="store",
default=None,
help="Provide a port to the existing memcached process for the test",
)
@ -251,11 +264,15 @@ def with_tls_server_args(tmp_dir, gen_ca_cert):
tls_server_req = os.path.join(tmp_dir, "df-req.pem")
tls_server_cert = os.path.join(tmp_dir, "df-cert.pem")
gen_certificate(gen_ca_cert["ca_key"], gen_ca_cert["ca_cert"], tls_server_req, tls_server_key, tls_server_cert)
gen_certificate(
gen_ca_cert["ca_key"],
gen_ca_cert["ca_cert"],
tls_server_req,
tls_server_key,
tls_server_cert,
)
args = {"tls": "",
"tls_key_file": tls_server_key,
"tls_cert_file": tls_server_cert}
args = {"tls": "", "tls_key_file": tls_server_key, "tls_cert_file": tls_server_cert}
return args
@ -272,11 +289,15 @@ def with_tls_client_args(tmp_dir, gen_ca_cert):
tls_client_req = os.path.join(tmp_dir, "client-req.pem")
tls_client_cert = os.path.join(tmp_dir, "client-cert.pem")
gen_certificate(gen_ca_cert["ca_key"], gen_ca_cert["ca_cert"], tls_client_req, tls_client_key, tls_client_cert)
gen_certificate(
gen_ca_cert["ca_key"],
gen_ca_cert["ca_cert"],
tls_client_req,
tls_client_key,
tls_client_cert,
)
args = {"ssl": True,
"ssl_keyfile": tls_client_key,
"ssl_certfile": tls_client_cert}
args = {"ssl": True, "ssl_keyfile": tls_client_key, "ssl_certfile": tls_client_cert}
return args

View file

@ -9,6 +9,7 @@ from . import DflyInstance, dfly_args
BASE_PORT = 1111
async def run_monitor_eval(monitor, expected):
async with monitor as mon:
count = 0
@ -29,33 +30,32 @@ async def run_monitor_eval(monitor, expected):
return False
return True
'''
"""
Test issue https://github.com/dragonflydb/dragonfly/issues/756
Monitor command do not return when we have lua script issue
'''
"""
@pytest.mark.asyncio
async def test_monitor_command_lua(async_pool):
expected = ["EVAL return redis",
"EVAL return redis", "SET foo2"]
expected = ["EVAL return redis", "EVAL return redis", "SET foo2"]
conn = aioredis.Redis(connection_pool=async_pool)
monitor = conn.monitor()
cmd1 = aioredis.Redis(connection_pool=async_pool)
future = asyncio.create_task(run_monitor_eval(
monitor=monitor, expected=expected))
future = asyncio.create_task(run_monitor_eval(monitor=monitor, expected=expected))
await asyncio.sleep(0.1)
try:
res = await cmd1.eval(r'return redis.call("GET", "bar")', 0)
assert False # this will return an error
assert False # this will return an error
except Exception as e:
assert "script tried accessing undeclared key" in str(e)
try:
res = await cmd1.eval(r'return redis.call("SET", KEYS[1], ARGV[1])', 1, 'foo2', 'bar2')
res = await cmd1.eval(r'return redis.call("SET", KEYS[1], ARGV[1])', 1, "foo2", "bar2")
except Exception as e:
print(f"EVAL error: {e}")
assert False
@ -66,12 +66,12 @@ async def test_monitor_command_lua(async_pool):
assert status
'''
"""
Test the monitor command.
Open connection which is used for monitoring
Then send on other connection commands to dragonfly instance
Make sure that we are getting the commands in the monitor context
'''
"""
@pytest.mark.asyncio
@ -101,9 +101,11 @@ async def process_cmd(monitor, key, value):
if "select" not in response["command"].lower():
success = verify_response(response, key, value)
if not success:
print(
f"failed to verify message {response} for {key}/{value}")
return False, f"failed on the verification of the message {response} at {key}: {value}"
print(f"failed to verify message {response} for {key}/{value}")
return (
False,
f"failed on the verification of the message {response} at {key}: {value}",
)
else:
return True, None
except asyncio.TimeoutError:
@ -146,11 +148,11 @@ async def run_monitor(messages: dict, pool: aioredis.ConnectionPool):
return False, f"monitor result: {status}: {message}, set command success {success}"
'''
"""
Run test in pipeline mode.
This is mostly how this is done with python - its more like a transaction that
the connections is running all commands in its context
'''
"""
@pytest.mark.asyncio
@ -194,12 +196,12 @@ async def run_pipeline_mode(async_client: aioredis.Redis, messages):
return True, "all command processed successfully"
'''
"""
Test the pipeline command
Open connection to the subscriber and publish on the other end messages
Make sure that we are able to send all of them and that we are getting the
expected results on the subscriber side
'''
"""
@pytest.mark.asyncio
@ -232,7 +234,10 @@ async def run_pubsub(async_client, messages, channel_name):
if status and success:
return True, "successfully completed all"
else:
return False, f"subscriber result: {status}: {message}, publisher publish: success {success}"
return (
False,
f"subscriber result: {status}: {message}, publisher publish: success {success}",
)
async def run_multi_pubsub(async_client, messages, channel_name):
@ -241,7 +246,8 @@ async def run_multi_pubsub(async_client, messages, channel_name):
await s.subscribe(channel_name)
tasks = [
asyncio.create_task(reader(s, messages, random.randint(0, len(messages)))) for s in subs]
asyncio.create_task(reader(s, messages, random.randint(0, len(messages)))) for s in subs
]
success = True
@ -260,18 +266,18 @@ async def run_multi_pubsub(async_client, messages, channel_name):
if success:
for status, message in results:
if not status:
return False, f"failed to process {message}"
return False, f"failed to process {message}"
return True, "success"
else:
return False, "failed to publish"
'''
"""
Test with multiple subscribers for a channel
We want to stress this to see if we have any issue
with the pub sub code since we are "sharing" the message
across multiple connections internally
'''
"""
@pytest.mark.asyncio
@ -279,6 +285,7 @@ async def test_multi_pubsub(async_client):
def generate(max):
for i in range(max):
yield f"this is message number {i} from the publisher on the channel"
messages = [a for a in generate(500)]
state, message = await run_multi_pubsub(async_client, messages, "my-channel")
@ -288,8 +295,13 @@ async def test_multi_pubsub(async_client):
@pytest.mark.asyncio
async def test_subscribers_with_active_publisher(df_server: DflyInstance, max_connections=100):
# TODO: I am not how to customize the max connections for the pool.
async_pool = aioredis.ConnectionPool(host="localhost", port=df_server.port,
db=0, decode_responses=True, max_connections=max_connections)
async_pool = aioredis.ConnectionPool(
host="localhost",
port=df_server.port,
db=0,
decode_responses=True,
max_connections=max_connections,
)
async def publish_worker():
client = aioredis.Redis(connection_pool=async_pool)
@ -322,12 +334,12 @@ async def test_subscribers_with_active_publisher(df_server: DflyInstance, max_co
async def test_big_command(df_server, size=8 * 1024):
reader, writer = await asyncio.open_connection('127.0.0.1', df_server.port)
reader, writer = await asyncio.open_connection("127.0.0.1", df_server.port)
writer.write(f"SET a {'v'*size}\n".encode())
await writer.drain()
assert 'OK' in (await reader.readline()).decode()
assert "OK" in (await reader.readline()).decode()
writer.close()
await writer.wait_closed()
@ -335,9 +347,8 @@ async def test_big_command(df_server, size=8 * 1024):
async def test_subscribe_pipelined(async_client: aioredis.Redis):
pipe = async_client.pipeline(transaction=False)
pipe.execute_command('subscribe channel').execute_command(
'subscribe channel')
await pipe.echo('bye bye').execute()
pipe.execute_command("subscribe channel").execute_command("subscribe channel")
await pipe.echo("bye bye").execute()
async def test_subscribe_in_pipeline(async_client: aioredis.Redis):
@ -349,8 +360,8 @@ async def test_subscribe_in_pipeline(async_client: aioredis.Redis):
pipe.echo("three")
res = await pipe.execute()
assert res == ['one', ['subscribe', 'ch1', 1],
'two', ['subscribe', 'ch2', 2], 'three']
assert res == ["one", ["subscribe", "ch1", 1], "two", ["subscribe", "ch2", 2], "three"]
"""
This test makes sure that Dragonfly can receive blocks of pipelined commands even
@ -376,9 +387,13 @@ MGET m4 m5 m6
MGET m7 m8 m9\n
"""
PACKET3 = """
PACKET3 = (
"""
PING
""" * 500 + "ECHO DONE\n"
"""
* 500
+ "ECHO DONE\n"
)
async def test_parser_while_script_running(async_client: aioredis.Redis, df_server: DflyInstance):
@ -386,7 +401,7 @@ async def test_parser_while_script_running(async_client: aioredis.Redis, df_serv
# Use a raw tcp connection for strict control of sent commands
# Below we send commands while the previous ones didn't finish
reader, writer = await asyncio.open_connection('localhost', df_server.port)
reader, writer = await asyncio.open_connection("localhost", df_server.port)
# Send first pipeline packet, last commands is a long executing script
writer.write(PACKET1.format(sha=sha).encode())
@ -409,7 +424,9 @@ async def test_parser_while_script_running(async_client: aioredis.Redis, df_serv
@dfly_args({"proactor_threads": 1})
async def test_large_cmd(async_client: aioredis.Redis):
MAX_ARR_SIZE = 65535
res = await async_client.hset('foo', mapping={f"key{i}": f"val{i}" for i in range(MAX_ARR_SIZE // 2)})
res = await async_client.hset(
"foo", mapping={f"key{i}": f"val{i}" for i in range(MAX_ARR_SIZE // 2)}
)
assert res == MAX_ARR_SIZE // 2
res = await async_client.mset({f"key{i}": f"val{i}" for i in range(MAX_ARR_SIZE // 2)})
@ -421,7 +438,9 @@ async def test_large_cmd(async_client: aioredis.Redis):
@pytest.mark.asyncio
async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_factory):
server = df_local_factory.create(no_tls_on_admin_port="true", admin_port=1111, port=1211, **with_tls_server_args)
server = df_local_factory.create(
no_tls_on_admin_port="true", admin_port=1111, port=1211, **with_tls_server_args
)
server.start()
client = aioredis.Redis(port=server.port)

View file

@ -75,14 +75,12 @@ return 'OK'
"""
def DJANGO_CACHEOPS_SCHEMA(vs): return {
"table_1": [
{"f-1": f'v-{vs[0]}'}, {"f-2": f'v-{vs[1]}'}
],
"table_2": [
{"f-1": f'v-{vs[2]}'}, {"f-2": f'v-{vs[3]}'}
]
}
def DJANGO_CACHEOPS_SCHEMA(vs):
return {
"table_1": [{"f-1": f"v-{vs[0]}"}, {"f-2": f"v-{vs[1]}"}],
"table_2": [{"f-1": f"v-{vs[2]}"}, {"f-2": f"v-{vs[3]}"}],
}
"""
Test the main caching script of https://github.com/Suor/django-cacheops.
@ -91,21 +89,25 @@ so Dragonfly must run in global (1) or non-atomic (4) multi eval mode.
"""
@dfly_multi_test_args({'default_lua_flags': 'allow-undeclared-keys', 'proactor_threads': 4},
{'default_lua_flags': 'allow-undeclared-keys disable-atomicity', 'proactor_threads': 4})
@dfly_multi_test_args(
{"default_lua_flags": "allow-undeclared-keys", "proactor_threads": 4},
{"default_lua_flags": "allow-undeclared-keys disable-atomicity", "proactor_threads": 4},
)
async def test_django_cacheops_script(async_client, num_keys=500):
script = async_client.register_script(DJANGO_CACHEOPS_SCRIPT)
data = [(f'k-{k}', [random.randint(0, 10) for _ in range(4)])
for k in range(num_keys)]
data = [(f"k-{k}", [random.randint(0, 10) for _ in range(4)]) for k in range(num_keys)]
for k, vs in data:
schema = DJANGO_CACHEOPS_SCHEMA(vs)
assert await script(keys=['', k, ''], args=['a' * 10, json.dumps(schema, sort_keys=True), 100]) == 'OK'
assert (
await script(keys=["", k, ""], args=["a" * 10, json.dumps(schema, sort_keys=True), 100])
== "OK"
)
# Check schema was built correctly
base_schema = DJANGO_CACHEOPS_SCHEMA([0] * 4)
for table, fields in base_schema.items():
schema = await async_client.smembers(f'schemes:{table}')
schema = await async_client.smembers(f"schemes:{table}")
fields = set.union(*(set(part.keys()) for part in fields))
assert schema == fields
@ -114,9 +116,9 @@ async def test_django_cacheops_script(async_client, num_keys=500):
assert await async_client.exists(k)
for table, fields in DJANGO_CACHEOPS_SCHEMA(vs).items():
for sub_schema in fields:
conj_key = f'conj:{table}:' + \
'&'.join("{}={}".format(f, v)
for f, v in sub_schema.items())
conj_key = f"conj:{table}:" + "&".join(
"{}={}".format(f, v) for f, v in sub_schema.items()
)
assert await async_client.sismember(conj_key, k)
@ -158,24 +160,29 @@ the task system should work reliably.
"""
@dfly_multi_test_args({'default_lua_flags': 'allow-undeclared-keys', 'proactor_threads': 4},
{'default_lua_flags': 'allow-undeclared-keys disable-atomicity', 'proactor_threads': 4})
@dfly_multi_test_args(
{"default_lua_flags": "allow-undeclared-keys", "proactor_threads": 4},
{"default_lua_flags": "allow-undeclared-keys disable-atomicity", "proactor_threads": 4},
)
async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100):
async def enqueue_worker(queue):
client = aioredis.Redis(connection_pool=async_pool)
enqueue = client.register_script(ASYNQ_ENQUEUE_SCRIPT)
task_ids = 2*list(range(num_tasks))
task_ids = 2 * list(range(num_tasks))
random.shuffle(task_ids)
res = [await enqueue(keys=[f"asynq:{{{queue}}}:t:{task_id}", f"asynq:{{{queue}}}:pending"],
args=[f"{task_id}", task_id, int(time.time())])
for task_id in task_ids]
res = [
await enqueue(
keys=[f"asynq:{{{queue}}}:t:{task_id}", f"asynq:{{{queue}}}:pending"],
args=[f"{task_id}", task_id, int(time.time())],
)
for task_id in task_ids
]
assert sum(res) == num_tasks
# Start filling the queues
jobs = [asyncio.create_task(enqueue_worker(
f"q-{queue}")) for queue in range(num_queues)]
jobs = [asyncio.create_task(enqueue_worker(f"q-{queue}")) for queue in range(num_queues)]
collected = 0
@ -185,15 +192,19 @@ async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100):
dequeue = client.register_script(ASYNQ_DEQUE_SCRIPT)
while collected < num_tasks * num_queues:
#pct = round(collected/(num_tasks*num_queues), 2)
#print(f'\r \r{pct}', end='', flush=True)
# pct = round(collected/(num_tasks*num_queues), 2)
# print(f'\r \r{pct}', end='', flush=True)
for queue in (f"q-{queue}" for queue in range(num_queues)):
prefix = f"asynq:{{{queue}}}:t:"
msg = await dequeue(keys=[f"asynq:{{{queue}}}:"+t for t in ["pending", "paused", "active", "lease"]],
args=[int(time.time()), prefix])
msg = await dequeue(
keys=[
f"asynq:{{{queue}}}:" + t for t in ["pending", "paused", "active", "lease"]
],
args=[int(time.time()), prefix],
)
if msg is not None:
collected += 1
assert await client.hget(prefix+msg, 'state') == 'active'
assert await client.hget(prefix + msg, "state") == "active"
# Run many contending workers
await asyncio.gather(*(dequeue_worker() for _ in range(num_queues * 2)))
@ -204,19 +215,19 @@ async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100):
ERROR_CALL_SCRIPT_TEMPLATE = [
"redis.{}('LTRIM', 'l', 'a', 'b')", # error only on evaluation
"redis.{}('obviously wrong')" # error immediately on preprocessing
"redis.{}('obviously wrong')", # error immediately on preprocessing
]
@dfly_args({"proactor_threads": 1})
@pytest.mark.asyncio
async def test_eval_error_propagation(async_client):
CMDS = ['call', 'pcall', 'acall', 'apcall']
CMDS = ["call", "pcall", "acall", "apcall"]
for cmd, template in itertools.product(CMDS, ERROR_CALL_SCRIPT_TEMPLATE):
does_abort = 'p' not in cmd
does_abort = "p" not in cmd
try:
await async_client.eval(template.format(cmd), 1, 'l')
await async_client.eval(template.format(cmd), 1, "l")
if does_abort:
assert False, "Eval must have thrown an error: " + cmd
except aioredis.RedisError as e:
@ -230,12 +241,12 @@ async def test_global_eval_in_multi(async_client: aioredis.Redis):
return redis.call('GET', 'any-key');
"""
await async_client.set('any-key', 'works')
await async_client.set("any-key", "works")
pipe = async_client.pipeline(transaction=True)
pipe.set('another-key', 'ok')
pipe.set("another-key", "ok")
pipe.eval(GLOBAL_SCRIPT, 0)
res = await pipe.execute()
print(res)
assert res[1] == 'works'
assert res[1] == "works"

View file

@ -8,22 +8,24 @@ from . import dfly_multi_test_args, dfly_args
from .utility import batch_fill_data, gen_test_data
@dfly_multi_test_args({'keys_output_limit': 512}, {'keys_output_limit': 1024})
@dfly_multi_test_args({"keys_output_limit": 512}, {"keys_output_limit": 1024})
class TestKeys:
async def test_max_keys(self, async_client: aioredis.Redis, df_server):
max_keys = df_server['keys_output_limit']
max_keys = df_server["keys_output_limit"]
pipe = async_client.pipeline()
batch_fill_data(pipe, gen_test_data(max_keys * 3))
await pipe.execute()
keys = await async_client.keys()
assert len(keys) in range(max_keys, max_keys+512)
assert len(keys) in range(max_keys, max_keys + 512)
@pytest.fixture(scope="function")
def export_dfly_password() -> str:
pwd = 'flypwd'
os.environ['DFLY_PASSWORD'] = pwd
pwd = "flypwd"
os.environ["DFLY_PASSWORD"] = pwd
yield pwd
del os.environ['DFLY_PASSWORD']
del os.environ["DFLY_PASSWORD"]
async def test_password(df_local_factory, export_dfly_password):
dfly = df_local_factory.create()
@ -38,7 +40,7 @@ async def test_password(df_local_factory, export_dfly_password):
dfly.stop()
# --requirepass should take precedence over environment variable
requirepass = 'requirepass'
requirepass = "requirepass"
dfly = df_local_factory.create(requirepass=requirepass)
dfly.start()
@ -61,6 +63,7 @@ for i = 0, ARGV[1] do
end
"""
@dfly_args({"proactor_threads": 1})
async def test_txq_ooo(async_client: aioredis.Redis, df_server):
async def task1(k, h):
@ -77,4 +80,6 @@ async def test_txq_ooo(async_client: aioredis.Redis, df_server):
pipe.blpop(k, 0.001)
await pipe.execute()
await asyncio.gather(task1('i1', 2), task1('i2', 3), task2('l1', 2), task2('l1', 2), task2('l1', 5))
await asyncio.gather(
task1("i1", 2), task1("i2", 3), task2("l1", 2), task2("l1", 2), task2("l1", 5)
)

View file

@ -4,15 +4,9 @@ from redis import asyncio as aioredis
from .utility import *
from json import JSONDecoder, JSONEncoder
jane = {
'name': "Jane",
'Age': 33,
'Location': "Chawton"
}
jane = {"name": "Jane", "Age": 33, "Location": "Chawton"}
json_num = {
"a": {"a": 1, "b": 2, "c": 3}
}
json_num = {"a": {"a": 1, "b": 2, "c": 3}}
async def get_set_json(connection: aioredis.Redis, key, value, path="$"):
@ -30,8 +24,8 @@ async def test_basic_json_get_set(async_client: aioredis.Redis):
the_type = await async_client.type(key_name)
assert the_type == "ReJSON-RL"
assert len(result) == 1
assert result[0]['name'] == 'Jane'
assert result[0]['Age'] == 33
assert result[0]["name"] == "Jane"
assert result[0]["Age"] == 33
async def test_access_json_value_as_string(async_client: aioredis.Redis):
@ -76,18 +70,16 @@ async def test_update_value(async_client: aioredis.Redis):
# make sure that we have valid JSON here
the_type = await async_client.type(key_name)
assert the_type == "ReJSON-RL"
result = await get_set_json(async_client, value="0",
key=key_name, path="$.a.*")
result = await get_set_json(async_client, value="0", key=key_name, path="$.a.*")
assert len(result) == 3
# make sure that all the values under 'a' where set to 0
assert result == ['0', '0', '0']
assert result == ["0", "0", "0"]
# Ensure that after we're changing this into STRING type, it will no longer work
await async_client.set(key_name, "some random value")
assert await async_client.type(key_name) == "string"
try:
await get_set_json(async_client, value="0", key=key_name,
path="$.a.*")
await get_set_json(async_client, value="0", key=key_name, path="$.a.*")
assert False, "should not be able to modify JSON value as string"
except redis.exceptions.ResponseError as e:
assert e.args[0] == "WRONGTYPE Operation against a key holding the wrong kind of value"

View file

@ -4,20 +4,19 @@ from redis import asyncio as aioredis
import pytest
@pytest.mark.parametrize('index', range(50))
@pytest.mark.parametrize("index", range(50))
class TestBlPop:
async def async_blpop(client: aioredis.Redis):
return await client.blpop(
['list1{t}', 'list2{t}', 'list2{t}', 'list1{t}'], 0.5)
return await client.blpop(["list1{t}", "list2{t}", "list2{t}", "list1{t}"], 0.5)
async def blpop_mult_keys(async_client: aioredis.Redis, key: str, val: str):
task = asyncio.create_task(TestBlPop.async_blpop(async_client))
await async_client.lpush(key, val)
result = await asyncio.wait_for(task, 3)
assert result[1] == val
watched = await async_client.execute_command('DEBUG WATCHED')
assert watched == ['awaked', [], 'watched', []]
watched = await async_client.execute_command("DEBUG WATCHED")
assert watched == ["awaked", [], "watched", []]
async def test_blpop_multiple_keys(self, async_client: aioredis.Redis, index):
await TestBlPop.blpop_mult_keys(async_client, 'list1{t}', 'a')
await TestBlPop.blpop_mult_keys(async_client, 'list2{t}', 'b')
await TestBlPop.blpop_mult_keys(async_client, "list1{t}", "a")
await TestBlPop.blpop_mult_keys(async_client, "list2{t}", "b")

View file

@ -8,12 +8,14 @@ def test_add_get(memcached_connection):
assert memcached_connection.add(b"key", b"data", noreply=False)
assert memcached_connection.get(b"key") == b"data"
@dfly_args({"memcached_port": 11211})
def test_add_set(memcached_connection):
assert memcached_connection.add(b"key", b"data", noreply=False)
memcached_connection.set(b"key", b"other")
assert memcached_connection.get(b"key") == b"other"
@dfly_args({"memcached_port": 11211})
def test_set_add(memcached_connection):
memcached_connection.set(b"key", b"data")
@ -23,6 +25,7 @@ def test_set_add(memcached_connection):
memcached_connection.set(b"key", b"other")
assert memcached_connection.get(b"key") == b"other"
@dfly_args({"memcached_port": 11211})
def test_mixed_reply(memcached_connection):
memcached_connection.set(b"key", b"data", noreply=True)

View file

@ -12,13 +12,17 @@ class RedisServer:
self.proc = None
def start(self):
self.proc = subprocess.Popen(["redis-server-6.2.11",
f"--port {self.port}",
"--save ''",
"--appendonly no",
"--protected-mode no",
"--repl-diskless-sync yes",
"--repl-diskless-sync-delay 0"])
self.proc = subprocess.Popen(
[
"redis-server-6.2.11",
f"--port {self.port}",
"--save ''",
"--appendonly no",
"--protected-mode no",
"--repl-diskless-sync yes",
"--repl-diskless-sync-delay 0",
]
)
print(self.proc.args)
def stop(self):
@ -28,6 +32,7 @@ class RedisServer:
except Exception as e:
pass
# Checks that master redis and dragonfly replica are synced by writing a random key to master
# and waiting for it to exist in replica. Foreach db in 0..dbcount-1.
async def await_synced(master_port, replica_port, dbcount=1):
@ -58,7 +63,7 @@ async def await_synced_all(c_master, c_replicas):
async def check_data(seeder, replicas, c_replicas):
capture = await seeder.capture()
for (replica, c_replica) in zip(replicas, c_replicas):
for replica, c_replica in zip(replicas, c_replicas):
await wait_available_async(c_replica)
assert await seeder.compare(capture, port=replica.port)
@ -84,7 +89,9 @@ full_sync_replication_specs = [
@pytest.mark.parametrize("t_replicas, seeder_config", full_sync_replication_specs)
async def test_replication_full_sync(df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker):
async def test_replication_full_sync(
df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker
):
master = redis_server
c_master = aioredis.Redis(port=master.port)
assert await c_master.ping()
@ -93,7 +100,8 @@ async def test_replication_full_sync(df_local_factory, df_seeder_factory, redis_
await seeder.run(target_deviation=0.1)
replica = df_local_factory.create(
port=port_picker.get_available_port(), proactor_threads=t_replicas[0])
port=port_picker.get_available_port(), proactor_threads=t_replicas[0]
)
replica.start()
c_replica = aioredis.Redis(port=replica.port)
assert await c_replica.ping()
@ -105,6 +113,7 @@ async def test_replication_full_sync(df_local_factory, df_seeder_factory, redis_
capture = await seeder.capture()
assert await seeder.compare(capture, port=replica.port)
stable_sync_replication_specs = [
([1], dict(keys=100, dbcount=1, unsupported_types=[ValueType.JSON])),
([1], dict(keys=10_000, dbcount=2, unsupported_types=[ValueType.JSON])),
@ -115,13 +124,16 @@ stable_sync_replication_specs = [
@pytest.mark.parametrize("t_replicas, seeder_config", stable_sync_replication_specs)
async def test_replication_stable_sync(df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker):
async def test_replication_stable_sync(
df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker
):
master = redis_server
c_master = aioredis.Redis(port=master.port)
assert await c_master.ping()
replica = df_local_factory.create(
port=port_picker.get_available_port(), proactor_threads=t_replicas[0])
port=port_picker.get_available_port(), proactor_threads=t_replicas[0]
)
replica.start()
c_replica = aioredis.Redis(port=replica.port)
assert await c_replica.ping()
@ -150,7 +162,9 @@ replication_specs = [
@pytest.mark.parametrize("t_replicas, seeder_config", replication_specs)
async def test_redis_replication_all(df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker):
async def test_redis_replication_all(
df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker
):
master = redis_server
c_master = aioredis.Redis(port=master.port)
assert await c_master.ping()
@ -178,11 +192,11 @@ async def test_redis_replication_all(df_local_factory, df_seeder_factory, redis_
await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
await wait_available_async(c_replica)
await asyncio.gather(*(asyncio.create_task(run_replication(c))
for c in c_replicas))
await asyncio.gather(*(asyncio.create_task(run_replication(c)) for c in c_replicas))
# Wait for streaming to finish
assert not stream_task.done(
assert (
not stream_task.done()
), "Weak testcase. Increase number of streamed iterations to surpass full sync"
seeder.stop()
await stream_task
@ -206,7 +220,15 @@ master_disconnect_cases = [
@pytest.mark.parametrize("t_replicas, t_disconnect, seeder_config", master_disconnect_cases)
async def test_disconnect_master(df_local_factory, df_seeder_factory, redis_server, t_replicas, t_disconnect, seeder_config, port_picker):
async def test_disconnect_master(
df_local_factory,
df_seeder_factory,
redis_server,
t_replicas,
t_disconnect,
seeder_config,
port_picker,
):
master = redis_server
c_master = aioredis.Redis(port=master.port)
assert await c_master.ping()
@ -234,11 +256,11 @@ async def test_disconnect_master(df_local_factory, df_seeder_factory, redis_serv
await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
await wait_available_async(c_replica)
await asyncio.gather(*(asyncio.create_task(run_replication(c))
for c in c_replicas))
await asyncio.gather(*(asyncio.create_task(run_replication(c)) for c in c_replicas))
# Wait for streaming to finish
assert not stream_task.done(
assert (
not stream_task.done()
), "Weak testcase. Increase number of streamed iterations to surpass full sync"
seeder.stop()
await stream_task

View file

@ -35,10 +35,12 @@ replication_cases = [
@pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replicas, seeder_config", replication_cases)
async def test_replication_all(df_local_factory, df_seeder_factory, t_master, t_replicas, seeder_config):
async def test_replication_all(
df_local_factory, df_seeder_factory, t_master, t_replicas, seeder_config
):
master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
replicas = [
df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t)
df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t)
for i, t in enumerate(t_replicas)
]
@ -63,11 +65,11 @@ async def test_replication_all(df_local_factory, df_seeder_factory, t_master, t_
async def run_replication(c_replica):
await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
await asyncio.gather(*(asyncio.create_task(run_replication(c))
for c in c_replicas))
await asyncio.gather(*(asyncio.create_task(run_replication(c)) for c in c_replicas))
# Wait for streaming to finish
assert not stream_task.done(
assert (
not stream_task.done()
), "Weak testcase. Increase number of streamed iterations to surpass full sync"
await stream_task
@ -87,7 +89,7 @@ async def check_replica_finished_exec(c_replica, c_master):
syncid, r_offset = await c_replica.execute_command("DEBUG REPLICA OFFSET")
m_offset = await c_master.execute_command("DFLY REPLICAOFFSET")
print(" offset", syncid.decode(), r_offset, m_offset)
print(" offset", syncid.decode(), r_offset, m_offset)
return r_offset == m_offset
@ -98,18 +100,16 @@ async def check_all_replicas_finished(c_replicas, c_master):
while len(waiting_for) > 0:
await asyncio.sleep(1.0)
tasks = (asyncio.create_task(check_replica_finished_exec(c, c_master))
for c in waiting_for)
tasks = (asyncio.create_task(check_replica_finished_exec(c, c_master)) for c in waiting_for)
finished_list = await asyncio.gather(*tasks)
# Remove clients that finished from waiting list
waiting_for = [c for (c, finished) in zip(
waiting_for, finished_list) if not finished]
waiting_for = [c for (c, finished) in zip(waiting_for, finished_list) if not finished]
async def check_data(seeder, replicas, c_replicas):
capture = await seeder.capture()
for (replica, c_replica) in zip(replicas, c_replicas):
for replica, c_replica in zip(replicas, c_replicas):
await wait_available_async(c_replica)
assert await seeder.compare(capture, port=replica.port)
@ -140,22 +140,30 @@ disconnect_cases = [
# stable state heavy
(8, [], [4] * 4, [], 4_000),
# disconnect only
(8, [], [], [4] * 4, 4_000)
(8, [], [], [4] * 4, 4_000),
]
@pytest.mark.skip(reason='Failing on github regression action')
@pytest.mark.skip(reason="Failing on github regression action")
@pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys", disconnect_cases)
async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seeder_factory, t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys):
async def test_disconnect_replica(
df_local_factory: DflyInstanceFactory,
df_seeder_factory,
t_master,
t_crash_fs,
t_crash_ss,
t_disonnect,
n_keys,
):
master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
replicas = [
(df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t), crash_fs)
(df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t), crash_fs)
for i, (t, crash_fs) in enumerate(
chain(
zip(t_crash_fs, repeat(DISCONNECT_CRASH_FULL_SYNC)),
zip(t_crash_ss, repeat(DISCONNECT_CRASH_STABLE_SYNC)),
zip(t_disonnect, repeat(DISCONNECT_NORMAL_STABLE_SYNC))
zip(t_disonnect, repeat(DISCONNECT_NORMAL_STABLE_SYNC)),
)
)
]
@ -168,15 +176,11 @@ async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seed
df_local_factory.start_all([replica for replica, _ in replicas])
c_replicas = [
(replica, aioredis.Redis(port=replica.port), crash_type)
for replica, crash_type in replicas
(replica, aioredis.Redis(port=replica.port), crash_type) for replica, crash_type in replicas
]
def replicas_of_type(tfunc):
return [
args for args in c_replicas
if tfunc(args[2])
]
return [args for args in c_replicas if tfunc(args[2])]
# Start data fill loop
seeder = df_seeder_factory.create(port=master.port, keys=n_keys, dbcount=2)
@ -187,7 +191,7 @@ async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seed
c_replica = aioredis.Redis(port=replica.port)
await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
if crash_type == 0:
await asyncio.sleep(random.random()/100+0.01)
await asyncio.sleep(random.random() / 100 + 0.01)
await c_replica.connection_pool.disconnect()
replica.stop(kill=True)
else:
@ -211,8 +215,7 @@ async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seed
await c_replica.connection_pool.disconnect()
replica.stop(kill=True)
await asyncio.gather(*(stable_sync(*args) for args
in replicas_of_type(lambda t: t == 1)))
await asyncio.gather(*(stable_sync(*args) for args in replicas_of_type(lambda t: t == 1)))
# Check master survived all crashes
assert await c_master.ping()
@ -241,8 +244,7 @@ async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seed
await asyncio.sleep(random.random() / 100)
await c_replica.execute_command("REPLICAOF NO ONE")
await asyncio.gather(*(disconnect(*args) for args
in replicas_of_type(lambda t: t == 2)))
await asyncio.gather(*(disconnect(*args) for args in replicas_of_type(lambda t: t == 2)))
await asyncio.sleep(0.5)
@ -282,10 +284,12 @@ master_crash_cases = [
@pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replicas, n_random_crashes, n_keys", master_crash_cases)
async def test_disconnect_master(df_local_factory, df_seeder_factory, t_master, t_replicas, n_random_crashes, n_keys):
async def test_disconnect_master(
df_local_factory, df_seeder_factory, t_master, t_replicas, n_random_crashes, n_keys
):
master = df_local_factory.create(port=1111, proactor_threads=t_master)
replicas = [
df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t)
df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t)
for i, t in enumerate(t_replicas)
]
@ -309,8 +313,12 @@ async def test_disconnect_master(df_local_factory, df_seeder_factory, t_master,
await start_master()
# Crash master during full sync, but with all passing initial connection phase
await asyncio.gather(*(c_replica.execute_command("REPLICAOF localhost " + str(master.port))
for c_replica in c_replicas))
await asyncio.gather(
*(
c_replica.execute_command("REPLICAOF localhost " + str(master.port))
for c_replica in c_replicas
)
)
await crash_master_fs()
await asyncio.sleep(1 + len(replicas) * 0.5)
@ -338,24 +346,25 @@ async def test_disconnect_master(df_local_factory, df_seeder_factory, t_master,
await wait_available_async(c_replica)
assert await seeder.compare(capture, port=replica.port)
"""
Test re-connecting replica to different masters.
"""
rotating_master_cases = [
(4, [4, 4, 4, 4], dict(keys=2_000, dbcount=4))
]
rotating_master_cases = [(4, [4, 4, 4, 4], dict(keys=2_000, dbcount=4))]
@pytest.mark.asyncio
@pytest.mark.parametrize("t_replica, t_masters, seeder_config", rotating_master_cases)
async def test_rotating_masters(df_local_factory, df_seeder_factory, t_replica, t_masters, seeder_config):
replica = df_local_factory.create(
port=BASE_PORT, proactor_threads=t_replica)
masters = [df_local_factory.create(
port=BASE_PORT+i+1, proactor_threads=t) for i, t in enumerate(t_masters)]
seeders = [df_seeder_factory.create(
port=m.port, **seeder_config) for m in masters]
async def test_rotating_masters(
df_local_factory, df_seeder_factory, t_replica, t_masters, seeder_config
):
replica = df_local_factory.create(port=BASE_PORT, proactor_threads=t_replica)
masters = [
df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t)
for i, t in enumerate(t_masters)
]
seeders = [df_seeder_factory.create(port=m.port, **seeder_config) for m in masters]
df_local_factory.start_all([replica] + masters)
@ -398,7 +407,7 @@ async def test_cancel_replication_immediately(df_local_factory, df_seeder_factor
COMMANDS_TO_ISSUE = 40
replica = df_local_factory.create(port=BASE_PORT)
masters = [df_local_factory.create(port=BASE_PORT+i+1) for i in range(4)]
masters = [df_local_factory.create(port=BASE_PORT + i + 1) for i in range(4)]
seeders = [df_seeder_factory.create(port=m.port) for m in masters]
df_local_factory.start_all([replica] + masters)
@ -443,7 +452,7 @@ Check replica keys at the end.
@pytest.mark.asyncio
async def test_flushall(df_local_factory):
master = df_local_factory.create(port=BASE_PORT, proactor_threads=4)
replica = df_local_factory.create(port=BASE_PORT+1, proactor_threads=2)
replica = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=2)
master.start()
replica.start()
@ -465,8 +474,7 @@ async def test_flushall(df_local_factory):
# flushall
pipe.flushall()
# Set simple keys n_keys..n_keys*2 on master
batch_fill_data(client=pipe, gen=gen_test_data(
n_keys, n_keys*2), batch_size=3)
batch_fill_data(client=pipe, gen=gen_test_data(n_keys, n_keys * 2), batch_size=3)
await pipe.execute()
# Check replica finished executing the replicated commands
@ -480,7 +488,7 @@ async def test_flushall(df_local_factory):
assert all(v is None for v in vals)
# Check replica keys n_keys..n_keys*2-1 exist
for i in range(n_keys, n_keys*2):
for i in range(n_keys, n_keys * 2):
pipe.get(f"key-{i}")
vals = await pipe.execute()
assert all(v is not None for v in vals)
@ -494,11 +502,11 @@ Test journal rewrites.
@dfly_args({"proactor_threads": 4})
@pytest.mark.asyncio
async def test_rewrites(df_local_factory):
CLOSE_TIMESTAMP = (int(time.time()) + 100)
CLOSE_TIMESTAMP = int(time.time()) + 100
CLOSE_TIMESTAMP_MS = CLOSE_TIMESTAMP * 1000
master = df_local_factory.create(port=BASE_PORT)
replica = df_local_factory.create(port=BASE_PORT+1)
replica = df_local_factory.create(port=BASE_PORT + 1)
master.start()
replica.start()
@ -513,16 +521,16 @@ async def test_rewrites(df_local_factory):
m_replica = c_replica.monitor()
async def get_next_command():
mcmd = (await m_replica.next_command())['command']
mcmd = (await m_replica.next_command())["command"]
# skip select command
if (mcmd == "SELECT 0"):
if mcmd == "SELECT 0":
print("Got:", mcmd)
mcmd = (await m_replica.next_command())['command']
mcmd = (await m_replica.next_command())["command"]
print("Got:", mcmd)
return mcmd
async def is_match_rsp(rx):
mcmd = (await get_next_command())
mcmd = await get_next_command()
print("Regex:", rx)
return re.match(rx, mcmd)
@ -531,14 +539,14 @@ async def test_rewrites(df_local_factory):
async def check(cmd, rx):
await c_master.execute_command(cmd)
match = (await is_match_rsp(rx))
match = await is_match_rsp(rx)
assert match
async def check_list(cmd, rx_list):
print("master cmd:", cmd)
await c_master.execute_command(cmd)
for rx in rx_list:
match = (await is_match_rsp(rx))
match = await is_match_rsp(rx)
assert match
async def check_list_ooo(cmd, rx_list):
@ -546,7 +554,7 @@ async def test_rewrites(df_local_factory):
await c_master.execute_command(cmd)
expected_cmds = len(rx_list)
for i in range(expected_cmds):
mcmd = (await get_next_command())
mcmd = await get_next_command()
# check command matches one regex from list
match_rx = list(filter(lambda rx: re.match(rx, mcmd), rx_list))
assert len(match_rx) == 1
@ -556,7 +564,7 @@ async def test_rewrites(df_local_factory):
ttl1 = await c_master.ttl(key)
ttl2 = await c_replica.ttl(key)
await skip_cmd()
assert abs(ttl1-ttl2) <= 1
assert abs(ttl1 - ttl2) <= 1
async with m_replica:
# CHECK EXPIRE, PEXPIRE, PEXPIRE turn into EXPIREAT
@ -650,10 +658,16 @@ async def test_rewrites(df_local_factory):
await c_master.set("renamekey", "1000", px=50000)
await skip_cmd()
# Check RENAME turns into DEL SET and PEXPIREAT
await check_list_ooo("RENAME renamekey renamed", [r"DEL renamekey", r"SET renamed 1000", r"PEXPIREAT renamed (.*?)"])
await check_list_ooo(
"RENAME renamekey renamed",
[r"DEL renamekey", r"SET renamed 1000", r"PEXPIREAT renamed (.*?)"],
)
await check_expire("renamed")
# Check RENAMENX turns into DEL SET and PEXPIREAT
await check_list_ooo("RENAMENX renamed renamekey", [r"DEL renamed", r"SET renamekey 1000", r"PEXPIREAT renamekey (.*?)"])
await check_list_ooo(
"RENAMENX renamed renamekey",
[r"DEL renamed", r"SET renamekey 1000", r"PEXPIREAT renamekey (.*?)"],
)
await check_expire("renamekey")
@ -666,7 +680,7 @@ Test automatic replication of expiry.
@pytest.mark.asyncio
async def test_expiry(df_local_factory, n_keys=1000):
master = df_local_factory.create(port=BASE_PORT)
replica = df_local_factory.create(port=BASE_PORT+1, logtostdout=True)
replica = df_local_factory.create(port=BASE_PORT + 1, logtostdout=True)
df_local_factory.start_all([master, replica])
@ -700,10 +714,9 @@ async def test_expiry(df_local_factory, n_keys=1000):
c_master_db = aioredis.Redis(port=master.port, db=i)
pipe = c_master_db.pipeline(transaction=is_multi)
# Set simple keys n_keys..n_keys*2 on master
start_key = n_keys*(i+1)
start_key = n_keys * (i + 1)
end_key = start_key + n_keys
batch_fill_data(client=pipe, gen=gen_test_data(
end_key, start_key), batch_size=20)
batch_fill_data(client=pipe, gen=gen_test_data(end_key, start_key), batch_size=20)
await pipe.execute()
@ -771,14 +784,15 @@ return 'OK'
"""
@pytest.mark.skip(reason='Failing')
@pytest.mark.skip(reason="Failing")
@pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replicas, num_ops, num_keys, num_par, flags", script_cases)
async def test_scripts(df_local_factory, t_master, t_replicas, num_ops, num_keys, num_par, flags):
master = df_local_factory.create(
port=BASE_PORT, proactor_threads=t_master)
replicas = [df_local_factory.create(
port=BASE_PORT+i+1, proactor_threads=t) for i, t in enumerate(t_replicas)]
master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
replicas = [
df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t)
for i, t in enumerate(t_replicas)
]
df_local_factory.start_all([master] + replicas)
@ -788,16 +802,15 @@ async def test_scripts(df_local_factory, t_master, t_replicas, num_ops, num_keys
await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica)
script = script_test_s1.format(
flags=f'#!lua flags={flags}' if flags else '')
script = script_test_s1.format(flags=f"#!lua flags={flags}" if flags else "")
sha = await c_master.script_load(script)
key_sets = [
[f'{i}-{j}' for j in range(num_keys)] for i in range(num_par)
]
key_sets = [[f"{i}-{j}" for j in range(num_keys)] for i in range(num_par)]
rsps = await asyncio.gather(*(c_master.evalsha(sha, len(keys), *keys, num_ops) for keys in key_sets))
assert rsps == [b'OK'] * num_par
rsps = await asyncio.gather(
*(c_master.evalsha(sha, len(keys), *keys, num_ops) for keys in key_sets)
)
assert rsps == [b"OK"] * num_par
await check_all_replicas_finished(c_replicas, c_master)
@ -805,17 +818,18 @@ async def test_scripts(df_local_factory, t_master, t_replicas, num_ops, num_keys
for key_set in key_sets:
for j, k in enumerate(key_set):
l = await c_replica.lrange(k, 0, -1)
assert l == [f'{j}'.encode()] * num_ops
assert l == [f"{j}".encode()] * num_ops
@dfly_args({"proactor_threads": 4})
@pytest.mark.asyncio
async def test_auth_master(df_local_factory, n_keys=20):
masterpass = 'requirepass'
replicapass = 'replicapass'
masterpass = "requirepass"
replicapass = "replicapass"
master = df_local_factory.create(port=BASE_PORT, requirepass=masterpass)
replica = df_local_factory.create(
port=BASE_PORT+1, logtostdout=True, masterauth=masterpass, requirepass=replicapass)
port=BASE_PORT + 1, logtostdout=True, masterauth=masterpass, requirepass=replicapass
)
df_local_factory.start_all([master, replica])
@ -845,7 +859,7 @@ SCRIPT_TEMPLATE = "return {}"
@dfly_args({"proactor_threads": 2})
async def test_script_transfer(df_local_factory):
master = df_local_factory.create(port=BASE_PORT)
replica = df_local_factory.create(port=BASE_PORT+1)
replica = df_local_factory.create(port=BASE_PORT + 1)
df_local_factory.start_all([master, replica])
@ -879,28 +893,38 @@ async def test_script_transfer(df_local_factory):
@pytest.mark.asyncio
async def test_role_command(df_local_factory, n_keys=20):
master = df_local_factory.create(port=BASE_PORT)
replica = df_local_factory.create(port=BASE_PORT+1, logtostdout=True)
replica = df_local_factory.create(port=BASE_PORT + 1, logtostdout=True)
df_local_factory.start_all([master, replica])
c_master = aioredis.Redis(port=master.port)
c_replica = aioredis.Redis(port=replica.port)
assert await c_master.execute_command("role") == [b'master', []]
assert await c_master.execute_command("role") == [b"master", []]
await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica)
assert await c_master.execute_command("role") == [
b'master', [[b'127.0.0.1', bytes(str(replica.port), 'ascii'), b'stable_sync']]]
b"master",
[[b"127.0.0.1", bytes(str(replica.port), "ascii"), b"stable_sync"]],
]
assert await c_replica.execute_command("role") == [
b'replica', b'localhost', bytes(str(master.port), 'ascii'), b'stable_sync']
b"replica",
b"localhost",
bytes(str(master.port), "ascii"),
b"stable_sync",
]
# This tests that we react fast to socket shutdowns and don't hang on
# things like the ACK or execution fibers.
master.stop()
await asyncio.sleep(0.1)
assert await c_replica.execute_command("role") == [
b'replica', b'localhost', bytes(str(master.port), 'ascii'), b'connecting']
b"replica",
b"localhost",
bytes(str(master.port), "ascii"),
b"connecting",
]
await c_master.connection_pool.disconnect()
await c_replica.connection_pool.disconnect()
@ -943,7 +967,8 @@ async def assert_lag_condition(inst, client, condition):
async def test_replication_info(df_local_factory, df_seeder_factory, n_keys=2000):
master = df_local_factory.create(port=BASE_PORT)
replica = df_local_factory.create(
port=BASE_PORT+1, logtostdout=True, replication_acks_interval=100)
port=BASE_PORT + 1, logtostdout=True, replication_acks_interval=100
)
df_local_factory.start_all([master, replica])
c_master = aioredis.Redis(port=master.port)
c_replica = aioredis.Redis(port=replica.port)
@ -974,18 +999,15 @@ More details in https://github.com/dragonflydb/dragonfly/issues/1231
@pytest.mark.asyncio
async def test_flushall_in_full_sync(df_local_factory, df_seeder_factory):
master = df_local_factory.create(
port=BASE_PORT, proactor_threads=4, logtostdout=True)
replica = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=2, logtostdout=True)
master = df_local_factory.create(port=BASE_PORT, proactor_threads=4, logtostdout=True)
replica = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=2, logtostdout=True)
# Start master
master.start()
c_master = aioredis.Redis(port=master.port)
# Fill master with test data
seeder = df_seeder_factory.create(
port=master.port, keys=100_000, dbcount=1)
seeder = df_seeder_factory.create(port=master.port, keys=100_000, dbcount=1)
await seeder.run(target_deviation=0.1)
# Start replica
@ -998,7 +1020,7 @@ async def test_flushall_in_full_sync(df_local_factory, df_seeder_factory):
return result[3]
async def is_full_sync_mode(c_replica):
return await get_sync_mode(c_replica) == b'full_sync'
return await get_sync_mode(c_replica) == b"full_sync"
# Wait for full sync to start
while not await is_full_sync_mode(c_replica):
@ -1010,12 +1032,10 @@ async def test_flushall_in_full_sync(df_local_factory, df_seeder_factory):
await c_master.execute_command("FLUSHALL")
if not await is_full_sync_mode(c_replica):
logging.error(
"!!! Full sync finished too fast. Adjust test parameters !!!")
logging.error("!!! Full sync finished too fast. Adjust test parameters !!!")
return
post_seeder = df_seeder_factory.create(
port=master.port, keys=10, dbcount=1)
post_seeder = df_seeder_factory.create(port=master.port, keys=10, dbcount=1)
await post_seeder.run(target_deviation=0.1)
await check_all_replicas_finished([c_replica], c_master)
@ -1047,28 +1067,26 @@ redis.call('SET', 'A', 'ErrroR')
@pytest.mark.asyncio
async def test_readonly_script(df_local_factory):
master = df_local_factory.create(
port=BASE_PORT, proactor_threads=2, logtostdout=True)
replica = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=2, logtostdout=True)
master = df_local_factory.create(port=BASE_PORT, proactor_threads=2, logtostdout=True)
replica = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=2, logtostdout=True)
df_local_factory.start_all([master, replica])
c_master = aioredis.Redis(port=master.port)
c_replica = aioredis.Redis(port=replica.port)
await c_master.set('WORKS', 'YES')
await c_master.set("WORKS", "YES")
await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica)
await c_replica.eval(READONLY_SCRIPT, 3, 'A', 'B', 'WORKS') == 'YES'
await c_replica.eval(READONLY_SCRIPT, 3, "A", "B", "WORKS") == "YES"
try:
await c_replica.eval(WRITE_SCRIPT, 1, 'A')
await c_replica.eval(WRITE_SCRIPT, 1, "A")
assert False
except aioredis.ResponseError as roe:
assert 'READONLY ' in str(roe)
assert "READONLY " in str(roe)
take_over_cases = [
@ -1082,16 +1100,15 @@ take_over_cases = [
@pytest.mark.parametrize("master_threads, replica_threads", take_over_cases)
@pytest.mark.asyncio
async def test_take_over_counters(df_local_factory, master_threads, replica_threads):
master = df_local_factory.create(proactor_threads=master_threads,
port=BASE_PORT,
# vmodule="journal_slice=2,dflycmd=2,main_service=1",
logtostderr=True)
replica1 = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=replica_threads)
replica2 = df_local_factory.create(
port=BASE_PORT+2, proactor_threads=replica_threads)
replica3 = df_local_factory.create(
port=BASE_PORT+3, proactor_threads=replica_threads)
master = df_local_factory.create(
proactor_threads=master_threads,
port=BASE_PORT,
# vmodule="journal_slice=2,dflycmd=2,main_service=1",
logtostderr=True,
)
replica1 = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=replica_threads)
replica2 = df_local_factory.create(port=BASE_PORT + 2, proactor_threads=replica_threads)
replica3 = df_local_factory.create(port=BASE_PORT + 3, proactor_threads=replica_threads)
df_local_factory.start_all([master, replica1, replica2, replica3])
c_master = master.client()
c1 = replica1.client()
@ -1130,8 +1147,10 @@ async def test_take_over_counters(df_local_factory, master_threads, replica_thre
await asyncio.sleep(1)
await c1.execute_command(f"REPLTAKEOVER 5")
_, _, *results = await asyncio.gather(delayed_takeover(), block_during_takeover(), *[counter(f"key{i}") for i in range(16)])
assert await c1.execute_command("role") == [b'master', []]
_, _, *results = await asyncio.gather(
delayed_takeover(), block_during_takeover(), *[counter(f"key{i}") for i in range(16)]
)
assert await c1.execute_command("role") == [b"master", []]
for key, client_value in results:
replicated_value = await c1.get(key)
@ -1142,18 +1161,20 @@ async def test_take_over_counters(df_local_factory, master_threads, replica_thre
@pytest.mark.parametrize("master_threads, replica_threads", take_over_cases)
@pytest.mark.asyncio
async def test_take_over_seeder(request, df_local_factory, df_seeder_factory, master_threads, replica_threads):
tmp_file_name = ''.join(random.choices(string.ascii_letters, k=10))
master = df_local_factory.create(proactor_threads=master_threads,
port=BASE_PORT,
dbfilename=f"dump_{tmp_file_name}",
logtostderr=True)
replica = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=replica_threads)
async def test_take_over_seeder(
request, df_local_factory, df_seeder_factory, master_threads, replica_threads
):
tmp_file_name = "".join(random.choices(string.ascii_letters, k=10))
master = df_local_factory.create(
proactor_threads=master_threads,
port=BASE_PORT,
dbfilename=f"dump_{tmp_file_name}",
logtostderr=True,
)
replica = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=replica_threads)
df_local_factory.start_all([master, replica])
seeder = df_seeder_factory.create(
port=master.port, keys=1000, dbcount=5, stop_on_failure=False)
seeder = df_seeder_factory.create(port=master.port, keys=1000, dbcount=5, stop_on_failure=False)
c_master = master.client()
c_replica = replica.client()
@ -1171,7 +1192,7 @@ async def test_take_over_seeder(request, df_local_factory, df_seeder_factory, ma
await c_replica.execute_command(f"REPLTAKEOVER 5")
seeder.stop()
assert await c_replica.execute_command("role") == [b'master', []]
assert await c_replica.execute_command("role") == [b"master", []]
# Need to wait a bit to give time to write the shutdown snapshot
await asyncio.sleep(1)
@ -1188,15 +1209,11 @@ async def test_take_over_seeder(request, df_local_factory, df_seeder_factory, ma
@pytest.mark.asyncio
async def test_take_over_timeout(df_local_factory, df_seeder_factory):
master = df_local_factory.create(proactor_threads=2,
port=BASE_PORT,
logtostderr=True)
replica = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=2)
master = df_local_factory.create(proactor_threads=2, port=BASE_PORT, logtostderr=True)
replica = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=2)
df_local_factory.start_all([master, replica])
seeder = df_seeder_factory.create(
port=master.port, keys=1000, dbcount=5, stop_on_failure=False)
seeder = df_seeder_factory.create(port=master.port, keys=1000, dbcount=5, stop_on_failure=False)
c_master = master.client()
c_replica = replica.client()
@ -1217,8 +1234,16 @@ async def test_take_over_timeout(df_local_factory, df_seeder_factory):
seeder.stop()
await fill_task
assert await c_master.execute_command("role") == [b'master', [[b'127.0.0.1', bytes(str(replica.port), 'ascii'), b'stable_sync']]]
assert await c_replica.execute_command("role") == [b'replica', b'localhost', bytes(str(master.port), 'ascii'), b'stable_sync']
assert await c_master.execute_command("role") == [
b"master",
[[b"127.0.0.1", bytes(str(replica.port), "ascii"), b"stable_sync"]],
]
assert await c_replica.execute_command("role") == [
b"replica",
b"localhost",
bytes(str(master.port), "ascii"),
b"stable_sync",
]
await disconnect_clients(c_master, c_replica)
@ -1230,11 +1255,18 @@ replication_cases = [(8, 8)]
@pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replica", replication_cases)
async def test_no_tls_on_admin_port(df_local_factory, df_seeder_factory, t_master, t_replica, with_tls_server_args):
async def test_no_tls_on_admin_port(
df_local_factory, df_seeder_factory, t_master, t_replica, with_tls_server_args
):
# 1. Spin up dragonfly without tls, debug populate
master = df_local_factory.create(
no_tls_on_admin_port="true", admin_port=ADMIN_PORT, **with_tls_server_args, port=BASE_PORT, proactor_threads=t_master)
no_tls_on_admin_port="true",
admin_port=ADMIN_PORT,
**with_tls_server_args,
port=BASE_PORT,
proactor_threads=t_master,
)
master.start()
c_master = aioredis.Redis(port=master.admin_port)
await c_master.execute_command("DEBUG POPULATE 100")
@ -1244,7 +1276,12 @@ async def test_no_tls_on_admin_port(df_local_factory, df_seeder_factory, t_maste
# 2. Spin up a replica and initiate a REPLICAOF
replica = df_local_factory.create(
no_tls_on_admin_port="true", admin_port=ADMIN_PORT + 1, **with_tls_server_args, port=BASE_PORT + 1, proactor_threads=t_replica)
no_tls_on_admin_port="true",
admin_port=ADMIN_PORT + 1,
**with_tls_server_args,
port=BASE_PORT + 1,
proactor_threads=t_replica,
)
replica.start()
c_replica = aioredis.Redis(port=replica.admin_port)
res = await c_replica.execute_command("REPLICAOF localhost " + str(master.admin_port))

View file

@ -14,29 +14,46 @@ from redis.commands.search.field import TextField, NumericField, TagField, Vecto
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
TEST_DATA = [
{"title": "First article", "content": "Long description",
"views": 100, "topic": "world, science"},
{"title": "Second article", "content": "Small text",
"views": 200, "topic": "national, policits"},
{"title": "Third piece", "content": "Brief description",
"views": 300, "topic": "health, lifestyle"},
{"title": "Last piece", "content": "Interesting text",
"views": 400, "topic": "world, business"},
{
"title": "First article",
"content": "Long description",
"views": 100,
"topic": "world, science",
},
{
"title": "Second article",
"content": "Small text",
"views": 200,
"topic": "national, policits",
},
{
"title": "Third piece",
"content": "Brief description",
"views": 300,
"topic": "health, lifestyle",
},
{
"title": "Last piece",
"content": "Interesting text",
"views": 400,
"topic": "world, business",
},
]
BASIC_TEST_SCHEMA = [TextField("title"), TextField(
"content"), NumericField("views"), TagField("topic")]
BASIC_TEST_SCHEMA = [
TextField("title"),
TextField("content"),
NumericField("views"),
TagField("topic"),
]
async def index_test_data(async_client: aioredis.Redis, itype: IndexType, prefix=""):
for i, e in enumerate(TEST_DATA):
if itype == IndexType.HASH:
await async_client.hset(prefix+str(i), mapping=e)
await async_client.hset(prefix + str(i), mapping=e)
else:
await async_client.json().set(prefix+str(i), "$", e)
await async_client.json().set(prefix + str(i), "$", e)
def doc_to_str(doc):
@ -44,10 +61,10 @@ def doc_to_str(doc):
doc = doc.__dict__
doc = dict(doc) # copy to remove fields
doc.pop('id', None)
doc.pop('payload', None)
doc.pop("id", None)
doc.pop("payload", None)
return '//'.join(sorted(doc))
return "//".join(sorted(doc))
def contains_test_data(res, td_indices):
@ -66,7 +83,7 @@ def contains_test_data(res, td_indices):
@dfly_args({"proactor_threads": 4})
@pytest.mark.parametrize("index_type", [IndexType.HASH, IndexType.JSON])
async def test_basic(async_client: aioredis.Redis, index_type):
i1 = async_client.ft("i1-"+str(index_type))
i1 = async_client.ft("i1-" + str(index_type))
await index_test_data(async_client, index_type)
await i1.create_index(BASIC_TEST_SCHEMA, definition=IndexDefinition(index_type=index_type))
@ -105,21 +122,27 @@ async def test_basic(async_client: aioredis.Redis, index_type):
async def knn_query(idx, query, vector):
params = {"vec": np.array(vector, dtype=np.float32).tobytes()}
result = await idx.search(query, params)
return {doc['id'] for doc in result.docs}
return {doc["id"] for doc in result.docs}
@dfly_args({"proactor_threads": 4})
@pytest.mark.parametrize("index_type", [IndexType.HASH, IndexType.JSON])
async def test_knn(async_client: aioredis.Redis, index_type):
i2 = async_client.ft("i2-"+str(index_type))
i2 = async_client.ft("i2-" + str(index_type))
vector_field = VectorField("pos", "FLAT", {
"TYPE": "FLOAT32",
"DIM": 1,
"DISTANCE_METRIC": "L2",
})
vector_field = VectorField(
"pos",
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": 1,
"DISTANCE_METRIC": "L2",
},
)
await i2.create_index([TagField("even"), vector_field], definition=IndexDefinition(index_type=index_type))
await i2.create_index(
[TagField("even"), vector_field], definition=IndexDefinition(index_type=index_type)
)
pipe = async_client.pipeline()
for i in range(100):
@ -131,13 +154,18 @@ async def test_knn(async_client: aioredis.Redis, index_type):
pipe.json().set(f"k{i}", "$", {"even": even, "pos": [float(i)]})
await pipe.execute()
assert await knn_query(i2, "* => [KNN 3 @pos VEC]", [50.0]) == {'k49', 'k50', 'k51'}
assert await knn_query(i2, "* => [KNN 3 @pos VEC]", [50.0]) == {"k49", "k50", "k51"}
assert await knn_query(i2, "@even:{yes} => [KNN 3 @pos VEC]", [20.0]) == {'k18', 'k20', 'k22'}
assert await knn_query(i2, "@even:{yes} => [KNN 3 @pos VEC]", [20.0]) == {"k18", "k20", "k22"}
assert await knn_query(i2, "@even:{no} => [KNN 4 @pos VEC]", [30.0]) == {'k27', 'k29', 'k31', 'k33'}
assert await knn_query(i2, "@even:{no} => [KNN 4 @pos VEC]", [30.0]) == {
"k27",
"k29",
"k31",
"k33",
}
assert await knn_query(i2, "@even:{yes} => [KNN 3 @pos VEC]", [10.0] == {'k8', 'k10', 'k12'})
assert await knn_query(i2, "@even:{yes} => [KNN 3 @pos VEC]", [10.0] == {"k8", "k10", "k12"})
NUM_DIMS = 10
@ -147,21 +175,24 @@ NUM_POINTS = 100
@dfly_args({"proactor_threads": 4})
@pytest.mark.parametrize("index_type", [IndexType.HASH, IndexType.JSON])
async def test_multidim_knn(async_client: aioredis.Redis, index_type):
vector_field = VectorField("pos", "FLAT", {
"TYPE": "FLOAT32",
"DIM": NUM_DIMS,
"DISTANCE_METRIC": "L2",
})
vector_field = VectorField(
"pos",
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": NUM_DIMS,
"DISTANCE_METRIC": "L2",
},
)
i3 = async_client.ft("i3-"+str(index_type))
i3 = async_client.ft("i3-" + str(index_type))
await i3.create_index([vector_field], definition=IndexDefinition(index_type=index_type))
def rand_point():
return np.random.uniform(0, 10, NUM_DIMS).astype(np.float32)
# Generate points and send to DF
points = [rand_point()
for _ in range(NUM_POINTS)]
points = [rand_point() for _ in range(NUM_POINTS)]
points = list(enumerate(points))
pipe = async_client.pipeline(transaction=False)
@ -175,10 +206,12 @@ async def test_multidim_knn(async_client: aioredis.Redis, index_type):
# Run 10 random queries
for _ in range(10):
center = rand_point()
limit = random.randint(1, NUM_POINTS/10)
limit = random.randint(1, NUM_POINTS / 10)
expected_ids = [f"k{i}" for i, point in sorted(
points, key=lambda p: np.linalg.norm(center - p[1]))[:limit]]
expected_ids = [
f"k{i}"
for i, point in sorted(points, key=lambda p: np.linalg.norm(center - p[1]))[:limit]
]
got_ids = await knn_query(i3, f"* => [KNN {limit} @pos VEC]", center)

View file

@ -14,26 +14,28 @@ import logging
# Output is expected be of even number of lines where each pair of consecutive lines results in a single key value pair.
# If new_dict_key is not empty, encountering it in the output will start a new dictionary, this let us return multiple
# dictionaries, for example in the 'slaves' command, one dictionary for each slave.
def stdout_as_list_of_dicts(cp: subprocess.CompletedProcess, new_dict_key =""):
def stdout_as_list_of_dicts(cp: subprocess.CompletedProcess, new_dict_key=""):
lines = cp.stdout.splitlines()
res = []
d = None
if (new_dict_key == ''):
if new_dict_key == "":
d = dict()
res.append(d)
for i in range(0, len(lines), 2):
if (lines[i]) == new_dict_key: # assumes output never has '' as a key
if (lines[i]) == new_dict_key: # assumes output never has '' as a key
d = dict()
res.append(d)
d[lines[i]] = lines[i + 1]
return res
def wait_for(func, pred, timeout_sec, timeout_msg=""):
while not pred(func()):
assert timeout_sec > 0, timeout_msg
timeout_sec = timeout_sec - 1
time.sleep(1)
async def await_for(func, pred, timeout_sec, timeout_msg=""):
done = False
while not done:
@ -60,21 +62,26 @@ class Sentinel:
config = [
f"port {self.port}",
f"sentinel monitor {self.default_deployment} 127.0.0.1 {self.initial_master_port} 1",
f"sentinel down-after-milliseconds {self.default_deployment} 3000"
]
f"sentinel down-after-milliseconds {self.default_deployment} 3000",
]
self.config_file.write_text("\n".join(config))
logging.info(self.config_file.read_text())
self.proc = subprocess.Popen(["redis-server", f"{self.config_file.absolute()}", "--sentinel"])
self.proc = subprocess.Popen(
["redis-server", f"{self.config_file.absolute()}", "--sentinel"]
)
def stop(self):
self.proc.terminate()
self.proc.wait(timeout=10)
def run_cmd(self, args, sentinel_cmd=True, capture_output=False, assert_ok=True) -> subprocess.CompletedProcess:
def run_cmd(
self, args, sentinel_cmd=True, capture_output=False, assert_ok=True
) -> subprocess.CompletedProcess:
run_args = ["redis-cli", "-p", f"{self.port}"]
if sentinel_cmd: run_args = run_args + ["sentinel"]
if sentinel_cmd:
run_args = run_args + ["sentinel"]
run_args = run_args + args
cp = subprocess.run(run_args, capture_output=capture_output, text=True)
if assert_ok:
@ -84,8 +91,10 @@ class Sentinel:
def wait_ready(self):
wait_for(
lambda: self.run_cmd(["ping"], sentinel_cmd=False, assert_ok=False),
lambda cp:cp.returncode == 0,
timeout_sec=10, timeout_msg="Timeout waiting for sentinel to become ready.")
lambda cp: cp.returncode == 0,
timeout_sec=10,
timeout_msg="Timeout waiting for sentinel to become ready.",
)
def master(self, deployment="") -> dict:
if deployment == "":
@ -108,10 +117,17 @@ class Sentinel:
def failover(self, deployment=""):
if deployment == "":
deployment = self.default_deployment
self.run_cmd(["failover", deployment,])
self.run_cmd(
[
"failover",
deployment,
]
)
@pytest.fixture(scope="function") # Sentinel has state which we don't want carried over form test to test.
@pytest.fixture(
scope="function"
) # Sentinel has state which we don't want carried over form test to test.
def sentinel(tmp_dir, port_picker) -> Sentinel:
s = Sentinel(port_picker.get_available_port(), port_picker.get_available_port(), tmp_dir)
s.start()
@ -138,9 +154,11 @@ async def test_failover(df_local_factory, sentinel, port_picker):
# Verify sentinel picked up replica.
await await_for(
lambda: sentinel.master(),
lambda m: m["num-slaves"] == "1",
timeout_sec=15, timeout_msg="Timeout waiting for sentinel to pick up replica.")
lambda: sentinel.master(),
lambda m: m["num-slaves"] == "1",
timeout_sec=15,
timeout_msg="Timeout waiting for sentinel to pick up replica.",
)
sentinel.failover()
@ -148,7 +166,8 @@ async def test_failover(df_local_factory, sentinel, port_picker):
await await_for(
lambda: sentinel.live_master_port(),
lambda p: p == replica.port,
timeout_sec=10, timeout_msg="Timeout waiting for sentinel to report replica as master."
timeout_sec=10,
timeout_msg="Timeout waiting for sentinel to report replica as master.",
)
assert sentinel.slaves()[0]["port"] == str(master.port)
@ -159,7 +178,8 @@ async def test_failover(df_local_factory, sentinel, port_picker):
await await_for(
lambda: master_client.get("key"),
lambda val: val == b"value",
15, "Timeout waiting for key to exist in replica."
15,
"Timeout waiting for key to exist in replica.",
)
except AssertionError:
syncid, r_offset = await master_client.execute_command("DEBUG REPLICA OFFSET")
@ -197,9 +217,11 @@ async def test_master_failure(df_local_factory, sentinel, port_picker):
# Verify sentinel picked up replica.
await await_for(
lambda: sentinel.master(),
lambda m: m["num-slaves"] == "1",
timeout_sec=15, timeout_msg="Timeout waiting for sentinel to pick up replica.")
lambda: sentinel.master(),
lambda m: m["num-slaves"] == "1",
timeout_sec=15,
timeout_msg="Timeout waiting for sentinel to pick up replica.",
)
# Simulate master failure.
master.stop()
@ -208,7 +230,8 @@ async def test_master_failure(df_local_factory, sentinel, port_picker):
await await_for(
lambda: sentinel.live_master_port(),
lambda p: p == replica.port,
timeout_sec=300, timeout_msg="Timeout waiting for sentinel to report replica as master."
timeout_sec=300,
timeout_msg="Timeout waiting for sentinel to report replica as master.",
)
# Verify we can now write to replica.

View file

@ -5,7 +5,7 @@ from .utility import *
def test_quit(connection):
connection.send_command("QUIT")
assert connection.read_response() == b'OK'
assert connection.read_response() == b"OK"
with pytest.raises(redis.exceptions.ConnectionError) as e:
connection.read_response()
@ -16,7 +16,7 @@ def test_quit_after_sub(connection):
connection.read_response()
connection.send_command("QUIT")
assert connection.read_response() == b'OK'
assert connection.read_response() == b"OK"
with pytest.raises(redis.exceptions.ConnectionError) as e:
connection.read_response()
@ -30,7 +30,7 @@ async def test_multi_exec(async_client: aioredis.Redis):
assert val == [True, "bar"]
'''
"""
see https://github.com/dragonflydb/dragonfly/issues/457
For now we would not allow for eval command inside multi
As this would create to level transactions (in effect recursive call
@ -38,7 +38,8 @@ to Schedule function).
When this issue is fully fixed, this test would failed, and then it should
change to match the fact that we supporting this operation.
For now we are expecting to get an error
'''
"""
@pytest.mark.skip("Skip until we decided on correct behaviour of eval inside multi")
async def test_multi_eval(async_client: aioredis.Redis):
@ -76,10 +77,12 @@ async def test_client_list(df_factory):
instance.stop()
await disconnect_clients(client, admin_client)
async def test_scan(async_client: aioredis.Redis):
'''
"""
make sure that the scan command is working with python
'''
"""
def gen_test_data():
for i in range(10):
yield f"key-{i}", f"value-{i}"

View file

@ -9,13 +9,16 @@ from . import dfly_args
BASIC_ARGS = {"dir": "{DRAGONFLY_TMP}/"}
@pytest.mark.skip(reason='Currently we can not guarantee that on shutdown if command is executed and value is written we response before breaking the connection')
@pytest.mark.skip(
reason="Currently we can not guarantee that on shutdown if command is executed and value is written we response before breaking the connection"
)
@dfly_args({"proactor_threads": "4"})
class TestDflyAutoLoadSnapshot():
class TestDflyAutoLoadSnapshot:
"""
Test automatic loading of dump files on startup with timestamp.
When command is executed if a value is written we should send the response before shutdown
"""
@pytest.mark.asyncio
async def test_gracefull_shutdown(self, df_local_factory):
df_args = {"dbfilename": "dump", **BASIC_ARGS, "port": 1111}
@ -39,7 +42,9 @@ class TestDflyAutoLoadSnapshot():
await client.execute_command("SHUTDOWN")
await client.connection_pool.disconnect()
_, *results = await asyncio.gather(delayed_takeover(), *[counter(f"key{i}") for i in range(16)])
_, *results = await asyncio.gather(
delayed_takeover(), *[counter(f"key{i}") for i in range(16)]
)
df_server.start()
client = aioredis.Redis(port=df_server.port)

View file

@ -18,9 +18,10 @@ class SnapshotTestBase:
self.tmp_dir = tmp_dir
def get_main_file(self, pattern):
def is_main(f): return "summary" in f if pattern.endswith(
"dfs") else True
files = glob.glob(str(self.tmp_dir.absolute()) + '/' + pattern)
def is_main(f):
return "summary" in f if pattern.endswith("dfs") else True
files = glob.glob(str(self.tmp_dir.absolute()) + "/" + pattern)
possible_mains = list(filter(is_main, files))
assert len(possible_mains) == 1, possible_mains
return possible_mains[0]
@ -29,6 +30,7 @@ class SnapshotTestBase:
@dfly_args({**BASIC_ARGS, "dbfilename": "test-rdb-{{timestamp}}"})
class TestRdbSnapshot(SnapshotTestBase):
"""Test single file rdb snapshot"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
super().setup(tmp_dir)
@ -51,6 +53,7 @@ class TestRdbSnapshot(SnapshotTestBase):
@dfly_args({**BASIC_ARGS, "dbfilename": "test-rdbexact.rdb", "nodf_snapshot_format": None})
class TestRdbSnapshotExactFilename(SnapshotTestBase):
"""Test single file rdb snapshot without a timestamp"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
super().setup(tmp_dir)
@ -74,6 +77,7 @@ class TestRdbSnapshotExactFilename(SnapshotTestBase):
@dfly_args({**BASIC_ARGS, "dbfilename": "test-dfs"})
class TestDflySnapshot(SnapshotTestBase):
"""Test multi file snapshot"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir
@ -88,16 +92,20 @@ class TestDflySnapshot(SnapshotTestBase):
# save + flush + load
await async_client.execute_command("SAVE DF")
assert await async_client.flushall()
await async_client.execute_command("DEBUG LOAD " + super().get_main_file("test-dfs-summary.dfs"))
await async_client.execute_command(
"DEBUG LOAD " + super().get_main_file("test-dfs-summary.dfs")
)
assert await seeder.compare(start_capture)
# We spawn instances manually, so reduce memory usage of default to minimum
@dfly_args({"proactor_threads": "1"})
class TestDflyAutoLoadSnapshot(SnapshotTestBase):
"""Test automatic loading of dump files on startup with timestamp"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir
@ -115,8 +123,8 @@ class TestDflyAutoLoadSnapshot(SnapshotTestBase):
@pytest.mark.parametrize("save_type, dbfilename", cases)
async def test_snapshot(self, df_local_factory, save_type, dbfilename):
df_args = {"dbfilename": dbfilename, **BASIC_ARGS, "port": 1111}
if save_type == 'rdb':
df_args['nodf_snapshot_format'] = ""
if save_type == "rdb":
df_args["nodf_snapshot_format"] = ""
df_server = df_local_factory.create(**df_args)
df_server.start()
@ -135,6 +143,7 @@ class TestDflyAutoLoadSnapshot(SnapshotTestBase):
@dfly_args({**BASIC_ARGS, "dbfilename": "test-periodic", "save_schedule": "*:*"})
class TestPeriodicSnapshot(SnapshotTestBase):
"""Test periodic snapshotting"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
super().setup(tmp_dir)
@ -142,7 +151,8 @@ class TestPeriodicSnapshot(SnapshotTestBase):
@pytest.mark.asyncio
async def test_snapshot(self, df_seeder_factory, df_server):
seeder = df_seeder_factory.create(
port=df_server.port, keys=10, multi_transaction_probability=0)
port=df_server.port, keys=10, multi_transaction_probability=0
)
await seeder.run(target_deviation=0.5)
time.sleep(60)
@ -154,14 +164,14 @@ class TestPeriodicSnapshot(SnapshotTestBase):
class TestPathEscapes(SnapshotTestBase):
"""Test that we don't allow path escapes. We just check that df_server.start()
fails because we don't have a much better way to test that."""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
super().setup(tmp_dir)
@pytest.mark.asyncio
async def test_snapshot(self, df_local_factory):
df_server = df_local_factory.create(
dbfilename="../../../../etc/passwd")
df_server = df_local_factory.create(dbfilename="../../../../etc/passwd")
try:
df_server.start()
assert False, "Server should not start correctly"
@ -172,6 +182,7 @@ class TestPathEscapes(SnapshotTestBase):
@dfly_args({**BASIC_ARGS, "dbfilename": "test-shutdown"})
class TestDflySnapshotOnShutdown(SnapshotTestBase):
"""Test multi file snapshot"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir
@ -192,15 +203,17 @@ class TestDflySnapshotOnShutdown(SnapshotTestBase):
assert await seeder.compare(start_capture)
@dfly_args({**BASIC_ARGS, "dbfilename": "test-info-persistence"})
class TestDflyInfoPersistenceLoadingField(SnapshotTestBase):
"""Test is_loading field on INFO PERSISTENCE during snapshot loading"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir
def extract_is_loading_field(self, res):
matcher = b'loading:'
matcher = b"loading:"
start = res.find(matcher)
pos = start + len(matcher)
return chr(res[pos])
@ -211,9 +224,9 @@ class TestDflyInfoPersistenceLoadingField(SnapshotTestBase):
await seeder.run(target_deviation=0.05)
a_client = aioredis.Redis(port=df_server.port)
#Wait for snapshot to finish loading and try INFO PERSISTENCE
# Wait for snapshot to finish loading and try INFO PERSISTENCE
await wait_available_async(a_client)
res = await a_client.execute_command("INFO PERSISTENCE")
assert '0' == self.extract_is_loading_field(res)
assert "0" == self.extract_is_loading_field(res)
await a_client.connection_pool.disconnect()

View file

@ -30,7 +30,7 @@ def eprint(*args, **kwargs):
def gen_test_data(n, start=0, seed=None):
for i in range(start, n):
yield "k-"+str(i), "v-"+str(i) + ("-"+str(seed) if seed else "")
yield "k-" + str(i), "v-" + str(i) + ("-" + str(seed) if seed else "")
def batch_fill_data(client, gen, batch_size=100):
@ -43,16 +43,16 @@ async def wait_available_async(client: aioredis.Redis):
its = 0
while True:
try:
await client.get('key')
await client.get("key")
return
except aioredis.ResponseError as e:
if ("MOVED" in str(e)):
if "MOVED" in str(e):
# MOVED means we *can* serve traffic, but 'key' does not belong to an owned slot
return
assert "Can not execute during LOADING" in str(e)
# Print W to indicate test is waiting for replica
print('W', end='', flush=True)
print("W", end="", flush=True)
await asyncio.sleep(0.01)
its += 1
@ -141,7 +141,8 @@ class CommandGenerator:
def generate_val(self, t: ValueType):
"""Generate filler value of configured size for type t"""
def rand_str(k=3, s=''):
def rand_str(k=3, s=""):
# Use small k value to reduce mem usage and increase number of ops
return s.join(random.choices(string.ascii_letters, k=k))
@ -150,19 +151,21 @@ class CommandGenerator:
return (rand_str(self.val_size),)
elif t == ValueType.LIST:
# Random sequence k-letter elements for LPUSH
return tuple(rand_str() for _ in range(self.val_size//4))
return tuple(rand_str() for _ in range(self.val_size // 4))
elif t == ValueType.SET:
# Random sequence of k-letter elements for SADD
return tuple(rand_str() for _ in range(self.val_size//4))
return tuple(rand_str() for _ in range(self.val_size // 4))
elif t == ValueType.HSET:
# Random sequence of k-letter keys + int and two start values for HSET
elements = ((rand_str(), random.randint(0, self.val_size))
for _ in range(self.val_size//5))
return ('v0', 0, 'v1', 0) + tuple(itertools.chain(*elements))
elements = (
(rand_str(), random.randint(0, self.val_size)) for _ in range(self.val_size // 5)
)
return ("v0", 0, "v1", 0) + tuple(itertools.chain(*elements))
elif t == ValueType.ZSET:
# Random sequnce of k-letter keys and int score for ZSET
elements = ((random.randint(0, self.val_size), rand_str())
for _ in range(self.val_size//4))
elements = (
(random.randint(0, self.val_size), rand_str()) for _ in range(self.val_size // 4)
)
return tuple(itertools.chain(*elements))
elif t == ValueType.JSON:
@ -170,9 +173,8 @@ class CommandGenerator:
# - arr (array of random strings)
# - ints (array of objects {i:random integer})
# - i (random integer)
ints = [{"i": random.randint(0, 100)}
for i in range(self.val_size//6)]
strs = [rand_str() for _ in range(self.val_size//6)]
ints = [{"i": random.randint(0, 100)} for i in range(self.val_size // 6)]
strs = [rand_str() for _ in range(self.val_size // 6)]
return "$", json.dumps({"arr": strs, "ints": ints, "i": random.randint(0, 100)})
else:
assert False, "Invalid ValueType"
@ -187,8 +189,9 @@ class CommandGenerator:
return None, 0
return f"PEXPIRE k{key} {random.randint(0, 50)}", -1
else:
keys_gen = (self.randomize_key(pop=True)
for _ in range(random.randint(1, self.max_multikey)))
keys_gen = (
self.randomize_key(pop=True) for _ in range(random.randint(1, self.max_multikey))
)
keys = [f"k{k}" for k, _ in keys_gen if k is not None]
if len(keys) == 0:
@ -196,19 +199,19 @@ class CommandGenerator:
return "DEL " + " ".join(keys), -len(keys)
UPDATE_ACTIONS = [
('APPEND {k} {val}', ValueType.STRING),
('SETRANGE {k} 10 {val}', ValueType.STRING),
('LPUSH {k} {val}', ValueType.LIST),
('LPOP {k}', ValueType.LIST),
('SADD {k} {val}', ValueType.SET),
('SPOP {k}', ValueType.SET),
('HSETNX {k} v0 {val}', ValueType.HSET),
('HINCRBY {k} v1 1', ValueType.HSET),
('ZPOPMIN {k} 1', ValueType.ZSET),
('ZADD {k} 0 {val}', ValueType.ZSET),
('JSON.NUMINCRBY {k} $..i 1', ValueType.JSON),
('JSON.ARRPOP {k} $.arr', ValueType.JSON),
('JSON.ARRAPPEND {k} $.arr "{val}"', ValueType.JSON)
("APPEND {k} {val}", ValueType.STRING),
("SETRANGE {k} 10 {val}", ValueType.STRING),
("LPUSH {k} {val}", ValueType.LIST),
("LPOP {k}", ValueType.LIST),
("SADD {k} {val}", ValueType.SET),
("SPOP {k}", ValueType.SET),
("HSETNX {k} v0 {val}", ValueType.HSET),
("HINCRBY {k} v1 1", ValueType.HSET),
("ZPOPMIN {k} 1", ValueType.ZSET),
("ZADD {k} 0 {val}", ValueType.ZSET),
("JSON.NUMINCRBY {k} $..i 1", ValueType.JSON),
("JSON.ARRPOP {k} $.arr", ValueType.JSON),
('JSON.ARRAPPEND {k} $.arr "{val}"', ValueType.JSON),
]
def gen_update_cmd(self):
@ -217,16 +220,16 @@ class CommandGenerator:
"""
cmd, t = random.choice(self.UPDATE_ACTIONS)
k, _ = self.randomize_key(t)
val = ''.join(random.choices(string.ascii_letters, k=3))
val = "".join(random.choices(string.ascii_letters, k=3))
return cmd.format(k=f"k{k}", val=val) if k is not None else None, 0
GROW_ACTINONS = {
ValueType.STRING: 'MSET',
ValueType.LIST: 'LPUSH',
ValueType.SET: 'SADD',
ValueType.HSET: 'HMSET',
ValueType.ZSET: 'ZADD',
ValueType.JSON: 'JSON.SET'
ValueType.STRING: "MSET",
ValueType.LIST: "LPUSH",
ValueType.SET: "SADD",
ValueType.HSET: "HMSET",
ValueType.ZSET: "ZADD",
ValueType.JSON: "JSON.SET",
}
def gen_grow_cmd(self):
@ -241,14 +244,13 @@ class CommandGenerator:
count = 1
keys = (self.add_key(t) for _ in range(count))
payload = itertools.chain(
*((f"k{k}",) + self.generate_val(t) for k in keys))
payload = itertools.chain(*((f"k{k}",) + self.generate_val(t) for k in keys))
filtered_payload = filter(lambda p: p is not None, payload)
return (self.GROW_ACTINONS[t],) + tuple(filtered_payload), count
def make(self, action):
""" Create command for action and return it together with number of keys added (removed)"""
"""Create command for action and return it together with number of keys added (removed)"""
if action == SizeChange.SHRINK:
return self.gen_shrink_cmd()
elif action == SizeChange.NO_CHANGE:
@ -269,8 +271,7 @@ class CommandGenerator:
return [
max(self.base_diff_prob - self.diff_speed * dist, self.min_diff_prob),
1.0,
max(self.base_diff_prob + 2 *
self.diff_speed * dist, self.min_diff_prob)
max(self.base_diff_prob + 2 * self.diff_speed * dist, self.min_diff_prob),
]
def generate(self):
@ -280,15 +281,14 @@ class CommandGenerator:
while len(cmds) < self.batch_size:
# Re-calculating changes in small groups
if len(changes) == 0:
changes = random.choices(
list(SizeChange), weights=self.size_change_probs(), k=20)
changes = random.choices(list(SizeChange), weights=self.size_change_probs(), k=20)
cmd, delta = self.make(changes.pop())
if cmd is not None:
cmds.append(cmd)
self.key_cnt += delta
return cmds, self.key_cnt/self.key_cnt_target
return cmds, self.key_cnt / self.key_cnt_target
class DataCapture:
@ -311,7 +311,7 @@ class DataCapture:
printed = 0
diff = difflib.ndiff(self.entries, other.entries)
for line in diff:
if line.startswith(' '):
if line.startswith(" "):
continue
eprint(line)
if printed >= 20:
@ -344,10 +344,20 @@ class DflySeeder:
assert await seeder.compare(capture, port=1112)
"""
def __init__(self, port=6379, keys=1000, val_size=50, batch_size=100, max_multikey=5, dbcount=1, multi_transaction_probability=0.3, log_file=None, unsupported_types=[], stop_on_failure=True):
self.gen = CommandGenerator(
keys, val_size, batch_size, max_multikey, unsupported_types
)
def __init__(
self,
port=6379,
keys=1000,
val_size=50,
batch_size=100,
max_multikey=5,
dbcount=1,
multi_transaction_probability=0.3,
log_file=None,
unsupported_types=[],
stop_on_failure=True,
):
self.gen = CommandGenerator(keys, val_size, batch_size, max_multikey, unsupported_types)
self.port = port
self.dbcount = dbcount
self.multi_transaction_probability = multi_transaction_probability
@ -356,7 +366,7 @@ class DflySeeder:
self.log_file = log_file
if self.log_file is not None:
open(self.log_file, 'w').close()
open(self.log_file, "w").close()
async def run(self, target_ops=None, target_deviation=None):
"""
@ -366,11 +376,11 @@ class DflySeeder:
print(f"Running ops:{target_ops} deviation:{target_deviation}")
self.stop_flag = False
queues = [asyncio.Queue(maxsize=3) for _ in range(self.dbcount)]
producer = asyncio.create_task(self._generator_task(
queues, target_ops=target_ops, target_deviation=target_deviation))
producer = asyncio.create_task(
self._generator_task(queues, target_ops=target_ops, target_deviation=target_deviation)
)
consumers = [
asyncio.create_task(self._executor_task(i, queue))
for i, queue in enumerate(queues)
asyncio.create_task(self._executor_task(i, queue)) for i, queue in enumerate(queues)
]
time_start = time.time()
@ -388,7 +398,7 @@ class DflySeeder:
self.stop_flag = True
def reset(self):
""" Reset internal state. Needs to be called after flush or restart"""
"""Reset internal state. Needs to be called after flush or restart"""
self.gen.reset()
async def capture(self, port=None):
@ -398,9 +408,9 @@ class DflySeeder:
port = self.port
keys = sorted(list(self.gen.keys_and_types()))
captures = await asyncio.gather(*(
self._capture_db(port=port, target_db=db, keys=keys) for db in range(self.dbcount)
))
captures = await asyncio.gather(
*(self._capture_db(port=port, target_db=db, keys=keys) for db in range(self.dbcount))
)
return captures
async def compare(self, initial_captures, port=6379):
@ -408,7 +418,9 @@ class DflySeeder:
print(f"comparing capture to {port}")
target_captures = await self.capture(port=port)
for db, target_capture, initial_capture in zip(range(self.dbcount), target_captures, initial_captures):
for db, target_capture, initial_capture in zip(
range(self.dbcount), target_captures, initial_captures
):
print(f"comparing capture to {port}, db: {db}")
if not initial_capture.compare(target_capture):
eprint(f">>> Inconsistent data on port {port}, db {db}")
@ -433,14 +445,14 @@ class DflySeeder:
file = None
if self.log_file:
file = open(self.log_file, 'a')
file = open(self.log_file, "a")
def should_run():
if self.stop_flag:
return False
if target_ops is not None and submitted >= target_ops:
return False
if target_deviation is not None and abs(1-deviation) < target_deviation:
if target_deviation is not None and abs(1 - deviation) < target_deviation:
return False
return True
@ -455,7 +467,7 @@ class DflySeeder:
blob, deviation = self.gen.generate()
is_multi_transaction = random.random() < self.multi_transaction_probability
tx_data = (blob, is_multi_transaction)
cpu_time += (time.time() - start_time)
cpu_time += time.time() - start_time
await asyncio.gather(*(q.put(tx_data) for q in queues))
submitted += len(blob)
@ -463,14 +475,12 @@ class DflySeeder:
if file is not None:
pattern = "MULTI\n{}\nEXEC\n" if is_multi_transaction else "{}\n"
file.write(pattern.format('\n'.join(stringify_cmd(cmd)
for cmd in blob)))
file.write(pattern.format("\n".join(stringify_cmd(cmd) for cmd in blob)))
print('.', end='', flush=True)
print(".", end="", flush=True)
await asyncio.sleep(0.0)
print("\ncpu time", cpu_time, "batches",
batches, "commands", submitted)
print("\ncpu time", cpu_time, "batches", batches, "commands", submitted)
await asyncio.gather(*(q.put(None) for q in queues))
for q in queues:
@ -512,20 +522,19 @@ class DflySeeder:
ValueType.LIST: lambda pipe, k: pipe.lrange(k, 0, -1),
ValueType.SET: lambda pipe, k: pipe.smembers(k),
ValueType.HSET: lambda pipe, k: pipe.hgetall(k),
ValueType.ZSET: lambda pipe, k: pipe.zrange(
k, start=0, end=-1, withscores=True),
ValueType.JSON: lambda pipe, k: pipe.execute_command(
"JSON.GET", k, "$")
ValueType.ZSET: lambda pipe, k: pipe.zrange(k, start=0, end=-1, withscores=True),
ValueType.JSON: lambda pipe, k: pipe.execute_command("JSON.GET", k, "$"),
}
CAPTURE_EXTRACTORS = {
ValueType.STRING: lambda res, tostr: (tostr(res),),
ValueType.LIST: lambda res, tostr: (tostr(s) for s in res),
ValueType.SET: lambda res, tostr: sorted(tostr(s) for s in res),
ValueType.HSET: lambda res, tostr: sorted(tostr(k)+"="+tostr(v) for k, v in res.items()),
ValueType.ZSET: lambda res, tostr: (
tostr(s)+"-"+str(f) for (s, f) in res),
ValueType.JSON: lambda res, tostr: (tostr(res),)
ValueType.HSET: lambda res, tostr: sorted(
tostr(k) + "=" + tostr(v) for k, v in res.items()
),
ValueType.ZSET: lambda res, tostr: (tostr(s) + "-" + str(f) for (s, f) in res),
ValueType.JSON: lambda res, tostr: (tostr(res),),
}
async def _capture_entries(self, client, keys):
@ -540,8 +549,7 @@ class DflySeeder:
results = await pipe.execute()
for (k, t), res in zip(group, results):
out = f"{t.name} k{k}: " + \
' '.join(self.CAPTURE_EXTRACTORS[t](res, tostr))
out = f"{t.name} k{k}: " + " ".join(self.CAPTURE_EXTRACTORS[t](res, tostr))
entries.append(out)
return entries
@ -563,11 +571,13 @@ async def disconnect_clients(*clients):
await asyncio.gather(*(c.connection_pool.disconnect() for c in clients))
def gen_certificate(ca_key_path, ca_certificate_path, certificate_request_path, private_key_path, certificate_path):
def gen_certificate(
ca_key_path, ca_certificate_path, certificate_request_path, private_key_path, certificate_path
):
# Generate Dragonfly's private key and certificate signing request (CSR)
step1 = rf'openssl req -newkey rsa:4096 -nodes -keyout {private_key_path} -out {certificate_request_path} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=Comp/CN=Gr/emailAddress=does_not_exist@gmail.com"'
subprocess.run(step1, shell=True)
# Use CA's private key to sign dragonfly's CSR and get back the signed certificate
step2 = fr'openssl x509 -req -in {certificate_request_path} -days 1 -CA {ca_certificate_path} -CAkey {ca_key_path} -CAcreateserial -out {certificate_path}'
step2 = rf"openssl x509 -req -in {certificate_request_path} -days 1 -CA {ca_certificate_path} -CAkey {ca_key_path} -CAcreateserial -out {certificate_path}"
subprocess.run(step2, shell=True)

View file

@ -13,12 +13,14 @@ from loguru import logger as log
import sys
import random
connection_pool = aioredis.ConnectionPool(host="localhost", port=6379,
db=1, decode_responses=True, max_connections=16)
connection_pool = aioredis.ConnectionPool(
host="localhost", port=6379, db=1, decode_responses=True, max_connections=16
)
key_index = 1
async def post_to_redis(sem, db_name, index):
global key_index
async with sem:
@ -26,10 +28,10 @@ async def post_to_redis(sem, db_name, index):
try:
redis_client = aioredis.Redis(connection_pool=connection_pool)
async with redis_client.pipeline(transaction=True) as pipe:
for i in range(1, 15):
for i in range(1, 15):
pipe.hsetnx(name=f"key_{key_index}", key="name", value="bla")
key_index += 1
#log.info(f"after first half {key_index}")
# log.info(f"after first half {key_index}")
for i in range(1, 15):
pipe.hsetnx(name=f"bla_{key_index}", key="name2", value="bla")
key_index += 1
@ -40,8 +42,8 @@ async def post_to_redis(sem, db_name, index):
finally:
# log.info(f"before close {index}")
await redis_client.close()
#log.info(f"after close {index} {len(results)}")
# log.info(f"after close {index} {len(results)}")
async def do_concurrent(db_name):
tasks = []
@ -49,10 +51,10 @@ async def do_concurrent(db_name):
for i in range(1, 3000):
tasks.append(post_to_redis(sem, db_name, i))
res = await asyncio.gather(*tasks)
if __name__ == '__main__':
if __name__ == "__main__":
log.remove()
log.add(sys.stdout, enqueue=True, level='INFO')
log.add(sys.stdout, enqueue=True, level="INFO")
loop = asyncio.get_event_loop()
loop.run_until_complete(do_concurrent("my_db"))

View file

@ -12,40 +12,40 @@ def fill_set(args, redis: rclient.Redis):
for j in range(args.num):
token = uuid.uuid1().hex
# print(token)
key = f'USER_OTP:{token}'
key = f"USER_OTP:{token}"
arr = []
for i in range(30):
otp = ''.join(random.choices(
string.ascii_uppercase + string.digits, k=12))
otp = "".join(random.choices(string.ascii_uppercase + string.digits, k=12))
arr.append(otp)
redis.execute_command('sadd', key, *arr)
redis.execute_command("sadd", key, *arr)
def fill_hset(args, redis):
for j in range(args.num):
token = uuid.uuid1().hex
key = f'USER_INFO:{token}'
phone = f'555-999-{j}'
user_id = 'user' * 5 + f'-{j}'
redis.hset(key, 'phone', phone)
redis.hset(key, 'user_id', user_id)
redis.hset(key, 'login_time', time.time())
key = f"USER_INFO:{token}"
phone = f"555-999-{j}"
user_id = "user" * 5 + f"-{j}"
redis.hset(key, "phone", phone)
redis.hset(key, "user_id", user_id)
redis.hset(key, "login_time", time.time())
def main():
parser = argparse.ArgumentParser(description='fill hset entities')
parser = argparse.ArgumentParser(description="fill hset entities")
parser.add_argument("-p", type=int, help="redis port", dest="port", default=6380)
parser.add_argument("-n", type=int, help="number of keys", dest="num", default=10000)
parser.add_argument(
'-p', type=int, help='redis port', dest='port', default=6380)
parser.add_argument(
'-n', type=int, help='number of keys', dest='num', default=10000)
parser.add_argument(
'--type', type=str, choices=['hset', 'set'], help='set type', default='hset')
"--type", type=str, choices=["hset", "set"], help="set type", default="hset"
)
args = parser.parse_args()
redis = rclient.Redis(host='localhost', port=args.port, db=0)
if args.type == 'hset':
redis = rclient.Redis(host="localhost", port=args.port, db=0)
if args.type == "hset":
fill_hset(args, redis)
elif args.type == 'set':
elif args.type == "set":
fill_set(args, redis)
if __name__ == "__main__":
main()