1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

fix: missing error reply to client after AddOrFind throw std::bad_alloc (#2411)

* Handle properly and reply on execution paths that throw std::bad_alloc within AddOrFind
This commit is contained in:
Kostas Kyrimis 2024-01-15 10:13:10 +02:00 committed by GitHub
parent 13718699d8
commit 39e7e5ad87
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 69 additions and 19 deletions

View file

@ -258,7 +258,12 @@ OpResult<int> PFMergeInternal(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
const OpArgs& op_args = t->GetOpArgs(shard);
auto& db_slice = op_args.shard->db_slice();
auto res = db_slice.AddOrFind(t->GetDbContext(), key);
DbSlice::AddOrFindResult res;
try {
res = db_slice.AddOrFind(t->GetDbContext(), key);
} catch (const bad_alloc& e) {
return OpStatus::OUT_OF_MEMORY;
}
res.it->second.SetString(hll);
return OpStatus::OK;
};

View file

@ -168,7 +168,12 @@ OpStatus IncrementValue(optional<string_view> prev_val, IncrByParam* param) {
OpStatus OpIncrBy(const OpArgs& op_args, string_view key, string_view field, IncrByParam* param) {
auto& db_slice = op_args.shard->db_slice();
auto add_res = db_slice.AddOrFind(op_args.db_cntx, key);
DbSlice::AddOrFindResult add_res;
try {
add_res = db_slice.AddOrFind(op_args.db_cntx, key);
} catch (const bad_alloc& e) {
return OpStatus::OUT_OF_MEMORY;
}
DbTableStats* stats = db_slice.MutableStats(op_args.db_cntx.db_index);

View file

@ -1076,7 +1076,12 @@ OpResult<bool> OpSet(const OpArgs& op_args, string_view key, string_view path,
}
}
SetJson(op_args, key, std::move(parsed_json.value()));
try {
SetJson(op_args, key, std::move(parsed_json.value()));
} catch (const bad_alloc& e) {
return OpStatus::OUT_OF_MEMORY;
}
return true;
}
@ -1154,7 +1159,9 @@ void JsonFamily::Set(CmdArgList args, ConnectionContext* cntx) {
};
Transaction* trans = cntx->transaction;
OpResult<bool> result = trans->ScheduleSingleHopT(std::move(cb));
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
if (result) {
if (*result) {

View file

@ -88,24 +88,29 @@ OpResult<uint32_t> OpSetRange(const OpArgs& op_args, string_view key, size_t sta
}
}
auto res = db_slice.AddOrFind(op_args.db_cntx, key);
DbSlice::AddOrFindResult res;
string s;
try {
res = db_slice.AddOrFind(op_args.db_cntx, key);
if (res.is_new) {
s.resize(range_len);
} else {
if (res.it->second.ObjType() != OBJ_STRING)
return OpStatus::WRONG_TYPE;
string s;
s = GetString(op_args.shard, res.it->second);
if (s.size() < range_len)
if (res.is_new) {
s.resize(range_len);
} else {
if (res.it->second.ObjType() != OBJ_STRING)
return OpStatus::WRONG_TYPE;
s = GetString(op_args.shard, res.it->second);
if (s.size() < range_len)
s.resize(range_len);
}
memcpy(s.data() + start, value.data(), value.size());
res.it->second.SetString(s);
} catch (const std::bad_alloc& e) {
return OpStatus::OUT_OF_MEMORY;
}
memcpy(s.data() + start, value.data(), value.size());
res.it->second.SetString(s);
return res.it->second.Size();
}
@ -161,7 +166,12 @@ OpResult<uint32_t> ExtendOrSet(const OpArgs& op_args, string_view key, string_vi
bool prepend) {
auto* shard = op_args.shard;
auto& db_slice = shard->db_slice();
auto add_res = db_slice.AddOrFind(op_args.db_cntx, key);
DbSlice::AddOrFindResult add_res;
try {
add_res = db_slice.AddOrFind(op_args.db_cntx, key);
} catch (const std::bad_alloc& e) {
return OpStatus::OUT_OF_MEMORY;
}
if (add_res.is_new) {
add_res.it->second.SetString(val);
return val.size();
@ -224,7 +234,12 @@ OpResult<string> OpMutableGet(const OpArgs& op_args, string_view key, bool del_h
OpResult<double> OpIncrFloat(const OpArgs& op_args, string_view key, double val) {
auto& db_slice = op_args.shard->db_slice();
auto add_res = db_slice.AddOrFind(op_args.db_cntx, key);
DbSlice::AddOrFindResult add_res;
try {
add_res = db_slice.AddOrFind(op_args.db_cntx, key);
} catch (const std::bad_alloc& e) {
return OpStatus::OUT_OF_MEMORY;
}
char buf[128];
@ -1351,7 +1366,7 @@ void StringFamily::SetRange(CmdArgList args, ConnectionContext* cntx) {
Transaction* trans = cntx->transaction;
OpResult<uint32_t> result = trans->ScheduleSingleHopT(std::move(cb));
if (result.status() == OpStatus::WRONG_TYPE) {
if (!result.ok()) {
cntx->SendError(result.status());
} else {
cntx->SendLong(result.value());

View file

@ -119,3 +119,21 @@ async def test_restricted_commands(df_local_factory):
async with aioredis.Redis(port=server.admin_port) as admin_client:
await admin_client.get("foo")
await admin_client.set("foo", "bar")
@pytest.mark.asyncio
async def test_reply_guard_oom(df_local_factory, df_seeder_factory):
master = df_local_factory.create(
proactor_threads=1, cache_mode="true", maxmemory="256mb", enable_heartbeat_eviction="false"
)
df_local_factory.start_all([master])
c_master = master.client()
await c_master.execute_command("DEBUG POPULATE 6000 size 44000")
seeder = df_seeder_factory.create(
port=master.port, keys=5000, val_size=1000, stop_on_failure=False
)
await seeder.run(target_deviation=0.1)
info = await c_master.info("stats")
assert info["evicted_keys"] > 0, "Weak testcase: policy based eviction was not triggered."