mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-14 11:58:02 +00:00
feat(search): sized vectors (#1788)
* feat(search): Sized vectors --------- Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
parent
36be222091
commit
aa4cadfa12
18 changed files with 275 additions and 122 deletions
|
@ -5,7 +5,7 @@ cur_gen_dir(gen_dir)
|
|||
|
||||
find_package(ICU REQUIRED COMPONENTS uc i18n)
|
||||
|
||||
add_library(query_parser ast_expr.cc query_driver.cc search.cc indices.cc vector.cc compressed_sorted_set.cc
|
||||
add_library(query_parser ast_expr.cc query_driver.cc search.cc indices.cc vector_utils.cc compressed_sorted_set.cc
|
||||
${gen_dir}/parser.cc ${gen_dir}/lexer.cc)
|
||||
|
||||
target_link_libraries(query_parser ICU::uc ICU::i18n)
|
||||
|
|
|
@ -56,9 +56,11 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
|
|||
tags.push_back(move(tag));
|
||||
}
|
||||
|
||||
AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string field, FtVector vec)
|
||||
: filter{make_unique<AstNode>(move(filter))}, limit{limit}, field{field.substr(1)}, vector{move(
|
||||
vec)} {
|
||||
AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string_view field, OwnedFtVector vec)
|
||||
: filter{make_unique<AstNode>(std::move(filter))},
|
||||
limit{limit},
|
||||
field{field.substr(1)},
|
||||
vec{std::move(vec)} {
|
||||
}
|
||||
|
||||
} // namespace dfly::search
|
||||
|
|
|
@ -74,12 +74,12 @@ struct AstTagsNode {
|
|||
|
||||
// Applies nearest neighbor search to the final result set
|
||||
struct AstKnnNode {
|
||||
AstKnnNode(AstNode&& sub, size_t limit, std::string field, FtVector vec);
|
||||
AstKnnNode(AstNode&& sub, size_t limit, std::string_view field, OwnedFtVector vec);
|
||||
|
||||
std::unique_ptr<AstNode> filter;
|
||||
size_t limit;
|
||||
std::string field;
|
||||
FtVector vector;
|
||||
OwnedFtVector vec;
|
||||
};
|
||||
|
||||
using NodeVariants =
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include <absl/container/flat_hash_map.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
@ -14,7 +15,9 @@ namespace dfly::search {
|
|||
|
||||
using DocId = uint32_t;
|
||||
|
||||
using FtVector = std::vector<float>;
|
||||
enum class VectorSimilarity { L2, COSINE };
|
||||
|
||||
using OwnedFtVector = std::pair<std::unique_ptr<float[]>, size_t /* dimension (size) */>;
|
||||
|
||||
// Query params represent named parameters for queries supplied via PARAMS.
|
||||
struct QueryParams {
|
||||
|
@ -38,9 +41,11 @@ struct QueryParams {
|
|||
|
||||
// Interface for accessing document values with different data structures underneath.
|
||||
struct DocumentAccessor {
|
||||
using VectorInfo = search::OwnedFtVector;
|
||||
|
||||
virtual ~DocumentAccessor() = default;
|
||||
virtual std::string_view GetString(std::string_view active_field) const = 0;
|
||||
virtual FtVector GetVector(std::string_view active_field) const = 0;
|
||||
virtual VectorInfo GetVector(std::string_view active_field) const = 0;
|
||||
};
|
||||
|
||||
// Base class for type-specific indices.
|
||||
|
|
|
@ -59,6 +59,11 @@ class CompressedSortedSet {
|
|||
size_t Size() const;
|
||||
size_t ByteSize() const;
|
||||
|
||||
// To use transparently in templates together with stl containers
|
||||
size_t size() const {
|
||||
return Size();
|
||||
}
|
||||
|
||||
private:
|
||||
struct EntryLocation {
|
||||
IntType value; // Value or 0
|
||||
|
|
|
@ -151,17 +151,31 @@ absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) cons
|
|||
return NormalizeTags(value);
|
||||
}
|
||||
|
||||
VectorIndex::VectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim}, entries_{} {
|
||||
}
|
||||
|
||||
void VectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
entries_[id] = doc->GetVector(field);
|
||||
DCHECK_LE(id * dim_, entries_.size());
|
||||
if (id * dim_ == entries_.size())
|
||||
entries_.resize((id + 1) * dim_);
|
||||
|
||||
// TODO: Let get vector write to buf itself
|
||||
auto [ptr, size] = doc->GetVector(field);
|
||||
|
||||
if (size == dim_)
|
||||
memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float));
|
||||
}
|
||||
|
||||
void VectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
entries_.erase(id);
|
||||
// noop
|
||||
}
|
||||
|
||||
FtVector VectorIndex::Get(DocId doc) const {
|
||||
auto it = entries_.find(doc);
|
||||
return it != entries_.end() ? it->second : FtVector{};
|
||||
const float* VectorIndex::Get(DocId doc) const {
|
||||
return &entries_[doc * dim_];
|
||||
}
|
||||
|
||||
std::pair<size_t /*dim*/, VectorSimilarity> VectorIndex::Info() const {
|
||||
return {dim_, sim_};
|
||||
}
|
||||
|
||||
} // namespace dfly::search
|
||||
|
|
|
@ -57,13 +57,18 @@ struct TagIndex : public BaseStringIndex {
|
|||
// Index for vector fields.
|
||||
// Only supports lookup by id.
|
||||
struct VectorIndex : public BaseIndex {
|
||||
VectorIndex(size_t dim, VectorSimilarity sim);
|
||||
|
||||
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
|
||||
FtVector Get(DocId doc) const;
|
||||
const float* Get(DocId doc) const;
|
||||
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<DocId, FtVector> entries_;
|
||||
size_t dim_;
|
||||
VectorSimilarity sim_;
|
||||
std::vector<float> entries_;
|
||||
};
|
||||
|
||||
} // namespace dfly::search
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
// Added to cc file
|
||||
%code {
|
||||
#include "core/search/query_driver.h"
|
||||
#include "core/search/vector.h"
|
||||
#include "core/search/vector_utils.h"
|
||||
|
||||
// Have to disable because GCC doesn't understand `symbol_type`'s union
|
||||
// implementation
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include <absl/strings/str_join.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <type_traits>
|
||||
#include <variant>
|
||||
|
||||
#include "base/logging.h"
|
||||
|
@ -18,7 +19,7 @@
|
|||
#include "core/search/compressed_sorted_set.h"
|
||||
#include "core/search/indices.h"
|
||||
#include "core/search/query_driver.h"
|
||||
#include "core/search/vector.h"
|
||||
#include "core/search/vector_utils.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -35,11 +36,18 @@ AstExpr ParseQuery(std::string_view query, const QueryParams* params) {
|
|||
return driver.Take();
|
||||
}
|
||||
|
||||
// GCC 12 yields a wrong warning in a deeply inlined call in UnifyResults, only ignoring the whole
|
||||
// scope solves it
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
|
||||
|
||||
// Represents an either owned or non-owned result set that can be accessed transparently.
|
||||
struct IndexResult {
|
||||
using DocVec = vector<DocId>;
|
||||
using BorrowedView = variant<const DocVec*, const CompressedSortedSet*>;
|
||||
|
||||
IndexResult() : value_{DocVec{}} {};
|
||||
IndexResult() : value_{DocVec{}} {
|
||||
}
|
||||
|
||||
IndexResult(const CompressedSortedSet* css) : value_{css} {
|
||||
if (css == nullptr)
|
||||
|
@ -49,10 +57,11 @@ struct IndexResult {
|
|||
IndexResult(DocVec&& dv) : value_{move(dv)} {
|
||||
}
|
||||
|
||||
IndexResult(const DocVec* dv) : value_{dv} {
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
if (holds_alternative<DocVec>(value_))
|
||||
return get<DocVec>(value_).size();
|
||||
return get<const CompressedSortedSet*>(value_)->Size();
|
||||
return visit([](auto* set) { return set->size(); }, Borrowed());
|
||||
}
|
||||
|
||||
bool IsOwned() const {
|
||||
|
@ -64,28 +73,31 @@ struct IndexResult {
|
|||
swap(get<DocVec>(value_), entries); // swap to keep backing array
|
||||
entries.clear();
|
||||
} else {
|
||||
value_ = move(entries);
|
||||
value_ = std::move(entries);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
variant<const DocVec*, const CompressedSortedSet*> Borrowed() {
|
||||
if (holds_alternative<DocVec>(value_))
|
||||
return &get<DocVec>(value_);
|
||||
return get<const CompressedSortedSet*>(value_);
|
||||
BorrowedView Borrowed() const {
|
||||
auto cb = [](const auto& v) -> BorrowedView {
|
||||
if constexpr (is_pointer_v<remove_reference_t<decltype(v)>>)
|
||||
return v;
|
||||
else
|
||||
return &v;
|
||||
};
|
||||
return visit(cb, value_);
|
||||
}
|
||||
|
||||
// Move out of owned or copy borrowed
|
||||
DocVec Take() {
|
||||
if (holds_alternative<DocVec>(value_))
|
||||
if (IsOwned())
|
||||
return move(get<DocVec>(value_));
|
||||
|
||||
const CompressedSortedSet* css = get<const CompressedSortedSet*>(value_);
|
||||
return DocVec(css->begin(), css->end());
|
||||
return visit([](auto* set) { return DocVec(set->begin(), set->end()); }, Borrowed());
|
||||
}
|
||||
|
||||
private:
|
||||
variant<DocVec /*owned*/, const CompressedSortedSet* /* borrowed */> value_;
|
||||
variant<DocVec /*owned*/, const CompressedSortedSet*, const DocVec*> value_;
|
||||
};
|
||||
|
||||
struct ProfileBuilder {
|
||||
|
@ -194,7 +206,7 @@ struct BasicSearch {
|
|||
sort(sub_results.begin(), sub_results.end(),
|
||||
[](const auto& l, const auto& r) { return l.Size() < r.Size(); });
|
||||
|
||||
IndexResult out{move(sub_results[0])};
|
||||
IndexResult out{std::move(sub_results[0])};
|
||||
for (auto& matched : absl::MakeSpan(sub_results).subspan(1))
|
||||
Merge(move(matched), &out, op);
|
||||
return out;
|
||||
|
@ -206,7 +218,7 @@ struct BasicSearch {
|
|||
|
||||
IndexResult Search(const AstStarNode& node, string_view active_field) {
|
||||
DCHECK(active_field.empty());
|
||||
return vector<DocId>{indices_->GetAllDocs()}; // TODO FIX;
|
||||
return {&indices_->GetAllDocs()};
|
||||
}
|
||||
|
||||
// "term": access field's text index or unify results from all text indices if no field is set
|
||||
|
@ -268,19 +280,23 @@ struct BasicSearch {
|
|||
auto sub_results = SearchGeneric(*knn.filter, active_field);
|
||||
|
||||
auto* vec_index = GetIndex<VectorIndex>(knn.field);
|
||||
if (auto [dim, _] = vec_index->Info(); dim != knn.vec.second)
|
||||
return IndexResult{};
|
||||
|
||||
distances_.reserve(sub_results.Size());
|
||||
auto cb = [&](auto* set) {
|
||||
auto [dim, sim] = vec_index->Info();
|
||||
for (DocId matched_doc : *set) {
|
||||
float dist = VectorDistance(knn.vector, vec_index->Get(matched_doc));
|
||||
float dist = VectorDistance(knn.vec.first.get(), vec_index->Get(matched_doc), dim, sim);
|
||||
distances_.emplace_back(dist, matched_doc);
|
||||
}
|
||||
};
|
||||
visit(cb, sub_results.Borrowed());
|
||||
|
||||
sort(distances_.begin(), distances_.end());
|
||||
size_t prefix_size = min(knn.limit, distances_.size());
|
||||
partial_sort(distances_.begin(), distances_.begin() + prefix_size, distances_.end());
|
||||
|
||||
vector<DocId> out(min(knn.limit, distances_.size()));
|
||||
vector<DocId> out(prefix_size);
|
||||
for (size_t i = 0; i < out.size(); i++)
|
||||
out[i] = distances_[i].second;
|
||||
|
||||
|
@ -331,6 +347,8 @@ struct BasicSearch {
|
|||
vector<pair<float, DocId>> distances_;
|
||||
};
|
||||
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
} // namespace
|
||||
|
||||
FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, indices_{} {
|
||||
|
@ -346,7 +364,7 @@ FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, i
|
|||
indices_[field_ident] = make_unique<NumericIndex>();
|
||||
break;
|
||||
case SchemaField::VECTOR:
|
||||
indices_[field_ident] = make_unique<VectorIndex>();
|
||||
indices_[field_ident] = make_unique<VectorIndex>(field_info.knn_dim, field_info.knn_sim);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,9 @@ struct SchemaField {
|
|||
|
||||
FieldType type;
|
||||
std::string short_name; // equal to ident if none provided
|
||||
|
||||
size_t knn_dim = 0u; // dimension of knn vectors
|
||||
VectorSimilarity knn_sim = VectorSimilarity::L2; // similarity type
|
||||
};
|
||||
|
||||
// Describes the fields of an index
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "base/logging.h"
|
||||
#include "core/search/base.h"
|
||||
#include "core/search/query_driver.h"
|
||||
#include "core/search/vector_utils.h"
|
||||
|
||||
namespace dfly {
|
||||
namespace search {
|
||||
|
@ -40,15 +41,8 @@ struct MockedDocument : public DocumentAccessor {
|
|||
return it != fields_.end() ? string_view{it->second} : "";
|
||||
}
|
||||
|
||||
FtVector GetVector(string_view field) const override {
|
||||
string_view str_value = fields_.at(field);
|
||||
FtVector out;
|
||||
for (string_view coord : absl::StrSplit(str_value, ',')) {
|
||||
float v;
|
||||
CHECK(absl::SimpleAtof(coord, &v));
|
||||
out.push_back(v);
|
||||
}
|
||||
return out;
|
||||
VectorInfo GetVector(string_view field) const override {
|
||||
return BytesToFtVector(GetString(field));
|
||||
}
|
||||
|
||||
string DebugFormat() {
|
||||
|
@ -331,17 +325,18 @@ TEST_F(SearchParserTest, IntegerTerms) {
|
|||
EXPECT_TRUE(Check()) << GetError();
|
||||
}
|
||||
|
||||
std::string FtVectorToBytes(FtVector vec) {
|
||||
std::string ToBytes(absl::Span<const float> vec) {
|
||||
return string{reinterpret_cast<const char*>(vec.data()), sizeof(float) * vec.size()};
|
||||
}
|
||||
|
||||
TEST_F(SearchParserTest, SimpleKnn) {
|
||||
auto schema = MakeSimpleSchema({{"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}});
|
||||
schema.fields["pos"].knn_dim = 1;
|
||||
FieldIndices indices{schema};
|
||||
|
||||
// Place points on a straight line
|
||||
for (size_t i = 0; i < 100; i++) {
|
||||
Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", to_string(float(i))}}};
|
||||
Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", ToBytes({float(i)})}}};
|
||||
MockedDocument doc{values};
|
||||
indices.Add(i, &doc);
|
||||
}
|
||||
|
@ -351,35 +346,35 @@ TEST_F(SearchParserTest, SimpleKnn) {
|
|||
|
||||
// Five closest to 50
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{50.0});
|
||||
params["vec"] = ToBytes({50.0});
|
||||
algo.Init("*=>[KNN 5 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(48, 49, 50, 51, 52));
|
||||
}
|
||||
|
||||
// Five closest to 0
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{0.0});
|
||||
params["vec"] = ToBytes({0.0});
|
||||
algo.Init("*=>[KNN 5 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
|
||||
}
|
||||
|
||||
// Five closest to 20, all even
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{20.0});
|
||||
params["vec"] = ToBytes({20.0});
|
||||
algo.Init("@even:{yes} =>[KNN 5 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(16, 18, 20, 22, 24));
|
||||
}
|
||||
|
||||
// Three closest to 31, all odd
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{31.0});
|
||||
params["vec"] = ToBytes({31.0});
|
||||
algo.Init("@even:{no} =>[KNN 3 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(29, 31, 33));
|
||||
}
|
||||
|
||||
// Two closest to 70.5
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{70.5});
|
||||
params["vec"] = ToBytes({70.5});
|
||||
algo.Init("* =>[KNN 2 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(70, 71));
|
||||
}
|
||||
|
@ -393,11 +388,11 @@ TEST_F(SearchParserTest, Simple2dKnn) {
|
|||
const pair<float, float> kTestCoords[] = {{0, 0}, {1, 0}, {1, 1}, {0, 1}, {0.5, 0.5}};
|
||||
|
||||
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
|
||||
schema.fields["pos"].knn_dim = 2;
|
||||
FieldIndices indices{schema};
|
||||
|
||||
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
|
||||
auto [x, y] = kTestCoords[i];
|
||||
string coords = absl::StrCat(x, ",", y);
|
||||
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
|
||||
MockedDocument doc{Map{{"pos", coords}}};
|
||||
indices.Add(i, &doc);
|
||||
}
|
||||
|
@ -407,47 +402,83 @@ TEST_F(SearchParserTest, Simple2dKnn) {
|
|||
|
||||
// Single center
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{0.5, 0.5});
|
||||
params["vec"] = ToBytes({0.5, 0.5});
|
||||
algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(4));
|
||||
}
|
||||
|
||||
// Lower left
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{0, 0});
|
||||
params["vec"] = ToBytes({0, 0});
|
||||
algo.Init("* =>[KNN 4 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 3, 4));
|
||||
}
|
||||
|
||||
// Upper right
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{1, 1});
|
||||
params["vec"] = ToBytes({1, 1});
|
||||
algo.Init("* =>[KNN 4 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 2, 3, 4));
|
||||
}
|
||||
|
||||
// Request more than there is
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{0, 0});
|
||||
params["vec"] = ToBytes({0, 0});
|
||||
algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
|
||||
}
|
||||
|
||||
// Test correct order: (0.7, 0.15)
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{0.7, 0.15});
|
||||
params["vec"] = ToBytes({0.7, 0.15});
|
||||
algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(1, 4, 0, 2, 3));
|
||||
}
|
||||
|
||||
// Test correct order: (0.8, 0.9)
|
||||
{
|
||||
params["vec"] = FtVectorToBytes(FtVector{0.8, 0.9});
|
||||
params["vec"] = ToBytes({0.8, 0.9});
|
||||
algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
|
||||
EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(2, 4, 3, 1, 0));
|
||||
}
|
||||
}
|
||||
|
||||
static void BM_VectorSearch(benchmark::State& state) {
|
||||
unsigned ndims = state.range(0);
|
||||
unsigned nvecs = state.range(1);
|
||||
|
||||
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
|
||||
schema.fields["pos"].knn_dim = ndims;
|
||||
FieldIndices indices{schema};
|
||||
|
||||
auto random_vec = [ndims]() {
|
||||
vector<float> coords;
|
||||
for (size_t j = 0; j < ndims; j++)
|
||||
coords.push_back(static_cast<float>(rand()) / static_cast<float>(RAND_MAX));
|
||||
return coords;
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < nvecs; i++) {
|
||||
auto rv = random_vec();
|
||||
MockedDocument doc{Map{{"pos", ToBytes(rv)}}};
|
||||
indices.Add(i, &doc);
|
||||
}
|
||||
|
||||
SearchAlgorithm algo{};
|
||||
QueryParams params;
|
||||
|
||||
auto rv = random_vec();
|
||||
params["vec"] = ToBytes(rv);
|
||||
algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
|
||||
|
||||
while (state.KeepRunningBatch(10)) {
|
||||
for (size_t i = 0; i < 10; i++)
|
||||
benchmark::DoNotOptimize(algo.Search(&indices));
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(BM_VectorSearch)->Args({120, 10'000});
|
||||
|
||||
} // namespace search
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
// Copyright 2023, DragonflyDB authors. All rights reserved.
|
||||
// See LICENSE for licensing terms.
|
||||
//
|
||||
|
||||
#include "core/search/vector.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
#include "base/logging.h"
|
||||
|
||||
namespace dfly::search {
|
||||
|
||||
using namespace std;
|
||||
|
||||
FtVector BytesToFtVector(string_view value) {
|
||||
DCHECK_EQ(value.size() % sizeof(float), 0u);
|
||||
FtVector out(value.size() / sizeof(float));
|
||||
|
||||
// Create copy for aligned access
|
||||
unique_ptr<float[]> float_ptr = make_unique<float[]>(out.size());
|
||||
memcpy(float_ptr.get(), value.data(), value.size());
|
||||
|
||||
for (size_t i = 0; i < out.size(); i++)
|
||||
out[i] = float_ptr[i];
|
||||
return out;
|
||||
}
|
||||
|
||||
// Euclidean vector distance: sqrt( sum: (u[i] - v[i])^2 )
|
||||
__attribute__((optimize("fast-math"))) float VectorDistance(const FtVector& u, const FtVector& v) {
|
||||
DCHECK_EQ(u.size(), v.size());
|
||||
float sum = 0;
|
||||
for (size_t i = 0; i < u.size(); i++)
|
||||
sum += (u[i] - v[i]) * (u[i] - v[i]);
|
||||
return sqrt(sum);
|
||||
}
|
||||
|
||||
} // namespace dfly::search
|
65
src/core/search/vector_utils.cc
Normal file
65
src/core/search/vector_utils.cc
Normal file
|
@ -0,0 +1,65 @@
|
|||
// Copyright 2023, DragonflyDB authors. All rights reserved.
|
||||
// See LICENSE for licensing terms.
|
||||
//
|
||||
|
||||
#include "core/search/vector_utils.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
#include "base/logging.h"
|
||||
|
||||
namespace dfly::search {
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace {
|
||||
|
||||
// Euclidean (L2) distance between two float arrays of length `dims`:
//   sqrt( sum over i of (u[i] - v[i])^2 )
// Both u and v must point to at least `dims` readable floats. With dims == 0 the result is 0.
// The function carries GCC's optimize("fast-math") attribute.
__attribute__((optimize("fast-math"))) float L2Distance(const float* u, const float* v,
                                                        size_t dims) {
  float squared_sum = 0;
  for (size_t i = 0; i < dims; i++) {
    // Per-component difference, squared and accumulated.
    float diff = u[i] - v[i];
    squared_sum += diff * diff;
  }
  return sqrt(squared_sum);
}
|
||||
|
||||
// Cosine distance between two float arrays of length `dims`:
//   1 - (u . v) / (|u| * |v|)
// The result is 0 for vectors pointing in the same direction, 1 for orthogonal vectors and 2 for
// opposite vectors, so that - like L2Distance - a SMALLER value means a CLOSER vector. Callers
// sort candidates by ascending distance, so returning the raw cosine similarity here would rank
// the least similar vectors first.
// If either vector has zero magnitude the cosine is undefined and 0 is returned.
// Both u and v must point to at least `dims` readable floats.
__attribute__((optimize("fast-math"))) float CosineDistance(const float* u, const float* v,
                                                            size_t dims) {
  float sum_uv = 0, sum_uu = 0, sum_vv = 0;
  for (size_t i = 0; i < dims; i++) {
    sum_uv += u[i] * v[i];  // dot product
    sum_uu += u[i] * u[i];  // squared magnitude of u
    sum_vv += v[i] * v[i];  // squared magnitude of v
  }

  // sqrt(sum_uu * sum_vv) == |u| * |v|; guard against division by zero for zero vectors.
  if (float denom = sum_uu * sum_vv; denom != 0.0f)
    return 1.0f - sum_uv / sqrt(denom);
  return 0.0f;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Decodes a raw byte string into a freshly allocated float array.
// Returns the array together with its element count (value.size() / sizeof(float)).
// The input length is expected to be a multiple of sizeof(float) (checked with DCHECK_EQ);
// any trailing remainder bytes are not copied.
OwnedFtVector BytesToFtVector(string_view value) {
  DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size();

  // The bytes cannot be reinterpreted as floats in place: the buffer might not be aligned
  // for float (4 bytes) and misaligned memory access is UB. Copy into a float array instead.
  size_t num_floats = value.size() / sizeof(float);
  auto buffer = make_unique<float[]>(num_floats);
  memcpy(buffer.get(), value.data(), num_floats * sizeof(float));

  return {std::move(buffer), num_floats};
}
|
||||
|
||||
float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim) {
|
||||
switch (sim) {
|
||||
case VectorSimilarity::L2:
|
||||
return L2Distance(u, v, dims);
|
||||
case VectorSimilarity::COSINE:
|
||||
return CosineDistance(u, v, dims);
|
||||
};
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
} // namespace dfly::search
|
|
@ -8,8 +8,8 @@
|
|||
|
||||
namespace dfly::search {
|
||||
|
||||
FtVector BytesToFtVector(std::string_view value);
|
||||
OwnedFtVector BytesToFtVector(std::string_view value);
|
||||
|
||||
float VectorDistance(const FtVector& v1, const FtVector& v2);
|
||||
float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim);
|
||||
|
||||
} // namespace dfly::search
|
|
@ -157,6 +157,10 @@ struct CmdArgParser {
|
|||
return cur_i_ < args_.size() && !error_;
|
||||
}
|
||||
|
||||
bool HasError() {
|
||||
return error_.has_value();
|
||||
}
|
||||
|
||||
// Get optional error if occurred
|
||||
std::optional<ErrorInfo> Error() {
|
||||
return std::exchange(error_, {});
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
#include "core/json_object.h"
|
||||
#include "core/search/search.h"
|
||||
#include "core/search/vector.h"
|
||||
#include "core/search/vector_utils.h"
|
||||
#include "core/string_map.h"
|
||||
#include "server/container_utils.h"
|
||||
|
||||
|
@ -32,10 +32,11 @@ string_view SdsToSafeSv(sds str) {
|
|||
}
|
||||
|
||||
string PrintField(search::SchemaField::FieldType type, string_view value) {
|
||||
if (type == search::SchemaField::VECTOR)
|
||||
return absl::StrCat("[", absl::StrJoin(search::BytesToFtVector(value), ","), "]");
|
||||
else
|
||||
return string{value};
|
||||
if (type == search::SchemaField::VECTOR) {
|
||||
auto [ptr, size] = search::BytesToFtVector(value);
|
||||
return absl::StrCat("[", absl::StrJoin(absl::Span<const float>{ptr.get(), size}, ","), "]");
|
||||
}
|
||||
return string{value};
|
||||
}
|
||||
|
||||
string ExtractValue(const search::Schema& schema, string_view key, string_view value) {
|
||||
|
@ -63,7 +64,7 @@ string_view ListPackAccessor::GetString(string_view active_field) const {
|
|||
return container_utils::LpFind(lp_, active_field, intbuf_[0].data()).value_or(""sv);
|
||||
}
|
||||
|
||||
search::FtVector ListPackAccessor::GetVector(string_view active_field) const {
|
||||
BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const {
|
||||
return search::BytesToFtVector(GetString(active_field));
|
||||
}
|
||||
|
||||
|
@ -89,7 +90,7 @@ string_view StringMapAccessor::GetString(string_view active_field) const {
|
|||
return SdsToSafeSv(hset_->Find(active_field));
|
||||
}
|
||||
|
||||
search::FtVector StringMapAccessor::GetVector(string_view active_field) const {
|
||||
BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const {
|
||||
return search::BytesToFtVector(GetString(active_field));
|
||||
}
|
||||
|
||||
|
@ -113,16 +114,20 @@ string_view JsonAccessor::GetString(string_view active_field) const {
|
|||
return buf_;
|
||||
}
|
||||
|
||||
search::FtVector JsonAccessor::GetVector(string_view active_field) const {
|
||||
BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const {
|
||||
auto res = GetPath(active_field)->evaluate(json_);
|
||||
DCHECK(res.is_array());
|
||||
if (res.empty())
|
||||
return {};
|
||||
return {nullptr, 0};
|
||||
|
||||
search::FtVector out;
|
||||
size_t size = res[0].size();
|
||||
auto ptr = make_unique<float[]>(size);
|
||||
|
||||
size_t i = 0;
|
||||
for (auto v : res[0].array_range())
|
||||
out.push_back(v.as<float>());
|
||||
return out;
|
||||
ptr[i++] = v.as<float>();
|
||||
|
||||
return {std::move(ptr), size};
|
||||
}
|
||||
|
||||
JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) const {
|
||||
|
|
|
@ -40,7 +40,7 @@ struct ListPackAccessor : public BaseAccessor {
|
|||
}
|
||||
|
||||
std::string_view GetString(std::string_view field) const override;
|
||||
search::FtVector GetVector(std::string_view field) const override;
|
||||
VectorInfo GetVector(std::string_view field) const override;
|
||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||
|
||||
private:
|
||||
|
@ -54,7 +54,7 @@ struct StringMapAccessor : public BaseAccessor {
|
|||
}
|
||||
|
||||
std::string_view GetString(std::string_view field) const override;
|
||||
search::FtVector GetVector(std::string_view field) const override;
|
||||
VectorInfo GetVector(std::string_view field) const override;
|
||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||
|
||||
private:
|
||||
|
@ -69,7 +69,7 @@ struct JsonAccessor : public BaseAccessor {
|
|||
}
|
||||
|
||||
std::string_view GetString(std::string_view field) const override;
|
||||
search::FtVector GetVector(std::string_view field) const override;
|
||||
VectorInfo GetVector(std::string_view field) const override;
|
||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||
|
||||
// The JsonAccessor works with structured types and not plain strings, so an overload is needed
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include "base/logging.h"
|
||||
#include "core/json_object.h"
|
||||
#include "core/search/search.h"
|
||||
#include "core/search/vector.h"
|
||||
#include "core/search/vector_utils.h"
|
||||
#include "facade/cmd_arg_parser.h"
|
||||
#include "facade/error.h"
|
||||
#include "facade/reply_builder.h"
|
||||
|
@ -46,6 +46,30 @@ bool IsValidJsonPath(string_view path) {
|
|||
return !ec;
|
||||
}
|
||||
|
||||
pair<size_t, search::VectorSimilarity> ParseVectorFieldInfo(CmdArgParser* parser,
|
||||
ConnectionContext* cntx) {
|
||||
size_t dim = 0;
|
||||
search::VectorSimilarity sim = search::VectorSimilarity::L2;
|
||||
|
||||
size_t num_args = parser->Next().Int<size_t>();
|
||||
for (size_t i = 0; i * 2 < num_args; i++) {
|
||||
parser->ToUpper();
|
||||
if (parser->Check("DIM").ExpectTail(1)) {
|
||||
dim = parser->Next().Int<size_t>();
|
||||
continue;
|
||||
}
|
||||
if (parser->Check("DISTANCE_METRIC").ExpectTail(1)) {
|
||||
sim = parser->Next()
|
||||
.Case("L2", search::VectorSimilarity::L2)
|
||||
.Case("COSINE", search::VectorSimilarity::COSINE);
|
||||
continue;
|
||||
}
|
||||
parser->Skip(2);
|
||||
}
|
||||
|
||||
return {dim, sim};
|
||||
}
|
||||
|
||||
optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParser parser,
|
||||
ConnectionContext* cntx) {
|
||||
search::Schema schema;
|
||||
|
@ -74,15 +98,24 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
|
|||
return nullopt;
|
||||
}
|
||||
|
||||
// Skip {algorithm} {dim} flags
|
||||
if (*type == search::SchemaField::VECTOR)
|
||||
parser.Skip(2);
|
||||
// Vector fields include: {algorithm} num_args args...
|
||||
size_t knn_dim = 0;
|
||||
search::VectorSimilarity knn_sim = search::VectorSimilarity::L2;
|
||||
if (*type == search::SchemaField::VECTOR) {
|
||||
parser.Skip(1); // algorithm
|
||||
std::tie(knn_dim, knn_sim) = ParseVectorFieldInfo(&parser, cntx);
|
||||
|
||||
if (!parser.HasError() && knn_dim == 0) {
|
||||
(*cntx)->SendError("Vector dimension cannot be zero");
|
||||
return nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// Skip all trailing ignored parameters
|
||||
while (kIgnoredOptions.count(parser.Peek()) > 0)
|
||||
parser.Skip(2);
|
||||
|
||||
schema.fields[field] = {*type, string{field_alias}};
|
||||
schema.fields[field] = {*type, string{field_alias}, knn_dim, knn_sim};
|
||||
}
|
||||
|
||||
// Build field name mapping table
|
||||
|
@ -208,8 +241,9 @@ void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchRes
|
|||
}
|
||||
}
|
||||
|
||||
partial_sort(docs.begin(),
|
||||
docs.begin() + min(params.limit_offset + params.limit_total, knn_limit), docs.end(),
|
||||
size_t prefix = min(params.limit_offset + params.limit_total, knn_limit);
|
||||
|
||||
partial_sort(docs.begin(), docs.begin() + min(docs.size(), prefix), docs.end(),
|
||||
[](const auto* l, const auto* r) { return l->knn_distance < r->knn_distance; });
|
||||
docs.resize(min(docs.size(), knn_limit));
|
||||
|
||||
|
|
Loading…
Reference in a new issue