1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-15 17:51:06 +00:00

feat: search json support + client tests (#1210)

* feat: search json support + client tests

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>

* fix: small fixes

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>

* fix: small fixes

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>

---------

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
Vladislav 2023-05-16 11:11:28 +03:00 committed by GitHub
parent 790e357aaf
commit 7f547151bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 213 additions and 41 deletions

View file

@ -13,30 +13,30 @@
namespace dfly::search {
// Interface for accessing hashset values with different data structures underneath.
struct HSetAccessor {
// Interface for accessing document values with different data structures underneath.
struct DocumentAccessor {
// Callback that's supplied with field values.
using FieldConsumer = std::function<bool(std::string_view)>;
virtual bool Check(FieldConsumer f, std::string_view active_field) const = 0;
};
// Wrapper around hashset accessor and optional active field.
// Wrapper around document accessor and optional active field.
struct SearchInput {
SearchInput(const HSetAccessor* hset, std::string_view active_field = {})
: hset_{hset}, active_field_{active_field} {
SearchInput(const DocumentAccessor* doc, std::string_view active_field = {})
: doc_{doc}, active_field_{active_field} {
}
SearchInput(const SearchInput& base, std::string_view active_field)
: hset_{base.hset_}, active_field_{active_field} {
: doc_{base.doc_}, active_field_{active_field} {
}
bool Check(HSetAccessor::FieldConsumer f) {
return hset_->Check(move(f), active_field_);
bool Check(DocumentAccessor::FieldConsumer f) {
return doc_->Check(move(f), active_field_);
}
private:
const HSetAccessor* hset_;
const DocumentAccessor* doc_;
std::string_view active_field_;
};

View file

@ -52,15 +52,15 @@ class SearchParserTest : public ::testing::Test {
QueryDriver query_driver_;
};
class MockedHSetAccessor : public HSetAccessor {
class MockedDocument : public DocumentAccessor {
public:
using Map = std::unordered_map<std::string, std::string>;
MockedHSetAccessor() = default;
MockedHSetAccessor(std::string test_field) : hset_{{"field", test_field}} {
MockedDocument() = default;
MockedDocument(std::string test_field) : hset_{{"field", test_field}} {
}
bool Check(HSetAccessor::FieldConsumer f, string_view active_field) const override {
bool Check(DocumentAccessor::FieldConsumer f, string_view active_field) const override {
if (!active_field.empty()) {
auto it = hset_.find(string{active_field});
return f(it != hset_.end() ? it->second : "");
@ -108,7 +108,7 @@ class MockedHSetAccessor : public HSetAccessor {
#define CHECK_ALL(...) \
{ \
for (auto str : {__VA_ARGS__}) { \
MockedHSetAccessor hset{str}; \
MockedDocument hset{str}; \
EXPECT_TRUE(Check(SearchInput{&hset})) << str << " failed on " << DebugExpr(); \
} \
}
@ -116,7 +116,7 @@ class MockedHSetAccessor : public HSetAccessor {
#define CHECK_NONE(...) \
{ \
for (auto str : {__VA_ARGS__}) { \
MockedHSetAccessor hset{str}; \
MockedDocument hset{str}; \
EXPECT_FALSE(Check(SearchInput{&hset})) << str << " failed on " << DebugExpr(); \
} \
}
@ -238,7 +238,7 @@ TEST_F(SearchParserTest, CheckParenthesisPriority) {
TEST_F(SearchParserTest, MatchField) {
ParseExpr("@f1:foo @f2:bar @f3:baz");
MockedHSetAccessor hset{};
MockedDocument hset{};
SearchInput input{&hset};
hset.Set({{"f1", "foo"}, {"f2", "bar"}, {"f3", "baz"}});
@ -260,7 +260,7 @@ TEST_F(SearchParserTest, MatchField) {
TEST_F(SearchParserTest, MatchRange) {
ParseExpr("@f1:[1 10] @f2:[50 100]");
MockedHSetAccessor hset{};
MockedDocument hset{};
SearchInput input{&hset};
hset.Set({{"f1", "5"}, {"f2", "50"}});
@ -282,7 +282,7 @@ TEST_F(SearchParserTest, MatchRange) {
TEST_F(SearchParserTest, CheckExprInField) {
ParseExpr("@f1:(a|b) @f2:(c d) @f3:-e");
MockedHSetAccessor hset{};
MockedDocument hset{};
SearchInput input{&hset};
hset.Set({{"f1", "a"}, {"f2", "c and d"}, {"f3", "right"}});

View file

@ -4,11 +4,14 @@
#include "server/search_family.h"
#include <jsoncons/json.hpp>
#include <variant>
#include <vector>
#include "base/logging.h"
#include "core/json_object.h"
#include "core/search/search.h"
#include "facade/error.h"
#include "facade/reply_builder.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
@ -36,16 +39,18 @@ using DocumentData = absl::flat_hash_map<std::string, std::string>;
using SerializedDocument = pair<std::string /*key*/, DocumentData>;
using Query = search::AstExpr;
struct BaseAccessor : public search::HSetAccessor {
using FieldConsumer = search::HSetAccessor::FieldConsumer;
// Base class for document accessors
struct BaseAccessor : public search::DocumentAccessor {
using FieldConsumer = search::DocumentAccessor::FieldConsumer;
virtual DocumentData Serialize() const = 0;
};
// Accessor for hashes stored with listpack
struct ListPackAccessor : public BaseAccessor {
using LpPtr = uint8_t*;
ListPackAccessor(LpPtr ptr) : lp_{ptr} {
explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} {
}
bool Check(FieldConsumer f, string_view active_field) const override {
@ -79,7 +84,7 @@ struct ListPackAccessor : public BaseAccessor {
while (fptr) {
string_view k = container_utils::LpGetView(fptr, intbuf[0].data());
fptr = lpNext(lp_, fptr); // skip key
fptr = lpNext(lp_, fptr);
string_view v = container_utils::LpGetView(fptr, intbuf[1].data());
fptr = lpNext(lp_, fptr);
@ -93,8 +98,9 @@ struct ListPackAccessor : public BaseAccessor {
LpPtr lp_;
};
// Accessor for hashes stored with StringMap
struct StringMapAccessor : public BaseAccessor {
StringMapAccessor(StringMap* hset) : hset_{hset} {
explicit StringMapAccessor(StringMap* hset) : hset_{hset} {
}
bool Check(FieldConsumer f, string_view active_field) const override {
@ -121,7 +127,42 @@ struct StringMapAccessor : public BaseAccessor {
StringMap* hset_;
};
// Accessor for json values
struct JsonAccessor : public BaseAccessor {
explicit JsonAccessor(JsonType* json) : json_{json} {
}
bool Check(FieldConsumer f, string_view active_field) const override {
if (!active_field.empty()) {
return f(json_->get_value_or<string>(active_field, string{}));
}
for (const auto& member : json_->object_range()) {
if (f(member.value().as_string()))
return true;
}
return false;
}
DocumentData Serialize() const override {
DocumentData out{};
for (const auto& member : json_->object_range()) {
out[member.key()] = member.value().as_string();
}
return out;
}
private:
JsonType* json_;
};
unique_ptr<BaseAccessor> GetAccessor(const OpArgs& op_args, const PrimeValue& pv) {
DCHECK(pv.ObjType() == OBJ_HASH || pv.ObjType() == OBJ_JSON);
if (pv.ObjType() == OBJ_JSON) {
DCHECK(pv.GetJson());
return make_unique<JsonAccessor>(pv.GetJson());
}
if (pv.Encoding() == kEncodingListPack) {
auto ptr = reinterpret_cast<ListPackAccessor::LpPtr>(pv.RObjPtr());
return make_unique<ListPackAccessor>(ptr);
@ -133,7 +174,7 @@ unique_ptr<BaseAccessor> GetAccessor(const OpArgs& op_args, const PrimeValue& pv
// Perform brute force search for all hashes in shard with specific prefix
// that match the query
void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,
void OpSearch(const OpArgs& op_args, const SearchFamily::IndexData& index, const Query& query,
vector<SerializedDocument>* shard_out) {
auto& db_slice = op_args.shard->db_slice();
DCHECK(db_slice.IsDbValid(op_args.db_cntx.db_index));
@ -143,12 +184,12 @@ void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,
auto cb = [&](PrimeTable::iterator it) {
// Check entry is hash
const PrimeValue& pv = it->second;
if (pv.ObjType() != OBJ_HASH)
if (pv.ObjType() != index.GetObjCode())
return;
// Check key starts with prefix
string_view key = it->first.GetSlice(&scratch);
if (key.rfind(prefix, 0) != 0)
if (key.rfind(index.prefix, 0) != 0)
return;
// Check entry matches filter
@ -167,35 +208,61 @@ void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,
void SearchFamily::FtCreate(CmdArgList args, ConnectionContext* cntx) {
string_view idx = ArgS(args, 0);
string prefix;
if (args.size() > 1 && ArgS(args, 1) == "ON") {
if (ArgS(args, 2) != "HASH" || ArgS(args, 3) != "PREFIX" || ArgS(args, 4) != "1") {
(*cntx)->SendError("Only simplest config supported");
return;
IndexData index{};
for (size_t i = 1; i < args.size(); i++) {
ToUpper(&args[i]);
// [ON HASH | JSON]
if (ArgS(args, i) == "ON") {
if (++i >= args.size())
return (*cntx)->SendError(kSyntaxErr);
ToUpper(&args[i]);
string_view type = ArgS(args, i);
if (type == "HASH")
index.type = IndexData::HASH;
else if (type == "JSON")
index.type = IndexData::JSON;
else
return (*cntx)->SendError("Invalid rule type: " + string{type});
continue;
}
// [PREFIX count prefix [prefix ...]]
if (ArgS(args, i) == "PREFIX") {
if (i + 2 >= args.size())
return (*cntx)->SendError(kSyntaxErr);
if (ArgS(args, ++i) != "1")
return (*cntx)->SendError("Multiple prefixes are not supported");
index.prefix = ArgS(args, ++i);
continue;
}
prefix = ArgS(args, 5);
}
{
lock_guard lk{indices_mu_};
indices_[idx] = prefix;
indices_[idx] = move(index);
}
(*cntx)->SendOk();
}
void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
string_view index = ArgS(args, 0);
string_view index_name = ArgS(args, 0);
string_view query_str = ArgS(args, 1);
string prefix;
IndexData index;
{
lock_guard lk{indices_mu_};
auto it = indices_.find(index);
auto it = indices_.find(index_name);
if (it == indices_.end()) {
(*cntx)->SendError(string{index} + ": no such index");
(*cntx)->SendError(string{index_name} + ": no such index");
return;
}
prefix = it->second;
index = it->second;
}
Query query = search::ParseQuery(query_str);
@ -206,7 +273,7 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
vector<vector<SerializedDocument>> docs(shard_set->size());
cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* shard) {
OpSearch(t->GetOpArgs(shard), prefix, query, &docs[shard->shard_id()]);
OpSearch(t->GetOpArgs(shard), index, query, &docs[shard->shard_id()]);
return OpStatus::OK;
});
@ -228,6 +295,10 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
}
}
uint8_t SearchFamily::IndexData::GetObjCode() const {
return type == JSON ? OBJ_JSON : OBJ_HASH;
}
#define HFUNC(x) SetHandler(&SearchFamily::x)
void SearchFamily::Register(CommandRegistry* registry) {
@ -238,6 +309,6 @@ void SearchFamily::Register(CommandRegistry* registry) {
}
Mutex SearchFamily::indices_mu_{};
absl::flat_hash_map<std::string, std::string> SearchFamily::indices_{};
absl::flat_hash_map<std::string, SearchFamily::IndexData> SearchFamily::indices_{};
} // namespace dfly

View file

@ -22,9 +22,19 @@ class SearchFamily {
public:
static void Register(CommandRegistry* registry);
struct IndexData {
enum DataType { HASH, JSON };
// Get numeric OBJ_ code
uint8_t GetObjCode() const;
std::string prefix{};
DataType type{HASH};
};
private:
static Mutex indices_mu_;
static absl::flat_hash_map<std::string, std::string> indices_;
static absl::flat_hash_map<std::string, IndexData> indices_;
};
} // namespace dfly

View file

@ -64,4 +64,33 @@ TEST_F(SearchFamilyTest, NoPrefix) {
EXPECT_THAT(Run({"ft.search", "i1", "one | three"}), ArrLen(1 + 2 * 2));
}
TEST_F(SearchFamilyTest, Json) {
EXPECT_EQ(Run({"ft.create", "i1", "on", "json"}), "OK");
Run({"json.set", "k1", ".", R"({"a": "small test", "b": "some details"})"});
Run({"json.set", "k2", ".", R"({"a": "another test", "b": "more details"})"});
Run({"json.set", "k3", ".", R"({"a": "last test", "b": "secret details"})"});
VLOG(0) << Run({"json.get", "k2", "$"});
{
auto resp = Run({"ft.search", "i1", "more"});
EXPECT_THAT(resp, ArrLen(1 + 2));
auto doc = resp.GetVec();
EXPECT_THAT(doc[0], IntArg(1));
EXPECT_EQ(doc[1], "k2");
EXPECT_THAT(doc[2], ArrLen(4));
}
EXPECT_THAT(Run({"ft.search", "i1", "some|more"}), ArrLen(1 + 2 * 2));
EXPECT_THAT(Run({"ft.search", "i1", "some|more|secret"}), ArrLen(1 + 3 * 2));
EXPECT_THAT(Run({"ft.search", "i1", "@a:last @b:details"}), ArrLen(1 + 1 * 2));
EXPECT_THAT(Run({"ft.search", "i1", "@a:(another|small)"}), ArrLen(1 + 2 * 2));
EXPECT_THAT(Run({"ft.search", "i1", "@a:(another|small|secret)"}), ArrLen(1 + 2 * 2));
EXPECT_THAT(Run({"ft.search", "i1", "none"}), kNoResults);
EXPECT_THAT(Run({"ft.search", "i1", "@a:small @b:secret"}), kNoResults);
}
} // namespace dfly

View file

@ -0,0 +1,62 @@
"""
Test compatibility with the redis-py client search module.
Search correctness should be ensured with unit tests.
"""
import pytest
from redis import asyncio as aioredis
from .utility import *
from redis.commands.search.query import Query
from redis.commands.search.field import TextField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
TEST_DATA = [
{"title": "First article", "content": "Long description"},
{"title": "Second article", "content": "Small text"},
{"title": "Third piece", "content": "Brief description"},
{"title": "Last piece", "content": "Interesting text"},
]
TEST_DATA_SCHEMA = [TextField("title"), TextField("content")]
async def index_test_data(async_client: aioredis.Redis, itype: IndexType, prefix=""):
for i, e in enumerate(TEST_DATA):
if itype == IndexType.HASH:
await async_client.hset(prefix+str(i), mapping=e)
else:
await async_client.json().set(prefix+str(i), "$", e)
def contains_test_data(res, td_indices):
if res.total != len(td_indices):
return False
docset = set()
for doc in res.docs:
docset.add(f"{doc.title}//{doc.content}")
for td_entry in (TEST_DATA[tdi] for tdi in td_indices):
if not f"{td_entry['title']}//{td_entry['content']}" in docset:
return False
return True
@pytest.mark.parametrize("index_type", [IndexType.HASH, IndexType.JSON])
async def test_basic(async_client, index_type):
i1 = async_client.ft("i1")
await i1.create_index(TEST_DATA_SCHEMA, definition=IndexDefinition(index_type=index_type))
await index_test_data(async_client, index_type)
res = await i1.search("article")
assert contains_test_data(res, [0, 1])
res = await i1.search("text")
assert contains_test_data(res, [1, 3])
res = await i1.search("brief piece")
assert contains_test_data(res, [2])
res = await i1.search("@title:(article|last) @content:text")
assert contains_test_data(res, [1, 3])