
feat: Add black formatter to the project (#1544)

Add black formatter and run it on pytests
Kostas Kyrimis authored on 2023-07-17 13:13:12 +03:00, committed by GitHub
parent 9448220607
commit 7944af3c62
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 793 additions and 566 deletions


@@ -24,3 +24,8 @@ repos:
     hooks:
       - id: clang-format
         name: Clang formatting
+
+  - repo: https://github.com/psf/black
+    rev: 23.7.0
+    hooks:
+      - id: black
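With the hook registered above, black runs on staged Python files whenever contributors commit. A minimal sketch of the one-time setup, driven from Python for consistency with the test suite, is below; it assumes the pre-commit package is installed and that it is run from the repository root (the hook id "black" matches the entry added above).

# Hedged sketch: install the git hooks, then run the new `black` hook once
# over the whole repository. Assumes `pip install pre-commit` was already done.
import subprocess

subprocess.run(["pre-commit", "install"], check=True)                       # set up the git hook
subprocess.run(["pre-commit", "run", "black", "--all-files"], check=True)   # one-off full run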

pyproject.toml (new file, 12 lines)

@@ -0,0 +1,12 @@
[tool.black]
line-length = 100
include = '\.py$'
extend-exclude = '''
/(
| .git
| .__pycache__
| build-dbg
| build-opt
| helio
)/
'''
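The [tool.black] table above fixes the style the hook enforces: 100-character lines, only .py files, and the build directories plus the vendored helio tree excluded. As a rough illustration of the effect (not part of the commit), the same line length can be applied through black's Python API; this sketch assumes the black package (the 23.7.0 release pinned above) is importable.

# Illustrative only: format a small snippet with the project's configured line length.
import black

src = "def make_pair(a,b):\n    return {'first':a,'second':b}\n"
formatted = black.format_str(src, mode=black.Mode(line_length=100))
print(formatted)
# Black normalizes spacing and quote style:
# def make_pair(a, b):
#     return {"first": a, "second": b}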


@ -70,8 +70,7 @@ class DflyInstance:
return return
base_args = [f"--{v}" for v in self.params.args] base_args = [f"--{v}" for v in self.params.args]
all_args = self.format_args(self.args) + base_args all_args = self.format_args(self.args) + base_args
print( print(f"Starting instance on {self.port} with arguments {all_args} from {self.params.path}")
f"Starting instance on {self.port} with arguments {all_args} from {self.params.path}")
run_cmd = [self.params.path, *all_args] run_cmd = [self.params.path, *all_args]
if self.params.gdb: if self.params.gdb:
@ -82,8 +81,7 @@ class DflyInstance:
if not self.params.existing_port: if not self.params.existing_port:
return_code = self.proc.poll() return_code = self.proc.poll()
if return_code is not None: if return_code is not None:
raise Exception( raise Exception(f"Failed to start instance, return code {return_code}")
f"Failed to start instance, return code {return_code}")
def __getitem__(self, k): def __getitem__(self, k):
return self.args.get(k) return self.args.get(k)
@ -93,11 +91,13 @@ class DflyInstance:
if self.params.existing_port: if self.params.existing_port:
return self.params.existing_port return self.params.existing_port
return int(self.args.get("port", "6379")) return int(self.args.get("port", "6379"))
@property @property
def admin_port(self) -> int: def admin_port(self) -> int:
if self.params.existing_admin_port: if self.params.existing_admin_port:
return self.params.existing_admin_port return self.params.existing_admin_port
return int(self.args.get("admin_port", "16379")) return int(self.args.get("admin_port", "16379"))
@property @property
def mc_port(self) -> int: def mc_port(self) -> int:
if self.params.existing_mc_port: if self.params.existing_mc_port:
@ -107,7 +107,7 @@ class DflyInstance:
@staticmethod @staticmethod
def format_args(args): def format_args(args):
out = [] out = []
for (k, v) in args.items(): for k, v in args.items():
out.append(f"--{k}") out.append(f"--{k}")
if v is not None: if v is not None:
out.append(str(v)) out.append(str(v))
@ -118,7 +118,10 @@ class DflyInstance:
resp = await session.get(f"http://localhost:{self.port}/metrics") resp = await session.get(f"http://localhost:{self.port}/metrics")
data = await resp.text() data = await resp.text()
await session.close() await session.close()
return {metric_family.name : metric_family for metric_family in text_string_to_metric_families(data)} return {
metric_family.name: metric_family
for metric_family in text_string_to_metric_families(data)
}
class DflyInstanceFactory: class DflyInstanceFactory:
@ -171,7 +174,7 @@ def dfly_multi_test_args(*args):
return pytest.mark.parametrize("df_factory", args, indirect=True) return pytest.mark.parametrize("df_factory", args, indirect=True)
class PortPicker(): class PortPicker:
"""A simple port manager to allocate available ports for tests""" """A simple port manager to allocate available ports for tests"""
def __init__(self): def __init__(self):
@ -185,5 +188,6 @@ class PortPicker():
def is_port_available(self, port): def is_port_available(self, port):
import socket import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) != 0 return s.connect_ex(("localhost", port)) != 0


@ -13,9 +13,9 @@ BASE_PORT = 30001
async def push_config(config, admin_connections): async def push_config(config, admin_connections):
print("Pushing config ", config) print("Pushing config ", config)
await asyncio.gather(*(c_admin.execute_command( await asyncio.gather(
"DFLYCLUSTER", "CONFIG", config) *(c_admin.execute_command("DFLYCLUSTER", "CONFIG", config) for c_admin in admin_connections)
for c_admin in admin_connections)) )
async def get_node_id(admin_connection): async def get_node_id(admin_connection):
@ -38,15 +38,13 @@ class TestNotEmulated:
@dfly_args({"cluster_mode": "emulated"}) @dfly_args({"cluster_mode": "emulated"})
class TestEmulated: class TestEmulated:
def test_cluster_slots_command(self, cluster_client: redis.RedisCluster): def test_cluster_slots_command(self, cluster_client: redis.RedisCluster):
expected = {(0, 16383): {'primary': ( expected = {(0, 16383): {"primary": ("127.0.0.1", 6379), "replicas": []}}
'127.0.0.1', 6379), 'replicas': []}}
res = cluster_client.execute_command("CLUSTER SLOTS") res = cluster_client.execute_command("CLUSTER SLOTS")
assert expected == res assert expected == res
def test_cluster_help_command(self, cluster_client: redis.RedisCluster): def test_cluster_help_command(self, cluster_client: redis.RedisCluster):
# `target_nodes` is necessary because CLUSTER HELP is not mapped on redis-py # `target_nodes` is necessary because CLUSTER HELP is not mapped on redis-py
res = cluster_client.execute_command( res = cluster_client.execute_command("CLUSTER HELP", target_nodes=redis.RedisCluster.RANDOM)
"CLUSTER HELP", target_nodes=redis.RedisCluster.RANDOM)
assert "HELP" in res assert "HELP" in res
assert "SLOTS" in res assert "SLOTS" in res
@ -61,15 +59,16 @@ class TestEmulated:
@dfly_args({"cluster_mode": "emulated", "cluster_announce_ip": "127.0.0.2"}) @dfly_args({"cluster_mode": "emulated", "cluster_announce_ip": "127.0.0.2"})
class TestEmulatedWithAnnounceIp: class TestEmulatedWithAnnounceIp:
def test_cluster_slots_command(self, cluster_client: redis.RedisCluster): def test_cluster_slots_command(self, cluster_client: redis.RedisCluster):
expected = {(0, 16383): {'primary': ( expected = {(0, 16383): {"primary": ("127.0.0.2", 6379), "replicas": []}}
'127.0.0.2', 6379), 'replicas': []}}
res = cluster_client.execute_command("CLUSTER SLOTS") res = cluster_client.execute_command("CLUSTER SLOTS")
assert expected == res assert expected == res
def verify_slots_result(ip: str, port: int, answer: list, rep_ip: str = None, rep_port: int = None) -> bool: def verify_slots_result(
ip: str, port: int, answer: list, rep_ip: str = None, rep_port: int = None
) -> bool:
def is_local_host(ip: str) -> bool: def is_local_host(ip: str) -> bool:
return ip == '127.0.0.1' or ip == 'localhost' return ip == "127.0.0.1" or ip == "localhost"
assert answer[0] == 0 # start shard assert answer[0] == 0 # start shard
assert answer[1] == 16383 # last shard assert answer[1] == 16383 # last shard
@ -77,15 +76,14 @@ def verify_slots_result(ip: str, port: int, answer: list, rep_ip: str = None, re
assert len(answer) == 4 # the network info assert len(answer) == 4 # the network info
rep_info = answer[3] rep_info = answer[3]
assert len(rep_info) == 3 assert len(rep_info) == 3
ip_addr = str(rep_info[0], 'utf-8') ip_addr = str(rep_info[0], "utf-8")
assert ip_addr == rep_ip or ( assert ip_addr == rep_ip or (is_local_host(ip_addr) and is_local_host(ip))
is_local_host(ip_addr) and is_local_host(ip))
assert rep_info[1] == rep_port assert rep_info[1] == rep_port
else: else:
assert len(answer) == 3 assert len(answer) == 3
info = answer[2] info = answer[2]
assert len(info) == 3 assert len(info) == 3
ip_addr = str(info[0], 'utf-8') ip_addr = str(info[0], "utf-8")
assert ip_addr == ip or (is_local_host(ip_addr) and is_local_host(ip)) assert ip_addr == ip or (is_local_host(ip_addr) and is_local_host(ip))
assert info[1] == port assert info[1] == port
return True return True
@ -103,44 +101,45 @@ async def test_cluster_slots_in_replicas(df_local_factory):
res = await c_replica.execute_command("CLUSTER SLOTS") res = await c_replica.execute_command("CLUSTER SLOTS")
assert len(res) == 1 assert len(res) == 1
assert verify_slots_result( assert verify_slots_result(ip="127.0.0.1", port=replica.port, answer=res[0])
ip="127.0.0.1", port=replica.port, answer=res[0])
res = await c_master.execute_command("CLUSTER SLOTS") res = await c_master.execute_command("CLUSTER SLOTS")
assert verify_slots_result( assert verify_slots_result(ip="127.0.0.1", port=master.port, answer=res[0])
ip="127.0.0.1", port=master.port, answer=res[0])
# Connect replica to master # Connect replica to master
rc = await c_replica.execute_command(f"REPLICAOF localhost {master.port}") rc = await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
assert str(rc, 'utf-8') == "OK" assert str(rc, "utf-8") == "OK"
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
res = await c_replica.execute_command("CLUSTER SLOTS") res = await c_replica.execute_command("CLUSTER SLOTS")
assert verify_slots_result( assert verify_slots_result(
ip="127.0.0.1", port=master.port, answer=res[0], rep_ip="127.0.0.1", rep_port=replica.port) ip="127.0.0.1", port=master.port, answer=res[0], rep_ip="127.0.0.1", rep_port=replica.port
)
res = await c_master.execute_command("CLUSTER SLOTS") res = await c_master.execute_command("CLUSTER SLOTS")
assert verify_slots_result( assert verify_slots_result(
ip="127.0.0.1", port=master.port, answer=res[0], rep_ip="127.0.0.1", rep_port=replica.port) ip="127.0.0.1", port=master.port, answer=res[0], rep_ip="127.0.0.1", rep_port=replica.port
)
@dfly_args({"cluster_mode": "emulated", "cluster_announce_ip": "127.0.0.2"}) @dfly_args({"cluster_mode": "emulated", "cluster_announce_ip": "127.0.0.2"})
async def test_cluster_info(async_client): async def test_cluster_info(async_client):
res = await async_client.execute_command("CLUSTER INFO") res = await async_client.execute_command("CLUSTER INFO")
assert len(res) == 16 assert len(res) == 16
assert res == {'cluster_current_epoch': '1', assert res == {
'cluster_known_nodes': '1', "cluster_current_epoch": "1",
'cluster_my_epoch': '1', "cluster_known_nodes": "1",
'cluster_size': '1', "cluster_my_epoch": "1",
'cluster_slots_assigned': '16384', "cluster_size": "1",
'cluster_slots_fail': '0', "cluster_slots_assigned": "16384",
'cluster_slots_ok': '16384', "cluster_slots_fail": "0",
'cluster_slots_pfail': '0', "cluster_slots_ok": "16384",
'cluster_state': 'ok', "cluster_slots_pfail": "0",
'cluster_stats_messages_meet_received': '0', "cluster_state": "ok",
'cluster_stats_messages_ping_received': '1', "cluster_stats_messages_meet_received": "0",
'cluster_stats_messages_ping_sent': '1', "cluster_stats_messages_ping_received": "1",
'cluster_stats_messages_pong_received': '1', "cluster_stats_messages_ping_sent": "1",
'cluster_stats_messages_pong_sent': '1', "cluster_stats_messages_pong_received": "1",
'cluster_stats_messages_received': '1', "cluster_stats_messages_pong_sent": "1",
'cluster_stats_messages_sent': '1' "cluster_stats_messages_received": "1",
"cluster_stats_messages_sent": "1",
} }
@ -149,14 +148,14 @@ async def test_cluster_info(async_client):
async def test_cluster_nodes(async_client): async def test_cluster_nodes(async_client):
res = await async_client.execute_command("CLUSTER NODES") res = await async_client.execute_command("CLUSTER NODES")
assert len(res) == 1 assert len(res) == 1
info = res['127.0.0.2:6379'] info = res["127.0.0.2:6379"]
assert res is not None assert res is not None
assert info['connected'] == True assert info["connected"] == True
assert info['epoch'] == '0' assert info["epoch"] == "0"
assert info['flags'] == 'myself,master' assert info["flags"] == "myself,master"
assert info['last_ping_sent'] == '0' assert info["last_ping_sent"] == "0"
assert info['slots'] == [['0', '16383']] assert info["slots"] == [["0", "16383"]]
assert info['master_id'] == "-" assert info["master_id"] == "-"
""" """
@ -166,6 +165,8 @@ Add a key to node0, then move the slot ownership to node1 and see that they both
intended. intended.
Also add keys to each of them that are *not* moved, and see that they are unaffected by the move. Also add keys to each of them that are *not* moved, and see that they are unaffected by the move.
""" """
@dfly_args({"proactor_threads": 4, "cluster_mode": "yes"}) @dfly_args({"proactor_threads": 4, "cluster_mode": "yes"})
async def test_cluster_slot_ownership_changes(df_local_factory): async def test_cluster_slot_ownership_changes(df_local_factory):
# Start and configure cluster with 2 nodes # Start and configure cluster with 2 nodes
@ -214,7 +215,10 @@ async def test_cluster_slot_ownership_changes(df_local_factory):
] ]
""" """
await push_config(config.replace('LAST_SLOT_CUTOFF', '5259').replace('NEXT_SLOT_CUTOFF', '5260'), c_nodes_admin) await push_config(
config.replace("LAST_SLOT_CUTOFF", "5259").replace("NEXT_SLOT_CUTOFF", "5260"),
c_nodes_admin,
)
# Slot for "KEY1" is 5259 # Slot for "KEY1" is 5259
@ -243,7 +247,10 @@ async def test_cluster_slot_ownership_changes(df_local_factory):
print("Moving ownership over 5259 ('KEY1') to other node") print("Moving ownership over 5259 ('KEY1') to other node")
await push_config(config.replace('LAST_SLOT_CUTOFF', '5258').replace('NEXT_SLOT_CUTOFF', '5259'), c_nodes_admin) await push_config(
config.replace("LAST_SLOT_CUTOFF", "5258").replace("NEXT_SLOT_CUTOFF", "5259"),
c_nodes_admin,
)
# node0 should have removed "KEY1" as it no longer owns it # node0 should have removed "KEY1" as it no longer owns it
assert await c_nodes[0].execute_command("DBSIZE") == 1 assert await c_nodes[0].execute_command("DBSIZE") == 1
@ -453,7 +460,7 @@ async def test_cluster_flush_slots_after_config_change(df_local_factory):
resp = await c_master_admin.execute_command("dflycluster", "getslotinfo", "slots", "0") resp = await c_master_admin.execute_command("dflycluster", "getslotinfo", "slots", "0")
assert resp[0][0] == 0 assert resp[0][0] == 0
slot_0_size = resp[0][2] slot_0_size = resp[0][2]
print(f'Slot 0 size = {slot_0_size}') print(f"Slot 0 size = {slot_0_size}")
assert slot_0_size > 0 assert slot_0_size > 0
config = f""" config = f"""
@ -597,22 +604,23 @@ async def test_cluster_native_client(df_local_factory):
client = aioredis.RedisCluster(decode_responses=True, host="localhost", port=masters[0].port) client = aioredis.RedisCluster(decode_responses=True, host="localhost", port=masters[0].port)
assert await client.set('key0', 'value') == True assert await client.set("key0", "value") == True
assert await client.get('key0') == 'value' assert await client.get("key0") == "value"
async def test_random_keys(): async def test_random_keys():
for i in range(100): for i in range(100):
key = 'key' + str(random.randint(0, 100_000)) key = "key" + str(random.randint(0, 100_000))
assert await client.set(key, 'value') == True assert await client.set(key, "value") == True
assert await client.get(key) == 'value' assert await client.get(key) == "value"
await test_random_keys() await test_random_keys()
await asyncio.gather(*(wait_available_async(c) for c in c_replicas)) await asyncio.gather(*(wait_available_async(c) for c in c_replicas))
# Make sure that getting a value from a replica works as well. # Make sure that getting a value from a replica works as well.
replica_response = await client.execute_command( replica_response = await client.execute_command(
'get', 'key0', target_nodes=aioredis.RedisCluster.REPLICAS) "get", "key0", target_nodes=aioredis.RedisCluster.REPLICAS
assert 'value' in replica_response.values() )
assert "value" in replica_response.values()
# Push new config # Push new config
config = f""" config = f"""


@ -21,7 +21,7 @@ from tempfile import TemporaryDirectory
from . import DflyInstance, DflyInstanceFactory, DflyParams, PortPicker, dfly_args from . import DflyInstance, DflyInstanceFactory, DflyParams, PortPicker, dfly_args
from .utility import DflySeederFactory, gen_certificate from .utility import DflySeederFactory, gen_certificate
logging.getLogger('asyncio').setLevel(logging.WARNING) logging.getLogger("asyncio").setLevel(logging.WARNING)
DATABASE_INDEX = 1 DATABASE_INDEX = 1
@ -55,7 +55,6 @@ def df_seeder_factory(request) -> DflySeederFactory:
if seed is None: if seed is None:
seed = random.randrange(sys.maxsize) seed = random.randrange(sys.maxsize)
random.seed(int(seed)) random.seed(int(seed))
print(f"--- Random seed: {seed}, check: {random.randrange(100)} ---") print(f"--- Random seed: {seed}, check: {random.randrange(100)} ---")
@ -68,8 +67,7 @@ def df_factory(request, tmp_dir, test_env) -> DflyInstanceFactory:
Create an instance factory with supplied params. Create an instance factory with supplied params.
""" """
scripts_dir = os.path.dirname(os.path.abspath(__file__)) scripts_dir = os.path.dirname(os.path.abspath(__file__))
path = os.environ.get("DRAGONFLY_PATH", os.path.join( path = os.environ.get("DRAGONFLY_PATH", os.path.join(scripts_dir, "../../build-dbg/dragonfly"))
scripts_dir, '../../build-dbg/dragonfly'))
args = request.param if request.param else {} args = request.param if request.param else {}
existing = request.config.getoption("--existing-port") existing = request.config.getoption("--existing-port")
@ -83,7 +81,7 @@ def df_factory(request, tmp_dir, test_env) -> DflyInstanceFactory:
existing_port=int(existing) if existing else None, existing_port=int(existing) if existing else None,
existing_admin_port=int(existing_admin) if existing_admin else None, existing_admin_port=int(existing_admin) if existing_admin else None,
existing_mc_port=int(existing_mc) if existing_mc else None, existing_mc_port=int(existing_mc) if existing_mc else None,
env=test_env env=test_env,
) )
factory = DflyInstanceFactory(params, args) factory = DflyInstanceFactory(params, args)
@ -129,7 +127,7 @@ def df_server(df_factory: DflyInstanceFactory) -> DflyInstance:
# else: # else:
# print("Cluster clients left: ", len(clients_left)) # print("Cluster clients left: ", len(clients_left))
if instance['cluster_mode']: if instance["cluster_mode"]:
print("Cluster clients left: ", len(clients_left)) print("Cluster clients left: ", len(clients_left))
@ -160,8 +158,7 @@ def cluster_client(df_server):
""" """
Return a cluster client to the default instance with all entries flushed. Return a cluster client to the default instance with all entries flushed.
""" """
client = redis.RedisCluster(decode_responses=True, host="localhost", client = redis.RedisCluster(decode_responses=True, host="localhost", port=df_server.port)
port=df_server.port)
client.client_setname("default-cluster-fixture") client.client_setname("default-cluster-fixture")
client.flushall() client.flushall()
@ -171,11 +168,17 @@ def cluster_client(df_server):
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
async def async_pool(df_server: DflyInstance): async def async_pool(df_server: DflyInstance):
pool = aioredis.ConnectionPool(host="localhost", port=df_server.port, pool = aioredis.ConnectionPool(
db=DATABASE_INDEX, decode_responses=True, max_connections=32) host="localhost",
port=df_server.port,
db=DATABASE_INDEX,
decode_responses=True,
max_connections=32,
)
yield pool yield pool
await pool.disconnect(inuse_connections=True) await pool.disconnect(inuse_connections=True)
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
async def async_client(async_pool): async def async_client(async_pool):
""" """
@ -197,25 +200,35 @@ def pytest_addoption(parser):
--existing-admin-port - to provide an admin port to an existing process instead of starting a new instance --existing-admin-port - to provide an admin port to an existing process instead of starting a new instance
--rand-seed - to set the global random seed --rand-seed - to set the global random seed
""" """
parser.addoption("--gdb", action="store_true", default=False, help="Run instances in gdb")
parser.addoption("--df", action="append", default=[], help="Add arguments to dragonfly")
parser.addoption( parser.addoption(
'--gdb', action='store_true', default=False, help='Run instances in gdb' "--log-seeder", action="store", default=None, help="Store last generator commands in file"
) )
parser.addoption( parser.addoption(
'--df', action='append', default=[], help='Add arguments to dragonfly' "--rand-seed",
action="store",
default=None,
help="Set seed for global random. Makes seeder predictable",
) )
parser.addoption( parser.addoption(
'--log-seeder', action='store', default=None, help='Store last generator commands in file' "--existing-port",
action="store",
default=None,
help="Provide a port to the existing process for the test",
) )
parser.addoption( parser.addoption(
'--rand-seed', action='store', default=None, help='Set seed for global random. Makes seeder predictable' "--existing-admin-port",
action="store",
default=None,
help="Provide an admin port to the existing process for the test",
) )
parser.addoption(
'--existing-port', action='store', default=None, help='Provide a port to the existing process for the test')
parser.addoption(
'--existing-admin-port', action='store', default=None, help='Provide an admin port to the existing process for the test')
parser.addoption( parser.addoption(
'--existing-mc-port', action='store', default=None, help='Provide a port to the existing memcached process for the test' "--existing-mc-port",
action="store",
default=None,
help="Provide a port to the existing memcached process for the test",
) )
@ -251,11 +264,15 @@ def with_tls_server_args(tmp_dir, gen_ca_cert):
tls_server_req = os.path.join(tmp_dir, "df-req.pem") tls_server_req = os.path.join(tmp_dir, "df-req.pem")
tls_server_cert = os.path.join(tmp_dir, "df-cert.pem") tls_server_cert = os.path.join(tmp_dir, "df-cert.pem")
gen_certificate(gen_ca_cert["ca_key"], gen_ca_cert["ca_cert"], tls_server_req, tls_server_key, tls_server_cert) gen_certificate(
gen_ca_cert["ca_key"],
gen_ca_cert["ca_cert"],
tls_server_req,
tls_server_key,
tls_server_cert,
)
args = {"tls": "", args = {"tls": "", "tls_key_file": tls_server_key, "tls_cert_file": tls_server_cert}
"tls_key_file": tls_server_key,
"tls_cert_file": tls_server_cert}
return args return args
@ -272,11 +289,15 @@ def with_tls_client_args(tmp_dir, gen_ca_cert):
tls_client_req = os.path.join(tmp_dir, "client-req.pem") tls_client_req = os.path.join(tmp_dir, "client-req.pem")
tls_client_cert = os.path.join(tmp_dir, "client-cert.pem") tls_client_cert = os.path.join(tmp_dir, "client-cert.pem")
gen_certificate(gen_ca_cert["ca_key"], gen_ca_cert["ca_cert"], tls_client_req, tls_client_key, tls_client_cert) gen_certificate(
gen_ca_cert["ca_key"],
gen_ca_cert["ca_cert"],
tls_client_req,
tls_client_key,
tls_client_cert,
)
args = {"ssl": True, args = {"ssl": True, "ssl_keyfile": tls_client_key, "ssl_certfile": tls_client_cert}
"ssl_keyfile": tls_client_key,
"ssl_certfile": tls_client_cert}
return args return args


@ -9,6 +9,7 @@ from . import DflyInstance, dfly_args
BASE_PORT = 1111 BASE_PORT = 1111
async def run_monitor_eval(monitor, expected): async def run_monitor_eval(monitor, expected):
async with monitor as mon: async with monitor as mon:
count = 0 count = 0
@ -29,23 +30,22 @@ async def run_monitor_eval(monitor, expected):
return False return False
return True return True
'''
"""
Test issue https://github.com/dragonflydb/dragonfly/issues/756 Test issue https://github.com/dragonflydb/dragonfly/issues/756
Monitor command do not return when we have lua script issue Monitor command do not return when we have lua script issue
''' """
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_monitor_command_lua(async_pool): async def test_monitor_command_lua(async_pool):
expected = ["EVAL return redis", expected = ["EVAL return redis", "EVAL return redis", "SET foo2"]
"EVAL return redis", "SET foo2"]
conn = aioredis.Redis(connection_pool=async_pool) conn = aioredis.Redis(connection_pool=async_pool)
monitor = conn.monitor() monitor = conn.monitor()
cmd1 = aioredis.Redis(connection_pool=async_pool) cmd1 = aioredis.Redis(connection_pool=async_pool)
future = asyncio.create_task(run_monitor_eval( future = asyncio.create_task(run_monitor_eval(monitor=monitor, expected=expected))
monitor=monitor, expected=expected))
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
try: try:
@ -55,7 +55,7 @@ async def test_monitor_command_lua(async_pool):
assert "script tried accessing undeclared key" in str(e) assert "script tried accessing undeclared key" in str(e)
try: try:
res = await cmd1.eval(r'return redis.call("SET", KEYS[1], ARGV[1])', 1, 'foo2', 'bar2') res = await cmd1.eval(r'return redis.call("SET", KEYS[1], ARGV[1])', 1, "foo2", "bar2")
except Exception as e: except Exception as e:
print(f"EVAL error: {e}") print(f"EVAL error: {e}")
assert False assert False
@ -66,12 +66,12 @@ async def test_monitor_command_lua(async_pool):
assert status assert status
''' """
Test the monitor command. Test the monitor command.
Open connection which is used for monitoring Open connection which is used for monitoring
Then send on other connection commands to dragonfly instance Then send on other connection commands to dragonfly instance
Make sure that we are getting the commands in the monitor context Make sure that we are getting the commands in the monitor context
''' """
@pytest.mark.asyncio @pytest.mark.asyncio
@ -101,9 +101,11 @@ async def process_cmd(monitor, key, value):
if "select" not in response["command"].lower(): if "select" not in response["command"].lower():
success = verify_response(response, key, value) success = verify_response(response, key, value)
if not success: if not success:
print( print(f"failed to verify message {response} for {key}/{value}")
f"failed to verify message {response} for {key}/{value}") return (
return False, f"failed on the verification of the message {response} at {key}: {value}" False,
f"failed on the verification of the message {response} at {key}: {value}",
)
else: else:
return True, None return True, None
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -146,11 +148,11 @@ async def run_monitor(messages: dict, pool: aioredis.ConnectionPool):
return False, f"monitor result: {status}: {message}, set command success {success}" return False, f"monitor result: {status}: {message}, set command success {success}"
''' """
Run test in pipeline mode. Run test in pipeline mode.
This is mostly how this is done with python - its more like a transaction that This is mostly how this is done with python - its more like a transaction that
the connections is running all commands in its context the connections is running all commands in its context
''' """
@pytest.mark.asyncio @pytest.mark.asyncio
@ -194,12 +196,12 @@ async def run_pipeline_mode(async_client: aioredis.Redis, messages):
return True, "all command processed successfully" return True, "all command processed successfully"
''' """
Test the pipeline command Test the pipeline command
Open connection to the subscriber and publish on the other end messages Open connection to the subscriber and publish on the other end messages
Make sure that we are able to send all of them and that we are getting the Make sure that we are able to send all of them and that we are getting the
expected results on the subscriber side expected results on the subscriber side
''' """
@pytest.mark.asyncio @pytest.mark.asyncio
@ -232,7 +234,10 @@ async def run_pubsub(async_client, messages, channel_name):
if status and success: if status and success:
return True, "successfully completed all" return True, "successfully completed all"
else: else:
return False, f"subscriber result: {status}: {message}, publisher publish: success {success}" return (
False,
f"subscriber result: {status}: {message}, publisher publish: success {success}",
)
async def run_multi_pubsub(async_client, messages, channel_name): async def run_multi_pubsub(async_client, messages, channel_name):
@ -241,7 +246,8 @@ async def run_multi_pubsub(async_client, messages, channel_name):
await s.subscribe(channel_name) await s.subscribe(channel_name)
tasks = [ tasks = [
asyncio.create_task(reader(s, messages, random.randint(0, len(messages)))) for s in subs] asyncio.create_task(reader(s, messages, random.randint(0, len(messages)))) for s in subs
]
success = True success = True
@ -266,12 +272,12 @@ async def run_multi_pubsub(async_client, messages, channel_name):
return False, "failed to publish" return False, "failed to publish"
''' """
Test with multiple subscribers for a channel Test with multiple subscribers for a channel
We want to stress this to see if we have any issue We want to stress this to see if we have any issue
with the pub sub code since we are "sharing" the message with the pub sub code since we are "sharing" the message
across multiple connections internally across multiple connections internally
''' """
@pytest.mark.asyncio @pytest.mark.asyncio
@ -279,6 +285,7 @@ async def test_multi_pubsub(async_client):
def generate(max): def generate(max):
for i in range(max): for i in range(max):
yield f"this is message number {i} from the publisher on the channel" yield f"this is message number {i} from the publisher on the channel"
messages = [a for a in generate(500)] messages = [a for a in generate(500)]
state, message = await run_multi_pubsub(async_client, messages, "my-channel") state, message = await run_multi_pubsub(async_client, messages, "my-channel")
@ -288,8 +295,13 @@ async def test_multi_pubsub(async_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_subscribers_with_active_publisher(df_server: DflyInstance, max_connections=100): async def test_subscribers_with_active_publisher(df_server: DflyInstance, max_connections=100):
# TODO: I am not how to customize the max connections for the pool. # TODO: I am not how to customize the max connections for the pool.
async_pool = aioredis.ConnectionPool(host="localhost", port=df_server.port, async_pool = aioredis.ConnectionPool(
db=0, decode_responses=True, max_connections=max_connections) host="localhost",
port=df_server.port,
db=0,
decode_responses=True,
max_connections=max_connections,
)
async def publish_worker(): async def publish_worker():
client = aioredis.Redis(connection_pool=async_pool) client = aioredis.Redis(connection_pool=async_pool)
@ -322,12 +334,12 @@ async def test_subscribers_with_active_publisher(df_server: DflyInstance, max_co
async def test_big_command(df_server, size=8 * 1024): async def test_big_command(df_server, size=8 * 1024):
reader, writer = await asyncio.open_connection('127.0.0.1', df_server.port) reader, writer = await asyncio.open_connection("127.0.0.1", df_server.port)
writer.write(f"SET a {'v'*size}\n".encode()) writer.write(f"SET a {'v'*size}\n".encode())
await writer.drain() await writer.drain()
assert 'OK' in (await reader.readline()).decode() assert "OK" in (await reader.readline()).decode()
writer.close() writer.close()
await writer.wait_closed() await writer.wait_closed()
@ -335,9 +347,8 @@ async def test_big_command(df_server, size=8 * 1024):
async def test_subscribe_pipelined(async_client: aioredis.Redis): async def test_subscribe_pipelined(async_client: aioredis.Redis):
pipe = async_client.pipeline(transaction=False) pipe = async_client.pipeline(transaction=False)
pipe.execute_command('subscribe channel').execute_command( pipe.execute_command("subscribe channel").execute_command("subscribe channel")
'subscribe channel') await pipe.echo("bye bye").execute()
await pipe.echo('bye bye').execute()
async def test_subscribe_in_pipeline(async_client: aioredis.Redis): async def test_subscribe_in_pipeline(async_client: aioredis.Redis):
@ -349,8 +360,8 @@ async def test_subscribe_in_pipeline(async_client: aioredis.Redis):
pipe.echo("three") pipe.echo("three")
res = await pipe.execute() res = await pipe.execute()
assert res == ['one', ['subscribe', 'ch1', 1], assert res == ["one", ["subscribe", "ch1", 1], "two", ["subscribe", "ch2", 2], "three"]
'two', ['subscribe', 'ch2', 2], 'three']
""" """
This test makes sure that Dragonfly can receive blocks of pipelined commands even This test makes sure that Dragonfly can receive blocks of pipelined commands even
@ -376,9 +387,13 @@ MGET m4 m5 m6
MGET m7 m8 m9\n MGET m7 m8 m9\n
""" """
PACKET3 = """ PACKET3 = (
"""
PING PING
""" * 500 + "ECHO DONE\n" """
* 500
+ "ECHO DONE\n"
)
async def test_parser_while_script_running(async_client: aioredis.Redis, df_server: DflyInstance): async def test_parser_while_script_running(async_client: aioredis.Redis, df_server: DflyInstance):
@ -386,7 +401,7 @@ async def test_parser_while_script_running(async_client: aioredis.Redis, df_serv
# Use a raw tcp connection for strict control of sent commands # Use a raw tcp connection for strict control of sent commands
# Below we send commands while the previous ones didn't finish # Below we send commands while the previous ones didn't finish
reader, writer = await asyncio.open_connection('localhost', df_server.port) reader, writer = await asyncio.open_connection("localhost", df_server.port)
# Send first pipeline packet, last commands is a long executing script # Send first pipeline packet, last commands is a long executing script
writer.write(PACKET1.format(sha=sha).encode()) writer.write(PACKET1.format(sha=sha).encode())
@ -409,7 +424,9 @@ async def test_parser_while_script_running(async_client: aioredis.Redis, df_serv
@dfly_args({"proactor_threads": 1}) @dfly_args({"proactor_threads": 1})
async def test_large_cmd(async_client: aioredis.Redis): async def test_large_cmd(async_client: aioredis.Redis):
MAX_ARR_SIZE = 65535 MAX_ARR_SIZE = 65535
res = await async_client.hset('foo', mapping={f"key{i}": f"val{i}" for i in range(MAX_ARR_SIZE // 2)}) res = await async_client.hset(
"foo", mapping={f"key{i}": f"val{i}" for i in range(MAX_ARR_SIZE // 2)}
)
assert res == MAX_ARR_SIZE // 2 assert res == MAX_ARR_SIZE // 2
res = await async_client.mset({f"key{i}": f"val{i}" for i in range(MAX_ARR_SIZE // 2)}) res = await async_client.mset({f"key{i}": f"val{i}" for i in range(MAX_ARR_SIZE // 2)})
@ -421,7 +438,9 @@ async def test_large_cmd(async_client: aioredis.Redis):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_factory): async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_factory):
server = df_local_factory.create(no_tls_on_admin_port="true", admin_port=1111, port=1211, **with_tls_server_args) server = df_local_factory.create(
no_tls_on_admin_port="true", admin_port=1111, port=1211, **with_tls_server_args
)
server.start() server.start()
client = aioredis.Redis(port=server.port) client = aioredis.Redis(port=server.port)


@ -75,15 +75,13 @@ return 'OK'
""" """
def DJANGO_CACHEOPS_SCHEMA(vs): return { def DJANGO_CACHEOPS_SCHEMA(vs):
"table_1": [ return {
{"f-1": f'v-{vs[0]}'}, {"f-2": f'v-{vs[1]}'} "table_1": [{"f-1": f"v-{vs[0]}"}, {"f-2": f"v-{vs[1]}"}],
], "table_2": [{"f-1": f"v-{vs[2]}"}, {"f-2": f"v-{vs[3]}"}],
"table_2": [
{"f-1": f'v-{vs[2]}'}, {"f-2": f'v-{vs[3]}'}
]
} }
""" """
Test the main caching script of https://github.com/Suor/django-cacheops. Test the main caching script of https://github.com/Suor/django-cacheops.
The script accesses undeclared keys (that are built based on argument data), The script accesses undeclared keys (that are built based on argument data),
@ -91,21 +89,25 @@ so Dragonfly must run in global (1) or non-atomic (4) multi eval mode.
""" """
@dfly_multi_test_args({'default_lua_flags': 'allow-undeclared-keys', 'proactor_threads': 4}, @dfly_multi_test_args(
{'default_lua_flags': 'allow-undeclared-keys disable-atomicity', 'proactor_threads': 4}) {"default_lua_flags": "allow-undeclared-keys", "proactor_threads": 4},
{"default_lua_flags": "allow-undeclared-keys disable-atomicity", "proactor_threads": 4},
)
async def test_django_cacheops_script(async_client, num_keys=500): async def test_django_cacheops_script(async_client, num_keys=500):
script = async_client.register_script(DJANGO_CACHEOPS_SCRIPT) script = async_client.register_script(DJANGO_CACHEOPS_SCRIPT)
data = [(f'k-{k}', [random.randint(0, 10) for _ in range(4)]) data = [(f"k-{k}", [random.randint(0, 10) for _ in range(4)]) for k in range(num_keys)]
for k in range(num_keys)]
for k, vs in data: for k, vs in data:
schema = DJANGO_CACHEOPS_SCHEMA(vs) schema = DJANGO_CACHEOPS_SCHEMA(vs)
assert await script(keys=['', k, ''], args=['a' * 10, json.dumps(schema, sort_keys=True), 100]) == 'OK' assert (
await script(keys=["", k, ""], args=["a" * 10, json.dumps(schema, sort_keys=True), 100])
== "OK"
)
# Check schema was built correctly # Check schema was built correctly
base_schema = DJANGO_CACHEOPS_SCHEMA([0] * 4) base_schema = DJANGO_CACHEOPS_SCHEMA([0] * 4)
for table, fields in base_schema.items(): for table, fields in base_schema.items():
schema = await async_client.smembers(f'schemes:{table}') schema = await async_client.smembers(f"schemes:{table}")
fields = set.union(*(set(part.keys()) for part in fields)) fields = set.union(*(set(part.keys()) for part in fields))
assert schema == fields assert schema == fields
@ -114,9 +116,9 @@ async def test_django_cacheops_script(async_client, num_keys=500):
assert await async_client.exists(k) assert await async_client.exists(k)
for table, fields in DJANGO_CACHEOPS_SCHEMA(vs).items(): for table, fields in DJANGO_CACHEOPS_SCHEMA(vs).items():
for sub_schema in fields: for sub_schema in fields:
conj_key = f'conj:{table}:' + \ conj_key = f"conj:{table}:" + "&".join(
'&'.join("{}={}".format(f, v) "{}={}".format(f, v) for f, v in sub_schema.items()
for f, v in sub_schema.items()) )
assert await async_client.sismember(conj_key, k) assert await async_client.sismember(conj_key, k)
@ -158,8 +160,10 @@ the task system should work reliably.
""" """
@dfly_multi_test_args({'default_lua_flags': 'allow-undeclared-keys', 'proactor_threads': 4}, @dfly_multi_test_args(
{'default_lua_flags': 'allow-undeclared-keys disable-atomicity', 'proactor_threads': 4}) {"default_lua_flags": "allow-undeclared-keys", "proactor_threads": 4},
{"default_lua_flags": "allow-undeclared-keys disable-atomicity", "proactor_threads": 4},
)
async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100): async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100):
async def enqueue_worker(queue): async def enqueue_worker(queue):
client = aioredis.Redis(connection_pool=async_pool) client = aioredis.Redis(connection_pool=async_pool)
@ -167,15 +171,18 @@ async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100):
task_ids = 2 * list(range(num_tasks)) task_ids = 2 * list(range(num_tasks))
random.shuffle(task_ids) random.shuffle(task_ids)
res = [await enqueue(keys=[f"asynq:{{{queue}}}:t:{task_id}", f"asynq:{{{queue}}}:pending"], res = [
args=[f"{task_id}", task_id, int(time.time())]) await enqueue(
for task_id in task_ids] keys=[f"asynq:{{{queue}}}:t:{task_id}", f"asynq:{{{queue}}}:pending"],
args=[f"{task_id}", task_id, int(time.time())],
)
for task_id in task_ids
]
assert sum(res) == num_tasks assert sum(res) == num_tasks
# Start filling the queues # Start filling the queues
jobs = [asyncio.create_task(enqueue_worker( jobs = [asyncio.create_task(enqueue_worker(f"q-{queue}")) for queue in range(num_queues)]
f"q-{queue}")) for queue in range(num_queues)]
collected = 0 collected = 0
@ -189,11 +196,15 @@ async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100):
# print(f'\r \r{pct}', end='', flush=True) # print(f'\r \r{pct}', end='', flush=True)
for queue in (f"q-{queue}" for queue in range(num_queues)): for queue in (f"q-{queue}" for queue in range(num_queues)):
prefix = f"asynq:{{{queue}}}:t:" prefix = f"asynq:{{{queue}}}:t:"
msg = await dequeue(keys=[f"asynq:{{{queue}}}:"+t for t in ["pending", "paused", "active", "lease"]], msg = await dequeue(
args=[int(time.time()), prefix]) keys=[
f"asynq:{{{queue}}}:" + t for t in ["pending", "paused", "active", "lease"]
],
args=[int(time.time()), prefix],
)
if msg is not None: if msg is not None:
collected += 1 collected += 1
assert await client.hget(prefix+msg, 'state') == 'active' assert await client.hget(prefix + msg, "state") == "active"
# Run many contending workers # Run many contending workers
await asyncio.gather(*(dequeue_worker() for _ in range(num_queues * 2))) await asyncio.gather(*(dequeue_worker() for _ in range(num_queues * 2)))
@ -204,19 +215,19 @@ async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100):
ERROR_CALL_SCRIPT_TEMPLATE = [ ERROR_CALL_SCRIPT_TEMPLATE = [
"redis.{}('LTRIM', 'l', 'a', 'b')", # error only on evaluation "redis.{}('LTRIM', 'l', 'a', 'b')", # error only on evaluation
"redis.{}('obviously wrong')" # error immediately on preprocessing "redis.{}('obviously wrong')", # error immediately on preprocessing
] ]
@dfly_args({"proactor_threads": 1}) @dfly_args({"proactor_threads": 1})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_eval_error_propagation(async_client): async def test_eval_error_propagation(async_client):
CMDS = ['call', 'pcall', 'acall', 'apcall'] CMDS = ["call", "pcall", "acall", "apcall"]
for cmd, template in itertools.product(CMDS, ERROR_CALL_SCRIPT_TEMPLATE): for cmd, template in itertools.product(CMDS, ERROR_CALL_SCRIPT_TEMPLATE):
does_abort = 'p' not in cmd does_abort = "p" not in cmd
try: try:
await async_client.eval(template.format(cmd), 1, 'l') await async_client.eval(template.format(cmd), 1, "l")
if does_abort: if does_abort:
assert False, "Eval must have thrown an error: " + cmd assert False, "Eval must have thrown an error: " + cmd
except aioredis.RedisError as e: except aioredis.RedisError as e:
@ -230,12 +241,12 @@ async def test_global_eval_in_multi(async_client: aioredis.Redis):
return redis.call('GET', 'any-key'); return redis.call('GET', 'any-key');
""" """
await async_client.set('any-key', 'works') await async_client.set("any-key", "works")
pipe = async_client.pipeline(transaction=True) pipe = async_client.pipeline(transaction=True)
pipe.set('another-key', 'ok') pipe.set("another-key", "ok")
pipe.eval(GLOBAL_SCRIPT, 0) pipe.eval(GLOBAL_SCRIPT, 0)
res = await pipe.execute() res = await pipe.execute()
print(res) print(res)
assert res[1] == 'works' assert res[1] == "works"


@ -8,22 +8,24 @@ from . import dfly_multi_test_args, dfly_args
from .utility import batch_fill_data, gen_test_data from .utility import batch_fill_data, gen_test_data
@dfly_multi_test_args({'keys_output_limit': 512}, {'keys_output_limit': 1024}) @dfly_multi_test_args({"keys_output_limit": 512}, {"keys_output_limit": 1024})
class TestKeys: class TestKeys:
async def test_max_keys(self, async_client: aioredis.Redis, df_server): async def test_max_keys(self, async_client: aioredis.Redis, df_server):
max_keys = df_server['keys_output_limit'] max_keys = df_server["keys_output_limit"]
pipe = async_client.pipeline() pipe = async_client.pipeline()
batch_fill_data(pipe, gen_test_data(max_keys * 3)) batch_fill_data(pipe, gen_test_data(max_keys * 3))
await pipe.execute() await pipe.execute()
keys = await async_client.keys() keys = await async_client.keys()
assert len(keys) in range(max_keys, max_keys + 512) assert len(keys) in range(max_keys, max_keys + 512)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def export_dfly_password() -> str: def export_dfly_password() -> str:
pwd = 'flypwd' pwd = "flypwd"
os.environ['DFLY_PASSWORD'] = pwd os.environ["DFLY_PASSWORD"] = pwd
yield pwd yield pwd
del os.environ['DFLY_PASSWORD'] del os.environ["DFLY_PASSWORD"]
async def test_password(df_local_factory, export_dfly_password): async def test_password(df_local_factory, export_dfly_password):
dfly = df_local_factory.create() dfly = df_local_factory.create()
@ -38,7 +40,7 @@ async def test_password(df_local_factory, export_dfly_password):
dfly.stop() dfly.stop()
# --requirepass should take precedence over environment variable # --requirepass should take precedence over environment variable
requirepass = 'requirepass' requirepass = "requirepass"
dfly = df_local_factory.create(requirepass=requirepass) dfly = df_local_factory.create(requirepass=requirepass)
dfly.start() dfly.start()
@ -61,6 +63,7 @@ for i = 0, ARGV[1] do
end end
""" """
@dfly_args({"proactor_threads": 1}) @dfly_args({"proactor_threads": 1})
async def test_txq_ooo(async_client: aioredis.Redis, df_server): async def test_txq_ooo(async_client: aioredis.Redis, df_server):
async def task1(k, h): async def task1(k, h):
@ -77,4 +80,6 @@ async def test_txq_ooo(async_client: aioredis.Redis, df_server):
pipe.blpop(k, 0.001) pipe.blpop(k, 0.001)
await pipe.execute() await pipe.execute()
await asyncio.gather(task1('i1', 2), task1('i2', 3), task2('l1', 2), task2('l1', 2), task2('l1', 5)) await asyncio.gather(
task1("i1", 2), task1("i2", 3), task2("l1", 2), task2("l1", 2), task2("l1", 5)
)


@ -4,15 +4,9 @@ from redis import asyncio as aioredis
from .utility import * from .utility import *
from json import JSONDecoder, JSONEncoder from json import JSONDecoder, JSONEncoder
jane = { jane = {"name": "Jane", "Age": 33, "Location": "Chawton"}
'name': "Jane",
'Age': 33,
'Location': "Chawton"
}
json_num = { json_num = {"a": {"a": 1, "b": 2, "c": 3}}
"a": {"a": 1, "b": 2, "c": 3}
}
async def get_set_json(connection: aioredis.Redis, key, value, path="$"): async def get_set_json(connection: aioredis.Redis, key, value, path="$"):
@ -30,8 +24,8 @@ async def test_basic_json_get_set(async_client: aioredis.Redis):
the_type = await async_client.type(key_name) the_type = await async_client.type(key_name)
assert the_type == "ReJSON-RL" assert the_type == "ReJSON-RL"
assert len(result) == 1 assert len(result) == 1
assert result[0]['name'] == 'Jane' assert result[0]["name"] == "Jane"
assert result[0]['Age'] == 33 assert result[0]["Age"] == 33
async def test_access_json_value_as_string(async_client: aioredis.Redis): async def test_access_json_value_as_string(async_client: aioredis.Redis):
@ -76,18 +70,16 @@ async def test_update_value(async_client: aioredis.Redis):
# make sure that we have valid JSON here # make sure that we have valid JSON here
the_type = await async_client.type(key_name) the_type = await async_client.type(key_name)
assert the_type == "ReJSON-RL" assert the_type == "ReJSON-RL"
result = await get_set_json(async_client, value="0", result = await get_set_json(async_client, value="0", key=key_name, path="$.a.*")
key=key_name, path="$.a.*")
assert len(result) == 3 assert len(result) == 3
# make sure that all the values under 'a' where set to 0 # make sure that all the values under 'a' where set to 0
assert result == ['0', '0', '0'] assert result == ["0", "0", "0"]
# Ensure that after we're changing this into STRING type, it will no longer work # Ensure that after we're changing this into STRING type, it will no longer work
await async_client.set(key_name, "some random value") await async_client.set(key_name, "some random value")
assert await async_client.type(key_name) == "string" assert await async_client.type(key_name) == "string"
try: try:
await get_set_json(async_client, value="0", key=key_name, await get_set_json(async_client, value="0", key=key_name, path="$.a.*")
path="$.a.*")
assert False, "should not be able to modify JSON value as string" assert False, "should not be able to modify JSON value as string"
except redis.exceptions.ResponseError as e: except redis.exceptions.ResponseError as e:
assert e.args[0] == "WRONGTYPE Operation against a key holding the wrong kind of value" assert e.args[0] == "WRONGTYPE Operation against a key holding the wrong kind of value"


@ -4,20 +4,19 @@ from redis import asyncio as aioredis
import pytest import pytest
@pytest.mark.parametrize('index', range(50)) @pytest.mark.parametrize("index", range(50))
class TestBlPop: class TestBlPop:
async def async_blpop(client: aioredis.Redis): async def async_blpop(client: aioredis.Redis):
return await client.blpop( return await client.blpop(["list1{t}", "list2{t}", "list2{t}", "list1{t}"], 0.5)
['list1{t}', 'list2{t}', 'list2{t}', 'list1{t}'], 0.5)
async def blpop_mult_keys(async_client: aioredis.Redis, key: str, val: str): async def blpop_mult_keys(async_client: aioredis.Redis, key: str, val: str):
task = asyncio.create_task(TestBlPop.async_blpop(async_client)) task = asyncio.create_task(TestBlPop.async_blpop(async_client))
await async_client.lpush(key, val) await async_client.lpush(key, val)
result = await asyncio.wait_for(task, 3) result = await asyncio.wait_for(task, 3)
assert result[1] == val assert result[1] == val
watched = await async_client.execute_command('DEBUG WATCHED') watched = await async_client.execute_command("DEBUG WATCHED")
assert watched == ['awaked', [], 'watched', []] assert watched == ["awaked", [], "watched", []]
async def test_blpop_multiple_keys(self, async_client: aioredis.Redis, index): async def test_blpop_multiple_keys(self, async_client: aioredis.Redis, index):
await TestBlPop.blpop_mult_keys(async_client, 'list1{t}', 'a') await TestBlPop.blpop_mult_keys(async_client, "list1{t}", "a")
await TestBlPop.blpop_mult_keys(async_client, 'list2{t}', 'b') await TestBlPop.blpop_mult_keys(async_client, "list2{t}", "b")


@ -8,12 +8,14 @@ def test_add_get(memcached_connection):
assert memcached_connection.add(b"key", b"data", noreply=False) assert memcached_connection.add(b"key", b"data", noreply=False)
assert memcached_connection.get(b"key") == b"data" assert memcached_connection.get(b"key") == b"data"
@dfly_args({"memcached_port": 11211}) @dfly_args({"memcached_port": 11211})
def test_add_set(memcached_connection): def test_add_set(memcached_connection):
assert memcached_connection.add(b"key", b"data", noreply=False) assert memcached_connection.add(b"key", b"data", noreply=False)
memcached_connection.set(b"key", b"other") memcached_connection.set(b"key", b"other")
assert memcached_connection.get(b"key") == b"other" assert memcached_connection.get(b"key") == b"other"
@dfly_args({"memcached_port": 11211}) @dfly_args({"memcached_port": 11211})
def test_set_add(memcached_connection): def test_set_add(memcached_connection):
memcached_connection.set(b"key", b"data") memcached_connection.set(b"key", b"data")
@ -23,6 +25,7 @@ def test_set_add(memcached_connection):
memcached_connection.set(b"key", b"other") memcached_connection.set(b"key", b"other")
assert memcached_connection.get(b"key") == b"other" assert memcached_connection.get(b"key") == b"other"
@dfly_args({"memcached_port": 11211}) @dfly_args({"memcached_port": 11211})
def test_mixed_reply(memcached_connection): def test_mixed_reply(memcached_connection):
memcached_connection.set(b"key", b"data", noreply=True) memcached_connection.set(b"key", b"data", noreply=True)


@ -12,13 +12,17 @@ class RedisServer:
self.proc = None self.proc = None
def start(self): def start(self):
self.proc = subprocess.Popen(["redis-server-6.2.11", self.proc = subprocess.Popen(
[
"redis-server-6.2.11",
f"--port {self.port}", f"--port {self.port}",
"--save ''", "--save ''",
"--appendonly no", "--appendonly no",
"--protected-mode no", "--protected-mode no",
"--repl-diskless-sync yes", "--repl-diskless-sync yes",
"--repl-diskless-sync-delay 0"]) "--repl-diskless-sync-delay 0",
]
)
print(self.proc.args) print(self.proc.args)
def stop(self): def stop(self):
@ -28,6 +32,7 @@ class RedisServer:
except Exception as e: except Exception as e:
pass pass
# Checks that master redis and dragonfly replica are synced by writing a random key to master # Checks that master redis and dragonfly replica are synced by writing a random key to master
# and waiting for it to exist in replica. Foreach db in 0..dbcount-1. # and waiting for it to exist in replica. Foreach db in 0..dbcount-1.
async def await_synced(master_port, replica_port, dbcount=1): async def await_synced(master_port, replica_port, dbcount=1):
@ -58,7 +63,7 @@ async def await_synced_all(c_master, c_replicas):
async def check_data(seeder, replicas, c_replicas): async def check_data(seeder, replicas, c_replicas):
capture = await seeder.capture() capture = await seeder.capture()
for (replica, c_replica) in zip(replicas, c_replicas): for replica, c_replica in zip(replicas, c_replicas):
await wait_available_async(c_replica) await wait_available_async(c_replica)
assert await seeder.compare(capture, port=replica.port) assert await seeder.compare(capture, port=replica.port)
@ -84,7 +89,9 @@ full_sync_replication_specs = [
@pytest.mark.parametrize("t_replicas, seeder_config", full_sync_replication_specs) @pytest.mark.parametrize("t_replicas, seeder_config", full_sync_replication_specs)
async def test_replication_full_sync(df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker): async def test_replication_full_sync(
df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker
):
master = redis_server master = redis_server
c_master = aioredis.Redis(port=master.port) c_master = aioredis.Redis(port=master.port)
assert await c_master.ping() assert await c_master.ping()
@ -93,7 +100,8 @@ async def test_replication_full_sync(df_local_factory, df_seeder_factory, redis_
await seeder.run(target_deviation=0.1) await seeder.run(target_deviation=0.1)
replica = df_local_factory.create( replica = df_local_factory.create(
port=port_picker.get_available_port(), proactor_threads=t_replicas[0]) port=port_picker.get_available_port(), proactor_threads=t_replicas[0]
)
replica.start() replica.start()
c_replica = aioredis.Redis(port=replica.port) c_replica = aioredis.Redis(port=replica.port)
assert await c_replica.ping() assert await c_replica.ping()
@ -105,6 +113,7 @@ async def test_replication_full_sync(df_local_factory, df_seeder_factory, redis_
capture = await seeder.capture() capture = await seeder.capture()
assert await seeder.compare(capture, port=replica.port) assert await seeder.compare(capture, port=replica.port)
stable_sync_replication_specs = [ stable_sync_replication_specs = [
([1], dict(keys=100, dbcount=1, unsupported_types=[ValueType.JSON])), ([1], dict(keys=100, dbcount=1, unsupported_types=[ValueType.JSON])),
([1], dict(keys=10_000, dbcount=2, unsupported_types=[ValueType.JSON])), ([1], dict(keys=10_000, dbcount=2, unsupported_types=[ValueType.JSON])),
@ -115,13 +124,16 @@ stable_sync_replication_specs = [
@pytest.mark.parametrize("t_replicas, seeder_config", stable_sync_replication_specs) @pytest.mark.parametrize("t_replicas, seeder_config", stable_sync_replication_specs)
async def test_replication_stable_sync(df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker): async def test_replication_stable_sync(
df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker
):
master = redis_server master = redis_server
c_master = aioredis.Redis(port=master.port) c_master = aioredis.Redis(port=master.port)
assert await c_master.ping() assert await c_master.ping()
replica = df_local_factory.create( replica = df_local_factory.create(
port=port_picker.get_available_port(), proactor_threads=t_replicas[0]) port=port_picker.get_available_port(), proactor_threads=t_replicas[0]
)
replica.start() replica.start()
c_replica = aioredis.Redis(port=replica.port) c_replica = aioredis.Redis(port=replica.port)
assert await c_replica.ping() assert await c_replica.ping()
@ -150,7 +162,9 @@ replication_specs = [
@pytest.mark.parametrize("t_replicas, seeder_config", replication_specs) @pytest.mark.parametrize("t_replicas, seeder_config", replication_specs)
async def test_redis_replication_all(df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker): async def test_redis_replication_all(
df_local_factory, df_seeder_factory, redis_server, t_replicas, seeder_config, port_picker
):
master = redis_server master = redis_server
c_master = aioredis.Redis(port=master.port) c_master = aioredis.Redis(port=master.port)
assert await c_master.ping() assert await c_master.ping()
@ -178,11 +192,11 @@ async def test_redis_replication_all(df_local_factory, df_seeder_factory, redis_
await c_replica.execute_command("REPLICAOF localhost " + str(master.port)) await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
await wait_available_async(c_replica) await wait_available_async(c_replica)
await asyncio.gather(*(asyncio.create_task(run_replication(c)) await asyncio.gather(*(asyncio.create_task(run_replication(c)) for c in c_replicas))
for c in c_replicas))
# Wait for streaming to finish # Wait for streaming to finish
assert not stream_task.done( assert (
not stream_task.done()
), "Weak testcase. Increase number of streamed iterations to surpass full sync" ), "Weak testcase. Increase number of streamed iterations to surpass full sync"
seeder.stop() seeder.stop()
await stream_task await stream_task
@ -206,7 +220,15 @@ master_disconnect_cases = [
@pytest.mark.parametrize("t_replicas, t_disconnect, seeder_config", master_disconnect_cases) @pytest.mark.parametrize("t_replicas, t_disconnect, seeder_config", master_disconnect_cases)
async def test_disconnect_master(df_local_factory, df_seeder_factory, redis_server, t_replicas, t_disconnect, seeder_config, port_picker): async def test_disconnect_master(
df_local_factory,
df_seeder_factory,
redis_server,
t_replicas,
t_disconnect,
seeder_config,
port_picker,
):
master = redis_server master = redis_server
c_master = aioredis.Redis(port=master.port) c_master = aioredis.Redis(port=master.port)
assert await c_master.ping() assert await c_master.ping()
@ -234,11 +256,11 @@ async def test_disconnect_master(df_local_factory, df_seeder_factory, redis_serv
await c_replica.execute_command("REPLICAOF localhost " + str(master.port)) await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
await wait_available_async(c_replica) await wait_available_async(c_replica)
await asyncio.gather(*(asyncio.create_task(run_replication(c)) await asyncio.gather(*(asyncio.create_task(run_replication(c)) for c in c_replicas))
for c in c_replicas))
# Wait for streaming to finish # Wait for streaming to finish
assert not stream_task.done( assert (
not stream_task.done()
), "Weak testcase. Increase number of streamed iterations to surpass full sync" ), "Weak testcase. Increase number of streamed iterations to surpass full sync"
seeder.stop() seeder.stop()
await stream_task await stream_task


@ -35,7 +35,9 @@ replication_cases = [
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replicas, seeder_config", replication_cases) @pytest.mark.parametrize("t_master, t_replicas, seeder_config", replication_cases)
async def test_replication_all(df_local_factory, df_seeder_factory, t_master, t_replicas, seeder_config): async def test_replication_all(
df_local_factory, df_seeder_factory, t_master, t_replicas, seeder_config
):
master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master) master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
replicas = [ replicas = [
df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t) df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t)
@ -63,11 +65,11 @@ async def test_replication_all(df_local_factory, df_seeder_factory, t_master, t_
async def run_replication(c_replica): async def run_replication(c_replica):
await c_replica.execute_command("REPLICAOF localhost " + str(master.port)) await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
await asyncio.gather(*(asyncio.create_task(run_replication(c)) await asyncio.gather(*(asyncio.create_task(run_replication(c)) for c in c_replicas))
for c in c_replicas))
# Wait for streaming to finish # Wait for streaming to finish
assert not stream_task.done( assert (
not stream_task.done()
), "Weak testcase. Increase number of streamed iterations to surpass full sync" ), "Weak testcase. Increase number of streamed iterations to surpass full sync"
await stream_task await stream_task
@ -98,18 +100,16 @@ async def check_all_replicas_finished(c_replicas, c_master):
while len(waiting_for) > 0: while len(waiting_for) > 0:
await asyncio.sleep(1.0) await asyncio.sleep(1.0)
tasks = (asyncio.create_task(check_replica_finished_exec(c, c_master)) tasks = (asyncio.create_task(check_replica_finished_exec(c, c_master)) for c in waiting_for)
for c in waiting_for)
finished_list = await asyncio.gather(*tasks) finished_list = await asyncio.gather(*tasks)
# Remove clients that finished from waiting list # Remove clients that finished from waiting list
waiting_for = [c for (c, finished) in zip( waiting_for = [c for (c, finished) in zip(waiting_for, finished_list) if not finished]
waiting_for, finished_list) if not finished]
async def check_data(seeder, replicas, c_replicas): async def check_data(seeder, replicas, c_replicas):
capture = await seeder.capture() capture = await seeder.capture()
for (replica, c_replica) in zip(replicas, c_replicas): for replica, c_replica in zip(replicas, c_replicas):
await wait_available_async(c_replica) await wait_available_async(c_replica)
assert await seeder.compare(capture, port=replica.port) assert await seeder.compare(capture, port=replica.port)
@ -140,14 +140,22 @@ disconnect_cases = [
# stable state heavy # stable state heavy
(8, [], [4] * 4, [], 4_000), (8, [], [4] * 4, [], 4_000),
# disconnect only # disconnect only
(8, [], [], [4] * 4, 4_000) (8, [], [], [4] * 4, 4_000),
] ]
@pytest.mark.skip(reason='Failing on github regression action') @pytest.mark.skip(reason="Failing on github regression action")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys", disconnect_cases) @pytest.mark.parametrize("t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys", disconnect_cases)
async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seeder_factory, t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys): async def test_disconnect_replica(
df_local_factory: DflyInstanceFactory,
df_seeder_factory,
t_master,
t_crash_fs,
t_crash_ss,
t_disonnect,
n_keys,
):
master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master) master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
replicas = [ replicas = [
(df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t), crash_fs) (df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t), crash_fs)
@ -155,7 +163,7 @@ async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seed
chain( chain(
zip(t_crash_fs, repeat(DISCONNECT_CRASH_FULL_SYNC)), zip(t_crash_fs, repeat(DISCONNECT_CRASH_FULL_SYNC)),
zip(t_crash_ss, repeat(DISCONNECT_CRASH_STABLE_SYNC)), zip(t_crash_ss, repeat(DISCONNECT_CRASH_STABLE_SYNC)),
zip(t_disonnect, repeat(DISCONNECT_NORMAL_STABLE_SYNC)) zip(t_disonnect, repeat(DISCONNECT_NORMAL_STABLE_SYNC)),
) )
) )
] ]
@ -168,15 +176,11 @@ async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seed
df_local_factory.start_all([replica for replica, _ in replicas]) df_local_factory.start_all([replica for replica, _ in replicas])
c_replicas = [ c_replicas = [
(replica, aioredis.Redis(port=replica.port), crash_type) (replica, aioredis.Redis(port=replica.port), crash_type) for replica, crash_type in replicas
for replica, crash_type in replicas
] ]
def replicas_of_type(tfunc): def replicas_of_type(tfunc):
return [ return [args for args in c_replicas if tfunc(args[2])]
args for args in c_replicas
if tfunc(args[2])
]
# Start data fill loop # Start data fill loop
seeder = df_seeder_factory.create(port=master.port, keys=n_keys, dbcount=2) seeder = df_seeder_factory.create(port=master.port, keys=n_keys, dbcount=2)
@ -211,8 +215,7 @@ async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seed
await c_replica.connection_pool.disconnect() await c_replica.connection_pool.disconnect()
replica.stop(kill=True) replica.stop(kill=True)
await asyncio.gather(*(stable_sync(*args) for args await asyncio.gather(*(stable_sync(*args) for args in replicas_of_type(lambda t: t == 1)))
in replicas_of_type(lambda t: t == 1)))
# Check master survived all crashes # Check master survived all crashes
assert await c_master.ping() assert await c_master.ping()
@ -241,8 +244,7 @@ async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seed
await asyncio.sleep(random.random() / 100) await asyncio.sleep(random.random() / 100)
await c_replica.execute_command("REPLICAOF NO ONE") await c_replica.execute_command("REPLICAOF NO ONE")
await asyncio.gather(*(disconnect(*args) for args await asyncio.gather(*(disconnect(*args) for args in replicas_of_type(lambda t: t == 2)))
in replicas_of_type(lambda t: t == 2)))
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
@ -282,7 +284,9 @@ master_crash_cases = [
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replicas, n_random_crashes, n_keys", master_crash_cases) @pytest.mark.parametrize("t_master, t_replicas, n_random_crashes, n_keys", master_crash_cases)
async def test_disconnect_master(df_local_factory, df_seeder_factory, t_master, t_replicas, n_random_crashes, n_keys): async def test_disconnect_master(
df_local_factory, df_seeder_factory, t_master, t_replicas, n_random_crashes, n_keys
):
master = df_local_factory.create(port=1111, proactor_threads=t_master) master = df_local_factory.create(port=1111, proactor_threads=t_master)
replicas = [ replicas = [
df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t) df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t)
@ -309,8 +313,12 @@ async def test_disconnect_master(df_local_factory, df_seeder_factory, t_master,
await start_master() await start_master()
# Crash master during full sync, but with all passing initial connection phase # Crash master during full sync, but with all passing initial connection phase
await asyncio.gather(*(c_replica.execute_command("REPLICAOF localhost " + str(master.port)) await asyncio.gather(
for c_replica in c_replicas)) *(
c_replica.execute_command("REPLICAOF localhost " + str(master.port))
for c_replica in c_replicas
)
)
await crash_master_fs() await crash_master_fs()
await asyncio.sleep(1 + len(replicas) * 0.5) await asyncio.sleep(1 + len(replicas) * 0.5)
@ -338,24 +346,25 @@ async def test_disconnect_master(df_local_factory, df_seeder_factory, t_master,
await wait_available_async(c_replica) await wait_available_async(c_replica)
assert await seeder.compare(capture, port=replica.port) assert await seeder.compare(capture, port=replica.port)
""" """
Test re-connecting replica to different masters. Test re-connecting replica to different masters.
""" """
rotating_master_cases = [ rotating_master_cases = [(4, [4, 4, 4, 4], dict(keys=2_000, dbcount=4))]
(4, [4, 4, 4, 4], dict(keys=2_000, dbcount=4))
]
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("t_replica, t_masters, seeder_config", rotating_master_cases) @pytest.mark.parametrize("t_replica, t_masters, seeder_config", rotating_master_cases)
async def test_rotating_masters(df_local_factory, df_seeder_factory, t_replica, t_masters, seeder_config): async def test_rotating_masters(
replica = df_local_factory.create( df_local_factory, df_seeder_factory, t_replica, t_masters, seeder_config
port=BASE_PORT, proactor_threads=t_replica) ):
masters = [df_local_factory.create( replica = df_local_factory.create(port=BASE_PORT, proactor_threads=t_replica)
port=BASE_PORT+i+1, proactor_threads=t) for i, t in enumerate(t_masters)] masters = [
seeders = [df_seeder_factory.create( df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t)
port=m.port, **seeder_config) for m in masters] for i, t in enumerate(t_masters)
]
seeders = [df_seeder_factory.create(port=m.port, **seeder_config) for m in masters]
df_local_factory.start_all([replica] + masters) df_local_factory.start_all([replica] + masters)
@ -465,8 +474,7 @@ async def test_flushall(df_local_factory):
# flushall # flushall
pipe.flushall() pipe.flushall()
# Set simple keys n_keys..n_keys*2 on master # Set simple keys n_keys..n_keys*2 on master
batch_fill_data(client=pipe, gen=gen_test_data( batch_fill_data(client=pipe, gen=gen_test_data(n_keys, n_keys * 2), batch_size=3)
n_keys, n_keys*2), batch_size=3)
await pipe.execute() await pipe.execute()
# Check replica finished executing the replicated commands # Check replica finished executing the replicated commands
@ -494,7 +502,7 @@ Test journal rewrites.
@dfly_args({"proactor_threads": 4}) @dfly_args({"proactor_threads": 4})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rewrites(df_local_factory): async def test_rewrites(df_local_factory):
CLOSE_TIMESTAMP = (int(time.time()) + 100) CLOSE_TIMESTAMP = int(time.time()) + 100
CLOSE_TIMESTAMP_MS = CLOSE_TIMESTAMP * 1000 CLOSE_TIMESTAMP_MS = CLOSE_TIMESTAMP * 1000
master = df_local_factory.create(port=BASE_PORT) master = df_local_factory.create(port=BASE_PORT)
@ -513,16 +521,16 @@ async def test_rewrites(df_local_factory):
m_replica = c_replica.monitor() m_replica = c_replica.monitor()
async def get_next_command(): async def get_next_command():
mcmd = (await m_replica.next_command())['command'] mcmd = (await m_replica.next_command())["command"]
# skip select command # skip select command
if (mcmd == "SELECT 0"): if mcmd == "SELECT 0":
print("Got:", mcmd) print("Got:", mcmd)
mcmd = (await m_replica.next_command())['command'] mcmd = (await m_replica.next_command())["command"]
print("Got:", mcmd) print("Got:", mcmd)
return mcmd return mcmd
async def is_match_rsp(rx): async def is_match_rsp(rx):
mcmd = (await get_next_command()) mcmd = await get_next_command()
print("Regex:", rx) print("Regex:", rx)
return re.match(rx, mcmd) return re.match(rx, mcmd)
@ -531,14 +539,14 @@ async def test_rewrites(df_local_factory):
async def check(cmd, rx): async def check(cmd, rx):
await c_master.execute_command(cmd) await c_master.execute_command(cmd)
match = (await is_match_rsp(rx)) match = await is_match_rsp(rx)
assert match assert match
async def check_list(cmd, rx_list): async def check_list(cmd, rx_list):
print("master cmd:", cmd) print("master cmd:", cmd)
await c_master.execute_command(cmd) await c_master.execute_command(cmd)
for rx in rx_list: for rx in rx_list:
match = (await is_match_rsp(rx)) match = await is_match_rsp(rx)
assert match assert match
async def check_list_ooo(cmd, rx_list): async def check_list_ooo(cmd, rx_list):
@ -546,7 +554,7 @@ async def test_rewrites(df_local_factory):
await c_master.execute_command(cmd) await c_master.execute_command(cmd)
expected_cmds = len(rx_list) expected_cmds = len(rx_list)
for i in range(expected_cmds): for i in range(expected_cmds):
mcmd = (await get_next_command()) mcmd = await get_next_command()
# check command matches one regex from list # check command matches one regex from list
match_rx = list(filter(lambda rx: re.match(rx, mcmd), rx_list)) match_rx = list(filter(lambda rx: re.match(rx, mcmd), rx_list))
assert len(match_rx) == 1 assert len(match_rx) == 1
@ -650,10 +658,16 @@ async def test_rewrites(df_local_factory):
await c_master.set("renamekey", "1000", px=50000) await c_master.set("renamekey", "1000", px=50000)
await skip_cmd() await skip_cmd()
# Check RENAME turns into DEL SET and PEXPIREAT # Check RENAME turns into DEL SET and PEXPIREAT
await check_list_ooo("RENAME renamekey renamed", [r"DEL renamekey", r"SET renamed 1000", r"PEXPIREAT renamed (.*?)"]) await check_list_ooo(
"RENAME renamekey renamed",
[r"DEL renamekey", r"SET renamed 1000", r"PEXPIREAT renamed (.*?)"],
)
await check_expire("renamed") await check_expire("renamed")
# Check RENAMENX turns into DEL SET and PEXPIREAT # Check RENAMENX turns into DEL SET and PEXPIREAT
await check_list_ooo("RENAMENX renamed renamekey", [r"DEL renamed", r"SET renamekey 1000", r"PEXPIREAT renamekey (.*?)"]) await check_list_ooo(
"RENAMENX renamed renamekey",
[r"DEL renamed", r"SET renamekey 1000", r"PEXPIREAT renamekey (.*?)"],
)
await check_expire("renamekey") await check_expire("renamekey")
@ -702,8 +716,7 @@ async def test_expiry(df_local_factory, n_keys=1000):
# Set simple keys n_keys..n_keys*2 on master # Set simple keys n_keys..n_keys*2 on master
start_key = n_keys * (i + 1) start_key = n_keys * (i + 1)
end_key = start_key + n_keys end_key = start_key + n_keys
batch_fill_data(client=pipe, gen=gen_test_data( batch_fill_data(client=pipe, gen=gen_test_data(end_key, start_key), batch_size=20)
end_key, start_key), batch_size=20)
await pipe.execute() await pipe.execute()
@ -771,14 +784,15 @@ return 'OK'
""" """
@pytest.mark.skip(reason='Failing') @pytest.mark.skip(reason="Failing")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replicas, num_ops, num_keys, num_par, flags", script_cases) @pytest.mark.parametrize("t_master, t_replicas, num_ops, num_keys, num_par, flags", script_cases)
async def test_scripts(df_local_factory, t_master, t_replicas, num_ops, num_keys, num_par, flags): async def test_scripts(df_local_factory, t_master, t_replicas, num_ops, num_keys, num_par, flags):
master = df_local_factory.create( master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
port=BASE_PORT, proactor_threads=t_master) replicas = [
replicas = [df_local_factory.create( df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=t)
port=BASE_PORT+i+1, proactor_threads=t) for i, t in enumerate(t_replicas)] for i, t in enumerate(t_replicas)
]
df_local_factory.start_all([master] + replicas) df_local_factory.start_all([master] + replicas)
@ -788,16 +802,15 @@ async def test_scripts(df_local_factory, t_master, t_replicas, num_ops, num_keys
await c_replica.execute_command(f"REPLICAOF localhost {master.port}") await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica) await wait_available_async(c_replica)
script = script_test_s1.format( script = script_test_s1.format(flags=f"#!lua flags={flags}" if flags else "")
flags=f'#!lua flags={flags}' if flags else '')
sha = await c_master.script_load(script) sha = await c_master.script_load(script)
key_sets = [ key_sets = [[f"{i}-{j}" for j in range(num_keys)] for i in range(num_par)]
[f'{i}-{j}' for j in range(num_keys)] for i in range(num_par)
]
rsps = await asyncio.gather(*(c_master.evalsha(sha, len(keys), *keys, num_ops) for keys in key_sets)) rsps = await asyncio.gather(
assert rsps == [b'OK'] * num_par *(c_master.evalsha(sha, len(keys), *keys, num_ops) for keys in key_sets)
)
assert rsps == [b"OK"] * num_par
await check_all_replicas_finished(c_replicas, c_master) await check_all_replicas_finished(c_replicas, c_master)
@ -805,17 +818,18 @@ async def test_scripts(df_local_factory, t_master, t_replicas, num_ops, num_keys
for key_set in key_sets: for key_set in key_sets:
for j, k in enumerate(key_set): for j, k in enumerate(key_set):
l = await c_replica.lrange(k, 0, -1) l = await c_replica.lrange(k, 0, -1)
assert l == [f'{j}'.encode()] * num_ops assert l == [f"{j}".encode()] * num_ops
@dfly_args({"proactor_threads": 4}) @dfly_args({"proactor_threads": 4})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_auth_master(df_local_factory, n_keys=20): async def test_auth_master(df_local_factory, n_keys=20):
masterpass = 'requirepass' masterpass = "requirepass"
replicapass = 'replicapass' replicapass = "replicapass"
master = df_local_factory.create(port=BASE_PORT, requirepass=masterpass) master = df_local_factory.create(port=BASE_PORT, requirepass=masterpass)
replica = df_local_factory.create( replica = df_local_factory.create(
port=BASE_PORT+1, logtostdout=True, masterauth=masterpass, requirepass=replicapass) port=BASE_PORT + 1, logtostdout=True, masterauth=masterpass, requirepass=replicapass
)
df_local_factory.start_all([master, replica]) df_local_factory.start_all([master, replica])
@ -886,21 +900,31 @@ async def test_role_command(df_local_factory, n_keys=20):
c_master = aioredis.Redis(port=master.port) c_master = aioredis.Redis(port=master.port)
c_replica = aioredis.Redis(port=replica.port) c_replica = aioredis.Redis(port=replica.port)
assert await c_master.execute_command("role") == [b'master', []] assert await c_master.execute_command("role") == [b"master", []]
await c_replica.execute_command(f"REPLICAOF localhost {master.port}") await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica) await wait_available_async(c_replica)
assert await c_master.execute_command("role") == [ assert await c_master.execute_command("role") == [
b'master', [[b'127.0.0.1', bytes(str(replica.port), 'ascii'), b'stable_sync']]] b"master",
[[b"127.0.0.1", bytes(str(replica.port), "ascii"), b"stable_sync"]],
]
assert await c_replica.execute_command("role") == [ assert await c_replica.execute_command("role") == [
b'replica', b'localhost', bytes(str(master.port), 'ascii'), b'stable_sync'] b"replica",
b"localhost",
bytes(str(master.port), "ascii"),
b"stable_sync",
]
# This tests that we react fast to socket shutdowns and don't hang on # This tests that we react fast to socket shutdowns and don't hang on
# things like the ACK or execution fibers. # things like the ACK or execution fibers.
master.stop() master.stop()
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
assert await c_replica.execute_command("role") == [ assert await c_replica.execute_command("role") == [
b'replica', b'localhost', bytes(str(master.port), 'ascii'), b'connecting'] b"replica",
b"localhost",
bytes(str(master.port), "ascii"),
b"connecting",
]
await c_master.connection_pool.disconnect() await c_master.connection_pool.disconnect()
await c_replica.connection_pool.disconnect() await c_replica.connection_pool.disconnect()
@ -943,7 +967,8 @@ async def assert_lag_condition(inst, client, condition):
async def test_replication_info(df_local_factory, df_seeder_factory, n_keys=2000): async def test_replication_info(df_local_factory, df_seeder_factory, n_keys=2000):
master = df_local_factory.create(port=BASE_PORT) master = df_local_factory.create(port=BASE_PORT)
replica = df_local_factory.create( replica = df_local_factory.create(
port=BASE_PORT+1, logtostdout=True, replication_acks_interval=100) port=BASE_PORT + 1, logtostdout=True, replication_acks_interval=100
)
df_local_factory.start_all([master, replica]) df_local_factory.start_all([master, replica])
c_master = aioredis.Redis(port=master.port) c_master = aioredis.Redis(port=master.port)
c_replica = aioredis.Redis(port=replica.port) c_replica = aioredis.Redis(port=replica.port)
@ -974,18 +999,15 @@ More details in https://github.com/dragonflydb/dragonfly/issues/1231
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flushall_in_full_sync(df_local_factory, df_seeder_factory): async def test_flushall_in_full_sync(df_local_factory, df_seeder_factory):
master = df_local_factory.create( master = df_local_factory.create(port=BASE_PORT, proactor_threads=4, logtostdout=True)
port=BASE_PORT, proactor_threads=4, logtostdout=True) replica = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=2, logtostdout=True)
replica = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=2, logtostdout=True)
# Start master # Start master
master.start() master.start()
c_master = aioredis.Redis(port=master.port) c_master = aioredis.Redis(port=master.port)
# Fill master with test data # Fill master with test data
seeder = df_seeder_factory.create( seeder = df_seeder_factory.create(port=master.port, keys=100_000, dbcount=1)
port=master.port, keys=100_000, dbcount=1)
await seeder.run(target_deviation=0.1) await seeder.run(target_deviation=0.1)
# Start replica # Start replica
@ -998,7 +1020,7 @@ async def test_flushall_in_full_sync(df_local_factory, df_seeder_factory):
return result[3] return result[3]
async def is_full_sync_mode(c_replica): async def is_full_sync_mode(c_replica):
return await get_sync_mode(c_replica) == b'full_sync' return await get_sync_mode(c_replica) == b"full_sync"
# Wait for full sync to start # Wait for full sync to start
while not await is_full_sync_mode(c_replica): while not await is_full_sync_mode(c_replica):
@ -1010,12 +1032,10 @@ async def test_flushall_in_full_sync(df_local_factory, df_seeder_factory):
await c_master.execute_command("FLUSHALL") await c_master.execute_command("FLUSHALL")
if not await is_full_sync_mode(c_replica): if not await is_full_sync_mode(c_replica):
logging.error( logging.error("!!! Full sync finished too fast. Adjust test parameters !!!")
"!!! Full sync finished too fast. Adjust test parameters !!!")
return return
post_seeder = df_seeder_factory.create( post_seeder = df_seeder_factory.create(port=master.port, keys=10, dbcount=1)
port=master.port, keys=10, dbcount=1)
await post_seeder.run(target_deviation=0.1) await post_seeder.run(target_deviation=0.1)
await check_all_replicas_finished([c_replica], c_master) await check_all_replicas_finished([c_replica], c_master)
@ -1047,28 +1067,26 @@ redis.call('SET', 'A', 'ErrroR')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_readonly_script(df_local_factory): async def test_readonly_script(df_local_factory):
master = df_local_factory.create( master = df_local_factory.create(port=BASE_PORT, proactor_threads=2, logtostdout=True)
port=BASE_PORT, proactor_threads=2, logtostdout=True) replica = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=2, logtostdout=True)
replica = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=2, logtostdout=True)
df_local_factory.start_all([master, replica]) df_local_factory.start_all([master, replica])
c_master = aioredis.Redis(port=master.port) c_master = aioredis.Redis(port=master.port)
c_replica = aioredis.Redis(port=replica.port) c_replica = aioredis.Redis(port=replica.port)
await c_master.set('WORKS', 'YES') await c_master.set("WORKS", "YES")
await c_replica.execute_command(f"REPLICAOF localhost {master.port}") await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica) await wait_available_async(c_replica)
await c_replica.eval(READONLY_SCRIPT, 3, 'A', 'B', 'WORKS') == 'YES' await c_replica.eval(READONLY_SCRIPT, 3, "A", "B", "WORKS") == "YES"
try: try:
await c_replica.eval(WRITE_SCRIPT, 1, 'A') await c_replica.eval(WRITE_SCRIPT, 1, "A")
assert False assert False
except aioredis.ResponseError as roe: except aioredis.ResponseError as roe:
assert 'READONLY ' in str(roe) assert "READONLY " in str(roe)
take_over_cases = [ take_over_cases = [
@ -1082,16 +1100,15 @@ take_over_cases = [
@pytest.mark.parametrize("master_threads, replica_threads", take_over_cases) @pytest.mark.parametrize("master_threads, replica_threads", take_over_cases)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_take_over_counters(df_local_factory, master_threads, replica_threads): async def test_take_over_counters(df_local_factory, master_threads, replica_threads):
master = df_local_factory.create(proactor_threads=master_threads, master = df_local_factory.create(
proactor_threads=master_threads,
port=BASE_PORT, port=BASE_PORT,
# vmodule="journal_slice=2,dflycmd=2,main_service=1", # vmodule="journal_slice=2,dflycmd=2,main_service=1",
logtostderr=True) logtostderr=True,
replica1 = df_local_factory.create( )
port=BASE_PORT+1, proactor_threads=replica_threads) replica1 = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=replica_threads)
replica2 = df_local_factory.create( replica2 = df_local_factory.create(port=BASE_PORT + 2, proactor_threads=replica_threads)
port=BASE_PORT+2, proactor_threads=replica_threads) replica3 = df_local_factory.create(port=BASE_PORT + 3, proactor_threads=replica_threads)
replica3 = df_local_factory.create(
port=BASE_PORT+3, proactor_threads=replica_threads)
df_local_factory.start_all([master, replica1, replica2, replica3]) df_local_factory.start_all([master, replica1, replica2, replica3])
c_master = master.client() c_master = master.client()
c1 = replica1.client() c1 = replica1.client()
@ -1130,8 +1147,10 @@ async def test_take_over_counters(df_local_factory, master_threads, replica_thre
await asyncio.sleep(1) await asyncio.sleep(1)
await c1.execute_command(f"REPLTAKEOVER 5") await c1.execute_command(f"REPLTAKEOVER 5")
_, _, *results = await asyncio.gather(delayed_takeover(), block_during_takeover(), *[counter(f"key{i}") for i in range(16)]) _, _, *results = await asyncio.gather(
assert await c1.execute_command("role") == [b'master', []] delayed_takeover(), block_during_takeover(), *[counter(f"key{i}") for i in range(16)]
)
assert await c1.execute_command("role") == [b"master", []]
for key, client_value in results: for key, client_value in results:
replicated_value = await c1.get(key) replicated_value = await c1.get(key)
@ -1142,18 +1161,20 @@ async def test_take_over_counters(df_local_factory, master_threads, replica_thre
@pytest.mark.parametrize("master_threads, replica_threads", take_over_cases) @pytest.mark.parametrize("master_threads, replica_threads", take_over_cases)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_take_over_seeder(request, df_local_factory, df_seeder_factory, master_threads, replica_threads): async def test_take_over_seeder(
tmp_file_name = ''.join(random.choices(string.ascii_letters, k=10)) request, df_local_factory, df_seeder_factory, master_threads, replica_threads
master = df_local_factory.create(proactor_threads=master_threads, ):
tmp_file_name = "".join(random.choices(string.ascii_letters, k=10))
master = df_local_factory.create(
proactor_threads=master_threads,
port=BASE_PORT, port=BASE_PORT,
dbfilename=f"dump_{tmp_file_name}", dbfilename=f"dump_{tmp_file_name}",
logtostderr=True) logtostderr=True,
replica = df_local_factory.create( )
port=BASE_PORT+1, proactor_threads=replica_threads) replica = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=replica_threads)
df_local_factory.start_all([master, replica]) df_local_factory.start_all([master, replica])
seeder = df_seeder_factory.create( seeder = df_seeder_factory.create(port=master.port, keys=1000, dbcount=5, stop_on_failure=False)
port=master.port, keys=1000, dbcount=5, stop_on_failure=False)
c_master = master.client() c_master = master.client()
c_replica = replica.client() c_replica = replica.client()
@ -1171,7 +1192,7 @@ async def test_take_over_seeder(request, df_local_factory, df_seeder_factory, ma
await c_replica.execute_command(f"REPLTAKEOVER 5") await c_replica.execute_command(f"REPLTAKEOVER 5")
seeder.stop() seeder.stop()
assert await c_replica.execute_command("role") == [b'master', []] assert await c_replica.execute_command("role") == [b"master", []]
# Need to wait a bit to give time to write the shutdown snapshot # Need to wait a bit to give time to write the shutdown snapshot
await asyncio.sleep(1) await asyncio.sleep(1)
@ -1188,15 +1209,11 @@ async def test_take_over_seeder(request, df_local_factory, df_seeder_factory, ma
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_take_over_timeout(df_local_factory, df_seeder_factory): async def test_take_over_timeout(df_local_factory, df_seeder_factory):
master = df_local_factory.create(proactor_threads=2, master = df_local_factory.create(proactor_threads=2, port=BASE_PORT, logtostderr=True)
port=BASE_PORT, replica = df_local_factory.create(port=BASE_PORT + 1, proactor_threads=2)
logtostderr=True)
replica = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=2)
df_local_factory.start_all([master, replica]) df_local_factory.start_all([master, replica])
seeder = df_seeder_factory.create( seeder = df_seeder_factory.create(port=master.port, keys=1000, dbcount=5, stop_on_failure=False)
port=master.port, keys=1000, dbcount=5, stop_on_failure=False)
c_master = master.client() c_master = master.client()
c_replica = replica.client() c_replica = replica.client()
@ -1217,8 +1234,16 @@ async def test_take_over_timeout(df_local_factory, df_seeder_factory):
seeder.stop() seeder.stop()
await fill_task await fill_task
assert await c_master.execute_command("role") == [b'master', [[b'127.0.0.1', bytes(str(replica.port), 'ascii'), b'stable_sync']]] assert await c_master.execute_command("role") == [
assert await c_replica.execute_command("role") == [b'replica', b'localhost', bytes(str(master.port), 'ascii'), b'stable_sync'] b"master",
[[b"127.0.0.1", bytes(str(replica.port), "ascii"), b"stable_sync"]],
]
assert await c_replica.execute_command("role") == [
b"replica",
b"localhost",
bytes(str(master.port), "ascii"),
b"stable_sync",
]
await disconnect_clients(c_master, c_replica) await disconnect_clients(c_master, c_replica)
@ -1230,11 +1255,18 @@ replication_cases = [(8, 8)]
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replica", replication_cases) @pytest.mark.parametrize("t_master, t_replica", replication_cases)
async def test_no_tls_on_admin_port(df_local_factory, df_seeder_factory, t_master, t_replica, with_tls_server_args): async def test_no_tls_on_admin_port(
df_local_factory, df_seeder_factory, t_master, t_replica, with_tls_server_args
):
# 1. Spin up dragonfly without tls, debug populate # 1. Spin up dragonfly without tls, debug populate
master = df_local_factory.create( master = df_local_factory.create(
no_tls_on_admin_port="true", admin_port=ADMIN_PORT, **with_tls_server_args, port=BASE_PORT, proactor_threads=t_master) no_tls_on_admin_port="true",
admin_port=ADMIN_PORT,
**with_tls_server_args,
port=BASE_PORT,
proactor_threads=t_master,
)
master.start() master.start()
c_master = aioredis.Redis(port=master.admin_port) c_master = aioredis.Redis(port=master.admin_port)
await c_master.execute_command("DEBUG POPULATE 100") await c_master.execute_command("DEBUG POPULATE 100")
@ -1244,7 +1276,12 @@ async def test_no_tls_on_admin_port(df_local_factory, df_seeder_factory, t_maste
# 2. Spin up a replica and initiate a REPLICAOF # 2. Spin up a replica and initiate a REPLICAOF
replica = df_local_factory.create( replica = df_local_factory.create(
no_tls_on_admin_port="true", admin_port=ADMIN_PORT + 1, **with_tls_server_args, port=BASE_PORT + 1, proactor_threads=t_replica) no_tls_on_admin_port="true",
admin_port=ADMIN_PORT + 1,
**with_tls_server_args,
port=BASE_PORT + 1,
proactor_threads=t_replica,
)
replica.start() replica.start()
c_replica = aioredis.Redis(port=replica.admin_port) c_replica = aioredis.Redis(port=replica.admin_port)
res = await c_replica.execute_command("REPLICAOF localhost " + str(master.admin_port)) res = await c_replica.execute_command("REPLICAOF localhost " + str(master.admin_port))


@ -14,21 +14,38 @@ from redis.commands.search.field import TextField, NumericField, TagField, Vecto
from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.indexDefinition import IndexDefinition, IndexType
TEST_DATA = [ TEST_DATA = [
{"title": "First article", "content": "Long description", {
"views": 100, "topic": "world, science"}, "title": "First article",
"content": "Long description",
{"title": "Second article", "content": "Small text", "views": 100,
"views": 200, "topic": "national, policits"}, "topic": "world, science",
},
{"title": "Third piece", "content": "Brief description", {
"views": 300, "topic": "health, lifestyle"}, "title": "Second article",
"content": "Small text",
{"title": "Last piece", "content": "Interesting text", "views": 200,
"views": 400, "topic": "world, business"}, "topic": "national, policits",
},
{
"title": "Third piece",
"content": "Brief description",
"views": 300,
"topic": "health, lifestyle",
},
{
"title": "Last piece",
"content": "Interesting text",
"views": 400,
"topic": "world, business",
},
] ]
BASIC_TEST_SCHEMA = [TextField("title"), TextField( BASIC_TEST_SCHEMA = [
"content"), NumericField("views"), TagField("topic")] TextField("title"),
TextField("content"),
NumericField("views"),
TagField("topic"),
]
async def index_test_data(async_client: aioredis.Redis, itype: IndexType, prefix=""): async def index_test_data(async_client: aioredis.Redis, itype: IndexType, prefix=""):
@ -44,10 +61,10 @@ def doc_to_str(doc):
doc = doc.__dict__ doc = doc.__dict__
doc = dict(doc) # copy to remove fields doc = dict(doc) # copy to remove fields
doc.pop('id', None) doc.pop("id", None)
doc.pop('payload', None) doc.pop("payload", None)
return '//'.join(sorted(doc)) return "//".join(sorted(doc))
def contains_test_data(res, td_indices): def contains_test_data(res, td_indices):
@ -105,7 +122,7 @@ async def test_basic(async_client: aioredis.Redis, index_type):
async def knn_query(idx, query, vector): async def knn_query(idx, query, vector):
params = {"vec": np.array(vector, dtype=np.float32).tobytes()} params = {"vec": np.array(vector, dtype=np.float32).tobytes()}
result = await idx.search(query, params) result = await idx.search(query, params)
return {doc['id'] for doc in result.docs} return {doc["id"] for doc in result.docs}
@dfly_args({"proactor_threads": 4}) @dfly_args({"proactor_threads": 4})
@ -113,13 +130,19 @@ async def knn_query(idx, query, vector):
async def test_knn(async_client: aioredis.Redis, index_type): async def test_knn(async_client: aioredis.Redis, index_type):
i2 = async_client.ft("i2-" + str(index_type)) i2 = async_client.ft("i2-" + str(index_type))
vector_field = VectorField("pos", "FLAT", { vector_field = VectorField(
"pos",
"FLAT",
{
"TYPE": "FLOAT32", "TYPE": "FLOAT32",
"DIM": 1, "DIM": 1,
"DISTANCE_METRIC": "L2", "DISTANCE_METRIC": "L2",
}) },
)
await i2.create_index([TagField("even"), vector_field], definition=IndexDefinition(index_type=index_type)) await i2.create_index(
[TagField("even"), vector_field], definition=IndexDefinition(index_type=index_type)
)
pipe = async_client.pipeline() pipe = async_client.pipeline()
for i in range(100): for i in range(100):
@ -131,13 +154,18 @@ async def test_knn(async_client: aioredis.Redis, index_type):
pipe.json().set(f"k{i}", "$", {"even": even, "pos": [float(i)]}) pipe.json().set(f"k{i}", "$", {"even": even, "pos": [float(i)]})
await pipe.execute() await pipe.execute()
assert await knn_query(i2, "* => [KNN 3 @pos VEC]", [50.0]) == {'k49', 'k50', 'k51'} assert await knn_query(i2, "* => [KNN 3 @pos VEC]", [50.0]) == {"k49", "k50", "k51"}
assert await knn_query(i2, "@even:{yes} => [KNN 3 @pos VEC]", [20.0]) == {'k18', 'k20', 'k22'} assert await knn_query(i2, "@even:{yes} => [KNN 3 @pos VEC]", [20.0]) == {"k18", "k20", "k22"}
assert await knn_query(i2, "@even:{no} => [KNN 4 @pos VEC]", [30.0]) == {'k27', 'k29', 'k31', 'k33'} assert await knn_query(i2, "@even:{no} => [KNN 4 @pos VEC]", [30.0]) == {
"k27",
"k29",
"k31",
"k33",
}
    assert await knn_query(i2, "@even:{yes} => [KNN 3 @pos VEC]", [10.0]) == {'k8', 'k10', 'k12'}     assert await knn_query(i2, "@even:{yes} => [KNN 3 @pos VEC]", [10.0]) == {"k8", "k10", "k12"}
NUM_DIMS = 10 NUM_DIMS = 10
@ -147,11 +175,15 @@ NUM_POINTS = 100
@dfly_args({"proactor_threads": 4}) @dfly_args({"proactor_threads": 4})
@pytest.mark.parametrize("index_type", [IndexType.HASH, IndexType.JSON]) @pytest.mark.parametrize("index_type", [IndexType.HASH, IndexType.JSON])
async def test_multidim_knn(async_client: aioredis.Redis, index_type): async def test_multidim_knn(async_client: aioredis.Redis, index_type):
vector_field = VectorField("pos", "FLAT", { vector_field = VectorField(
"pos",
"FLAT",
{
"TYPE": "FLOAT32", "TYPE": "FLOAT32",
"DIM": NUM_DIMS, "DIM": NUM_DIMS,
"DISTANCE_METRIC": "L2", "DISTANCE_METRIC": "L2",
}) },
)
i3 = async_client.ft("i3-" + str(index_type)) i3 = async_client.ft("i3-" + str(index_type))
await i3.create_index([vector_field], definition=IndexDefinition(index_type=index_type)) await i3.create_index([vector_field], definition=IndexDefinition(index_type=index_type))
@ -160,8 +192,7 @@ async def test_multidim_knn(async_client: aioredis.Redis, index_type):
return np.random.uniform(0, 10, NUM_DIMS).astype(np.float32) return np.random.uniform(0, 10, NUM_DIMS).astype(np.float32)
# Generate points and send to DF # Generate points and send to DF
points = [rand_point() points = [rand_point() for _ in range(NUM_POINTS)]
for _ in range(NUM_POINTS)]
points = list(enumerate(points)) points = list(enumerate(points))
pipe = async_client.pipeline(transaction=False) pipe = async_client.pipeline(transaction=False)
@ -177,8 +208,10 @@ async def test_multidim_knn(async_client: aioredis.Redis, index_type):
center = rand_point() center = rand_point()
limit = random.randint(1, NUM_POINTS / 10) limit = random.randint(1, NUM_POINTS / 10)
expected_ids = [f"k{i}" for i, point in sorted( expected_ids = [
points, key=lambda p: np.linalg.norm(center - p[1]))[:limit]] f"k{i}"
for i, point in sorted(points, key=lambda p: np.linalg.norm(center - p[1]))[:limit]
]
got_ids = await knn_query(i3, f"* => [KNN {limit} @pos VEC]", center) got_ids = await knn_query(i3, f"* => [KNN {limit} @pos VEC]", center)


@ -18,7 +18,7 @@ def stdout_as_list_of_dicts(cp: subprocess.CompletedProcess, new_dict_key =""):
lines = cp.stdout.splitlines() lines = cp.stdout.splitlines()
res = [] res = []
d = None d = None
if (new_dict_key == ''): if new_dict_key == "":
d = dict() d = dict()
res.append(d) res.append(d)
for i in range(0, len(lines), 2): for i in range(0, len(lines), 2):
@ -28,12 +28,14 @@ def stdout_as_list_of_dicts(cp: subprocess.CompletedProcess, new_dict_key =""):
d[lines[i]] = lines[i + 1] d[lines[i]] = lines[i + 1]
return res return res
def wait_for(func, pred, timeout_sec, timeout_msg=""): def wait_for(func, pred, timeout_sec, timeout_msg=""):
while not pred(func()): while not pred(func()):
assert timeout_sec > 0, timeout_msg assert timeout_sec > 0, timeout_msg
timeout_sec = timeout_sec - 1 timeout_sec = timeout_sec - 1
time.sleep(1) time.sleep(1)
async def await_for(func, pred, timeout_sec, timeout_msg=""): async def await_for(func, pred, timeout_sec, timeout_msg=""):
done = False done = False
while not done: while not done:
@ -60,21 +62,26 @@ class Sentinel:
config = [ config = [
f"port {self.port}", f"port {self.port}",
f"sentinel monitor {self.default_deployment} 127.0.0.1 {self.initial_master_port} 1", f"sentinel monitor {self.default_deployment} 127.0.0.1 {self.initial_master_port} 1",
f"sentinel down-after-milliseconds {self.default_deployment} 3000" f"sentinel down-after-milliseconds {self.default_deployment} 3000",
] ]
self.config_file.write_text("\n".join(config)) self.config_file.write_text("\n".join(config))
logging.info(self.config_file.read_text()) logging.info(self.config_file.read_text())
self.proc = subprocess.Popen(["redis-server", f"{self.config_file.absolute()}", "--sentinel"]) self.proc = subprocess.Popen(
["redis-server", f"{self.config_file.absolute()}", "--sentinel"]
)
def stop(self): def stop(self):
self.proc.terminate() self.proc.terminate()
self.proc.wait(timeout=10) self.proc.wait(timeout=10)
def run_cmd(self, args, sentinel_cmd=True, capture_output=False, assert_ok=True) -> subprocess.CompletedProcess: def run_cmd(
self, args, sentinel_cmd=True, capture_output=False, assert_ok=True
) -> subprocess.CompletedProcess:
run_args = ["redis-cli", "-p", f"{self.port}"] run_args = ["redis-cli", "-p", f"{self.port}"]
if sentinel_cmd: run_args = run_args + ["sentinel"] if sentinel_cmd:
run_args = run_args + ["sentinel"]
run_args = run_args + args run_args = run_args + args
cp = subprocess.run(run_args, capture_output=capture_output, text=True) cp = subprocess.run(run_args, capture_output=capture_output, text=True)
if assert_ok: if assert_ok:
@ -85,7 +92,9 @@ class Sentinel:
wait_for( wait_for(
lambda: self.run_cmd(["ping"], sentinel_cmd=False, assert_ok=False), lambda: self.run_cmd(["ping"], sentinel_cmd=False, assert_ok=False),
lambda cp: cp.returncode == 0, lambda cp: cp.returncode == 0,
timeout_sec=10, timeout_msg="Timeout waiting for sentinel to become ready.") timeout_sec=10,
timeout_msg="Timeout waiting for sentinel to become ready.",
)
def master(self, deployment="") -> dict: def master(self, deployment="") -> dict:
if deployment == "": if deployment == "":
@ -108,10 +117,17 @@ class Sentinel:
def failover(self, deployment=""): def failover(self, deployment=""):
if deployment == "": if deployment == "":
deployment = self.default_deployment deployment = self.default_deployment
self.run_cmd(["failover", deployment,]) self.run_cmd(
[
"failover",
deployment,
]
)
@pytest.fixture(scope="function") # Sentinel has state which we don't want carried over from test to test. @pytest.fixture(
    scope="function"
) # Sentinel has state which we don't want carried over from test to test.
def sentinel(tmp_dir, port_picker) -> Sentinel: def sentinel(tmp_dir, port_picker) -> Sentinel:
s = Sentinel(port_picker.get_available_port(), port_picker.get_available_port(), tmp_dir) s = Sentinel(port_picker.get_available_port(), port_picker.get_available_port(), tmp_dir)
s.start() s.start()
@ -140,7 +156,9 @@ async def test_failover(df_local_factory, sentinel, port_picker):
await await_for( await await_for(
lambda: sentinel.master(), lambda: sentinel.master(),
lambda m: m["num-slaves"] == "1", lambda m: m["num-slaves"] == "1",
timeout_sec=15, timeout_msg="Timeout waiting for sentinel to pick up replica.") timeout_sec=15,
timeout_msg="Timeout waiting for sentinel to pick up replica.",
)
sentinel.failover() sentinel.failover()
@ -148,7 +166,8 @@ async def test_failover(df_local_factory, sentinel, port_picker):
await await_for( await await_for(
lambda: sentinel.live_master_port(), lambda: sentinel.live_master_port(),
lambda p: p == replica.port, lambda p: p == replica.port,
timeout_sec=10, timeout_msg="Timeout waiting for sentinel to report replica as master." timeout_sec=10,
timeout_msg="Timeout waiting for sentinel to report replica as master.",
) )
assert sentinel.slaves()[0]["port"] == str(master.port) assert sentinel.slaves()[0]["port"] == str(master.port)
@ -159,7 +178,8 @@ async def test_failover(df_local_factory, sentinel, port_picker):
await await_for( await await_for(
lambda: master_client.get("key"), lambda: master_client.get("key"),
lambda val: val == b"value", lambda val: val == b"value",
15, "Timeout waiting for key to exist in replica." 15,
"Timeout waiting for key to exist in replica.",
) )
except AssertionError: except AssertionError:
syncid, r_offset = await master_client.execute_command("DEBUG REPLICA OFFSET") syncid, r_offset = await master_client.execute_command("DEBUG REPLICA OFFSET")
@ -199,7 +219,9 @@ async def test_master_failure(df_local_factory, sentinel, port_picker):
await await_for( await await_for(
lambda: sentinel.master(), lambda: sentinel.master(),
lambda m: m["num-slaves"] == "1", lambda m: m["num-slaves"] == "1",
timeout_sec=15, timeout_msg="Timeout waiting for sentinel to pick up replica.") timeout_sec=15,
timeout_msg="Timeout waiting for sentinel to pick up replica.",
)
# Simulate master failure. # Simulate master failure.
master.stop() master.stop()
@ -208,7 +230,8 @@ async def test_master_failure(df_local_factory, sentinel, port_picker):
await await_for( await await_for(
lambda: sentinel.live_master_port(), lambda: sentinel.live_master_port(),
lambda p: p == replica.port, lambda p: p == replica.port,
timeout_sec=300, timeout_msg="Timeout waiting for sentinel to report replica as master." timeout_sec=300,
timeout_msg="Timeout waiting for sentinel to report replica as master.",
) )
# Verify we can now write to replica. # Verify we can now write to replica.


@ -5,7 +5,7 @@ from .utility import *
def test_quit(connection): def test_quit(connection):
connection.send_command("QUIT") connection.send_command("QUIT")
assert connection.read_response() == b'OK' assert connection.read_response() == b"OK"
with pytest.raises(redis.exceptions.ConnectionError) as e: with pytest.raises(redis.exceptions.ConnectionError) as e:
connection.read_response() connection.read_response()
@ -16,7 +16,7 @@ def test_quit_after_sub(connection):
connection.read_response() connection.read_response()
connection.send_command("QUIT") connection.send_command("QUIT")
assert connection.read_response() == b'OK' assert connection.read_response() == b"OK"
with pytest.raises(redis.exceptions.ConnectionError) as e: with pytest.raises(redis.exceptions.ConnectionError) as e:
connection.read_response() connection.read_response()
@ -30,7 +30,7 @@ async def test_multi_exec(async_client: aioredis.Redis):
assert val == [True, "bar"] assert val == [True, "bar"]
''' """
see https://github.com/dragonflydb/dragonfly/issues/457 see https://github.com/dragonflydb/dragonfly/issues/457
For now we do not allow the eval command inside multi For now we do not allow the eval command inside multi
As this would create two-level transactions (in effect a recursive call As this would create two-level transactions (in effect a recursive call
to the Schedule function). to the Schedule function).
When this issue is fully fixed, this test will fail, and then it should When this issue is fully fixed, this test will fail, and then it should
change to match the fact that we support this operation. change to match the fact that we support this operation.
For now we expect to get an error For now we expect to get an error
''' """
@pytest.mark.skip("Skip until we decided on correct behaviour of eval inside multi") @pytest.mark.skip("Skip until we decided on correct behaviour of eval inside multi")
async def test_multi_eval(async_client: aioredis.Redis): async def test_multi_eval(async_client: aioredis.Redis):
@ -76,10 +77,12 @@ async def test_client_list(df_factory):
instance.stop() instance.stop()
await disconnect_clients(client, admin_client) await disconnect_clients(client, admin_client)
async def test_scan(async_client: aioredis.Redis): async def test_scan(async_client: aioredis.Redis):
''' """
make sure that the scan command is working with python make sure that the scan command is working with python
''' """
def gen_test_data(): def gen_test_data():
for i in range(10): for i in range(10):
yield f"key-{i}", f"value-{i}" yield f"key-{i}", f"value-{i}"


@ -9,13 +9,16 @@ from . import dfly_args
BASIC_ARGS = {"dir": "{DRAGONFLY_TMP}/"} BASIC_ARGS = {"dir": "{DRAGONFLY_TMP}/"}
@pytest.mark.skip(reason='Currently we cannot guarantee that on shutdown, if a command is executed and a value is written, we respond before breaking the connection') @pytest.mark.skip(
    reason="Currently we cannot guarantee that on shutdown, if a command is executed and a value is written, we respond before breaking the connection"
)
@dfly_args({"proactor_threads": "4"}) @dfly_args({"proactor_threads": "4"})
class TestDflyAutoLoadSnapshot(): class TestDflyAutoLoadSnapshot:
""" """
Test automatic loading of dump files on startup with timestamp. Test automatic loading of dump files on startup with timestamp.
    When a command is executed and a value is written, we should send the response before shutdown     When a command is executed and a value is written, we should send the response before shutdown
""" """
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gracefull_shutdown(self, df_local_factory): async def test_gracefull_shutdown(self, df_local_factory):
df_args = {"dbfilename": "dump", **BASIC_ARGS, "port": 1111} df_args = {"dbfilename": "dump", **BASIC_ARGS, "port": 1111}
@ -39,7 +42,9 @@ class TestDflyAutoLoadSnapshot():
await client.execute_command("SHUTDOWN") await client.execute_command("SHUTDOWN")
await client.connection_pool.disconnect() await client.connection_pool.disconnect()
_, *results = await asyncio.gather(delayed_takeover(), *[counter(f"key{i}") for i in range(16)]) _, *results = await asyncio.gather(
delayed_takeover(), *[counter(f"key{i}") for i in range(16)]
)
df_server.start() df_server.start()
client = aioredis.Redis(port=df_server.port) client = aioredis.Redis(port=df_server.port)


@ -18,9 +18,10 @@ class SnapshotTestBase:
self.tmp_dir = tmp_dir self.tmp_dir = tmp_dir
def get_main_file(self, pattern): def get_main_file(self, pattern):
def is_main(f): return "summary" in f if pattern.endswith( def is_main(f):
"dfs") else True return "summary" in f if pattern.endswith("dfs") else True
files = glob.glob(str(self.tmp_dir.absolute()) + '/' + pattern)
files = glob.glob(str(self.tmp_dir.absolute()) + "/" + pattern)
possible_mains = list(filter(is_main, files)) possible_mains = list(filter(is_main, files))
assert len(possible_mains) == 1, possible_mains assert len(possible_mains) == 1, possible_mains
return possible_mains[0] return possible_mains[0]
@ -29,6 +30,7 @@ class SnapshotTestBase:
@dfly_args({**BASIC_ARGS, "dbfilename": "test-rdb-{{timestamp}}"}) @dfly_args({**BASIC_ARGS, "dbfilename": "test-rdb-{{timestamp}}"})
class TestRdbSnapshot(SnapshotTestBase): class TestRdbSnapshot(SnapshotTestBase):
"""Test single file rdb snapshot""" """Test single file rdb snapshot"""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path): def setup(self, tmp_dir: Path):
super().setup(tmp_dir) super().setup(tmp_dir)
@ -51,6 +53,7 @@ class TestRdbSnapshot(SnapshotTestBase):
@dfly_args({**BASIC_ARGS, "dbfilename": "test-rdbexact.rdb", "nodf_snapshot_format": None}) @dfly_args({**BASIC_ARGS, "dbfilename": "test-rdbexact.rdb", "nodf_snapshot_format": None})
class TestRdbSnapshotExactFilename(SnapshotTestBase): class TestRdbSnapshotExactFilename(SnapshotTestBase):
"""Test single file rdb snapshot without a timestamp""" """Test single file rdb snapshot without a timestamp"""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path): def setup(self, tmp_dir: Path):
super().setup(tmp_dir) super().setup(tmp_dir)
@ -74,6 +77,7 @@ class TestRdbSnapshotExactFilename(SnapshotTestBase):
@dfly_args({**BASIC_ARGS, "dbfilename": "test-dfs"}) @dfly_args({**BASIC_ARGS, "dbfilename": "test-dfs"})
class TestDflySnapshot(SnapshotTestBase): class TestDflySnapshot(SnapshotTestBase):
"""Test multi file snapshot""" """Test multi file snapshot"""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path): def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir self.tmp_dir = tmp_dir
@ -88,16 +92,20 @@ class TestDflySnapshot(SnapshotTestBase):
# save + flush + load # save + flush + load
await async_client.execute_command("SAVE DF") await async_client.execute_command("SAVE DF")
assert await async_client.flushall() assert await async_client.flushall()
await async_client.execute_command("DEBUG LOAD " + super().get_main_file("test-dfs-summary.dfs")) await async_client.execute_command(
"DEBUG LOAD " + super().get_main_file("test-dfs-summary.dfs")
)
assert await seeder.compare(start_capture) assert await seeder.compare(start_capture)
# We spawn instances manually, so reduce memory usage of default to minimum # We spawn instances manually, so reduce memory usage of default to minimum
@dfly_args({"proactor_threads": "1"}) @dfly_args({"proactor_threads": "1"})
class TestDflyAutoLoadSnapshot(SnapshotTestBase): class TestDflyAutoLoadSnapshot(SnapshotTestBase):
"""Test automatic loading of dump files on startup with timestamp""" """Test automatic loading of dump files on startup with timestamp"""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path): def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir self.tmp_dir = tmp_dir
@ -115,8 +123,8 @@ class TestDflyAutoLoadSnapshot(SnapshotTestBase):
@pytest.mark.parametrize("save_type, dbfilename", cases) @pytest.mark.parametrize("save_type, dbfilename", cases)
async def test_snapshot(self, df_local_factory, save_type, dbfilename): async def test_snapshot(self, df_local_factory, save_type, dbfilename):
df_args = {"dbfilename": dbfilename, **BASIC_ARGS, "port": 1111} df_args = {"dbfilename": dbfilename, **BASIC_ARGS, "port": 1111}
if save_type == 'rdb': if save_type == "rdb":
df_args['nodf_snapshot_format'] = "" df_args["nodf_snapshot_format"] = ""
df_server = df_local_factory.create(**df_args) df_server = df_local_factory.create(**df_args)
df_server.start() df_server.start()
@ -135,6 +143,7 @@ class TestDflyAutoLoadSnapshot(SnapshotTestBase):
@dfly_args({**BASIC_ARGS, "dbfilename": "test-periodic", "save_schedule": "*:*"}) @dfly_args({**BASIC_ARGS, "dbfilename": "test-periodic", "save_schedule": "*:*"})
class TestPeriodicSnapshot(SnapshotTestBase): class TestPeriodicSnapshot(SnapshotTestBase):
"""Test periodic snapshotting""" """Test periodic snapshotting"""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path): def setup(self, tmp_dir: Path):
super().setup(tmp_dir) super().setup(tmp_dir)
@ -142,7 +151,8 @@ class TestPeriodicSnapshot(SnapshotTestBase):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_snapshot(self, df_seeder_factory, df_server): async def test_snapshot(self, df_seeder_factory, df_server):
seeder = df_seeder_factory.create( seeder = df_seeder_factory.create(
port=df_server.port, keys=10, multi_transaction_probability=0) port=df_server.port, keys=10, multi_transaction_probability=0
)
await seeder.run(target_deviation=0.5) await seeder.run(target_deviation=0.5)
time.sleep(60) time.sleep(60)
@ -154,14 +164,14 @@ class TestPeriodicSnapshot(SnapshotTestBase):
class TestPathEscapes(SnapshotTestBase): class TestPathEscapes(SnapshotTestBase):
"""Test that we don't allow path escapes. We just check that df_server.start() """Test that we don't allow path escapes. We just check that df_server.start()
fails because we don't have a much better way to test that.""" fails because we don't have a much better way to test that."""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path): def setup(self, tmp_dir: Path):
super().setup(tmp_dir) super().setup(tmp_dir)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_snapshot(self, df_local_factory): async def test_snapshot(self, df_local_factory):
df_server = df_local_factory.create( df_server = df_local_factory.create(dbfilename="../../../../etc/passwd")
dbfilename="../../../../etc/passwd")
try: try:
df_server.start() df_server.start()
assert False, "Server should not start correctly" assert False, "Server should not start correctly"
@ -172,6 +182,7 @@ class TestPathEscapes(SnapshotTestBase):
@dfly_args({**BASIC_ARGS, "dbfilename": "test-shutdown"}) @dfly_args({**BASIC_ARGS, "dbfilename": "test-shutdown"})
class TestDflySnapshotOnShutdown(SnapshotTestBase): class TestDflySnapshotOnShutdown(SnapshotTestBase):
"""Test multi file snapshot""" """Test multi file snapshot"""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path): def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir self.tmp_dir = tmp_dir
@ -192,15 +203,17 @@ class TestDflySnapshotOnShutdown(SnapshotTestBase):
assert await seeder.compare(start_capture) assert await seeder.compare(start_capture)
@dfly_args({**BASIC_ARGS, "dbfilename": "test-info-persistence"}) @dfly_args({**BASIC_ARGS, "dbfilename": "test-info-persistence"})
class TestDflyInfoPersistenceLoadingField(SnapshotTestBase): class TestDflyInfoPersistenceLoadingField(SnapshotTestBase):
"""Test is_loading field on INFO PERSISTENCE during snapshot loading""" """Test is_loading field on INFO PERSISTENCE during snapshot loading"""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path): def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir self.tmp_dir = tmp_dir
def extract_is_loading_field(self, res): def extract_is_loading_field(self, res):
matcher = b'loading:' matcher = b"loading:"
start = res.find(matcher) start = res.find(matcher)
pos = start + len(matcher) pos = start + len(matcher)
return chr(res[pos]) return chr(res[pos])
@ -214,6 +227,6 @@ class TestDflyInfoPersistenceLoadingField(SnapshotTestBase):
# Wait for snapshot to finish loading and try INFO PERSISTENCE # Wait for snapshot to finish loading and try INFO PERSISTENCE
await wait_available_async(a_client) await wait_available_async(a_client)
res = await a_client.execute_command("INFO PERSISTENCE") res = await a_client.execute_command("INFO PERSISTENCE")
assert '0' == self.extract_is_loading_field(res) assert "0" == self.extract_is_loading_field(res)
await a_client.connection_pool.disconnect() await a_client.connection_pool.disconnect()
View file
@ -43,16 +43,16 @@ async def wait_available_async(client: aioredis.Redis):
its = 0 its = 0
while True: while True:
try: try:
await client.get('key') await client.get("key")
return return
except aioredis.ResponseError as e: except aioredis.ResponseError as e:
if ("MOVED" in str(e)): if "MOVED" in str(e):
# MOVED means we *can* serve traffic, but 'key' does not belong to an owned slot # MOVED means we *can* serve traffic, but 'key' does not belong to an owned slot
return return
assert "Can not execute during LOADING" in str(e) assert "Can not execute during LOADING" in str(e)
# Print W to indicate test is waiting for replica # Print W to indicate test is waiting for replica
print('W', end='', flush=True) print("W", end="", flush=True)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
its += 1 its += 1
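wait_available_async above simply retries GET until the server stops answering with a LOADING error, and treats MOVED as proof that traffic is already being served. A minimal sketch of waiting for a freshly (re)started instance follows; the aioredis client construction is an assumption for illustration, while df_server.port comes from the instance wrapper used throughout these tests.

a_client = aioredis.Redis(port=df_server.port)
# After a restart the server may still be loading its snapshot; block until ready.
await wait_available_async(a_client)
await a_client.connection_pool.disconnect()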
@ -141,7 +141,8 @@ class CommandGenerator:
def generate_val(self, t: ValueType): def generate_val(self, t: ValueType):
"""Generate filler value of configured size for type t""" """Generate filler value of configured size for type t"""
def rand_str(k=3, s=''):
def rand_str(k=3, s=""):
# Use small k value to reduce mem usage and increase number of ops # Use small k value to reduce mem usage and increase number of ops
return s.join(random.choices(string.ascii_letters, k=k)) return s.join(random.choices(string.ascii_letters, k=k))
@ -156,13 +157,15 @@ class CommandGenerator:
return tuple(rand_str() for _ in range(self.val_size // 4)) return tuple(rand_str() for _ in range(self.val_size // 4))
elif t == ValueType.HSET: elif t == ValueType.HSET:
# Random sequence of k-letter keys + int and two start values for HSET # Random sequence of k-letter keys + int and two start values for HSET
elements = ((rand_str(), random.randint(0, self.val_size)) elements = (
for _ in range(self.val_size//5)) (rand_str(), random.randint(0, self.val_size)) for _ in range(self.val_size // 5)
return ('v0', 0, 'v1', 0) + tuple(itertools.chain(*elements)) )
return ("v0", 0, "v1", 0) + tuple(itertools.chain(*elements))
elif t == ValueType.ZSET: elif t == ValueType.ZSET:
# Random sequence of k-letter keys and int score for ZSET # Random sequence of k-letter keys and int score for ZSET
elements = ((random.randint(0, self.val_size), rand_str()) elements = (
for _ in range(self.val_size//4)) (random.randint(0, self.val_size), rand_str()) for _ in range(self.val_size // 4)
)
return tuple(itertools.chain(*elements)) return tuple(itertools.chain(*elements))
elif t == ValueType.JSON: elif t == ValueType.JSON:
@ -170,8 +173,7 @@ class CommandGenerator:
# - arr (array of random strings) # - arr (array of random strings)
# - ints (array of objects {i:random integer}) # - ints (array of objects {i:random integer})
# - i (random integer) # - i (random integer)
ints = [{"i": random.randint(0, 100)} ints = [{"i": random.randint(0, 100)} for i in range(self.val_size // 6)]
for i in range(self.val_size//6)]
strs = [rand_str() for _ in range(self.val_size // 6)] strs = [rand_str() for _ in range(self.val_size // 6)]
return "$", json.dumps({"arr": strs, "ints": ints, "i": random.randint(0, 100)}) return "$", json.dumps({"arr": strs, "ints": ints, "i": random.randint(0, 100)})
else: else:
@ -187,8 +189,9 @@ class CommandGenerator:
return None, 0 return None, 0
return f"PEXPIRE k{key} {random.randint(0, 50)}", -1 return f"PEXPIRE k{key} {random.randint(0, 50)}", -1
else: else:
keys_gen = (self.randomize_key(pop=True) keys_gen = (
for _ in range(random.randint(1, self.max_multikey))) self.randomize_key(pop=True) for _ in range(random.randint(1, self.max_multikey))
)
keys = [f"k{k}" for k, _ in keys_gen if k is not None] keys = [f"k{k}" for k, _ in keys_gen if k is not None]
if len(keys) == 0: if len(keys) == 0:
@ -196,19 +199,19 @@ class CommandGenerator:
return "DEL " + " ".join(keys), -len(keys) return "DEL " + " ".join(keys), -len(keys)
UPDATE_ACTIONS = [ UPDATE_ACTIONS = [
('APPEND {k} {val}', ValueType.STRING), ("APPEND {k} {val}", ValueType.STRING),
('SETRANGE {k} 10 {val}', ValueType.STRING), ("SETRANGE {k} 10 {val}", ValueType.STRING),
('LPUSH {k} {val}', ValueType.LIST), ("LPUSH {k} {val}", ValueType.LIST),
('LPOP {k}', ValueType.LIST), ("LPOP {k}", ValueType.LIST),
('SADD {k} {val}', ValueType.SET), ("SADD {k} {val}", ValueType.SET),
('SPOP {k}', ValueType.SET), ("SPOP {k}", ValueType.SET),
('HSETNX {k} v0 {val}', ValueType.HSET), ("HSETNX {k} v0 {val}", ValueType.HSET),
('HINCRBY {k} v1 1', ValueType.HSET), ("HINCRBY {k} v1 1", ValueType.HSET),
('ZPOPMIN {k} 1', ValueType.ZSET), ("ZPOPMIN {k} 1", ValueType.ZSET),
('ZADD {k} 0 {val}', ValueType.ZSET), ("ZADD {k} 0 {val}", ValueType.ZSET),
('JSON.NUMINCRBY {k} $..i 1', ValueType.JSON), ("JSON.NUMINCRBY {k} $..i 1", ValueType.JSON),
('JSON.ARRPOP {k} $.arr', ValueType.JSON), ("JSON.ARRPOP {k} $.arr", ValueType.JSON),
('JSON.ARRAPPEND {k} $.arr "{val}"', ValueType.JSON) ('JSON.ARRAPPEND {k} $.arr "{val}"', ValueType.JSON),
] ]
def gen_update_cmd(self): def gen_update_cmd(self):
@ -217,16 +220,16 @@ class CommandGenerator:
""" """
cmd, t = random.choice(self.UPDATE_ACTIONS) cmd, t = random.choice(self.UPDATE_ACTIONS)
k, _ = self.randomize_key(t) k, _ = self.randomize_key(t)
val = ''.join(random.choices(string.ascii_letters, k=3)) val = "".join(random.choices(string.ascii_letters, k=3))
return cmd.format(k=f"k{k}", val=val) if k is not None else None, 0 return cmd.format(k=f"k{k}", val=val) if k is not None else None, 0
GROW_ACTINONS = { GROW_ACTINONS = {
ValueType.STRING: 'MSET', ValueType.STRING: "MSET",
ValueType.LIST: 'LPUSH', ValueType.LIST: "LPUSH",
ValueType.SET: 'SADD', ValueType.SET: "SADD",
ValueType.HSET: 'HMSET', ValueType.HSET: "HMSET",
ValueType.ZSET: 'ZADD', ValueType.ZSET: "ZADD",
ValueType.JSON: 'JSON.SET' ValueType.JSON: "JSON.SET",
} }
def gen_grow_cmd(self): def gen_grow_cmd(self):
@ -241,8 +244,7 @@ class CommandGenerator:
count = 1 count = 1
keys = (self.add_key(t) for _ in range(count)) keys = (self.add_key(t) for _ in range(count))
payload = itertools.chain( payload = itertools.chain(*((f"k{k}",) + self.generate_val(t) for k in keys))
*((f"k{k}",) + self.generate_val(t) for k in keys))
filtered_payload = filter(lambda p: p is not None, payload) filtered_payload = filter(lambda p: p is not None, payload)
return (self.GROW_ACTINONS[t],) + tuple(filtered_payload), count return (self.GROW_ACTINONS[t],) + tuple(filtered_payload), count
@ -269,8 +271,7 @@ class CommandGenerator:
return [ return [
max(self.base_diff_prob - self.diff_speed * dist, self.min_diff_prob), max(self.base_diff_prob - self.diff_speed * dist, self.min_diff_prob),
1.0, 1.0,
max(self.base_diff_prob + 2 * max(self.base_diff_prob + 2 * self.diff_speed * dist, self.min_diff_prob),
self.diff_speed * dist, self.min_diff_prob)
] ]
def generate(self): def generate(self):
@ -280,8 +281,7 @@ class CommandGenerator:
while len(cmds) < self.batch_size: while len(cmds) < self.batch_size:
# Re-calculating changes in small groups # Re-calculating changes in small groups
if len(changes) == 0: if len(changes) == 0:
changes = random.choices( changes = random.choices(list(SizeChange), weights=self.size_change_probs(), k=20)
list(SizeChange), weights=self.size_change_probs(), k=20)
cmd, delta = self.make(changes.pop()) cmd, delta = self.make(changes.pop())
if cmd is not None: if cmd is not None:
@ -311,7 +311,7 @@ class DataCapture:
printed = 0 printed = 0
diff = difflib.ndiff(self.entries, other.entries) diff = difflib.ndiff(self.entries, other.entries)
for line in diff: for line in diff:
if line.startswith(' '): if line.startswith(" "):
continue continue
eprint(line) eprint(line)
if printed >= 20: if printed >= 20:
@ -344,10 +344,20 @@ class DflySeeder:
assert await seeder.compare(capture, port=1112) assert await seeder.compare(capture, port=1112)
""" """
def __init__(self, port=6379, keys=1000, val_size=50, batch_size=100, max_multikey=5, dbcount=1, multi_transaction_probability=0.3, log_file=None, unsupported_types=[], stop_on_failure=True): def __init__(
self.gen = CommandGenerator( self,
keys, val_size, batch_size, max_multikey, unsupported_types port=6379,
) keys=1000,
val_size=50,
batch_size=100,
max_multikey=5,
dbcount=1,
multi_transaction_probability=0.3,
log_file=None,
unsupported_types=[],
stop_on_failure=True,
):
self.gen = CommandGenerator(keys, val_size, batch_size, max_multikey, unsupported_types)
self.port = port self.port = port
self.dbcount = dbcount self.dbcount = dbcount
self.multi_transaction_probability = multi_transaction_probability self.multi_transaction_probability = multi_transaction_probability
@ -356,7 +366,7 @@ class DflySeeder:
self.log_file = log_file self.log_file = log_file
if self.log_file is not None: if self.log_file is not None:
open(self.log_file, 'w').close() open(self.log_file, "w").close()
async def run(self, target_ops=None, target_deviation=None): async def run(self, target_ops=None, target_deviation=None):
""" """
@ -366,11 +376,11 @@ class DflySeeder:
print(f"Running ops:{target_ops} deviation:{target_deviation}") print(f"Running ops:{target_ops} deviation:{target_deviation}")
self.stop_flag = False self.stop_flag = False
queues = [asyncio.Queue(maxsize=3) for _ in range(self.dbcount)] queues = [asyncio.Queue(maxsize=3) for _ in range(self.dbcount)]
producer = asyncio.create_task(self._generator_task( producer = asyncio.create_task(
queues, target_ops=target_ops, target_deviation=target_deviation)) self._generator_task(queues, target_ops=target_ops, target_deviation=target_deviation)
)
consumers = [ consumers = [
asyncio.create_task(self._executor_task(i, queue)) asyncio.create_task(self._executor_task(i, queue)) for i, queue in enumerate(queues)
for i, queue in enumerate(queues)
] ]
time_start = time.time() time_start = time.time()
@ -398,9 +408,9 @@ class DflySeeder:
port = self.port port = self.port
keys = sorted(list(self.gen.keys_and_types())) keys = sorted(list(self.gen.keys_and_types()))
captures = await asyncio.gather(*( captures = await asyncio.gather(
self._capture_db(port=port, target_db=db, keys=keys) for db in range(self.dbcount) *(self._capture_db(port=port, target_db=db, keys=keys) for db in range(self.dbcount))
)) )
return captures return captures
async def compare(self, initial_captures, port=6379): async def compare(self, initial_captures, port=6379):
@ -408,7 +418,9 @@ class DflySeeder:
print(f"comparing capture to {port}") print(f"comparing capture to {port}")
target_captures = await self.capture(port=port) target_captures = await self.capture(port=port)
for db, target_capture, initial_capture in zip(range(self.dbcount), target_captures, initial_captures): for db, target_capture, initial_capture in zip(
range(self.dbcount), target_captures, initial_captures
):
print(f"comparing capture to {port}, db: {db}") print(f"comparing capture to {port}, db: {db}")
if not initial_capture.compare(target_capture): if not initial_capture.compare(target_capture):
eprint(f">>> Inconsistent data on port {port}, db {db}") eprint(f">>> Inconsistent data on port {port}, db {db}")
@ -433,7 +445,7 @@ class DflySeeder:
file = None file = None
if self.log_file: if self.log_file:
file = open(self.log_file, 'a') file = open(self.log_file, "a")
def should_run(): def should_run():
if self.stop_flag: if self.stop_flag:
@ -455,7 +467,7 @@ class DflySeeder:
blob, deviation = self.gen.generate() blob, deviation = self.gen.generate()
is_multi_transaction = random.random() < self.multi_transaction_probability is_multi_transaction = random.random() < self.multi_transaction_probability
tx_data = (blob, is_multi_transaction) tx_data = (blob, is_multi_transaction)
cpu_time += (time.time() - start_time) cpu_time += time.time() - start_time
await asyncio.gather(*(q.put(tx_data) for q in queues)) await asyncio.gather(*(q.put(tx_data) for q in queues))
submitted += len(blob) submitted += len(blob)
@ -463,14 +475,12 @@ class DflySeeder:
if file is not None: if file is not None:
pattern = "MULTI\n{}\nEXEC\n" if is_multi_transaction else "{}\n" pattern = "MULTI\n{}\nEXEC\n" if is_multi_transaction else "{}\n"
file.write(pattern.format('\n'.join(stringify_cmd(cmd) file.write(pattern.format("\n".join(stringify_cmd(cmd) for cmd in blob)))
for cmd in blob)))
print('.', end='', flush=True) print(".", end="", flush=True)
await asyncio.sleep(0.0) await asyncio.sleep(0.0)
print("\ncpu time", cpu_time, "batches", print("\ncpu time", cpu_time, "batches", batches, "commands", submitted)
batches, "commands", submitted)
await asyncio.gather(*(q.put(None) for q in queues)) await asyncio.gather(*(q.put(None) for q in queues))
for q in queues: for q in queues:
@ -512,20 +522,19 @@ class DflySeeder:
ValueType.LIST: lambda pipe, k: pipe.lrange(k, 0, -1), ValueType.LIST: lambda pipe, k: pipe.lrange(k, 0, -1),
ValueType.SET: lambda pipe, k: pipe.smembers(k), ValueType.SET: lambda pipe, k: pipe.smembers(k),
ValueType.HSET: lambda pipe, k: pipe.hgetall(k), ValueType.HSET: lambda pipe, k: pipe.hgetall(k),
ValueType.ZSET: lambda pipe, k: pipe.zrange( ValueType.ZSET: lambda pipe, k: pipe.zrange(k, start=0, end=-1, withscores=True),
k, start=0, end=-1, withscores=True), ValueType.JSON: lambda pipe, k: pipe.execute_command("JSON.GET", k, "$"),
ValueType.JSON: lambda pipe, k: pipe.execute_command(
"JSON.GET", k, "$")
} }
CAPTURE_EXTRACTORS = { CAPTURE_EXTRACTORS = {
ValueType.STRING: lambda res, tostr: (tostr(res),), ValueType.STRING: lambda res, tostr: (tostr(res),),
ValueType.LIST: lambda res, tostr: (tostr(s) for s in res), ValueType.LIST: lambda res, tostr: (tostr(s) for s in res),
ValueType.SET: lambda res, tostr: sorted(tostr(s) for s in res), ValueType.SET: lambda res, tostr: sorted(tostr(s) for s in res),
ValueType.HSET: lambda res, tostr: sorted(tostr(k)+"="+tostr(v) for k, v in res.items()), ValueType.HSET: lambda res, tostr: sorted(
ValueType.ZSET: lambda res, tostr: ( tostr(k) + "=" + tostr(v) for k, v in res.items()
tostr(s)+"-"+str(f) for (s, f) in res), ),
ValueType.JSON: lambda res, tostr: (tostr(res),) ValueType.ZSET: lambda res, tostr: (tostr(s) + "-" + str(f) for (s, f) in res),
ValueType.JSON: lambda res, tostr: (tostr(res),),
} }
async def _capture_entries(self, client, keys): async def _capture_entries(self, client, keys):
@ -540,8 +549,7 @@ class DflySeeder:
results = await pipe.execute() results = await pipe.execute()
for (k, t), res in zip(group, results): for (k, t), res in zip(group, results):
out = f"{t.name} k{k}: " + \ out = f"{t.name} k{k}: " + " ".join(self.CAPTURE_EXTRACTORS[t](res, tostr))
' '.join(self.CAPTURE_EXTRACTORS[t](res, tostr))
entries.append(out) entries.append(out)
return entries return entries
@ -563,11 +571,13 @@ async def disconnect_clients(*clients):
await asyncio.gather(*(c.connection_pool.disconnect() for c in clients)) await asyncio.gather(*(c.connection_pool.disconnect() for c in clients))
def gen_certificate(ca_key_path, ca_certificate_path, certificate_request_path, private_key_path, certificate_path): def gen_certificate(
ca_key_path, ca_certificate_path, certificate_request_path, private_key_path, certificate_path
):
# Generate Dragonfly's private key and certificate signing request (CSR) # Generate Dragonfly's private key and certificate signing request (CSR)
step1 = rf'openssl req -newkey rsa:4096 -nodes -keyout {private_key_path} -out {certificate_request_path} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=Comp/CN=Gr/emailAddress=does_not_exist@gmail.com"' step1 = rf'openssl req -newkey rsa:4096 -nodes -keyout {private_key_path} -out {certificate_request_path} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=Comp/CN=Gr/emailAddress=does_not_exist@gmail.com"'
subprocess.run(step1, shell=True) subprocess.run(step1, shell=True)
# Use CA's private key to sign dragonfly's CSR and get back the signed certificate # Use CA's private key to sign dragonfly's CSR and get back the signed certificate
step2 = fr'openssl x509 -req -in {certificate_request_path} -days 1 -CA {ca_certificate_path} -CAkey {ca_key_path} -CAcreateserial -out {certificate_path}' step2 = rf"openssl x509 -req -in {certificate_request_path} -days 1 -CA {ca_certificate_path} -CAkey {ca_key_path} -CAcreateserial -out {certificate_path}"
subprocess.run(step2, shell=True) subprocess.run(step2, shell=True)
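For context, gen_certificate shells out to openssl twice: step1 creates Dragonfly's private key and a certificate signing request, and step2 signs that CSR with the supplied CA key to produce the server certificate. A hedged usage sketch follows; every path below is made up for illustration, and the CA key and certificate are assumed to already exist at those locations.

from pathlib import Path

tls_dir = Path("/tmp/dfly-tls")  # hypothetical scratch directory
tls_dir.mkdir(parents=True, exist_ok=True)

gen_certificate(
    ca_key_path=tls_dir / "ca-key.pem",
    ca_certificate_path=tls_dir / "ca-cert.pem",
    certificate_request_path=tls_dir / "server.csr",
    private_key_path=tls_dir / "server-key.pem",
    certificate_path=tls_dir / "server-cert.pem",
)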
View file
@ -13,12 +13,14 @@ from loguru import logger as log
import sys import sys
import random import random
connection_pool = aioredis.ConnectionPool(host="localhost", port=6379, connection_pool = aioredis.ConnectionPool(
db=1, decode_responses=True, max_connections=16) host="localhost", port=6379, db=1, decode_responses=True, max_connections=16
)
key_index = 1 key_index = 1
async def post_to_redis(sem, db_name, index): async def post_to_redis(sem, db_name, index):
global key_index global key_index
async with sem: async with sem:
@ -51,8 +53,8 @@ async def do_concurrent(db_name):
res = await asyncio.gather(*tasks) res = await asyncio.gather(*tasks)
if __name__ == '__main__': if __name__ == "__main__":
log.remove() log.remove()
log.add(sys.stdout, enqueue=True, level='INFO') log.add(sys.stdout, enqueue=True, level="INFO")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(do_concurrent("my_db")) loop.run_until_complete(do_concurrent("my_db"))
View file
@ -12,40 +12,40 @@ def fill_set(args, redis: rclient.Redis):
for j in range(args.num): for j in range(args.num):
token = uuid.uuid1().hex token = uuid.uuid1().hex
# print(token) # print(token)
key = f'USER_OTP:{token}' key = f"USER_OTP:{token}"
arr = [] arr = []
for i in range(30): for i in range(30):
otp = ''.join(random.choices( otp = "".join(random.choices(string.ascii_uppercase + string.digits, k=12))
string.ascii_uppercase + string.digits, k=12))
arr.append(otp) arr.append(otp)
redis.execute_command('sadd', key, *arr) redis.execute_command("sadd", key, *arr)
def fill_hset(args, redis): def fill_hset(args, redis):
for j in range(args.num): for j in range(args.num):
token = uuid.uuid1().hex token = uuid.uuid1().hex
key = f'USER_INFO:{token}' key = f"USER_INFO:{token}"
phone = f'555-999-{j}' phone = f"555-999-{j}"
user_id = 'user' * 5 + f'-{j}' user_id = "user" * 5 + f"-{j}"
redis.hset(key, 'phone', phone) redis.hset(key, "phone", phone)
redis.hset(key, 'user_id', user_id) redis.hset(key, "user_id", user_id)
redis.hset(key, 'login_time', time.time()) redis.hset(key, "login_time", time.time())
def main(): def main():
parser = argparse.ArgumentParser(description='fill hset entities') parser = argparse.ArgumentParser(description="fill hset entities")
parser.add_argument("-p", type=int, help="redis port", dest="port", default=6380)
parser.add_argument("-n", type=int, help="number of keys", dest="num", default=10000)
parser.add_argument( parser.add_argument(
'-p', type=int, help='redis port', dest='port', default=6380) "--type", type=str, choices=["hset", "set"], help="set type", default="hset"
parser.add_argument( )
'-n', type=int, help='number of keys', dest='num', default=10000)
parser.add_argument(
'--type', type=str, choices=['hset', 'set'], help='set type', default='hset')
args = parser.parse_args() args = parser.parse_args()
redis = rclient.Redis(host='localhost', port=args.port, db=0) redis = rclient.Redis(host="localhost", port=args.port, db=0)
if args.type == 'hset': if args.type == "hset":
fill_hset(args, redis) fill_hset(args, redis)
elif args.type == 'set': elif args.type == "set":
fill_set(args, redis) fill_set(args, redis)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
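A rough way to drive these fillers directly from Python rather than through main() (the port and key count are arbitrary, and fill_hset only reads args.num, so an argparse.Namespace stand-in is enough):

import argparse
import redis as rclient

args = argparse.Namespace(num=100)
redis = rclient.Redis(host="localhost", port=6380, db=0)
fill_hset(args, redis)  # or fill_set(args, redis) for the set variant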