diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc
index 6ef550715..c710d03f5 100644
--- a/src/server/search/search_family.cc
+++ b/src/server/search/search_family.cc
@@ -226,49 +226,49 @@ search::QueryParams ParseQueryParams(CmdArgParser* parser) {
   return params;
 }
 
-optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, SinkReplyBuilder* builder) {
+optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyBuilder* builder) {
   SearchParams params;
 
-  while (parser.HasNext()) {
+  while (parser->HasNext()) {
     // [LIMIT offset total]
-    if (parser.Check("LIMIT")) {
-      params.limit_offset = parser.Next<size_t>();
-      params.limit_total = parser.Next<size_t>();
-    } else if (parser.Check("LOAD")) {
+    if (parser->Check("LIMIT")) {
+      params.limit_offset = parser->Next<size_t>();
+      params.limit_total = parser->Next<size_t>();
+    } else if (parser->Check("LOAD")) {
       if (params.return_fields) {
         builder->SendError("LOAD cannot be applied after RETURN");
         return std::nullopt;
       }
 
-      ParseLoadFields(&parser, &params.load_fields);
-    } else if (parser.Check("RETURN")) {
+      ParseLoadFields(parser, &params.load_fields);
+    } else if (parser->Check("RETURN")) {
       if (params.load_fields) {
         builder->SendError("RETURN cannot be applied after LOAD");
         return std::nullopt;
       }
 
       // RETURN {num} [{ident} AS {name}...]
-      size_t num_fields = parser.Next<size_t>();
+      size_t num_fields = parser->Next<size_t>();
       params.return_fields.emplace();
       while (params.return_fields->size() < num_fields) {
-        string_view ident = parser.Next();
-        string_view alias = parser.Check("AS") ? parser.Next() : ident;
+        string_view ident = parser->Next();
+        string_view alias = parser->Check("AS") ? parser->Next() : ident;
         params.return_fields->emplace_back(ident, alias);
       }
-    } else if (parser.Check("NOCONTENT")) {  // NOCONTENT
+    } else if (parser->Check("NOCONTENT")) {  // NOCONTENT
       params.load_fields.emplace();
       params.return_fields.emplace();
-    } else if (parser.Check("PARAMS")) {  // [PARAMS num(ignored) name(ignored) knn_vector]
-      params.query_params = ParseQueryParams(&parser);
-    } else if (parser.Check("SORTBY")) {
-      params.sort_option = search::SortOption{string{parser.Next()}, bool(parser.Check("DESC"))};
+    } else if (parser->Check("PARAMS")) {  // [PARAMS num(ignored) name(ignored) knn_vector]
+      params.query_params = ParseQueryParams(parser);
+    } else if (parser->Check("SORTBY")) {
+      params.sort_option = search::SortOption{string{parser->Next()}, bool(parser->Check("DESC"))};
     } else {
       // Unsupported parameters are ignored for now
-      parser.Skip(1);
+      parser->Skip(1);
    }
  }
 
-  if (auto err = parser.Error(); err) {
+  if (auto err = parser->Error(); err) {
     builder->SendError(err->MakeReply());
     return nullopt;
   }
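`ParseSearchParamsOrReply` now takes `CmdArgParser` by pointer, so the option loop and helpers like `ParseLoadFields` and `ParseQueryParams` advance a single cursor and error state owned by the caller instead of consuming a copy. Below is a minimal standalone sketch of that pattern, not Dragonfly code: `MiniParser` and `ParseLimit` are hypothetical stand-ins for `CmdArgParser` and its callers.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for CmdArgParser: a cursor over argv plus an error flag.
struct MiniParser {
  std::vector<std::string> args;
  size_t pos = 0;
  bool error = false;

  bool HasNext() const { return pos < args.size(); }
  std::string Next() { return HasNext() ? args[pos++] : (error = true, ""); }
  bool Check(const std::string& tag) {  // consume the tag if it matches
    if (HasNext() && args[pos] == tag) { ++pos; return true; }
    return false;
  }
};

// Callee consumes "LIMIT off total" through the shared cursor.
std::optional<std::pair<std::string, std::string>> ParseLimit(MiniParser* p) {
  if (!p->Check("LIMIT")) return std::nullopt;
  auto offset = p->Next();  // read in two statements: argument
  auto total = p->Next();   // evaluation order is unspecified
  return std::make_pair(offset, total);
}

int main() {
  MiniParser p{{"LIMIT", "0", "10", "NOCONTENT"}};
  auto limit = ParseLimit(&p);            // advances the shared cursor
  bool nocontent = p.Check("NOCONTENT");  // caller continues where the callee stopped
  return (limit && limit->second == "10" && nocontent && !p.error) ? 0 : 1;
}
```

With a by-value parser, the tokens consumed and any error recorded inside the callee would be lost on return; the pointer makes the final `parser->Error()` check in the loop above see everything.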
@@ -716,10 +716,11 @@ void SearchFamily::FtList(CmdArgList args, Transaction* tx, SinkReplyBuilder* bu
 }
 
 void SearchFamily::FtSearch(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
-  string_view index_name = ArgS(args, 0);
-  string_view query_str = ArgS(args, 1);
+  CmdArgParser parser{args};
+  string_view index_name = parser.Next();
+  string_view query_str = parser.Next();
 
-  auto params = ParseSearchParamsOrReply(args.subspan(2), builder);
+  auto params = ParseSearchParamsOrReply(&parser, builder);
   if (!params.has_value())
     return;
 
@@ -749,77 +750,129 @@ void SearchFamily::FtSearch(CmdArgList args, Transaction* tx, SinkReplyBuilder*
   }
 
   if (auto agg = search_algo.HasAggregation(); agg)
-    ReplySorted(std::move(*agg), *params, absl::MakeSpan(docs), builder);
+    ReplySorted(*agg, *params, absl::MakeSpan(docs), builder);
   else
     ReplyWithResults(*params, absl::MakeSpan(docs), builder);
 }
 
 void SearchFamily::FtProfile(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
-  string_view index_name = ArgS(args, 0);
-  string_view query_str = ArgS(args, 3);
+  CmdArgParser parser{args};
 
-  optional<SearchParams> params = ParseSearchParamsOrReply(args.subspan(4), builder);
+  string_view index_name = parser.Next();
+
+  if (!parser.Check("SEARCH") && !parser.Check("AGGREGATE")) {
+    return builder->SendError("no `SEARCH` or `AGGREGATE` provided");
+  }
+
+  parser.Check("LIMITED");  // TODO: Implement limited profiling
+  parser.ExpectTag("QUERY");
+
+  string_view query_str = parser.Next();
+
+  optional<SearchParams> params = ParseSearchParamsOrReply(&parser, builder);
   if (!params.has_value())
     return;
 
   search::SearchAlgorithm search_algo;
   search::SortOption* sort_opt = params->sort_option.has_value() ? &*params->sort_option : nullptr;
   if (!search_algo.Init(query_str, &params->query_params, sort_opt))
-    return builder->SendError("Query syntax error");
+    return builder->SendError("query syntax error");
 
   search_algo.EnableProfiling();
 
   absl::Time start = absl::Now();
-  atomic_uint total_docs = 0;
-  atomic_uint total_serialized = 0;
+  const size_t shards_count = shard_set->size();
 
-  vector<pair<search::AlgorithmProfile, absl::Duration>> results(shard_set->size());
+  // Because our coordinator thread may not have a shard, we can't check ahead if the index exists.
+  std::atomic<bool> index_not_found{false};
+  std::vector<SearchResult> search_results(shards_count);
+  std::vector<absl::Duration> profile_results(shards_count);
 
   tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) {
     auto* index = es->search_indices()->GetIndex(index_name);
-    if (!index)
+    if (!index) {
+      index_not_found.store(true, memory_order_relaxed);
       return OpStatus::OK;
+    }
+
+    const ShardId shard_id = es->shard_id();
 
     auto shard_start = absl::Now();
-    auto res = index->Search(t->GetOpArgs(es), *params, &search_algo);
-
-    total_docs.fetch_add(res.total_hits);
-    total_serialized.fetch_add(res.docs.size());
-
-    DCHECK(res.profile);
-    results[es->shard_id()] = {std::move(*res.profile), absl::Now() - shard_start};
+    search_results[shard_id] = index->Search(t->GetOpArgs(es), *params, &search_algo);
+    profile_results[shard_id] = {absl::Now() - shard_start};
 
     return OpStatus::OK;
   });
 
+  if (index_not_found.load())
+    return builder->SendError(std::string{index_name} + ": no such index");
+
   auto took = absl::Now() - start;
 
+  bool result_is_empty = false;
+  size_t total_docs = 0;
+  size_t total_serialized = 0;
+  for (const auto& result : search_results) {
+    if (!result.error) {
+      total_docs += result.total_hits;
+      total_serialized += result.docs.size();
+    } else {
+      result_is_empty = true;
+    }
+  }
+
   auto* rb = static_cast<RedisReplyBuilder*>(builder);
-  rb->StartArray(results.size() + 1);
+  // First element -> Result of the search command
+  // Second element -> Profile information
+  rb->StartArray(2);
+
+  // Result of the search command
+  if (!result_is_empty) {
+    auto agg = search_algo.HasAggregation();
+    if (agg) {
+      ReplySorted(*agg, *params, absl::MakeSpan(search_results), builder);
+    } else {
+      ReplyWithResults(*params, absl::MakeSpan(search_results), builder);
    }
+  } else {
+    rb->StartArray(1);
+    rb->SendLong(0);
+  }
+
+  // Profile information
+  rb->StartArray(shards_count + 1);
 
   // General stats
   rb->StartCollection(3, RedisReplyBuilder::MAP);
   rb->SendBulkString("took");
   rb->SendLong(absl::ToInt64Microseconds(took));
   rb->SendBulkString("hits");
-  rb->SendLong(total_docs);
+  rb->SendLong(static_cast<long>(total_docs));
   rb->SendBulkString("serialized");
-  rb->SendLong(total_serialized);
+  rb->SendLong(static_cast<long>(total_serialized));
 
   // Per-shard stats
-  for (const auto& [profile, shard_took] : results) {
+  for (size_t shard_id = 0; shard_id < shards_count; shard_id++) {
     rb->StartCollection(2, RedisReplyBuilder::MAP);
     rb->SendBulkString("took");
-    rb->SendLong(absl::ToInt64Microseconds(shard_took));
+    rb->SendLong(absl::ToInt64Microseconds(profile_results[shard_id]));
 
     rb->SendBulkString("tree");
-    for (size_t i = 0; i < profile.events.size(); i++) {
-      const auto& event = profile.events[i];
+    const auto& search_result = search_results[shard_id];
+    if (search_result.error || !search_result.profile || search_result.profile->events.empty()) {
+      rb->SendEmptyArray();
+      continue;
+    }
+
+    const auto& events = search_result.profile->events;
+    for (size_t i = 0; i < events.size(); i++) {
+      const auto& event = events[i];
 
       size_t children = 0;
-      for (size_t j = i + 1; j < profile.events.size(); j++) {
-        if (profile.events[j].depth == event.depth)
+      for (size_t j = i + 1; j < events.size(); j++) {
+        if (events[j].depth == event.depth)
           break;
-        if (profile.events[j].depth == event.depth + 1)
+        if (events[j].depth == event.depth + 1)
           children++;
       }
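The per-shard `tree` reply serializes the profile as a flat, depth-annotated event list: a node's direct children are the following events exactly one level deeper, and its subtree ends at the next event of equal depth. A standalone sketch of that counting logic, where `Event` is a hypothetical stand-in for the profile event type:

```cpp
#include <cstddef>
#include <vector>

struct Event {
  size_t depth;  // distance from the root of the profile tree
};

// Mirrors the children-counting loop in FtProfile above.
size_t CountChildren(const std::vector<Event>& events, size_t i) {
  size_t children = 0;
  for (size_t j = i + 1; j < events.size(); j++) {
    if (events[j].depth == events[i].depth)
      break;  // a sibling was reached, node i's subtree is over
    if (events[j].depth == events[i].depth + 1)
      children++;  // a direct child; deeper entries belong to grandchildren
  }
  return children;
}

int main() {
  // Logical(and) at depth 0 with three children at depth 1,
  // the last of which has one child of its own at depth 2.
  std::vector<Event> events{{0}, {1}, {1}, {1}, {2}};
  return CountChildren(events, 0) == 3 ? 0 : 1;
}
```

This is why the `FtProfile` test below can assert `Logical{n=3,o=and}` as the root of the tree and three entries beneath it.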
diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc
index dbe063ab2..7b1198035 100644
--- a/src/server/search/search_family_test.cc
+++ b/src/server/search/search_family_test.cc
@@ -22,6 +22,27 @@ class SearchFamilyTest : public BaseFamilyTest {
 
 const auto kNoResults = IntArg(0);  // tests auto destruct single element arrays
 
+/* Asserts that response is array of two arrays. Used to test FT.PROFILE response */
+::testing::AssertionResult AssertArrayOfTwoArrays(const RespExpr& resp) {
+  if (resp.GetVec().size() != 2) {
+    return ::testing::AssertionFailure()
+           << "Expected response array length to be 2, but was " << resp.GetVec().size();
+  }
+
+  const auto& vec = resp.GetVec();
+  if (vec[0].type != RespExpr::ARRAY) {
+    return ::testing::AssertionFailure()
+           << "Expected resp[0] to be an array, but was " << vec[0].type;
+  }
+  if (vec[1].type != RespExpr::ARRAY) {
+    return ::testing::AssertionFailure()
+           << "Expected resp[1] to be an array, but was " << vec[1].type;
+  }
+  return ::testing::AssertionSuccess();
+}
+
+#define ASSERT_ARRAY_OF_TWO_ARRAYS(resp) ASSERT_PRED1(AssertArrayOfTwoArrays, resp)
+
 MATCHER_P2(DocIds, total, arg_ids, "") {
   if (arg_ids.empty()) {
     if (auto res = arg.GetInt(); !res || *res != 0) {
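The helper relies on GoogleTest's predicate assertions: a function returning `::testing::AssertionResult` can stream a descriptive message into its failure, and `ASSERT_PRED1` accepts any predicate whose result converts to `bool`, which `AssertionResult` does. A minimal self-contained sketch of the same pattern (link with `gtest_main`):

```cpp
#include <gtest/gtest.h>
#include <vector>

// A predicate that explains itself on failure instead of returning a bare bool.
::testing::AssertionResult HasTwoElements(const std::vector<int>& v) {
  if (v.size() != 2) {
    return ::testing::AssertionFailure()
           << "Expected size 2, but was " << v.size();
  }
  return ::testing::AssertionSuccess();
}

TEST(PredicateSketch, TwoElements) {
  std::vector<int> v{1, 2};
  // On failure, gtest reports the predicate expression and its argument.
  ASSERT_PRED1(HasTwoElements, v);
}
```

Wrapping it in an `ASSERT_*` macro, as the patch does, also makes the check fatal, so the tests below can index into `resp.GetVec()[0]` and `[1]` safely afterwards.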
"search", "limited", "query", "(a | b) c d"}); + ASSERT_ARRAY_OF_TWO_ARRAYS(resp); +} + +TEST_F(SearchFamilyTest, FtProfileInvalidQuery) { + Run({"json.set", "j1", ".", R"({"id":"1"})"}); + Run({"ft.create", "i1", "on", "json", "schema", "$.id", "as", "id", "tag"}); + + auto resp = Run({"ft.profile", "i1", "search", "query", "@id:[1 1]"}); + ASSERT_ARRAY_OF_TWO_ARRAYS(resp); + + EXPECT_THAT(resp.GetVec()[0], IsMapWithSize()); + + resp = Run({"ft.profile", "i1", "search", "query", "@{invalid13289}"}); + EXPECT_THAT(resp, ErrArg("query syntax error")); +} + +TEST_F(SearchFamilyTest, FtProfileErrorReply) { + Run({"ft.create", "i1", "schema", "name", "text"}); + ; + + auto resp = Run({"ft.profile", "i1", "not_search", "query", "(a | b) c d"}); + EXPECT_THAT(resp, ErrArg("no `SEARCH` or `AGGREGATE` provided")); + + resp = Run({"ft.profile", "i1", "search", "not_query", "(a | b) c d"}); + EXPECT_THAT(resp, ErrArg("syntax error")); + + resp = Run({"ft.profile", "non_existent_key", "search", "query", "(a | b) c d"}); + EXPECT_THAT(resp, ErrArg("non_existent_key: no such index")); } TEST_F(SearchFamilyTest, SimpleExpiry) {