1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

fix(search_family): Remove the output of extra fields in the FT.AGGREGATE command (#4231)

* fix(search_family): Remove the output of extra fields in the FT.AGGREGATE command

fixes dragonflydb#4230

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor: address comments

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
This commit is contained in:
Stepan Bagritsevich 2024-12-11 15:21:20 +04:00 committed by GitHub
parent 1e3d9de0f6
commit 76f79f0e0b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 124 additions and 63 deletions

View file

@ -11,10 +11,10 @@ namespace dfly::aggregate {
namespace { namespace {
struct GroupStep { struct GroupStep {
PipelineResult operator()(std::vector<DocValues> values) { PipelineResult operator()(PipelineResult result) {
// Separate items into groups // Separate items into groups
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups; absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
for (auto& value : values) { for (auto& value : result.values) {
groups[Extract(value)].push_back(std::move(value)); groups[Extract(value)].push_back(std::move(value));
} }
@ -28,7 +28,18 @@ struct GroupStep {
} }
out.push_back(std::move(doc)); out.push_back(std::move(doc));
} }
return out;
absl::flat_hash_set<std::string> fields_to_print;
fields_to_print.reserve(fields_.size() + reducers_.size());
for (auto& field : fields_) {
fields_to_print.insert(std::move(field));
}
for (auto& reducer : reducers_) {
fields_to_print.insert(std::move(reducer.result_field));
}
return {std::move(out), std::move(fields_to_print)};
} }
absl::FixedArray<Value> Extract(const DocValues& dv) { absl::FixedArray<Value> Extract(const DocValues& dv) {
@ -104,34 +115,42 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
} }
PipelineStep MakeSortStep(std::string_view field, bool descending) { PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](std::vector<DocValues> values) -> PipelineResult { return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
auto& values = result.values;
std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) { std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
auto it1 = l.find(field); auto it1 = l.find(field);
auto it2 = r.find(field); auto it2 = r.find(field);
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second); return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
}); });
if (descending)
if (descending) {
std::reverse(values.begin(), values.end()); std::reverse(values.begin(), values.end());
return values; }
result.fields_to_print.insert(field);
return result;
}; };
} }
PipelineStep MakeLimitStep(size_t offset, size_t num) { PipelineStep MakeLimitStep(size_t offset, size_t num) {
return [offset, num](std::vector<DocValues> values) -> PipelineResult { return [offset, num](PipelineResult result) {
auto& values = result.values;
values.erase(values.begin(), values.begin() + std::min(offset, values.size())); values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
values.resize(std::min(num, values.size())); values.resize(std::min(num, values.size()));
return values; return result;
}; };
} }
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps) { PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps) {
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
for (auto& step : steps) { for (auto& step : steps) {
auto result = step(std::move(values)); PipelineResult step_result = step(std::move(result));
if (!result.has_value()) result = std::move(step_result);
return result;
values = std::move(result.value());
} }
return values; return result;
} }
} // namespace dfly::aggregate } // namespace dfly::aggregate

View file

@ -5,6 +5,7 @@
#pragma once #pragma once
#include <absl/container/flat_hash_map.h> #include <absl/container/flat_hash_map.h>
#include <absl/container/flat_hash_set.h>
#include <absl/types/span.h> #include <absl/types/span.h>
#include <string> #include <string>
@ -19,10 +20,16 @@ namespace dfly::aggregate {
using Value = ::dfly::search::SortableValue; using Value = ::dfly::search::SortableValue;
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline
// TODO: Replace DocValues with compact linear search map instead of hash map struct PipelineResult {
// Values to be passed to the next step
// TODO: Replace DocValues with compact linear search map instead of hash map
std::vector<DocValues> values;
using PipelineResult = io::Result<std::vector<DocValues>, facade::ErrorReply>; // Fields from values to be printed
using PipelineStep = std::function<PipelineResult(std::vector<DocValues>)>; // Group, Sort, etc. absl::flat_hash_set<std::string> fields_to_print;
};
using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.
// Iterator over Span<DocValues> that yields doc[field] or monostate if not present. // Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
// Extra clumsy for STL compatibility! // Extra clumsy for STL compatibility!
@ -82,6 +89,8 @@ PipelineStep MakeSortStep(std::string_view field, bool descending = false);
PipelineStep MakeLimitStep(size_t offset, size_t num); PipelineStep MakeLimitStep(size_t offset, size_t num);
// Process values with given steps // Process values with given steps
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps); PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps);
} // namespace dfly::aggregate } // namespace dfly::aggregate

View file

@ -18,12 +18,11 @@ TEST(AggregatorTest, Sort) {
}; };
PipelineStep steps[] = {MakeSortStep("a", false)}; PipelineStep steps[] = {MakeSortStep("a", false)};
auto result = Process(values, steps); auto result = Process(values, {"a"}, steps);
EXPECT_TRUE(result); EXPECT_EQ(result.values[0]["a"], Value(0.5));
EXPECT_EQ(result->at(0)["a"], Value(0.5)); EXPECT_EQ(result.values[1]["a"], Value(1.0));
EXPECT_EQ(result->at(1)["a"], Value(1.0)); EXPECT_EQ(result.values[2]["a"], Value(1.5));
EXPECT_EQ(result->at(2)["a"], Value(1.5));
} }
TEST(AggregatorTest, Limit) { TEST(AggregatorTest, Limit) {
@ -35,12 +34,11 @@ TEST(AggregatorTest, Limit) {
}; };
PipelineStep steps[] = {MakeLimitStep(1, 2)}; PipelineStep steps[] = {MakeLimitStep(1, 2)};
auto result = Process(values, steps); auto result = Process(values, {"i"}, steps);
EXPECT_TRUE(result); EXPECT_EQ(result.values.size(), 2);
EXPECT_EQ(result->size(), 2); EXPECT_EQ(result.values[0]["i"], Value(2.0));
EXPECT_EQ(result->at(0)["i"], Value(2.0)); EXPECT_EQ(result.values[1]["i"], Value(3.0));
EXPECT_EQ(result->at(1)["i"], Value(3.0));
} }
TEST(AggregatorTest, SimpleGroup) { TEST(AggregatorTest, SimpleGroup) {
@ -54,12 +52,11 @@ TEST(AggregatorTest, SimpleGroup) {
std::string_view fields[] = {"tag"}; std::string_view fields[] = {"tag"};
PipelineStep steps[] = {MakeGroupStep(fields, {})}; PipelineStep steps[] = {MakeGroupStep(fields, {})};
auto result = Process(values, steps); auto result = Process(values, {"i", "tag"}, steps);
EXPECT_TRUE(result); EXPECT_EQ(result.values.size(), 2);
EXPECT_EQ(result->size(), 2);
EXPECT_EQ(result->at(0).size(), 1); EXPECT_EQ(result.values[0].size(), 1);
std::set<Value> groups{result->at(0)["tag"], result->at(1)["tag"]}; std::set<Value> groups{result.values[0]["tag"], result.values[1]["tag"]};
std::set<Value> expected{"even", "odd"}; std::set<Value> expected{"even", "odd"};
EXPECT_EQ(groups, expected); EXPECT_EQ(groups, expected);
} }
@ -83,25 +80,24 @@ TEST(AggregatorTest, GroupWithReduce) {
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}}; Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))}; PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};
auto result = Process(values, steps); auto result = Process(values, {"i", "half-i", "tag"}, steps);
EXPECT_TRUE(result); EXPECT_EQ(result.values.size(), 2);
EXPECT_EQ(result->size(), 2);
// Reorder even first // Reorder even first
if (result->at(0).at("tag") == Value("odd")) if (result.values[0].at("tag") == Value("odd"))
std::swap(result->at(0), result->at(1)); std::swap(result.values[0], result.values[1]);
// Even // Even
EXPECT_EQ(result->at(0).at("count"), Value{(double)5}); EXPECT_EQ(result.values[0].at("count"), Value{(double)5});
EXPECT_EQ(result->at(0).at("sum-i"), Value{(double)2 + 4 + 6 + 8}); EXPECT_EQ(result.values[0].at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result->at(0).at("distinct-hi"), Value{(double)3}); EXPECT_EQ(result.values[0].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(0).at("distinct-null"), Value{(double)1}); EXPECT_EQ(result.values[0].at("distinct-null"), Value{(double)1});
// Odd // Odd
EXPECT_EQ(result->at(1).at("count"), Value{(double)5}); EXPECT_EQ(result.values[1].at("count"), Value{(double)5});
EXPECT_EQ(result->at(1).at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9}); EXPECT_EQ(result.values[1].at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result->at(1).at("distinct-hi"), Value{(double)3}); EXPECT_EQ(result.values[1].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(1).at("distinct-null"), Value{(double)1}); EXPECT_EQ(result.values[1].at("distinct-null"), Value{(double)1});
} }
} // namespace dfly::aggregate } // namespace dfly::aggregate

View file

@ -981,22 +981,34 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx)
make_move_iterator(sub_results.end())); make_move_iterator(sub_results.end()));
} }
auto agg_results = aggregate::Process(std::move(values), params->steps); std::vector<std::string_view> load_fields;
if (!agg_results.has_value()) if (params->load_fields) {
return builder->SendError(agg_results.error()); load_fields.reserve(params->load_fields->size());
for (const auto& field : params->load_fields.value()) {
load_fields.push_back(field.GetShortName());
}
}
auto agg_results = aggregate::Process(std::move(values), load_fields, params->steps);
size_t result_size = agg_results->size();
auto* rb = static_cast<RedisReplyBuilder*>(cmd_cntx.rb); auto* rb = static_cast<RedisReplyBuilder*>(cmd_cntx.rb);
auto sortable_value_sender = SortableValueSender(rb); auto sortable_value_sender = SortableValueSender(rb);
const size_t result_size = agg_results.values.size();
rb->StartArray(result_size + 1); rb->StartArray(result_size + 1);
rb->SendLong(result_size); rb->SendLong(result_size);
for (const auto& result : agg_results.value()) { const size_t field_count = agg_results.fields_to_print.size();
rb->StartArray(result.size() * 2); for (const auto& value : agg_results.values) {
for (const auto& [k, v] : result) { rb->StartArray(field_count * 2);
rb->SendBulkString(k); for (const auto& field : agg_results.fields_to_print) {
std::visit(sortable_value_sender, v); rb->SendBulkString(field);
if (auto it = value.find(field); it != value.end()) {
std::visit(sortable_value_sender, it->second);
} else {
rb->SendNull();
}
} }
} }
} }

View file

@ -962,15 +962,12 @@ TEST_F(SearchFamilyTest, AggregateGroupBy) {
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "20", "word", "item2"), EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "20", "word", "item2"),
IsMap("foo_total", "50", "word", "item1"))); IsMap("foo_total", "50", "word", "item1")));
/*
Temporary not supported
resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "text", "GROUPBY", "2", "@word", resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "text", "GROUPBY", "2", "@word",
"@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); EXPECT_THAT(resp, "@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"});
IsUnordArrayWithSize(IsMap("foo_total", "20", "word", ArgType(RespExpr::NIL), "text", "\"second EXPECT_THAT(resp, IsUnordArrayWithSize(
key\""), IsMap("foo_total", "40", "word", ArgType(RespExpr::NIL), "text", "\"third key\""), IsMap("foo_total", "40", "word", "item1", "text", "\"third key\""),
IsMap({"foo_total", "10", "word", ArgType(RespExpr::NIL), "text", "\"first key"}))); IsMap("foo_total", "20", "word", "item2", "text", "\"second key\""),
*/ IsMap("foo_total", "10", "word", "item1", "text", "\"first key\"")));
} }
TEST_F(SearchFamilyTest, JsonAggregateGroupBy) { TEST_F(SearchFamilyTest, JsonAggregateGroupBy) {
@ -1632,4 +1629,32 @@ TEST_F(SearchFamilyTest, SearchLoadReturnHash) {
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("a", "two"), "h1", IsMap("a", "one"))); EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("a", "two"), "h1", IsMap("a", "one")));
} }
// Test that FT.AGGREGATE prints only needed fields
TEST_F(SearchFamilyTest, AggregateResultFields) {
Run({"JSON.SET", "j1", ".", R"({"a":"1","b":"2","c":"3"})"});
Run({"JSON.SET", "j2", ".", R"({"a":"4","b":"5","c":"6"})"});
Run({"JSON.SET", "j3", ".", R"({"a":"7","b":"8","c":"9"})"});
auto resp = Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.a", "AS", "a", "TEXT",
"SORTABLE", "$.b", "AS", "b", "TEXT", "$.c", "AS", "c", "TEXT"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.AGGREGATE", "index", "*"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap(), IsMap(), IsMap()));
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("a", "1"), IsMap("a", "4"), IsMap("a", "7")));
resp = Run({"FT.AGGREGATE", "index", "*", "LOAD", "1", "@b", "SORTBY", "1", "a"});
EXPECT_THAT(resp,
IsUnordArrayWithSize(IsMap("b", "\"2\"", "a", "1"), IsMap("b", "\"5\"", "a", "4"),
IsMap("b", "\"8\"", "a", "7")));
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a", "GROUPBY", "2", "@b", "@a",
"REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("b", "\"8\"", "a", "7", "count", "1"),
IsMap("b", "\"2\"", "a", "1", "count", "1"),
IsMap("b", "\"5\"", "a", "4", "count", "1")));
}
} // namespace dfly } // namespace dfly