
fix(search_family): Remove the output of extra fields in the FT.AGGREGATE command (#4231)

* fix(search_family): Remove the output of extra fields in the FT.AGGREGATE command

fixes dragonflydb#4230

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor: address comments

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
Authored by Stepan Bagritsevich on 2024-12-11 15:21:20 +04:00; committed by GitHub
parent 1e3d9de0f6
commit 76f79f0e0b
5 changed files with 124 additions and 63 deletions
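
The gist of the change: a pipeline step no longer receives and returns a bare std::vector<DocValues> wrapped in io::Result, but a PipelineResult that carries the documents together with the accumulated set of fields the reply is allowed to print. GROUPBY registers its group keys and reducer result fields, SORTBY registers its sort key, and LOAD fields seed the set up front, so the reply writer emits exactly those fields instead of every key stored in a document. Below is a minimal, self-contained sketch of that idea; it uses std containers and string values in place of Dragonfly's absl types and SortableValue, and all names are illustrative rather than the project's actual API.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Simplified stand-ins for DocValues / PipelineResult / PipelineStep.
using Doc = std::unordered_map<std::string, std::string>;

struct PipelineResult {
  std::vector<Doc> values;                          // documents flowing through the pipeline
  std::unordered_set<std::string> fields_to_print;  // fields the reply may contain
};

using Step = std::function<PipelineResult(PipelineResult)>;

int main() {
  PipelineResult result{{Doc{{"a", "1"}, {"b", "2"}, {"c", "3"}}}, /*fields_to_print=*/{}};

  // A sort-like step: besides transforming the documents, it registers
  // the field it operates on so the reply may include it.
  Step sort_step = [](PipelineResult r) {
    r.fields_to_print.insert("a");
    return r;
  };
  result = sort_step(std::move(result));

  // The reply writer iterates the registered fields, not the document keys,
  // and answers with a null for a registered field a document lacks.
  for (const Doc& doc : result.values) {
    for (const std::string& field : result.fields_to_print) {
      auto it = doc.find(field);
      std::cout << field << " = " << (it != doc.end() ? it->second : "(null)") << "\n";
    }
  }
  // Prints only "a = 1"; "b" and "c" no longer leak into the output.
}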

src/server/search/aggregator.cc

@@ -11,10 +11,10 @@ namespace dfly::aggregate {
 namespace {
 
 struct GroupStep {
-  PipelineResult operator()(std::vector<DocValues> values) {
+  PipelineResult operator()(PipelineResult result) {
     // Separate items into groups
     absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
-    for (auto& value : values) {
+    for (auto& value : result.values) {
       groups[Extract(value)].push_back(std::move(value));
     }
 
@@ -28,7 +28,18 @@ struct GroupStep {
       }
       out.push_back(std::move(doc));
     }
-    return out;
+
+    absl::flat_hash_set<std::string> fields_to_print;
+    fields_to_print.reserve(fields_.size() + reducers_.size());
+
+    for (auto& field : fields_) {
+      fields_to_print.insert(std::move(field));
+    }
+    for (auto& reducer : reducers_) {
+      fields_to_print.insert(std::move(reducer.result_field));
+    }
+
+    return {std::move(out), std::move(fields_to_print)};
   }
 
   absl::FixedArray<Value> Extract(const DocValues& dv) {
@@ -104,34 +115,42 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
 }
 
 PipelineStep MakeSortStep(std::string_view field, bool descending) {
-  return [field = std::string(field), descending](std::vector<DocValues> values) -> PipelineResult {
+  return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
+    auto& values = result.values;
+
     std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
       auto it1 = l.find(field);
       auto it2 = r.find(field);
       return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
     });
 
-    if (descending)
+    if (descending) {
       std::reverse(values.begin(), values.end());
-    return values;
+    }
+
+    result.fields_to_print.insert(field);
+    return result;
   };
 }
 
 PipelineStep MakeLimitStep(size_t offset, size_t num) {
-  return [offset, num](std::vector<DocValues> values) -> PipelineResult {
+  return [offset, num](PipelineResult result) {
+    auto& values = result.values;
     values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
     values.resize(std::min(num, values.size()));
-    return values;
+    return result;
   };
 }
 
-PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps) {
+PipelineResult Process(std::vector<DocValues> values,
+                       absl::Span<const std::string_view> fields_to_print,
+                       absl::Span<const PipelineStep> steps) {
+  PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
   for (auto& step : steps) {
-    auto result = step(std::move(values));
-    if (!result.has_value())
-      return result;
-    values = std::move(result.value());
+    PipelineResult step_result = step(std::move(result));
+    result = std::move(step_result);
   }
-  return values;
+  return result;
 }
 
 }  // namespace dfly::aggregate

src/server/search/aggregator.h

@@ -5,6 +5,7 @@
 #pragma once
 
 #include <absl/container/flat_hash_map.h>
+#include <absl/container/flat_hash_set.h>
 #include <absl/types/span.h>
 
 #include <string>
@@ -19,10 +20,16 @@ namespace dfly::aggregate {
 using Value = ::dfly::search::SortableValue;
 using DocValues = absl::flat_hash_map<std::string, Value>;  // documents sent through the pipeline
 
-// TODO: Replace DocValues with compact linear search map instead of hash map
+struct PipelineResult {
+  // Values to be passed to the next step
+  // TODO: Replace DocValues with compact linear search map instead of hash map
+  std::vector<DocValues> values;
 
-using PipelineResult = io::Result<std::vector<DocValues>, facade::ErrorReply>;
-using PipelineStep = std::function<PipelineResult(std::vector<DocValues>)>;  // Group, Sort, etc.
+  // Fields from values to be printed
+  absl::flat_hash_set<std::string> fields_to_print;
+};
+
+using PipelineStep = std::function<PipelineResult(PipelineResult)>;  // Group, Sort, etc.
 
 // Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
 // Extra clumsy for STL compatibility!
@@ -82,6 +89,8 @@ PipelineStep MakeSortStep(std::string_view field, bool descending = false);
 PipelineStep MakeLimitStep(size_t offset, size_t num);
 
 // Process values with given steps
-PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps);
+PipelineResult Process(std::vector<DocValues> values,
+                       absl::Span<const std::string_view> fields_to_print,
+                       absl::Span<const PipelineStep> steps);
 
 }  // namespace dfly::aggregate
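
For reference, a hypothetical call against the new header; the argument values are invented here, but the updated tests in aggregator_test.cc below exercise exactly this signature:

  std::vector<DocValues> values = {DocValues{{"a", 1.0}}, DocValues{{"a", 0.5}}};
  PipelineStep steps[] = {MakeSortStep("a"), MakeLimitStep(0, 10)};

  // fields_to_print starts as {"a"}; MakeSortStep re-inserts its sort field.
  PipelineResult out = Process(std::move(values), {"a"}, steps);
  // out.values: the documents after every step ran, sorted ascending by "a".
  // out.fields_to_print: the only fields an FT.AGGREGATE reply will include.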

src/server/search/aggregator_test.cc

@@ -18,12 +18,11 @@ TEST(AggregatorTest, Sort) {
   };
 
   PipelineStep steps[] = {MakeSortStep("a", false)};
-  auto result = Process(values, steps);
+  auto result = Process(values, {"a"}, steps);
 
-  EXPECT_TRUE(result);
-  EXPECT_EQ(result->at(0)["a"], Value(0.5));
-  EXPECT_EQ(result->at(1)["a"], Value(1.0));
-  EXPECT_EQ(result->at(2)["a"], Value(1.5));
+  EXPECT_EQ(result.values[0]["a"], Value(0.5));
+  EXPECT_EQ(result.values[1]["a"], Value(1.0));
+  EXPECT_EQ(result.values[2]["a"], Value(1.5));
 }
 
 TEST(AggregatorTest, Limit) {
@@ -35,12 +34,11 @@ TEST(AggregatorTest, Limit) {
   };
 
   PipelineStep steps[] = {MakeLimitStep(1, 2)};
-  auto result = Process(values, steps);
+  auto result = Process(values, {"i"}, steps);
 
-  EXPECT_TRUE(result);
-  EXPECT_EQ(result->size(), 2);
-  EXPECT_EQ(result->at(0)["i"], Value(2.0));
-  EXPECT_EQ(result->at(1)["i"], Value(3.0));
+  EXPECT_EQ(result.values.size(), 2);
+  EXPECT_EQ(result.values[0]["i"], Value(2.0));
+  EXPECT_EQ(result.values[1]["i"], Value(3.0));
 }
 
 TEST(AggregatorTest, SimpleGroup) {
@@ -54,12 +52,11 @@ TEST(AggregatorTest, SimpleGroup) {
   std::string_view fields[] = {"tag"};
   PipelineStep steps[] = {MakeGroupStep(fields, {})};
-  auto result = Process(values, steps);
-  EXPECT_TRUE(result);
-  EXPECT_EQ(result->size(), 2);
+  auto result = Process(values, {"i", "tag"}, steps);
+  EXPECT_EQ(result.values.size(), 2);
 
-  EXPECT_EQ(result->at(0).size(), 1);
-  std::set<Value> groups{result->at(0)["tag"], result->at(1)["tag"]};
+  EXPECT_EQ(result.values[0].size(), 1);
+  std::set<Value> groups{result.values[0]["tag"], result.values[1]["tag"]};
 
   std::set<Value> expected{"even", "odd"};
   EXPECT_EQ(groups, expected);
 }
@@ -83,25 +80,24 @@ TEST(AggregatorTest, GroupWithReduce) {
                       Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
 
   PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};
-  auto result = Process(values, steps);
-  EXPECT_TRUE(result);
-  EXPECT_EQ(result->size(), 2);
+  auto result = Process(values, {"i", "half-i", "tag"}, steps);
+  EXPECT_EQ(result.values.size(), 2);
 
   // Reorder even first
-  if (result->at(0).at("tag") == Value("odd"))
-    std::swap(result->at(0), result->at(1));
+  if (result.values[0].at("tag") == Value("odd"))
+    std::swap(result.values[0], result.values[1]);
 
   // Even
-  EXPECT_EQ(result->at(0).at("count"), Value{(double)5});
-  EXPECT_EQ(result->at(0).at("sum-i"), Value{(double)2 + 4 + 6 + 8});
-  EXPECT_EQ(result->at(0).at("distinct-hi"), Value{(double)3});
-  EXPECT_EQ(result->at(0).at("distinct-null"), Value{(double)1});
+  EXPECT_EQ(result.values[0].at("count"), Value{(double)5});
+  EXPECT_EQ(result.values[0].at("sum-i"), Value{(double)2 + 4 + 6 + 8});
+  EXPECT_EQ(result.values[0].at("distinct-hi"), Value{(double)3});
+  EXPECT_EQ(result.values[0].at("distinct-null"), Value{(double)1});
 
   // Odd
-  EXPECT_EQ(result->at(1).at("count"), Value{(double)5});
-  EXPECT_EQ(result->at(1).at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
-  EXPECT_EQ(result->at(1).at("distinct-hi"), Value{(double)3});
-  EXPECT_EQ(result->at(1).at("distinct-null"), Value{(double)1});
+  EXPECT_EQ(result.values[1].at("count"), Value{(double)5});
+  EXPECT_EQ(result.values[1].at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
+  EXPECT_EQ(result.values[1].at("distinct-hi"), Value{(double)3});
+  EXPECT_EQ(result.values[1].at("distinct-null"), Value{(double)1});
 }
 
 }  // namespace dfly::aggregate

src/server/search/search_family.cc

@@ -981,22 +981,34 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx)
                   make_move_iterator(sub_results.end()));
   }
 
-  auto agg_results = aggregate::Process(std::move(values), params->steps);
-  if (!agg_results.has_value())
-    return builder->SendError(agg_results.error());
+  std::vector<std::string_view> load_fields;
+  if (params->load_fields) {
+    load_fields.reserve(params->load_fields->size());
+    for (const auto& field : params->load_fields.value()) {
+      load_fields.push_back(field.GetShortName());
+    }
+  }
 
-  size_t result_size = agg_results->size();
+  auto agg_results = aggregate::Process(std::move(values), load_fields, params->steps);
+
   auto* rb = static_cast<RedisReplyBuilder*>(cmd_cntx.rb);
   auto sortable_value_sender = SortableValueSender(rb);
 
+  const size_t result_size = agg_results.values.size();
   rb->StartArray(result_size + 1);
   rb->SendLong(result_size);
-  for (const auto& result : agg_results.value()) {
-    rb->StartArray(result.size() * 2);
-    for (const auto& [k, v] : result) {
-      rb->SendBulkString(k);
-      std::visit(sortable_value_sender, v);
+
+  const size_t field_count = agg_results.fields_to_print.size();
+  for (const auto& value : agg_results.values) {
+    rb->StartArray(field_count * 2);
+    for (const auto& field : agg_results.fields_to_print) {
+      rb->SendBulkString(field);
+      if (auto it = value.find(field); it != value.end()) {
+        std::visit(sortable_value_sender, it->second);
+      } else {
+        rb->SendNull();
+      }
     }
   }
 }
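
A consequence of iterating fields_to_print instead of each document's own keys: every row now contains exactly field_count key/value pairs, and a requested field that a particular document lacks is answered with an explicit null via rb->SendNull(). A hypothetical assertion in the style of the test suite below; the index, documents, and expected matchers are invented for illustration and are not part of this commit:

  // Suppose a document was indexed with "a" present but no "b" to LOAD:
  resp = Run({"FT.AGGREGATE", "index", "*", "LOAD", "1", "@b", "SORTBY", "1", "a"});
  EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("b", "\"2\"", "a", "1"),
                                         IsMap("b", ArgType(RespExpr::NIL), "a", "4")));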

src/server/search/search_family_test.cc

@@ -962,15 +962,12 @@ TEST_F(SearchFamilyTest, AggregateGroupBy) {
   EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "20", "word", "item2"),
                                          IsMap("foo_total", "50", "word", "item1")));
 
-  /*
-  Temporary not supported
   resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "text", "GROUPBY", "2", "@word",
-  "@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); EXPECT_THAT(resp,
-  IsUnordArrayWithSize(IsMap("foo_total", "20", "word", ArgType(RespExpr::NIL), "text", "\"second
-  key\""), IsMap("foo_total", "40", "word", ArgType(RespExpr::NIL), "text", "\"third key\""),
-  IsMap({"foo_total", "10", "word", ArgType(RespExpr::NIL), "text", "\"first key"})));
-  */
+              "@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"});
+  EXPECT_THAT(resp, IsUnordArrayWithSize(
+                        IsMap("foo_total", "40", "word", "item1", "text", "\"third key\""),
+                        IsMap("foo_total", "20", "word", "item2", "text", "\"second key\""),
+                        IsMap("foo_total", "10", "word", "item1", "text", "\"first key\"")));
 }
 
 TEST_F(SearchFamilyTest, JsonAggregateGroupBy) {
@@ -1632,4 +1629,32 @@ TEST_F(SearchFamilyTest, SearchLoadReturnHash) {
   EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("a", "two"), "h1", IsMap("a", "one")));
 }
 
+// Test that FT.AGGREGATE prints only needed fields
+TEST_F(SearchFamilyTest, AggregateResultFields) {
+  Run({"JSON.SET", "j1", ".", R"({"a":"1","b":"2","c":"3"})"});
+  Run({"JSON.SET", "j2", ".", R"({"a":"4","b":"5","c":"6"})"});
+  Run({"JSON.SET", "j3", ".", R"({"a":"7","b":"8","c":"9"})"});
+
+  auto resp = Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.a", "AS", "a", "TEXT",
+                   "SORTABLE", "$.b", "AS", "b", "TEXT", "$.c", "AS", "c", "TEXT"});
+  EXPECT_EQ(resp, "OK");
+
+  resp = Run({"FT.AGGREGATE", "index", "*"});
+  EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap(), IsMap(), IsMap()));
+
+  resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a"});
+  EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("a", "1"), IsMap("a", "4"), IsMap("a", "7")));
+
+  resp = Run({"FT.AGGREGATE", "index", "*", "LOAD", "1", "@b", "SORTBY", "1", "a"});
+  EXPECT_THAT(resp,
+              IsUnordArrayWithSize(IsMap("b", "\"2\"", "a", "1"), IsMap("b", "\"5\"", "a", "4"),
+                                   IsMap("b", "\"8\"", "a", "7")));
+
+  resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a", "GROUPBY", "2", "@b", "@a",
+              "REDUCE", "COUNT", "0", "AS", "count"});
+  EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("b", "\"8\"", "a", "7", "count", "1"),
+                                         IsMap("b", "\"2\"", "a", "1", "count", "1"),
+                                         IsMap("b", "\"5\"", "a", "4", "count", "1")));
+}
+
 }  // namespace dfly