mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-15 17:51:06 +00:00
feat: search json support + client tests (#1210)
* feat: search json support + client tests Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io> * fix: small fixes Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io> * fix: small fixes Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io> --------- Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
parent
790e357aaf
commit
7f547151bf
6 changed files with 213 additions and 41 deletions
|
@ -13,30 +13,30 @@
|
|||
|
||||
namespace dfly::search {
|
||||
|
||||
// Interface for accessing hashset values with different data structures underneath.
|
||||
struct HSetAccessor {
|
||||
// Interface for accessing document values with different data structures underneath.
|
||||
struct DocumentAccessor {
|
||||
// Callback that's supplied with field values.
|
||||
using FieldConsumer = std::function<bool(std::string_view)>;
|
||||
|
||||
virtual bool Check(FieldConsumer f, std::string_view active_field) const = 0;
|
||||
};
|
||||
|
||||
// Wrapper around hashset accessor and optional active field.
|
||||
// Wrapper around document accessor and optional active field.
|
||||
struct SearchInput {
|
||||
SearchInput(const HSetAccessor* hset, std::string_view active_field = {})
|
||||
: hset_{hset}, active_field_{active_field} {
|
||||
SearchInput(const DocumentAccessor* doc, std::string_view active_field = {})
|
||||
: doc_{doc}, active_field_{active_field} {
|
||||
}
|
||||
|
||||
SearchInput(const SearchInput& base, std::string_view active_field)
|
||||
: hset_{base.hset_}, active_field_{active_field} {
|
||||
: doc_{base.doc_}, active_field_{active_field} {
|
||||
}
|
||||
|
||||
bool Check(HSetAccessor::FieldConsumer f) {
|
||||
return hset_->Check(move(f), active_field_);
|
||||
bool Check(DocumentAccessor::FieldConsumer f) {
|
||||
return doc_->Check(move(f), active_field_);
|
||||
}
|
||||
|
||||
private:
|
||||
const HSetAccessor* hset_;
|
||||
const DocumentAccessor* doc_;
|
||||
std::string_view active_field_;
|
||||
};
|
||||
|
||||
|
|
|
@ -52,15 +52,15 @@ class SearchParserTest : public ::testing::Test {
|
|||
QueryDriver query_driver_;
|
||||
};
|
||||
|
||||
class MockedHSetAccessor : public HSetAccessor {
|
||||
class MockedDocument : public DocumentAccessor {
|
||||
public:
|
||||
using Map = std::unordered_map<std::string, std::string>;
|
||||
|
||||
MockedHSetAccessor() = default;
|
||||
MockedHSetAccessor(std::string test_field) : hset_{{"field", test_field}} {
|
||||
MockedDocument() = default;
|
||||
MockedDocument(std::string test_field) : hset_{{"field", test_field}} {
|
||||
}
|
||||
|
||||
bool Check(HSetAccessor::FieldConsumer f, string_view active_field) const override {
|
||||
bool Check(DocumentAccessor::FieldConsumer f, string_view active_field) const override {
|
||||
if (!active_field.empty()) {
|
||||
auto it = hset_.find(string{active_field});
|
||||
return f(it != hset_.end() ? it->second : "");
|
||||
|
@ -108,7 +108,7 @@ class MockedHSetAccessor : public HSetAccessor {
|
|||
#define CHECK_ALL(...) \
|
||||
{ \
|
||||
for (auto str : {__VA_ARGS__}) { \
|
||||
MockedHSetAccessor hset{str}; \
|
||||
MockedDocument hset{str}; \
|
||||
EXPECT_TRUE(Check(SearchInput{&hset})) << str << " failed on " << DebugExpr(); \
|
||||
} \
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ class MockedHSetAccessor : public HSetAccessor {
|
|||
#define CHECK_NONE(...) \
|
||||
{ \
|
||||
for (auto str : {__VA_ARGS__}) { \
|
||||
MockedHSetAccessor hset{str}; \
|
||||
MockedDocument hset{str}; \
|
||||
EXPECT_FALSE(Check(SearchInput{&hset})) << str << " failed on " << DebugExpr(); \
|
||||
} \
|
||||
}
|
||||
|
@ -238,7 +238,7 @@ TEST_F(SearchParserTest, CheckParenthesisPriority) {
|
|||
TEST_F(SearchParserTest, MatchField) {
|
||||
ParseExpr("@f1:foo @f2:bar @f3:baz");
|
||||
|
||||
MockedHSetAccessor hset{};
|
||||
MockedDocument hset{};
|
||||
SearchInput input{&hset};
|
||||
|
||||
hset.Set({{"f1", "foo"}, {"f2", "bar"}, {"f3", "baz"}});
|
||||
|
@ -260,7 +260,7 @@ TEST_F(SearchParserTest, MatchField) {
|
|||
TEST_F(SearchParserTest, MatchRange) {
|
||||
ParseExpr("@f1:[1 10] @f2:[50 100]");
|
||||
|
||||
MockedHSetAccessor hset{};
|
||||
MockedDocument hset{};
|
||||
SearchInput input{&hset};
|
||||
|
||||
hset.Set({{"f1", "5"}, {"f2", "50"}});
|
||||
|
@ -282,7 +282,7 @@ TEST_F(SearchParserTest, MatchRange) {
|
|||
TEST_F(SearchParserTest, CheckExprInField) {
|
||||
ParseExpr("@f1:(a|b) @f2:(c d) @f3:-e");
|
||||
|
||||
MockedHSetAccessor hset{};
|
||||
MockedDocument hset{};
|
||||
SearchInput input{&hset};
|
||||
|
||||
hset.Set({{"f1", "a"}, {"f2", "c and d"}, {"f3", "right"}});
|
||||
|
|
|
@ -4,11 +4,14 @@
|
|||
|
||||
#include "server/search_family.h"
|
||||
|
||||
#include <jsoncons/json.hpp>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "base/logging.h"
|
||||
#include "core/json_object.h"
|
||||
#include "core/search/search.h"
|
||||
#include "facade/error.h"
|
||||
#include "facade/reply_builder.h"
|
||||
#include "server/command_registry.h"
|
||||
#include "server/conn_context.h"
|
||||
|
@ -36,16 +39,18 @@ using DocumentData = absl::flat_hash_map<std::string, std::string>;
|
|||
using SerializedDocument = pair<std::string /*key*/, DocumentData>;
|
||||
using Query = search::AstExpr;
|
||||
|
||||
struct BaseAccessor : public search::HSetAccessor {
|
||||
using FieldConsumer = search::HSetAccessor::FieldConsumer;
|
||||
// Base class for document accessors
|
||||
struct BaseAccessor : public search::DocumentAccessor {
|
||||
using FieldConsumer = search::DocumentAccessor::FieldConsumer;
|
||||
|
||||
virtual DocumentData Serialize() const = 0;
|
||||
};
|
||||
|
||||
// Accessor for hashes stored with listpack
|
||||
struct ListPackAccessor : public BaseAccessor {
|
||||
using LpPtr = uint8_t*;
|
||||
|
||||
ListPackAccessor(LpPtr ptr) : lp_{ptr} {
|
||||
explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} {
|
||||
}
|
||||
|
||||
bool Check(FieldConsumer f, string_view active_field) const override {
|
||||
|
@ -79,7 +84,7 @@ struct ListPackAccessor : public BaseAccessor {
|
|||
|
||||
while (fptr) {
|
||||
string_view k = container_utils::LpGetView(fptr, intbuf[0].data());
|
||||
fptr = lpNext(lp_, fptr); // skip key
|
||||
fptr = lpNext(lp_, fptr);
|
||||
string_view v = container_utils::LpGetView(fptr, intbuf[1].data());
|
||||
fptr = lpNext(lp_, fptr);
|
||||
|
||||
|
@ -93,8 +98,9 @@ struct ListPackAccessor : public BaseAccessor {
|
|||
LpPtr lp_;
|
||||
};
|
||||
|
||||
// Accessor for hashes stored with StringMap
|
||||
struct StringMapAccessor : public BaseAccessor {
|
||||
StringMapAccessor(StringMap* hset) : hset_{hset} {
|
||||
explicit StringMapAccessor(StringMap* hset) : hset_{hset} {
|
||||
}
|
||||
|
||||
bool Check(FieldConsumer f, string_view active_field) const override {
|
||||
|
@ -121,7 +127,42 @@ struct StringMapAccessor : public BaseAccessor {
|
|||
StringMap* hset_;
|
||||
};
|
||||
|
||||
// Accessor for json values
|
||||
struct JsonAccessor : public BaseAccessor {
|
||||
explicit JsonAccessor(JsonType* json) : json_{json} {
|
||||
}
|
||||
|
||||
bool Check(FieldConsumer f, string_view active_field) const override {
|
||||
if (!active_field.empty()) {
|
||||
return f(json_->get_value_or<string>(active_field, string{}));
|
||||
}
|
||||
for (const auto& member : json_->object_range()) {
|
||||
if (f(member.value().as_string()))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
DocumentData Serialize() const override {
|
||||
DocumentData out{};
|
||||
for (const auto& member : json_->object_range()) {
|
||||
out[member.key()] = member.value().as_string();
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
private:
|
||||
JsonType* json_;
|
||||
};
|
||||
|
||||
unique_ptr<BaseAccessor> GetAccessor(const OpArgs& op_args, const PrimeValue& pv) {
|
||||
DCHECK(pv.ObjType() == OBJ_HASH || pv.ObjType() == OBJ_JSON);
|
||||
|
||||
if (pv.ObjType() == OBJ_JSON) {
|
||||
DCHECK(pv.GetJson());
|
||||
return make_unique<JsonAccessor>(pv.GetJson());
|
||||
}
|
||||
|
||||
if (pv.Encoding() == kEncodingListPack) {
|
||||
auto ptr = reinterpret_cast<ListPackAccessor::LpPtr>(pv.RObjPtr());
|
||||
return make_unique<ListPackAccessor>(ptr);
|
||||
|
@ -133,7 +174,7 @@ unique_ptr<BaseAccessor> GetAccessor(const OpArgs& op_args, const PrimeValue& pv
|
|||
|
||||
// Perform brute force search for all hashes in shard with specific prefix
|
||||
// that match the query
|
||||
void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,
|
||||
void OpSearch(const OpArgs& op_args, const SearchFamily::IndexData& index, const Query& query,
|
||||
vector<SerializedDocument>* shard_out) {
|
||||
auto& db_slice = op_args.shard->db_slice();
|
||||
DCHECK(db_slice.IsDbValid(op_args.db_cntx.db_index));
|
||||
|
@ -143,12 +184,12 @@ void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,
|
|||
auto cb = [&](PrimeTable::iterator it) {
|
||||
// Check entry is hash
|
||||
const PrimeValue& pv = it->second;
|
||||
if (pv.ObjType() != OBJ_HASH)
|
||||
if (pv.ObjType() != index.GetObjCode())
|
||||
return;
|
||||
|
||||
// Check key starts with prefix
|
||||
string_view key = it->first.GetSlice(&scratch);
|
||||
if (key.rfind(prefix, 0) != 0)
|
||||
if (key.rfind(index.prefix, 0) != 0)
|
||||
return;
|
||||
|
||||
// Check entry matches filter
|
||||
|
@ -167,35 +208,61 @@ void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,
|
|||
|
||||
void SearchFamily::FtCreate(CmdArgList args, ConnectionContext* cntx) {
|
||||
string_view idx = ArgS(args, 0);
|
||||
string prefix;
|
||||
|
||||
if (args.size() > 1 && ArgS(args, 1) == "ON") {
|
||||
if (ArgS(args, 2) != "HASH" || ArgS(args, 3) != "PREFIX" || ArgS(args, 4) != "1") {
|
||||
(*cntx)->SendError("Only simplest config supported");
|
||||
return;
|
||||
IndexData index{};
|
||||
|
||||
for (size_t i = 1; i < args.size(); i++) {
|
||||
ToUpper(&args[i]);
|
||||
|
||||
// [ON HASH | JSON]
|
||||
if (ArgS(args, i) == "ON") {
|
||||
if (++i >= args.size())
|
||||
return (*cntx)->SendError(kSyntaxErr);
|
||||
|
||||
ToUpper(&args[i]);
|
||||
string_view type = ArgS(args, i);
|
||||
if (type == "HASH")
|
||||
index.type = IndexData::HASH;
|
||||
else if (type == "JSON")
|
||||
index.type = IndexData::JSON;
|
||||
else
|
||||
return (*cntx)->SendError("Invalid rule type: " + string{type});
|
||||
continue;
|
||||
}
|
||||
|
||||
// [PREFIX count prefix [prefix ...]]
|
||||
if (ArgS(args, i) == "PREFIX") {
|
||||
if (i + 2 >= args.size())
|
||||
return (*cntx)->SendError(kSyntaxErr);
|
||||
|
||||
if (ArgS(args, ++i) != "1")
|
||||
return (*cntx)->SendError("Multiple prefixes are not supported");
|
||||
|
||||
index.prefix = ArgS(args, ++i);
|
||||
continue;
|
||||
}
|
||||
prefix = ArgS(args, 5);
|
||||
}
|
||||
|
||||
{
|
||||
lock_guard lk{indices_mu_};
|
||||
indices_[idx] = prefix;
|
||||
indices_[idx] = move(index);
|
||||
}
|
||||
(*cntx)->SendOk();
|
||||
}
|
||||
|
||||
void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
|
||||
string_view index = ArgS(args, 0);
|
||||
string_view index_name = ArgS(args, 0);
|
||||
string_view query_str = ArgS(args, 1);
|
||||
|
||||
string prefix;
|
||||
IndexData index;
|
||||
{
|
||||
lock_guard lk{indices_mu_};
|
||||
auto it = indices_.find(index);
|
||||
auto it = indices_.find(index_name);
|
||||
if (it == indices_.end()) {
|
||||
(*cntx)->SendError(string{index} + ": no such index");
|
||||
(*cntx)->SendError(string{index_name} + ": no such index");
|
||||
return;
|
||||
}
|
||||
prefix = it->second;
|
||||
index = it->second;
|
||||
}
|
||||
|
||||
Query query = search::ParseQuery(query_str);
|
||||
|
@ -206,7 +273,7 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
|
|||
|
||||
vector<vector<SerializedDocument>> docs(shard_set->size());
|
||||
cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* shard) {
|
||||
OpSearch(t->GetOpArgs(shard), prefix, query, &docs[shard->shard_id()]);
|
||||
OpSearch(t->GetOpArgs(shard), index, query, &docs[shard->shard_id()]);
|
||||
return OpStatus::OK;
|
||||
});
|
||||
|
||||
|
@ -228,6 +295,10 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
|
|||
}
|
||||
}
|
||||
|
||||
uint8_t SearchFamily::IndexData::GetObjCode() const {
|
||||
return type == JSON ? OBJ_JSON : OBJ_HASH;
|
||||
}
|
||||
|
||||
#define HFUNC(x) SetHandler(&SearchFamily::x)
|
||||
|
||||
void SearchFamily::Register(CommandRegistry* registry) {
|
||||
|
@ -238,6 +309,6 @@ void SearchFamily::Register(CommandRegistry* registry) {
|
|||
}
|
||||
|
||||
Mutex SearchFamily::indices_mu_{};
|
||||
absl::flat_hash_map<std::string, std::string> SearchFamily::indices_{};
|
||||
absl::flat_hash_map<std::string, SearchFamily::IndexData> SearchFamily::indices_{};
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -22,9 +22,19 @@ class SearchFamily {
|
|||
public:
|
||||
static void Register(CommandRegistry* registry);
|
||||
|
||||
struct IndexData {
|
||||
enum DataType { HASH, JSON };
|
||||
|
||||
// Get numeric OBJ_ code
|
||||
uint8_t GetObjCode() const;
|
||||
|
||||
std::string prefix{};
|
||||
DataType type{HASH};
|
||||
};
|
||||
|
||||
private:
|
||||
static Mutex indices_mu_;
|
||||
static absl::flat_hash_map<std::string, std::string> indices_;
|
||||
static absl::flat_hash_map<std::string, IndexData> indices_;
|
||||
};
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -64,4 +64,33 @@ TEST_F(SearchFamilyTest, NoPrefix) {
|
|||
EXPECT_THAT(Run({"ft.search", "i1", "one | three"}), ArrLen(1 + 2 * 2));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, Json) {
|
||||
EXPECT_EQ(Run({"ft.create", "i1", "on", "json"}), "OK");
|
||||
Run({"json.set", "k1", ".", R"({"a": "small test", "b": "some details"})"});
|
||||
Run({"json.set", "k2", ".", R"({"a": "another test", "b": "more details"})"});
|
||||
Run({"json.set", "k3", ".", R"({"a": "last test", "b": "secret details"})"});
|
||||
|
||||
VLOG(0) << Run({"json.get", "k2", "$"});
|
||||
|
||||
{
|
||||
auto resp = Run({"ft.search", "i1", "more"});
|
||||
EXPECT_THAT(resp, ArrLen(1 + 2));
|
||||
|
||||
auto doc = resp.GetVec();
|
||||
EXPECT_THAT(doc[0], IntArg(1));
|
||||
EXPECT_EQ(doc[1], "k2");
|
||||
EXPECT_THAT(doc[2], ArrLen(4));
|
||||
}
|
||||
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "some|more"}), ArrLen(1 + 2 * 2));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "some|more|secret"}), ArrLen(1 + 3 * 2));
|
||||
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@a:last @b:details"}), ArrLen(1 + 1 * 2));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@a:(another|small)"}), ArrLen(1 + 2 * 2));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@a:(another|small|secret)"}), ArrLen(1 + 2 * 2));
|
||||
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "none"}), kNoResults);
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@a:small @b:secret"}), kNoResults);
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
62
tests/dragonfly/search_test.py
Normal file
62
tests/dragonfly/search_test.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
"""
|
||||
Test compatibility with the redis-py client search module.
|
||||
Search correctness should be ensured with unit tests.
|
||||
"""
|
||||
import pytest
|
||||
from redis import asyncio as aioredis
|
||||
from .utility import *
|
||||
|
||||
from redis.commands.search.query import Query
|
||||
from redis.commands.search.field import TextField
|
||||
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
||||
|
||||
TEST_DATA = [
|
||||
{"title": "First article", "content": "Long description"},
|
||||
{"title": "Second article", "content": "Small text"},
|
||||
{"title": "Third piece", "content": "Brief description"},
|
||||
{"title": "Last piece", "content": "Interesting text"},
|
||||
]
|
||||
|
||||
TEST_DATA_SCHEMA = [TextField("title"), TextField("content")]
|
||||
|
||||
|
||||
async def index_test_data(async_client: aioredis.Redis, itype: IndexType, prefix=""):
|
||||
for i, e in enumerate(TEST_DATA):
|
||||
if itype == IndexType.HASH:
|
||||
await async_client.hset(prefix+str(i), mapping=e)
|
||||
else:
|
||||
await async_client.json().set(prefix+str(i), "$", e)
|
||||
|
||||
|
||||
def contains_test_data(res, td_indices):
|
||||
if res.total != len(td_indices):
|
||||
return False
|
||||
|
||||
docset = set()
|
||||
for doc in res.docs:
|
||||
docset.add(f"{doc.title}//{doc.content}")
|
||||
|
||||
for td_entry in (TEST_DATA[tdi] for tdi in td_indices):
|
||||
if not f"{td_entry['title']}//{td_entry['content']}" in docset:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_type", [IndexType.HASH, IndexType.JSON])
|
||||
async def test_basic(async_client, index_type):
|
||||
i1 = async_client.ft("i1")
|
||||
await i1.create_index(TEST_DATA_SCHEMA, definition=IndexDefinition(index_type=index_type))
|
||||
await index_test_data(async_client, index_type)
|
||||
|
||||
res = await i1.search("article")
|
||||
assert contains_test_data(res, [0, 1])
|
||||
|
||||
res = await i1.search("text")
|
||||
assert contains_test_data(res, [1, 3])
|
||||
|
||||
res = await i1.search("brief piece")
|
||||
assert contains_test_data(res, [2])
|
||||
|
||||
res = await i1.search("@title:(article|last) @content:text")
|
||||
assert contains_test_data(res, [1, 3])
|
Loading…
Reference in a new issue