
feat: Support tags in search (#1341)

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
Authored by Vladislav <vlad@dragonflydb.io> on 2023-06-05 00:26:21 +03:00; committed by GitHub
parent ab3a67ced3
commit 9ab70e4f15
13 changed files with 186 additions and 28 deletions


@@ -9,6 +9,8 @@
#include <algorithm>
#include <regex>
#include "base/logging.h"
using namespace std;
namespace dfly::search {
@@ -42,4 +44,16 @@ AstFieldNode::AstFieldNode(string field, AstNode&& node)
: field{field.substr(1)}, node{make_unique<AstNode>(move(node))} {
}
AstTagsNode::AstTagsNode(std::string tag) {
tags = {move(tag)};
}
AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
DCHECK(holds_alternative<AstTagsNode>(l));
auto& tags_node = get<AstTagsNode>(l);
tags = move(tags_node.tags);
tags.push_back(move(tag));
}
} // namespace dfly::search
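
Taken together, the two constructors above implement the parser's left-recursive tag_list rule (see parser.y below): the single-tag constructor seeds the vector, and each further "| TERM" reduction moves the list out of the previous node and appends one tag. A minimal sketch of that accumulation, using a simplified stand-in for AstTagsNode (illustrative only, not Dragonfly code):

#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for AstTagsNode; the real node takes an AstExpr&&.
struct TagsNode {
  explicit TagsNode(std::string tag) { tags.push_back(std::move(tag)); }
  TagsNode(TagsNode&& l, std::string tag) : tags(std::move(l.tags)) {
    tags.push_back(std::move(tag));
  }
  std::vector<std::string> tags;
};

int main() {
  // Parsing "{ red | green | blue }" reduces left to right:
  TagsNode n{"red"};                    // tag_list: TERM
  n = TagsNode{std::move(n), "green"};  // tag_list OR_OP TERM
  n = TagsNode{std::move(n), "blue"};   // tag_list OR_OP TERM
  assert((n.tags == std::vector<std::string>{"red", "green", "blue"}));
}

Moving the vector out of the previous node keeps each reduction cheap instead of copying the whole list on every step.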


@@ -59,8 +59,16 @@ struct AstFieldNode {
std::unique_ptr<AstNode> node;
};
// Stores a list of tags for a tag query
struct AstTagsNode {
AstTagsNode(std::string tag);
AstTagsNode(AstNode&& l, std::string tag);
std::vector<std::string> tags;
};
using NodeVariants = std::variant<std::monostate, AstTermNode, AstRangeNode, AstNegateNode,
-                                  AstLogicalNode, AstFieldNode>;
+                                  AstLogicalNode, AstFieldNode, AstTagsNode>;
struct AstNode : public NodeVariants {
using variant::variant;
};
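
With AstTagsNode added to NodeVariants, a query such as @color:{red | blue} parses into an AstFieldNode whose child is a tags node. A rough, self-contained picture of that shape, with trimmed-down stand-ins for the real node types (an assumption-laden sketch mirroring the header above, not the actual code):

#include <cassert>
#include <memory>
#include <string>
#include <variant>
#include <vector>

struct Node;
struct TagsNode { std::vector<std::string> tags; };
struct FieldNode { std::string field; std::unique_ptr<Node> node; };
// Trimmed-down NodeVariants: only the two alternatives needed here.
struct Node : std::variant<TagsNode, FieldNode> { using variant::variant; };

int main() {
  // "@color:{red | blue}"  =>  FieldNode{ "color", TagsNode{ red, blue } }
  FieldNode ast{"color", std::make_unique<Node>(TagsNode{{"red", "blue"}})};
  assert(ast.field == "color");
  assert((std::get<TagsNode>(*ast.node).tags ==
          std::vector<std::string>{"red", "blue"}));
}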


@@ -7,6 +7,7 @@
#include <absl/container/flat_hash_set.h>
#include <absl/strings/ascii.h>
#include <absl/strings/numbers.h>
#include <absl/strings/str_split.h>
#include <algorithm>
#include <regex>
@@ -54,6 +55,11 @@ vector<DocId> NumericIndex::Range(int64_t l, int64_t r) const {
return out;
}
vector<DocId> BaseStringIndex::Matching(string_view str) const {
auto it = entries_.find(absl::StripAsciiWhitespace(str));
return (it != entries_.end()) ? it->second : vector<DocId>{};
}
void TextIndex::Add(DocId doc, string_view value) {
for (const auto& word : GetWords(value)) {
auto& list = entries_[word];
@@ -61,9 +67,12 @@ void TextIndex::Add(DocId doc, string_view value) {
}
}
-vector<DocId> TextIndex::Matching(string_view word_sv) const {
-  auto it = entries_.find(word_sv);
-  return (it != entries_.end()) ? it->second : vector<DocId>{};
-}
+void TagIndex::Add(DocId doc, string_view value) {
+  auto tags = absl::StrSplit(value, ',');
+  for (string_view tag : tags) {
+    auto& list = entries_[absl::StripAsciiWhitespace(tag)];
+    list.insert(upper_bound(list.begin(), list.end(), doc), doc);
+  }
+}
} // namespace dfly::search
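
TagIndex::Add above splits the stored value on commas, strips ASCII whitespace around each tag, and inserts the document id at upper_bound so every posting list stays sorted regardless of insertion order. A standalone demo of the same behavior using only the standard library (SplitTags is a hand-rolled stand-in for absl::StrSplit plus absl::StripAsciiWhitespace):

#include <algorithm>
#include <cassert>
#include <cctype>
#include <map>
#include <string>
#include <vector>

using DocId = unsigned;

// Split on ',' and trim ASCII whitespace from each piece.
static std::vector<std::string> SplitTags(const std::string& value) {
  std::vector<std::string> out;
  size_t start = 0;
  while (start <= value.size()) {
    size_t comma = value.find(',', start);
    if (comma == std::string::npos) comma = value.size();
    size_t b = start, e = comma;
    while (b < e && std::isspace(static_cast<unsigned char>(value[b]))) ++b;
    while (e > b && std::isspace(static_cast<unsigned char>(value[e - 1]))) --e;
    out.push_back(value.substr(b, e - b));
    start = comma + 1;
  }
  return out;
}

int main() {
  std::map<std::string, std::vector<DocId>> entries;
  auto add = [&](DocId doc, const std::string& value) {
    for (const auto& tag : SplitTags(value)) {
      auto& list = entries[tag];
      list.insert(std::upper_bound(list.begin(), list.end(), doc), doc);
    }
  };
  add(2, "red, green");
  add(1, "green ,blue");
  add(3, "red");
  assert((entries["red"] == std::vector<DocId>{2, 3}));  // sorted despite add order
  assert((entries["green"] == std::vector<DocId>{1, 2}));
  assert((entries["blue"] == std::vector<DocId>{1}));
}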


@@ -23,15 +23,24 @@ struct NumericIndex : public BaseIndex {
  std::set<std::pair<int64_t, DocId>> entries_;
};

-// Index for text fields.
-// Hashmap based lookup per word.
-struct TextIndex : public BaseIndex {
-  void Add(DocId doc, std::string_view value) override;
-  std::vector<DocId> Matching(std::string_view word) const;
- private:
+// Base index for string based indices.
+struct BaseStringIndex : public BaseIndex {
+  std::vector<DocId> Matching(std::string_view str) const;
+ protected:
  absl::flat_hash_map<std::string, std::vector<DocId>> entries_;
};

+// Index for text fields.
+// Hashmap based lookup per word.
+struct TextIndex : public BaseStringIndex {
+  void Add(DocId doc, std::string_view value) override;
+};
+
+// Index for tag fields.
+// Hashmap based lookup per tag value.
+struct TagIndex : public BaseStringIndex {
+  void Add(DocId doc, std::string_view value) override;
+};
} // namespace dfly::search


@@ -58,6 +58,8 @@ term_char [_]|\w
"=>" return Parser::make_ARROW (loc());
"[" return Parser::make_LBRACKET (loc());
"]" return Parser::make_RBRACKET (loc());
"{" return Parser::make_LCURLBR (loc());
"}" return Parser::make_RCURLBR (loc());
"|" return Parser::make_OR_OP (loc());
-?[0-9]+ return make_INT64(matched_view(), loc());


@@ -52,6 +52,8 @@ using namespace std;
COLON ":"
LBRACKET "["
RBRACKET "]"
LCURLBR "{"
RCURLBR "}"
OR_OP "|"
;
@@ -67,7 +69,7 @@ using namespace std;
%precedence LPAREN RPAREN
%token <int64_t> INT64 "int64"
-%nterm <AstExpr> final_query filter search_expr field_cond field_cond_expr
+%nterm <AstExpr> final_query filter search_expr field_cond field_cond_expr tag_list
%printer { yyo << $$; } <*>;
@@ -92,6 +94,7 @@ field_cond:
| NOT_OP field_cond { $$ = AstNegateNode(move($2)); }
| LPAREN field_cond_expr RPAREN { $$ = move($2); }
| LBRACKET INT64 INT64 RBRACKET { $$ = AstRangeNode(move($2), move($3)); }
| LCURLBR tag_list RCURLBR { $$ = move($2); }
field_cond_expr:
LPAREN field_cond_expr RPAREN { $$ = move($2); }
@@ -99,6 +102,11 @@ field_cond_expr:
| field_cond_expr OR_OP field_cond_expr { $$ = AstLogicalNode(move($1), move($3), AstLogicalNode::OR); }
| NOT_OP field_cond_expr { $$ = AstNegateNode(move($2)); };
| TERM { $$ = AstTermNode(move($1)); }
tag_list:
TERM { $$ = AstTagsNode(move($1)); }
| tag_list OR_OP TERM { $$ = AstTagsNode(move($1), move($3)); }
%%
void


@@ -113,9 +113,20 @@ struct BasicSearch {
vector<DocId> Search(const AstFieldNode& node, string_view active_field) {
DCHECK(active_field.empty());
DCHECK(node.node);
return SearchGeneric(*node.node, node.field);
}
vector<DocId> Search(const AstTagsNode& node, string_view active_field) {
auto* tag_index = GetIndex<TagIndex>(active_field);
vector<DocId> out, tmp;
for (const auto& tag : node.tags)
UnifyResults(tag_index->Matching(tag), &out, &tmp);
return out;
}
vector<DocId> SearchGeneric(const AstNode& node, string_view active_field) {
auto cb = [this, active_field](const auto& inner) { return Search(inner, active_field); };
auto result = visit(cb, static_cast<const NodeVariants&>(node));
@@ -135,6 +146,9 @@ struct BasicSearch {
FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, indices_{} {
for (auto& [field, type] : schema_.fields) {
switch (type) {
case Schema::TAG:
indices_[field] = make_unique<TagIndex>();
break;
case Schema::TEXT:
indices_[field] = make_unique<TextIndex>();
break;
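
UnifyResults is not part of this diff; from its use in Search(const AstTagsNode&) above it evidently merges each tag's sorted posting list into a running union. Assuming that contract, the tag search is equivalent to repeated std::set_union over sorted DocId vectors, sketched here (UnionInto is a hypothetical stand-in, not the Dragonfly helper):

#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

using DocId = unsigned;

// Union of two sorted posting lists; out accumulates results, tmp is
// scratch space, mirroring the assumed (matched, out, tmp) signature.
static void UnionInto(const std::vector<DocId>& matched,
                      std::vector<DocId>* out, std::vector<DocId>* tmp) {
  tmp->clear();
  std::set_union(out->begin(), out->end(), matched.begin(), matched.end(),
                 std::back_inserter(*tmp));
  out->swap(*tmp);
}

int main() {
  // Posting lists for "@color:{red | blue}":
  std::vector<DocId> red{1, 3, 4}, blue{2, 3, 6};
  std::vector<DocId> out, tmp;
  UnionInto(red, &out, &tmp);
  UnionInto(blue, &out, &tmp);
  assert((out == std::vector<DocId>{1, 2, 3, 4, 6}));
}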


@@ -26,7 +26,7 @@ struct DocumentAccessor {
};
struct Schema {
-enum FieldType { TEXT, NUMERIC };
+enum FieldType { TAG, TEXT, NUMERIC };
absl::flat_hash_map<std::string, FieldType> fields;
};


@@ -85,6 +85,13 @@ TEST_F(SearchParserTest, Scanner) {
NEXT_TOK(TOK_COLON);
NEXT_EQ(TOK_TERM, string, "hello");
SetInput("@field:{ tag }");
NEXT_EQ(TOK_FIELD, string, "@field");
NEXT_TOK(TOK_COLON);
NEXT_TOK(TOK_LCURLBR);
NEXT_EQ(TOK_TERM, string, "tag");
NEXT_TOK(TOK_RCURLBR);
SetInput("почтальон Печкин");
NEXT_EQ(TOK_TERM, string, "почтальон");
NEXT_EQ(TOK_TERM, string, "Печкин");
@@ -96,6 +103,8 @@ TEST_F(SearchParserTest, Scanner) {
TEST_F(SearchParserTest, Parse) {
EXPECT_EQ(0, Parse(" foo bar (baz) "));
EXPECT_EQ(0, Parse(" -(foo) @foo:bar @ss:[1 2]"));
EXPECT_EQ(0, Parse("@foo:{ tag1 | tag2 }"));
EXPECT_EQ(1, Parse(" -(foo "));
EXPECT_EQ(1, Parse(" foo:bar "));
EXPECT_EQ(1, Parse(" @foo:@bar "));


@@ -4,6 +4,7 @@
#include "core/search/search.h"
#include <absl/cleanup/cleanup.h>
#include <absl/container/flat_hash_map.h>
#include <algorithm>
@@ -79,6 +80,8 @@ class SearchParserTest : public ::testing::Test {
}
bool Check() {
absl::Cleanup cl{[this] { entries_.clear(); }};
FieldIndices index{schema_};
shuffle(entries_.begin(), entries_.end(), default_random_engine{});
@@ -86,7 +89,11 @@ class SearchParserTest : public ::testing::Test {
index.Add(i, &entries_[i].first);
SearchAlgorithm search_algo{};
-    search_algo.Init(query_);
+    if (!search_algo.Init(query_)) {
+      error_ = "Failed to parse query";
+      return false;
+    }
auto matched = search_algo.Search(&index);
if (!is_sorted(matched.begin(), matched.end()))
@@ -97,12 +104,10 @@ class SearchParserTest : public ::testing::Test {
if (doc_matched != entries_[i].second) {
error_ = "doc: \"" + entries_[i].first.DebugFormat() + "\"" + " was expected" +
(entries_[i].second ? "" : " not") + " to match" + " query: \"" + query_ + "\"";
-        entries_.clear();
return false;
}
}
-    entries_.clear();
return true;
}
@@ -268,6 +273,22 @@ TEST_F(SearchParserTest, CheckExprInField) {
}
}
TEST_F(SearchParserTest, CheckTag) {
PrepareSchema({{{"f1", Schema::TAG}, {"f2", Schema::TAG}}});
PrepareQuery("@f1:{red | blue} @f2:{circle | square}");
ExpectAll(Map{{"f1", "red"}, {"f2", "square"}}, Map{{"f1", "blue"}, {"f2", "square"}},
Map{{"f1", "red"}, {"f2", "circle"}}, Map{{"f1", "red"}, {"f2", "circle, square"}},
Map{{"f1", "red"}, {"f2", "triangle, circle"}},
Map{{"f1", "red, green"}, {"f2", "square"}},
Map{{"f1", "green, blue"}, {"f2", "circle"}});
ExpectNone(Map{{"f1", "green"}, {"f2", "square"}}, Map{{"f1", "green"}, {"f2", "circle"}},
Map{{"f1", "red"}, {"f2", "triangle"}}, Map{{"f1", "blue"}, {"f2", "line, triangle"}});
EXPECT_TRUE(Check()) << GetError();
}
} // namespace search
} // namespace dfly
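
In plain terms, CheckTag asserts OR semantics inside the braces and AND semantics across fields: a document matches "@f1:{red | blue} @f2:{circle | square}" when its f1 tag set contains red or blue, and its f2 tag set contains circle or square. A small sketch of that predicate (illustrative only, not Dragonfly code):

#include <cassert>
#include <set>
#include <string>
#include <vector>

// True if any wanted tag appears in the document's tag set for one field.
static bool FieldMatches(const std::set<std::string>& doc_tags,
                         const std::vector<std::string>& wanted) {
  for (const auto& tag : wanted)
    if (doc_tags.count(tag)) return true;
  return false;
}

int main() {
  // Doc: f1 = "red, green", f2 = "circle"
  std::set<std::string> f1{"red", "green"}, f2{"circle"};
  bool m = FieldMatches(f1, {"red", "blue"}) &&    // @f1:{red | blue}
           FieldMatches(f2, {"circle", "square"}); // @f2:{circle | square}
  assert(m);
  assert(!(FieldMatches(f1, {"blue"}) && FieldMatches(f2, {"square"})));
}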


@@ -31,7 +31,9 @@ using namespace facade;
namespace {

unordered_map<string_view, search::Schema::FieldType> kSchemaTypes = {
-    {"TEXT"sv, search::Schema::TEXT}, {"NUMERIC"sv, search::Schema::NUMERIC}};
+    {"TAG"sv, search::Schema::TAG},
+    {"TEXT"sv, search::Schema::TEXT},
+    {"NUMERIC"sv, search::Schema::NUMERIC}};
optional<search::Schema> ParseSchemaOrReply(CmdArgList args, ConnectionContext* cntx) {
search::Schema schema;
@@ -50,6 +52,12 @@ optional<search::Schema> ParseSchemaOrReply(CmdArgList args, ConnectionContext*
return nullopt;
}
+    // Skip optional WEIGHT or SEPARATOR flags
+    if (i + 2 < args.size() &&
+        (ArgS(args, i + 1) == "WEIGHT" || ArgS(args, i + 1) == "SEPARATOR")) {
+      i += 2;
+    }
+
schema.fields[field] = it->second;
-    // Skip optional WEIGHT flag
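
The new branch consumes FT.CREATE's SCHEMA arguments as (field, type) pairs and skips a trailing WEIGHT or SEPARATOR flag together with its value. A loose sketch of that walk over plain strings (ParseSchema here is hypothetical; ArgS, CmdArgList, and the error replies of the real code are omitted):

#include <cassert>
#include <map>
#include <string>
#include <vector>

enum FieldType { TAG, TEXT, NUMERIC };

static std::map<std::string, FieldType> ParseSchema(
    const std::vector<std::string>& args) {
  static const std::map<std::string, FieldType> kTypes{
      {"TAG", TAG}, {"TEXT", TEXT}, {"NUMERIC", NUMERIC}};
  std::map<std::string, FieldType> schema;
  for (size_t i = 0; i + 1 < args.size(); i += 2) {
    std::string field = args[i], type = args[i + 1];
    // Skip an optional WEIGHT or SEPARATOR flag together with its value.
    if (i + 2 < args.size() &&
        (args[i + 2] == "WEIGHT" || args[i + 2] == "SEPARATOR"))
      i += 2;
    schema[field] = kTypes.at(type);
  }
  return schema;
}

int main() {
  auto schema = ParseSchema({"title", "TEXT", "topic", "TAG", "SEPARATOR", ",",
                             "views", "NUMERIC"});
  assert(schema.size() == 3);
  assert(schema.at("topic") == TAG);
}

Note that as of this commit the flag's value is parsed past but not applied: TagIndex::Add always splits on ',', so a custom SEPARATOR is accepted yet has no effect.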


@@ -118,6 +118,29 @@ TEST_F(SearchFamilyTest, Json) {
EXPECT_THAT(Run({"ft.search", "i1", "@a:small @b:secret"}), kNoResults);
}
TEST_F(SearchFamilyTest, Tags) {
Run({"hset", "d:1", "color", "red, green"});
Run({"hset", "d:2", "color", "green, blue"});
Run({"hset", "d:3", "color", "blue, red"});
Run({"hset", "d:4", "color", "red"});
Run({"hset", "d:5", "color", "green"});
Run({"hset", "d:6", "color", "blue"});
EXPECT_EQ(Run({"ft.create", "i1", "on", "hash", "schema", "color", "tag"}), "OK");
// Tags don't participate in full text search
EXPECT_THAT(Run({"ft.search", "i1", "red"}), kNoResults);
EXPECT_THAT(Run({"ft.search", "i1", "@color:{ red }"}), AreDocIds("d:1", "d:3", "d:4"));
EXPECT_THAT(Run({"ft.search", "i1", "@color:{green}"}), AreDocIds("d:1", "d:2", "d:5"));
EXPECT_THAT(Run({"ft.search", "i1", "@color:{blue}"}), AreDocIds("d:2", "d:3", "d:6"));
EXPECT_THAT(Run({"ft.search", "i1", "@color:{red | green}"}),
AreDocIds("d:1", "d:2", "d:3", "d:4", "d:5"));
EXPECT_THAT(Run({"ft.search", "i1", "@color:{blue | green}"}),
AreDocIds("d:1", "d:2", "d:3", "d:5", "d:6"));
}
TEST_F(SearchFamilyTest, Numbers) {
for (unsigned i = 0; i <= 10; i++) {
for (unsigned j = 0; j <= 10; j++) {


@@ -6,18 +6,25 @@ import pytest
from redis import asyncio as aioredis
from .utility import *
from redis.commands.search.query import Query
-from redis.commands.search.field import TextField
+from redis.commands.search.field import TextField, NumericField, TagField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
TEST_DATA = [
-    {"title": "First article", "content": "Long description"},
-    {"title": "Second article", "content": "Small text"},
-    {"title": "Third piece", "content": "Brief description"},
-    {"title": "Last piece", "content": "Interesting text"},
+    {"title": "First article", "content": "Long description",
+     "views": 100, "topic": "world, science"},
+    {"title": "Second article", "content": "Small text",
+     "views": 200, "topic": "national, politics"},
+    {"title": "Third piece", "content": "Brief description",
+     "views": 300, "topic": "health, lifestyle"},
+    {"title": "Last piece", "content": "Interesting text",
+     "views": 400, "topic": "world, business"},
]
-TEST_DATA_SCHEMA = [TextField("title"), TextField("content")]
+TEST_DATA_SCHEMA = [TextField("title"), TextField(
+    "content"), NumericField("views"), TagField("topic")]
async def index_test_data(async_client: aioredis.Redis, itype: IndexType, prefix=""):
@@ -27,17 +34,24 @@ async def index_test_data(async_client: aioredis.Redis, itype: IndexType, prefix
else:
await async_client.json().set(prefix+str(i), "$", e)
def doc_to_str(doc):
if not type(doc) is dict:
doc = doc.__dict__
doc = dict(doc) # copy to remove fields
doc.pop('id', None)
doc.pop('payload', None)
return '//'.join(sorted(doc))
def contains_test_data(res, td_indices):
if res.total != len(td_indices):
return False
-    docset = set()
-    for doc in res.docs:
-        docset.add(f"{doc.title}//{doc.content}")
+    docset = {doc_to_str(doc) for doc in res.docs}
for td_entry in (TEST_DATA[tdi] for tdi in td_indices):
-        if not f"{td_entry['title']}//{td_entry['content']}" in docset:
+        if not doc_to_str(td_entry) in docset:
return False
return True
@@ -47,6 +61,7 @@ def contains_test_data(res, td_indices):
async def test_basic(async_client, index_type):
i1 = async_client.ft("i-"+str(index_type))
await index_test_data(async_client, index_type)
await i1.create_index(TEST_DATA_SCHEMA, definition=IndexDefinition(index_type=index_type))
@@ -61,3 +76,21 @@ async def test_basic(async_client, index_type):
res = await i1.search("@title:(article|last) @content:text")
assert contains_test_data(res, [1, 3])
res = await i1.search("@views:[200 300]")
assert contains_test_data(res, [1, 2])
res = await i1.search("@views:[0 150] | @views:[350 500]")
assert contains_test_data(res, [0, 3])
res = await i1.search("@topic:{world}")
assert contains_test_data(res, [0, 3])
res = await i1.search("@topic:{business}")
assert contains_test_data(res, [3])
res = await i1.search("@topic:{world | national}")
assert contains_test_data(res, [0, 1, 3])
res = await i1.search("@topic:{science | health}")
assert contains_test_data(res, [0, 2])