1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

fix(server): fix JSON.MGET implementation (#849) (#876)

fix(server): fix json.mget implementation (#849)

Signed-off-by: iko1 <me@remotecpp.dev>
This commit is contained in:
iko1 2023-03-03 00:16:35 +02:00 committed by GitHub
parent bcc3d3ec4f
commit 04f4362c72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 94 additions and 40 deletions

View file

@ -884,33 +884,50 @@ OpResult<vector<OptLong>> OpArrIndex(const OpArgs& op_args, string_view key,
}
// Returns string vector that represents the query result of each supplied key.
OpResult<vector<OptString>> OpMGet(const OpArgs& op_args, const vector<string_view>& keys,
JsonExpression expression) {
vector<OptString> vec;
for (auto& it : keys) {
// OpResult<JsonType> result = GetJson(op_args, it);
OpResult<JsonType*> result = GetJson(op_args, it);
if (!result) {
vec.emplace_back();
vector<OptString> OpMGet(JsonExpression expression, const Transaction* t, EngineShard* shard) {
auto args = t->GetShardArgs(shard->shard_id());
DCHECK(!args.empty());
vector<OptString> response(args.size());
auto& db_slice = shard->db_slice();
for (size_t i = 0; i < args.size(); ++i) {
OpResult<PrimeIterator> it_res = db_slice.Find(t->GetDbContext(), args[i], OBJ_JSON);
if (!it_res.ok())
continue;
auto& dest = response[i].emplace();
JsonType* json_val = it_res.value()->second.GetJson();
DCHECK(json_val) << "should have a valid JSON object for key " << args[i];
vector<JsonType> query_result;
auto cb = [&query_result](const string_view& path, const JsonType& val) {
query_result.push_back(val);
};
const JsonType& json_entry = *(json_val);
expression.evaluate(json_entry, cb);
if (query_result.empty()) {
continue;
}
auto cb = [&vec](const string_view& path, const JsonType& val) {
string str;
error_code ec;
val.dump(str, {}, ec);
if (ec) {
VLOG(1) << "Failed to dump JSON to string with the error: " << ec.message();
return;
}
JsonType arr(json_array_arg);
arr.reserve(query_result.size());
for (auto& s : query_result) {
arr.push_back(s);
}
vec.push_back(move(str));
};
const JsonType& json_entry = *(result.value());
expression.evaluate(json_entry, cb);
string str;
error_code ec;
arr.dump(str, {}, ec);
if (ec) {
VLOG(1) << "Failed to dump JSON array to string with the error: " << ec.message();
}
dest = move(str);
}
return vec;
return response;
}
// Returns numeric vector that represents the number of fields of JSON value at each path.
@ -1091,6 +1108,8 @@ void JsonFamily::Debug(CmdArgList args, ConnectionContext* cntx) {
}
void JsonFamily::MGet(CmdArgList args, ConnectionContext* cntx) {
DCHECK_GE(args.size(), 2U);
error_code ec;
string_view path = ArgS(args, args.size() - 1);
JsonExpression expression = jsonpath::make_expression<JsonType>(path, ec);
@ -1101,29 +1120,46 @@ void JsonFamily::MGet(CmdArgList args, ConnectionContext* cntx) {
return;
}
vector<string_view> vec;
for (auto i = 1U; i < args.size() - 1; i++) {
vec.emplace_back(ArgS(args, i));
}
Transaction* transaction = cntx->transaction;
unsigned shard_count = shard_set->size();
std::vector<vector<OptString>> mget_resp(shard_count);
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpMGet(t->GetOpArgs(shard), vec, move(expression));
ShardId sid = shard->shard_id();
mget_resp[sid] = OpMGet(jsonpath::make_expression<JsonType>(path, ec), t, shard);
return OpStatus::OK;
};
Transaction* trans = cntx->transaction;
OpResult<vector<OptString>> result = trans->ScheduleSingleHopT(move(cb));
OpStatus result = transaction->ScheduleSingleHop(std::move(cb));
CHECK_EQ(OpStatus::OK, result);
if (result) {
(*cntx)->StartArray(result->size());
for (auto& it : *result) {
if (!it) {
(*cntx)->SendNull();
} else {
(*cntx)->SendSimpleString(*it);
}
std::vector<OptString> results(args.size() - 2);
for (ShardId sid = 0; sid < shard_count; ++sid) {
if (!transaction->IsActive(sid))
continue;
vector<OptString>& res = mget_resp[sid];
ArgSlice slice = transaction->GetShardArgs(sid);
DCHECK(!slice.empty());
DCHECK_EQ(slice.size(), res.size());
for (size_t j = 0; j < slice.size(); ++j) {
if (!res[j])
continue;
uint32_t indx = transaction->ReverseArgIndex(sid, j);
results[indx] = move(res[j]);
}
}
(*cntx)->StartArray(results.size());
for (auto& it : results) {
if (!it) {
(*cntx)->SendNull();
} else {
(*cntx)->SendBulkString(*it);
}
} else {
(*cntx)->SendError(result.status());
}
}
@ -1645,7 +1681,8 @@ void JsonFamily::Get(CmdArgList args, ConnectionContext* cntx) {
void JsonFamily::Register(CommandRegistry* registry) {
*registry << CI{"JSON.GET", CO::READONLY | CO::FAST, -2, 1, 1, 1}.HFUNC(Get);
*registry << CI{"JSON.MGET", CO::READONLY | CO::FAST, -3, 1, 1, 1}.HFUNC(MGet);
*registry << CI{"JSON.MGET", CO::READONLY | CO::FAST | CO::REVERSE_MAPPING, -3, 1, -2, 1}.HFUNC(
MGet);
*registry << CI{"JSON.TYPE", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(Type);
*registry << CI{"JSON.STRLEN", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(StrLen);
*registry << CI{"JSON.OBJLEN", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ObjLen);

View file

@ -864,6 +864,12 @@ TEST_F(JsonFamilyTest, MGet) {
)",
R"(
{"address":{"street":"Oranienburger Str. 27","city":"Berlin","country":"Germany","zipcode":"10117"}}
)",
R"(
{"a":1, "b": 2, "nested": {"a": 3}, "c": null}
)",
R"(
{"a":4, "b": 5, "nested": {"a": 6}, "c": null}
)"};
auto resp = Run({"JSON.SET", "json1", ".", json[0]});
@ -874,7 +880,18 @@ TEST_F(JsonFamilyTest, MGet) {
resp = Run({"JSON.MGET", "json1", "json2", "json3", "$.address.country"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
EXPECT_THAT(resp.GetVec(), ElementsAre(R"("Israel")", R"("Germany")", ArgType(RespExpr::NIL)));
EXPECT_THAT(resp.GetVec(),
ElementsAre(R"(["Israel"])", R"(["Germany"])", ArgType(RespExpr::NIL)));
resp = Run({"JSON.SET", "json3", ".", json[2]});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.SET", "json4", ".", json[3]});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.MGET", "json3", "json4", "$..a"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
EXPECT_THAT(resp.GetVec(), ElementsAre(R"([1,3])", R"([4,6])"));
}
TEST_F(JsonFamilyTest, DebugFields) {