1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

fix(search_family): Fix crash in FT.PROFILE command for invalid queries (#4043)

* refactor(search_family): Remove unnecessary std::move in FT.SEARCH

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* fix(search_family): Fix crash in FT.PROFILE command for invalid queries

fixes dragonflydb#3983

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor(search_family_test): address comments

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
This commit is contained in:
Stepan Bagritsevich 2024-11-04 18:18:12 +01:00 committed by GitHub
parent 9c2fc3fe63
commit 7ac853567b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 159 additions and 50 deletions

View file

@ -226,49 +226,49 @@ search::QueryParams ParseQueryParams(CmdArgParser* parser) {
return params; return params;
} }
optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, SinkReplyBuilder* builder) { optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyBuilder* builder) {
SearchParams params; SearchParams params;
while (parser.HasNext()) { while (parser->HasNext()) {
// [LIMIT offset total] // [LIMIT offset total]
if (parser.Check("LIMIT")) { if (parser->Check("LIMIT")) {
params.limit_offset = parser.Next<size_t>(); params.limit_offset = parser->Next<size_t>();
params.limit_total = parser.Next<size_t>(); params.limit_total = parser->Next<size_t>();
} else if (parser.Check("LOAD")) { } else if (parser->Check("LOAD")) {
if (params.return_fields) { if (params.return_fields) {
builder->SendError("LOAD cannot be applied after RETURN"); builder->SendError("LOAD cannot be applied after RETURN");
return std::nullopt; return std::nullopt;
} }
ParseLoadFields(&parser, &params.load_fields); ParseLoadFields(parser, &params.load_fields);
} else if (parser.Check("RETURN")) { } else if (parser->Check("RETURN")) {
if (params.load_fields) { if (params.load_fields) {
builder->SendError("RETURN cannot be applied after LOAD"); builder->SendError("RETURN cannot be applied after LOAD");
return std::nullopt; return std::nullopt;
} }
// RETURN {num} [{ident} AS {name}...] // RETURN {num} [{ident} AS {name}...]
size_t num_fields = parser.Next<size_t>(); size_t num_fields = parser->Next<size_t>();
params.return_fields.emplace(); params.return_fields.emplace();
while (params.return_fields->size() < num_fields) { while (params.return_fields->size() < num_fields) {
string_view ident = parser.Next(); string_view ident = parser->Next();
string_view alias = parser.Check("AS") ? parser.Next() : ident; string_view alias = parser->Check("AS") ? parser->Next() : ident;
params.return_fields->emplace_back(ident, alias); params.return_fields->emplace_back(ident, alias);
} }
} else if (parser.Check("NOCONTENT")) { // NOCONTENT } else if (parser->Check("NOCONTENT")) { // NOCONTENT
params.load_fields.emplace(); params.load_fields.emplace();
params.return_fields.emplace(); params.return_fields.emplace();
} else if (parser.Check("PARAMS")) { // [PARAMS num(ignored) name(ignored) knn_vector] } else if (parser->Check("PARAMS")) { // [PARAMS num(ignored) name(ignored) knn_vector]
params.query_params = ParseQueryParams(&parser); params.query_params = ParseQueryParams(parser);
} else if (parser.Check("SORTBY")) { } else if (parser->Check("SORTBY")) {
params.sort_option = search::SortOption{string{parser.Next()}, bool(parser.Check("DESC"))}; params.sort_option = search::SortOption{string{parser->Next()}, bool(parser->Check("DESC"))};
} else { } else {
// Unsupported parameters are ignored for now // Unsupported parameters are ignored for now
parser.Skip(1); parser->Skip(1);
} }
} }
if (auto err = parser.Error(); err) { if (auto err = parser->Error(); err) {
builder->SendError(err->MakeReply()); builder->SendError(err->MakeReply());
return nullopt; return nullopt;
} }
@ -716,10 +716,11 @@ void SearchFamily::FtList(CmdArgList args, Transaction* tx, SinkReplyBuilder* bu
} }
void SearchFamily::FtSearch(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) { void SearchFamily::FtSearch(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
string_view index_name = ArgS(args, 0); CmdArgParser parser{args};
string_view query_str = ArgS(args, 1); string_view index_name = parser.Next();
string_view query_str = parser.Next();
auto params = ParseSearchParamsOrReply(args.subspan(2), builder); auto params = ParseSearchParamsOrReply(&parser, builder);
if (!params.has_value()) if (!params.has_value())
return; return;
@ -749,77 +750,129 @@ void SearchFamily::FtSearch(CmdArgList args, Transaction* tx, SinkReplyBuilder*
} }
if (auto agg = search_algo.HasAggregation(); agg) if (auto agg = search_algo.HasAggregation(); agg)
ReplySorted(std::move(*agg), *params, absl::MakeSpan(docs), builder); ReplySorted(*agg, *params, absl::MakeSpan(docs), builder);
else else
ReplyWithResults(*params, absl::MakeSpan(docs), builder); ReplyWithResults(*params, absl::MakeSpan(docs), builder);
} }
void SearchFamily::FtProfile(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) { void SearchFamily::FtProfile(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
string_view index_name = ArgS(args, 0); CmdArgParser parser{args};
string_view query_str = ArgS(args, 3);
optional<SearchParams> params = ParseSearchParamsOrReply(args.subspan(4), builder); string_view index_name = parser.Next();
if (!parser.Check("SEARCH") && !parser.Check("AGGREGATE")) {
return builder->SendError("no `SEARCH` or `AGGREGATE` provided");
}
parser.Check("LIMITED"); // TODO: Implement limited profiling
parser.ExpectTag("QUERY");
string_view query_str = parser.Next();
optional<SearchParams> params = ParseSearchParamsOrReply(&parser, builder);
if (!params.has_value()) if (!params.has_value())
return; return;
search::SearchAlgorithm search_algo; search::SearchAlgorithm search_algo;
search::SortOption* sort_opt = params->sort_option.has_value() ? &*params->sort_option : nullptr; search::SortOption* sort_opt = params->sort_option.has_value() ? &*params->sort_option : nullptr;
if (!search_algo.Init(query_str, &params->query_params, sort_opt)) if (!search_algo.Init(query_str, &params->query_params, sort_opt))
return builder->SendError("Query syntax error"); return builder->SendError("query syntax error");
search_algo.EnableProfiling(); search_algo.EnableProfiling();
absl::Time start = absl::Now(); absl::Time start = absl::Now();
atomic_uint total_docs = 0; const size_t shards_count = shard_set->size();
atomic_uint total_serialized = 0;
vector<pair<search::AlgorithmProfile, absl::Duration>> results(shard_set->size()); // Because our coordinator thread may not have a shard, we can't check ahead if the index exists.
std::atomic<bool> index_not_found{false};
std::vector<SearchResult> search_results(shards_count);
std::vector<absl::Duration> profile_results(shards_count);
tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) {
auto* index = es->search_indices()->GetIndex(index_name); auto* index = es->search_indices()->GetIndex(index_name);
if (!index) if (!index) {
index_not_found.store(true, memory_order_relaxed);
return OpStatus::OK; return OpStatus::OK;
}
const ShardId shard_id = es->shard_id();
auto shard_start = absl::Now(); auto shard_start = absl::Now();
auto res = index->Search(t->GetOpArgs(es), *params, &search_algo); search_results[shard_id] = index->Search(t->GetOpArgs(es), *params, &search_algo);
profile_results[shard_id] = {absl::Now() - shard_start};
total_docs.fetch_add(res.total_hits);
total_serialized.fetch_add(res.docs.size());
DCHECK(res.profile);
results[es->shard_id()] = {std::move(*res.profile), absl::Now() - shard_start};
return OpStatus::OK; return OpStatus::OK;
}); });
if (index_not_found.load())
return builder->SendError(std::string{index_name} + ": no such index");
auto took = absl::Now() - start; auto took = absl::Now() - start;
bool result_is_empty = false;
size_t total_docs = 0;
size_t total_serialized = 0;
for (const auto& result : search_results) {
if (!result.error) {
total_docs += result.total_hits;
total_serialized += result.docs.size();
} else {
result_is_empty = true;
}
}
auto* rb = static_cast<RedisReplyBuilder*>(builder); auto* rb = static_cast<RedisReplyBuilder*>(builder);
rb->StartArray(results.size() + 1); // First element -> Result of the search command
// Second element -> Profile information
rb->StartArray(2);
// Result of the search command
if (!result_is_empty) {
auto agg = search_algo.HasAggregation();
if (agg) {
ReplySorted(*agg, *params, absl::MakeSpan(search_results), builder);
} else {
ReplyWithResults(*params, absl::MakeSpan(search_results), builder);
}
} else {
rb->StartArray(1);
rb->SendLong(0);
}
// Profile information
rb->StartArray(shards_count + 1);
// General stats // General stats
rb->StartCollection(3, RedisReplyBuilder::MAP); rb->StartCollection(3, RedisReplyBuilder::MAP);
rb->SendBulkString("took"); rb->SendBulkString("took");
rb->SendLong(absl::ToInt64Microseconds(took)); rb->SendLong(absl::ToInt64Microseconds(took));
rb->SendBulkString("hits"); rb->SendBulkString("hits");
rb->SendLong(total_docs); rb->SendLong(static_cast<long>(total_docs));
rb->SendBulkString("serialized"); rb->SendBulkString("serialized");
rb->SendLong(total_serialized); rb->SendLong(static_cast<long>(total_serialized));
// Per-shard stats // Per-shard stats
for (const auto& [profile, shard_took] : results) { for (size_t shard_id = 0; shard_id < shards_count; shard_id++) {
rb->StartCollection(2, RedisReplyBuilder::MAP); rb->StartCollection(2, RedisReplyBuilder::MAP);
rb->SendBulkString("took"); rb->SendBulkString("took");
rb->SendLong(absl::ToInt64Microseconds(shard_took)); rb->SendLong(absl::ToInt64Microseconds(profile_results[shard_id]));
rb->SendBulkString("tree"); rb->SendBulkString("tree");
for (size_t i = 0; i < profile.events.size(); i++) { const auto& search_result = search_results[shard_id];
const auto& event = profile.events[i]; if (search_result.error || !search_result.profile || search_result.profile->events.empty()) {
rb->SendEmptyArray();
continue;
}
const auto& events = search_result.profile->events;
for (size_t i = 0; i < events.size(); i++) {
const auto& event = events[i];
size_t children = 0; size_t children = 0;
for (size_t j = i + 1; j < profile.events.size(); j++) { for (size_t j = i + 1; j < events.size(); j++) {
if (profile.events[j].depth == event.depth) if (events[j].depth == event.depth)
break; break;
if (profile.events[j].depth == event.depth + 1) if (events[j].depth == event.depth + 1)
children++; children++;
} }

View file

@ -22,6 +22,27 @@ class SearchFamilyTest : public BaseFamilyTest {
const auto kNoResults = IntArg(0); // tests auto destruct single element arrays const auto kNoResults = IntArg(0); // tests auto destruct single element arrays
/* Asserts that response is array of two arrays. Used to test FT.PROFILE response */
::testing::AssertionResult AssertArrayOfTwoArrays(const RespExpr& resp) {
if (resp.GetVec().size() != 2) {
return ::testing::AssertionFailure()
<< "Expected response array length to be 2, but was " << resp.GetVec().size();
}
const auto& vec = resp.GetVec();
if (vec[0].type != RespExpr::ARRAY) {
return ::testing::AssertionFailure()
<< "Expected resp[0] to be an array, but was " << vec[0].type;
}
if (vec[1].type != RespExpr::ARRAY) {
return ::testing::AssertionFailure()
<< "Expected resp[1] to be an array, but was " << vec[1].type;
}
return ::testing::AssertionSuccess();
}
#define ASSERT_ARRAY_OF_TWO_ARRAYS(resp) ASSERT_PRED1(AssertArrayOfTwoArrays, resp)
MATCHER_P2(DocIds, total, arg_ids, "") { MATCHER_P2(DocIds, total, arg_ids, "") {
if (arg_ids.empty()) { if (arg_ids.empty()) {
if (auto res = arg.GetInt(); !res || *res != 0) { if (auto res = arg.GetInt(); !res || *res != 0) {
@ -790,20 +811,55 @@ TEST_F(SearchFamilyTest, FtProfile) {
Run({"ft.create", "i1", "schema", "name", "text"}); Run({"ft.create", "i1", "schema", "name", "text"});
auto resp = Run({"ft.profile", "i1", "search", "query", "(a | b) c d"}); auto resp = Run({"ft.profile", "i1", "search", "query", "(a | b) c d"});
ASSERT_ARRAY_OF_TWO_ARRAYS(resp);
const auto& top_level = resp.GetVec(); const auto& top_level = resp.GetVec();
EXPECT_EQ(top_level.size(), shard_set->size() + 1); EXPECT_THAT(top_level[0], IsMapWithSize());
EXPECT_THAT(top_level[0].GetVec(), ElementsAre("took", _, "hits", _, "serialized", _)); const auto& profile_result = top_level[1].GetVec();
EXPECT_EQ(profile_result.size(), shard_set->size() + 1);
EXPECT_THAT(profile_result[0].GetVec(), ElementsAre("took", _, "hits", _, "serialized", _));
for (size_t sid = 0; sid < shard_set->size(); sid++) { for (size_t sid = 0; sid < shard_set->size(); sid++) {
const auto& shard_resp = top_level[sid + 1].GetVec(); const auto& shard_resp = profile_result[sid + 1].GetVec();
EXPECT_THAT(shard_resp, ElementsAre("took", _, "tree", _)); EXPECT_THAT(shard_resp, ElementsAre("took", _, "tree", _));
const auto& tree = shard_resp[3].GetVec(); const auto& tree = shard_resp[3].GetVec();
EXPECT_THAT(tree[0].GetString(), HasSubstr("Logical{n=3,o=and}"sv)); EXPECT_THAT(tree[0].GetString(), HasSubstr("Logical{n=3,o=and}"sv));
EXPECT_EQ(tree[1].GetVec().size(), 3); EXPECT_EQ(tree[1].GetVec().size(), 3);
} }
// Test LIMITED throws no errors
resp = Run({"ft.profile", "i1", "search", "limited", "query", "(a | b) c d"});
ASSERT_ARRAY_OF_TWO_ARRAYS(resp);
}
TEST_F(SearchFamilyTest, FtProfileInvalidQuery) {
Run({"json.set", "j1", ".", R"({"id":"1"})"});
Run({"ft.create", "i1", "on", "json", "schema", "$.id", "as", "id", "tag"});
auto resp = Run({"ft.profile", "i1", "search", "query", "@id:[1 1]"});
ASSERT_ARRAY_OF_TWO_ARRAYS(resp);
EXPECT_THAT(resp.GetVec()[0], IsMapWithSize());
resp = Run({"ft.profile", "i1", "search", "query", "@{invalid13289}"});
EXPECT_THAT(resp, ErrArg("query syntax error"));
}
TEST_F(SearchFamilyTest, FtProfileErrorReply) {
Run({"ft.create", "i1", "schema", "name", "text"});
;
auto resp = Run({"ft.profile", "i1", "not_search", "query", "(a | b) c d"});
EXPECT_THAT(resp, ErrArg("no `SEARCH` or `AGGREGATE` provided"));
resp = Run({"ft.profile", "i1", "search", "not_query", "(a | b) c d"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"ft.profile", "non_existent_key", "search", "query", "(a | b) c d"});
EXPECT_THAT(resp, ErrArg("non_existent_key: no such index"));
} }
TEST_F(SearchFamilyTest, SimpleExpiry) { TEST_F(SearchFamilyTest, SimpleExpiry) {