mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-14 11:58:02 +00:00
feat: Support tags in search (#1341)
Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
parent
ab3a67ced3
commit
9ab70e4f15
13 changed files with 186 additions and 28 deletions
|
@ -9,6 +9,8 @@
|
|||
#include <algorithm>
|
||||
#include <regex>
|
||||
|
||||
#include "base/logging.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace dfly::search {
|
||||
|
@ -42,4 +44,16 @@ AstFieldNode::AstFieldNode(string field, AstNode&& node)
|
|||
: field{field.substr(1)}, node{make_unique<AstNode>(move(node))} {
|
||||
}
|
||||
|
||||
AstTagsNode::AstTagsNode(std::string tag) {
|
||||
tags = {move(tag)};
|
||||
}
|
||||
|
||||
AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
|
||||
DCHECK(holds_alternative<AstTagsNode>(l));
|
||||
auto& tags_node = get<AstTagsNode>(l);
|
||||
|
||||
tags = move(tags_node.tags);
|
||||
tags.push_back(move(tag));
|
||||
}
|
||||
|
||||
} // namespace dfly::search
|
||||
|
|
|
@ -59,8 +59,16 @@ struct AstFieldNode {
|
|||
std::unique_ptr<AstNode> node;
|
||||
};
|
||||
|
||||
// Stores a list of tags for a tag query
|
||||
struct AstTagsNode {
|
||||
AstTagsNode(std::string tag);
|
||||
AstTagsNode(AstNode&& l, std::string tag);
|
||||
|
||||
std::vector<std::string> tags;
|
||||
};
|
||||
|
||||
using NodeVariants = std::variant<std::monostate, AstTermNode, AstRangeNode, AstNegateNode,
|
||||
AstLogicalNode, AstFieldNode>;
|
||||
AstLogicalNode, AstFieldNode, AstTagsNode>;
|
||||
struct AstNode : public NodeVariants {
|
||||
using variant::variant;
|
||||
};
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include <absl/container/flat_hash_set.h>
|
||||
#include <absl/strings/ascii.h>
|
||||
#include <absl/strings/numbers.h>
|
||||
#include <absl/strings/str_split.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <regex>
|
||||
|
@ -54,6 +55,11 @@ vector<DocId> NumericIndex::Range(int64_t l, int64_t r) const {
|
|||
return out;
|
||||
}
|
||||
|
||||
vector<DocId> BaseStringIndex::Matching(string_view str) const {
|
||||
auto it = entries_.find(absl::StripAsciiWhitespace(str));
|
||||
return (it != entries_.end()) ? it->second : vector<DocId>{};
|
||||
}
|
||||
|
||||
void TextIndex::Add(DocId doc, string_view value) {
|
||||
for (const auto& word : GetWords(value)) {
|
||||
auto& list = entries_[word];
|
||||
|
@ -61,9 +67,12 @@ void TextIndex::Add(DocId doc, string_view value) {
|
|||
}
|
||||
}
|
||||
|
||||
vector<DocId> TextIndex::Matching(string_view word_sv) const {
|
||||
auto it = entries_.find(word_sv);
|
||||
return (it != entries_.end()) ? it->second : vector<DocId>{};
|
||||
void TagIndex::Add(DocId doc, string_view value) {
|
||||
auto tags = absl::StrSplit(value, ',');
|
||||
for (string_view tag : tags) {
|
||||
auto& list = entries_[absl::StripAsciiWhitespace(tag)];
|
||||
list.insert(upper_bound(list.begin(), list.end(), doc), doc);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace dfly::search
|
||||
|
|
|
@ -23,15 +23,24 @@ struct NumericIndex : public BaseIndex {
|
|||
std::set<std::pair<int64_t, DocId>> entries_;
|
||||
};
|
||||
|
||||
// Index for text fields.
|
||||
// Hashmap based lookup per word.
|
||||
struct TextIndex : public BaseIndex {
|
||||
void Add(DocId doc, std::string_view value) override;
|
||||
// Base index for string based indices.
|
||||
struct BaseStringIndex : public BaseIndex {
|
||||
std::vector<DocId> Matching(std::string_view str) const;
|
||||
|
||||
std::vector<DocId> Matching(std::string_view word) const;
|
||||
|
||||
private:
|
||||
protected:
|
||||
absl::flat_hash_map<std::string, std::vector<DocId>> entries_;
|
||||
};
|
||||
|
||||
// Index for text fields.
|
||||
// Hashmap based lookup per word.
|
||||
struct TextIndex : public BaseStringIndex {
|
||||
void Add(DocId doc, std::string_view value) override;
|
||||
};
|
||||
|
||||
// Index for text fields.
|
||||
// Hashmap based lookup per word.
|
||||
struct TagIndex : public BaseStringIndex {
|
||||
void Add(DocId doc, std::string_view value) override;
|
||||
};
|
||||
|
||||
} // namespace dfly::search
|
||||
|
|
|
@ -58,6 +58,8 @@ term_char [_]|\w
|
|||
"=>" return Parser::make_ARROW (loc());
|
||||
"[" return Parser::make_LBRACKET (loc());
|
||||
"]" return Parser::make_RBRACKET (loc());
|
||||
"{" return Parser::make_LCURLBR (loc());
|
||||
"}" return Parser::make_RCURLBR (loc());
|
||||
"|" return Parser::make_OR_OP (loc());
|
||||
|
||||
-?[0-9]+ return make_INT64(matched_view(), loc());
|
||||
|
|
|
@ -52,6 +52,8 @@ using namespace std;
|
|||
COLON ":"
|
||||
LBRACKET "["
|
||||
RBRACKET "]"
|
||||
LCURLBR "{"
|
||||
RCURLBR "}"
|
||||
OR_OP "|"
|
||||
;
|
||||
|
||||
|
@ -67,7 +69,7 @@ using namespace std;
|
|||
%precedence LPAREN RPAREN
|
||||
|
||||
%token <int64_t> INT64 "int64"
|
||||
%nterm <AstExpr> final_query filter search_expr field_cond field_cond_expr
|
||||
%nterm <AstExpr> final_query filter search_expr field_cond field_cond_expr tag_list
|
||||
|
||||
%printer { yyo << $$; } <*>;
|
||||
|
||||
|
@ -92,6 +94,7 @@ field_cond:
|
|||
| NOT_OP field_cond { $$ = AstNegateNode(move($2)); }
|
||||
| LPAREN field_cond_expr RPAREN { $$ = move($2); }
|
||||
| LBRACKET INT64 INT64 RBRACKET { $$ = AstRangeNode(move($2), move($3)); }
|
||||
| LCURLBR tag_list RCURLBR { $$ = move($2); }
|
||||
|
||||
field_cond_expr:
|
||||
LPAREN field_cond_expr RPAREN { $$ = move($2); }
|
||||
|
@ -99,6 +102,11 @@ field_cond_expr:
|
|||
| field_cond_expr OR_OP field_cond_expr { $$ = AstLogicalNode(move($1), move($3), AstLogicalNode::OR); }
|
||||
| NOT_OP field_cond_expr { $$ = AstNegateNode(move($2)); };
|
||||
| TERM { $$ = AstTermNode(move($1)); }
|
||||
|
||||
tag_list:
|
||||
TERM { $$ = AstTagsNode(move($1)); }
|
||||
| tag_list OR_OP TERM { $$ = AstTagsNode(move($1), move($3)); }
|
||||
|
||||
%%
|
||||
|
||||
void
|
||||
|
|
|
@ -113,9 +113,20 @@ struct BasicSearch {
|
|||
|
||||
vector<DocId> Search(const AstFieldNode& node, string_view active_field) {
|
||||
DCHECK(active_field.empty());
|
||||
DCHECK(node.node);
|
||||
return SearchGeneric(*node.node, node.field);
|
||||
}
|
||||
|
||||
vector<DocId> Search(const AstTagsNode& node, string_view active_field) {
|
||||
auto* tag_index = GetIndex<TagIndex>(active_field);
|
||||
|
||||
vector<DocId> out, tmp;
|
||||
for (const auto& tag : node.tags)
|
||||
UnifyResults(tag_index->Matching(tag), &out, &tmp);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
vector<DocId> SearchGeneric(const AstNode& node, string_view active_field) {
|
||||
auto cb = [this, active_field](const auto& inner) { return Search(inner, active_field); };
|
||||
auto result = visit(cb, static_cast<const NodeVariants&>(node));
|
||||
|
@ -135,6 +146,9 @@ struct BasicSearch {
|
|||
FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, indices_{} {
|
||||
for (auto& [field, type] : schema_.fields) {
|
||||
switch (type) {
|
||||
case Schema::TAG:
|
||||
indices_[field] = make_unique<TagIndex>();
|
||||
break;
|
||||
case Schema::TEXT:
|
||||
indices_[field] = make_unique<TextIndex>();
|
||||
break;
|
||||
|
|
|
@ -26,7 +26,7 @@ struct DocumentAccessor {
|
|||
};
|
||||
|
||||
struct Schema {
|
||||
enum FieldType { TEXT, NUMERIC };
|
||||
enum FieldType { TAG, TEXT, NUMERIC };
|
||||
|
||||
absl::flat_hash_map<std::string, FieldType> fields;
|
||||
};
|
||||
|
|
|
@ -85,6 +85,13 @@ TEST_F(SearchParserTest, Scanner) {
|
|||
NEXT_TOK(TOK_COLON);
|
||||
NEXT_EQ(TOK_TERM, string, "hello");
|
||||
|
||||
SetInput("@field:{ tag }");
|
||||
NEXT_EQ(TOK_FIELD, string, "@field");
|
||||
NEXT_TOK(TOK_COLON);
|
||||
NEXT_TOK(TOK_LCURLBR);
|
||||
NEXT_EQ(TOK_TERM, string, "tag");
|
||||
NEXT_TOK(TOK_RCURLBR);
|
||||
|
||||
SetInput("почтальон Печкин");
|
||||
NEXT_EQ(TOK_TERM, string, "почтальон");
|
||||
NEXT_EQ(TOK_TERM, string, "Печкин");
|
||||
|
@ -96,6 +103,8 @@ TEST_F(SearchParserTest, Scanner) {
|
|||
TEST_F(SearchParserTest, Parse) {
|
||||
EXPECT_EQ(0, Parse(" foo bar (baz) "));
|
||||
EXPECT_EQ(0, Parse(" -(foo) @foo:bar @ss:[1 2]"));
|
||||
EXPECT_EQ(0, Parse("@foo:{ tag1 | tag2 }"));
|
||||
|
||||
EXPECT_EQ(1, Parse(" -(foo "));
|
||||
EXPECT_EQ(1, Parse(" foo:bar "));
|
||||
EXPECT_EQ(1, Parse(" @foo:@bar "));
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
#include "core/search/search.h"
|
||||
|
||||
#include <absl/cleanup/cleanup.h>
|
||||
#include <absl/container/flat_hash_map.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
@ -79,6 +80,8 @@ class SearchParserTest : public ::testing::Test {
|
|||
}
|
||||
|
||||
bool Check() {
|
||||
absl::Cleanup cl{[this] { entries_.clear(); }};
|
||||
|
||||
FieldIndices index{schema_};
|
||||
|
||||
shuffle(entries_.begin(), entries_.end(), default_random_engine{});
|
||||
|
@ -86,7 +89,11 @@ class SearchParserTest : public ::testing::Test {
|
|||
index.Add(i, &entries_[i].first);
|
||||
|
||||
SearchAlgorithm search_algo{};
|
||||
search_algo.Init(query_);
|
||||
if (!search_algo.Init(query_)) {
|
||||
error_ = "Failed to parse query";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto matched = search_algo.Search(&index);
|
||||
|
||||
if (!is_sorted(matched.begin(), matched.end()))
|
||||
|
@ -97,12 +104,10 @@ class SearchParserTest : public ::testing::Test {
|
|||
if (doc_matched != entries_[i].second) {
|
||||
error_ = "doc: \"" + entries_[i].first.DebugFormat() + "\"" + " was expected" +
|
||||
(entries_[i].second ? "" : " not") + " to match" + " query: \"" + query_ + "\"";
|
||||
entries_.clear();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
entries_.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -268,6 +273,22 @@ TEST_F(SearchParserTest, CheckExprInField) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(SearchParserTest, CheckTag) {
|
||||
PrepareSchema({{{"f1", Schema::TAG}, {"f2", Schema::TAG}}});
|
||||
|
||||
PrepareQuery("@f1:{red | blue} @f2:{circle | square}");
|
||||
|
||||
ExpectAll(Map{{"f1", "red"}, {"f2", "square"}}, Map{{"f1", "blue"}, {"f2", "square"}},
|
||||
Map{{"f1", "red"}, {"f2", "circle"}}, Map{{"f1", "red"}, {"f2", "circle, square"}},
|
||||
Map{{"f1", "red"}, {"f2", "triangle, circle"}},
|
||||
Map{{"f1", "red, green"}, {"f2", "square"}},
|
||||
Map{{"f1", "green, blue"}, {"f2", "circle"}});
|
||||
ExpectNone(Map{{"f1", "green"}, {"f2", "square"}}, Map{{"f1", "green"}, {"f2", "circle"}},
|
||||
Map{{"f1", "red"}, {"f2", "triangle"}}, Map{{"f1", "blue"}, {"f2", "line, triangle"}});
|
||||
|
||||
EXPECT_TRUE(Check()) << GetError();
|
||||
}
|
||||
|
||||
} // namespace search
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -31,7 +31,9 @@ using namespace facade;
|
|||
namespace {
|
||||
|
||||
unordered_map<string_view, search::Schema::FieldType> kSchemaTypes = {
|
||||
{"TEXT"sv, search::Schema::TEXT}, {"NUMERIC"sv, search::Schema::NUMERIC}};
|
||||
{"TAG"sv, search::Schema::TAG},
|
||||
{"TEXT"sv, search::Schema::TEXT},
|
||||
{"NUMERIC"sv, search::Schema::NUMERIC}};
|
||||
|
||||
optional<search::Schema> ParseSchemaOrReply(CmdArgList args, ConnectionContext* cntx) {
|
||||
search::Schema schema;
|
||||
|
@ -50,6 +52,12 @@ optional<search::Schema> ParseSchemaOrReply(CmdArgList args, ConnectionContext*
|
|||
return nullopt;
|
||||
}
|
||||
|
||||
// Skip optional WEIGHT or SEPARATOR flags
|
||||
if (i + 2 < args.size() &&
|
||||
(ArgS(args, i + 1) == "WEIGHT" || ArgS(args, i + 1) == "SEPARATOR")) {
|
||||
i += 2;
|
||||
}
|
||||
|
||||
schema.fields[field] = it->second;
|
||||
|
||||
// Skip optional WEIGHT flag
|
||||
|
|
|
@ -118,6 +118,29 @@ TEST_F(SearchFamilyTest, Json) {
|
|||
EXPECT_THAT(Run({"ft.search", "i1", "@a:small @b:secret"}), kNoResults);
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, Tags) {
|
||||
Run({"hset", "d:1", "color", "red, green"});
|
||||
Run({"hset", "d:2", "color", "green, blue"});
|
||||
Run({"hset", "d:3", "color", "blue, red"});
|
||||
Run({"hset", "d:4", "color", "red"});
|
||||
Run({"hset", "d:5", "color", "green"});
|
||||
Run({"hset", "d:6", "color", "blue"});
|
||||
|
||||
EXPECT_EQ(Run({"ft.create", "i1", "on", "hash", "schema", "color", "tag"}), "OK");
|
||||
|
||||
// Tags don't participate in full text search
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "red"}), kNoResults);
|
||||
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@color:{ red }"}), AreDocIds("d:1", "d:3", "d:4"));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@color:{green}"}), AreDocIds("d:1", "d:2", "d:5"));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@color:{blue}"}), AreDocIds("d:2", "d:3", "d:6"));
|
||||
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@color:{red | green}"}),
|
||||
AreDocIds("d:1", "d:2", "d:3", "d:4", "d:5"));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@color:{blue | green}"}),
|
||||
AreDocIds("d:1", "d:2", "d:3", "d:5", "d:6"));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, Numbers) {
|
||||
for (unsigned i = 0; i <= 10; i++) {
|
||||
for (unsigned j = 0; j <= 10; j++) {
|
||||
|
|
|
@ -6,18 +6,25 @@ import pytest
|
|||
from redis import asyncio as aioredis
|
||||
from .utility import *
|
||||
|
||||
from redis.commands.search.query import Query
|
||||
from redis.commands.search.field import TextField
|
||||
from redis.commands.search.field import TextField, NumericField, TagField
|
||||
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
||||
|
||||
TEST_DATA = [
|
||||
{"title": "First article", "content": "Long description"},
|
||||
{"title": "Second article", "content": "Small text"},
|
||||
{"title": "Third piece", "content": "Brief description"},
|
||||
{"title": "Last piece", "content": "Interesting text"},
|
||||
{"title": "First article", "content": "Long description",
|
||||
"views": 100, "topic": "world, science"},
|
||||
|
||||
{"title": "Second article", "content": "Small text",
|
||||
"views": 200, "topic": "national, policits"},
|
||||
|
||||
{"title": "Third piece", "content": "Brief description",
|
||||
"views": 300, "topic": "health, lifestyle"},
|
||||
|
||||
{"title": "Last piece", "content": "Interesting text",
|
||||
"views": 400, "topic": "world, business"},
|
||||
]
|
||||
|
||||
TEST_DATA_SCHEMA = [TextField("title"), TextField("content")]
|
||||
TEST_DATA_SCHEMA = [TextField("title"), TextField(
|
||||
"content"), NumericField("views"), TagField("topic")]
|
||||
|
||||
|
||||
async def index_test_data(async_client: aioredis.Redis, itype: IndexType, prefix=""):
|
||||
|
@ -27,17 +34,24 @@ async def index_test_data(async_client: aioredis.Redis, itype: IndexType, prefix
|
|||
else:
|
||||
await async_client.json().set(prefix+str(i), "$", e)
|
||||
|
||||
def doc_to_str(doc):
|
||||
if not type(doc) is dict:
|
||||
doc = doc.__dict__
|
||||
|
||||
doc = dict(doc) # copy to remove fields
|
||||
doc.pop('id', None)
|
||||
doc.pop('payload', None)
|
||||
|
||||
return '//'.join(sorted(doc))
|
||||
|
||||
def contains_test_data(res, td_indices):
|
||||
if res.total != len(td_indices):
|
||||
return False
|
||||
|
||||
docset = set()
|
||||
for doc in res.docs:
|
||||
docset.add(f"{doc.title}//{doc.content}")
|
||||
docset = {doc_to_str(doc) for doc in res.docs}
|
||||
|
||||
for td_entry in (TEST_DATA[tdi] for tdi in td_indices):
|
||||
if not f"{td_entry['title']}//{td_entry['content']}" in docset:
|
||||
if not doc_to_str(td_entry) in docset:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
@ -47,6 +61,7 @@ def contains_test_data(res, td_indices):
|
|||
async def test_basic(async_client, index_type):
|
||||
i1 = async_client.ft("i-"+str(index_type))
|
||||
await index_test_data(async_client, index_type)
|
||||
await i1.create_index(TEST_DATA_SCHEMA, definition=IndexDefinition(index_type=index_type))
|
||||
|
||||
await i1.create_index(TEST_DATA_SCHEMA, definition=IndexDefinition(index_type=index_type))
|
||||
|
||||
|
@ -61,3 +76,21 @@ async def test_basic(async_client, index_type):
|
|||
|
||||
res = await i1.search("@title:(article|last) @content:text")
|
||||
assert contains_test_data(res, [1, 3])
|
||||
|
||||
res = await i1.search("@views:[200 300]")
|
||||
assert contains_test_data(res, [1, 2])
|
||||
|
||||
res = await i1.search("@views:[0 150] | @views:[350 500]")
|
||||
assert contains_test_data(res, [0, 3])
|
||||
|
||||
res = await i1.search("@topic:{world}")
|
||||
assert contains_test_data(res, [0, 3])
|
||||
|
||||
res = await i1.search("@topic:{business}")
|
||||
assert contains_test_data(res, [3])
|
||||
|
||||
res = await i1.search("@topic:{world | national}")
|
||||
assert contains_test_data(res, [0, 1, 3])
|
||||
|
||||
res = await i1.search("@topic:{science | health}")
|
||||
assert contains_test_data(res, [0, 2])
|
||||
|
|
Loading…
Reference in a new issue