1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

feat(search): sized vectors (#1788)

* feat(search): Sized vectors

---------

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
Vladislav 2023-09-06 11:00:03 +03:00 committed by GitHub
parent 36be222091
commit aa4cadfa12
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 275 additions and 122 deletions

View file

@ -5,7 +5,7 @@ cur_gen_dir(gen_dir)
find_package(ICU REQUIRED COMPONENTS uc i18n)
add_library(query_parser ast_expr.cc query_driver.cc search.cc indices.cc vector.cc compressed_sorted_set.cc
add_library(query_parser ast_expr.cc query_driver.cc search.cc indices.cc vector_utils.cc compressed_sorted_set.cc
${gen_dir}/parser.cc ${gen_dir}/lexer.cc)
target_link_libraries(query_parser ICU::uc ICU::i18n)

View file

@ -56,9 +56,11 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
tags.push_back(move(tag));
}
AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string field, FtVector vec)
: filter{make_unique<AstNode>(move(filter))}, limit{limit}, field{field.substr(1)}, vector{move(
vec)} {
AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string_view field, OwnedFtVector vec)
: filter{make_unique<AstNode>(std::move(filter))},
limit{limit},
field{field.substr(1)},
vec{std::move(vec)} {
}
} // namespace dfly::search

View file

@ -74,12 +74,12 @@ struct AstTagsNode {
// Applies nearest neighbor search to the final result set
struct AstKnnNode {
AstKnnNode(AstNode&& sub, size_t limit, std::string field, FtVector vec);
AstKnnNode(AstNode&& sub, size_t limit, std::string_view field, OwnedFtVector vec);
std::unique_ptr<AstNode> filter;
size_t limit;
std::string field;
FtVector vector;
OwnedFtVector vec;
};
using NodeVariants =

View file

@ -3,6 +3,7 @@
#include <absl/container/flat_hash_map.h>
#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <vector>
@ -14,7 +15,9 @@ namespace dfly::search {
using DocId = uint32_t;
using FtVector = std::vector<float>;
enum class VectorSimilarity { L2, COSINE };
using OwnedFtVector = std::pair<std::unique_ptr<float[]>, size_t /* dimension (size) */>;
// Query params represent named parameters for queries supplied via PARAMS.
struct QueryParams {
@ -38,9 +41,11 @@ struct QueryParams {
// Interface for accessing document values with different data structures underneath.
struct DocumentAccessor {
using VectorInfo = search::OwnedFtVector;
virtual ~DocumentAccessor() = default;
virtual std::string_view GetString(std::string_view active_field) const = 0;
virtual FtVector GetVector(std::string_view active_field) const = 0;
virtual VectorInfo GetVector(std::string_view active_field) const = 0;
};
// Base class for type-specific indices.

View file

@ -59,6 +59,11 @@ class CompressedSortedSet {
size_t Size() const;
size_t ByteSize() const;
// To use transparently in templates together with stl containers
size_t size() const {
return Size();
}
private:
struct EntryLocation {
IntType value; // Value or 0

View file

@ -151,17 +151,31 @@ absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) cons
return NormalizeTags(value);
}
VectorIndex::VectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim}, entries_{} {
}
void VectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
entries_[id] = doc->GetVector(field);
DCHECK_LE(id * dim_, entries_.size());
if (id * dim_ == entries_.size())
entries_.resize((id + 1) * dim_);
// TODO: Let get vector write to buf itself
auto [ptr, size] = doc->GetVector(field);
if (size == dim_)
memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float));
}
void VectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
entries_.erase(id);
// noop
}
FtVector VectorIndex::Get(DocId doc) const {
auto it = entries_.find(doc);
return it != entries_.end() ? it->second : FtVector{};
const float* VectorIndex::Get(DocId doc) const {
return &entries_[doc * dim_];
}
std::pair<size_t /*dim*/, VectorSimilarity> VectorIndex::Info() const {
return {dim_, sim_};
}
} // namespace dfly::search

View file

@ -57,13 +57,18 @@ struct TagIndex : public BaseStringIndex {
// Index for vector fields.
// Only supports lookup by id.
struct VectorIndex : public BaseIndex {
VectorIndex(size_t dim, VectorSimilarity sim);
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
FtVector Get(DocId doc) const;
const float* Get(DocId doc) const;
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;
private:
absl::flat_hash_map<DocId, FtVector> entries_;
size_t dim_;
VectorSimilarity sim_;
std::vector<float> entries_;
};
} // namespace dfly::search

View file

@ -25,7 +25,7 @@
// Added to cc file
%code {
#include "core/search/query_driver.h"
#include "core/search/vector.h"
#include "core/search/vector_utils.h"
// Have to disable because GCC doesn't understand `symbol_type`'s union
// implementation

View file

@ -10,6 +10,7 @@
#include <absl/strings/str_join.h>
#include <chrono>
#include <type_traits>
#include <variant>
#include "base/logging.h"
@ -18,7 +19,7 @@
#include "core/search/compressed_sorted_set.h"
#include "core/search/indices.h"
#include "core/search/query_driver.h"
#include "core/search/vector.h"
#include "core/search/vector_utils.h"
using namespace std;
@ -35,11 +36,18 @@ AstExpr ParseQuery(std::string_view query, const QueryParams* params) {
return driver.Take();
}
// GCC 12 yields a wrong warning in a deeply inlined call in UnifyResults, only ignoring the whole
// scope solves it
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
// Represents an either owned or non-owned result set that can be accessed transparently.
struct IndexResult {
using DocVec = vector<DocId>;
using BorrowedView = variant<const DocVec*, const CompressedSortedSet*>;
IndexResult() : value_{DocVec{}} {};
IndexResult() : value_{DocVec{}} {
}
IndexResult(const CompressedSortedSet* css) : value_{css} {
if (css == nullptr)
@ -49,10 +57,11 @@ struct IndexResult {
IndexResult(DocVec&& dv) : value_{move(dv)} {
}
IndexResult(const DocVec* dv) : value_{dv} {
}
size_t Size() const {
if (holds_alternative<DocVec>(value_))
return get<DocVec>(value_).size();
return get<const CompressedSortedSet*>(value_)->Size();
return visit([](auto* set) { return set->size(); }, Borrowed());
}
bool IsOwned() const {
@ -64,28 +73,31 @@ struct IndexResult {
swap(get<DocVec>(value_), entries); // swap to keep backing array
entries.clear();
} else {
value_ = move(entries);
value_ = std::move(entries);
}
return *this;
}
variant<const DocVec*, const CompressedSortedSet*> Borrowed() {
if (holds_alternative<DocVec>(value_))
return &get<DocVec>(value_);
return get<const CompressedSortedSet*>(value_);
BorrowedView Borrowed() const {
auto cb = [](const auto& v) -> BorrowedView {
if constexpr (is_pointer_v<remove_reference_t<decltype(v)>>)
return v;
else
return &v;
};
return visit(cb, value_);
}
// Move out of owned or copy borrowed
DocVec Take() {
if (holds_alternative<DocVec>(value_))
if (IsOwned())
return move(get<DocVec>(value_));
const CompressedSortedSet* css = get<const CompressedSortedSet*>(value_);
return DocVec(css->begin(), css->end());
return visit([](auto* set) { return DocVec(set->begin(), set->end()); }, Borrowed());
}
private:
variant<DocVec /*owned*/, const CompressedSortedSet* /* borrowed */> value_;
variant<DocVec /*owned*/, const CompressedSortedSet*, const DocVec*> value_;
};
struct ProfileBuilder {
@ -194,7 +206,7 @@ struct BasicSearch {
sort(sub_results.begin(), sub_results.end(),
[](const auto& l, const auto& r) { return l.Size() < r.Size(); });
IndexResult out{move(sub_results[0])};
IndexResult out{std::move(sub_results[0])};
for (auto& matched : absl::MakeSpan(sub_results).subspan(1))
Merge(move(matched), &out, op);
return out;
@ -206,7 +218,7 @@ struct BasicSearch {
IndexResult Search(const AstStarNode& node, string_view active_field) {
DCHECK(active_field.empty());
return vector<DocId>{indices_->GetAllDocs()}; // TODO FIX;
return {&indices_->GetAllDocs()};
}
// "term": access field's text index or unify results from all text indices if no field is set
@ -268,19 +280,23 @@ struct BasicSearch {
auto sub_results = SearchGeneric(*knn.filter, active_field);
auto* vec_index = GetIndex<VectorIndex>(knn.field);
if (auto [dim, _] = vec_index->Info(); dim != knn.vec.second)
return IndexResult{};
distances_.reserve(sub_results.Size());
auto cb = [&](auto* set) {
auto [dim, sim] = vec_index->Info();
for (DocId matched_doc : *set) {
float dist = VectorDistance(knn.vector, vec_index->Get(matched_doc));
float dist = VectorDistance(knn.vec.first.get(), vec_index->Get(matched_doc), dim, sim);
distances_.emplace_back(dist, matched_doc);
}
};
visit(cb, sub_results.Borrowed());
sort(distances_.begin(), distances_.end());
size_t prefix_size = min(knn.limit, distances_.size());
partial_sort(distances_.begin(), distances_.begin() + prefix_size, distances_.end());
vector<DocId> out(min(knn.limit, distances_.size()));
vector<DocId> out(prefix_size);
for (size_t i = 0; i < out.size(); i++)
out[i] = distances_[i].second;
@ -331,6 +347,8 @@ struct BasicSearch {
vector<pair<float, DocId>> distances_;
};
#pragma GCC diagnostic pop
} // namespace
FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, indices_{} {
@ -346,7 +364,7 @@ FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, i
indices_[field_ident] = make_unique<NumericIndex>();
break;
case SchemaField::VECTOR:
indices_[field_ident] = make_unique<VectorIndex>();
indices_[field_ident] = make_unique<VectorIndex>(field_info.knn_dim, field_info.knn_sim);
break;
}
}

View file

@ -25,6 +25,9 @@ struct SchemaField {
FieldType type;
std::string short_name; // equal to ident if none provided
size_t knn_dim = 0u; // dimension of knn vectors
VectorSimilarity knn_sim = VectorSimilarity::L2; // similarity type
};
// Describes the fields of an index

View file

@ -19,6 +19,7 @@
#include "base/logging.h"
#include "core/search/base.h"
#include "core/search/query_driver.h"
#include "core/search/vector_utils.h"
namespace dfly {
namespace search {
@ -40,15 +41,8 @@ struct MockedDocument : public DocumentAccessor {
return it != fields_.end() ? string_view{it->second} : "";
}
FtVector GetVector(string_view field) const override {
string_view str_value = fields_.at(field);
FtVector out;
for (string_view coord : absl::StrSplit(str_value, ',')) {
float v;
CHECK(absl::SimpleAtof(coord, &v));
out.push_back(v);
}
return out;
VectorInfo GetVector(string_view field) const override {
return BytesToFtVector(GetString(field));
}
string DebugFormat() {
@ -331,17 +325,18 @@ TEST_F(SearchParserTest, IntegerTerms) {
EXPECT_TRUE(Check()) << GetError();
}
std::string FtVectorToBytes(FtVector vec) {
std::string ToBytes(absl::Span<const float> vec) {
return string{reinterpret_cast<const char*>(vec.data()), sizeof(float) * vec.size()};
}
TEST_F(SearchParserTest, SimpleKnn) {
auto schema = MakeSimpleSchema({{"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}});
schema.fields["pos"].knn_dim = 1;
FieldIndices indices{schema};
// Place points on a straight line
for (size_t i = 0; i < 100; i++) {
Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", to_string(float(i))}}};
Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", ToBytes({float(i)})}}};
MockedDocument doc{values};
indices.Add(i, &doc);
}
@ -351,35 +346,35 @@ TEST_F(SearchParserTest, SimpleKnn) {
// Five closest to 50
{
params["vec"] = FtVectorToBytes(FtVector{50.0});
params["vec"] = ToBytes({50.0});
algo.Init("*=>[KNN 5 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(48, 49, 50, 51, 52));
}
// Five closest to 0
{
params["vec"] = FtVectorToBytes(FtVector{0.0});
params["vec"] = ToBytes({0.0});
algo.Init("*=>[KNN 5 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
}
// Five closest to 20, all even
{
params["vec"] = FtVectorToBytes(FtVector{20.0});
params["vec"] = ToBytes({20.0});
algo.Init("@even:{yes} =>[KNN 5 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(16, 18, 20, 22, 24));
}
// Three closest to 31, all odd
{
params["vec"] = FtVectorToBytes(FtVector{31.0});
params["vec"] = ToBytes({31.0});
algo.Init("@even:{no} =>[KNN 3 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(29, 31, 33));
}
// Two closest to 70.5
{
params["vec"] = FtVectorToBytes(FtVector{70.5});
params["vec"] = ToBytes({70.5});
algo.Init("* =>[KNN 2 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(70, 71));
}
@ -393,11 +388,11 @@ TEST_F(SearchParserTest, Simple2dKnn) {
const pair<float, float> kTestCoords[] = {{0, 0}, {1, 0}, {1, 1}, {0, 1}, {0.5, 0.5}};
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
schema.fields["pos"].knn_dim = 2;
FieldIndices indices{schema};
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
auto [x, y] = kTestCoords[i];
string coords = absl::StrCat(x, ",", y);
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
MockedDocument doc{Map{{"pos", coords}}};
indices.Add(i, &doc);
}
@ -407,47 +402,83 @@ TEST_F(SearchParserTest, Simple2dKnn) {
// Single center
{
params["vec"] = FtVectorToBytes(FtVector{0.5, 0.5});
params["vec"] = ToBytes({0.5, 0.5});
algo.Init("* =>[KNN 1 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(4));
}
// Lower left
{
params["vec"] = FtVectorToBytes(FtVector{0, 0});
params["vec"] = ToBytes({0, 0});
algo.Init("* =>[KNN 4 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 3, 4));
}
// Upper right
{
params["vec"] = FtVectorToBytes(FtVector{1, 1});
params["vec"] = ToBytes({1, 1});
algo.Init("* =>[KNN 4 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 2, 3, 4));
}
// Request more than there is
{
params["vec"] = FtVectorToBytes(FtVector{0, 0});
params["vec"] = ToBytes({0, 0});
algo.Init("* => [KNN 10 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
}
// Test correct order: (0.7, 0.15)
{
params["vec"] = FtVectorToBytes(FtVector{0.7, 0.15});
params["vec"] = ToBytes({0.7, 0.15});
algo.Init("* => [KNN 10 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(1, 4, 0, 2, 3));
}
// Test correct order: (0.8, 0.9)
{
params["vec"] = FtVectorToBytes(FtVector{0.8, 0.9});
params["vec"] = ToBytes({0.8, 0.9});
algo.Init("* => [KNN 10 @pos $vec]", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(2, 4, 3, 1, 0));
}
}
// Benchmarks a single-nearest-neighbour KNN query over randomly generated vectors.
// state.range(0) is the vector dimension, state.range(1) the number of indexed vectors.
static void BM_VectorSearch(benchmark::State& state) {
  unsigned ndims = state.range(0);
  unsigned nvecs = state.range(1);

  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].knn_dim = ndims;
  FieldIndices indices{schema};

  // Produces `ndims` coordinates, each rand() scaled into [0, 1].
  auto make_random_vec = [ndims]() {
    vector<float> coords(ndims);
    for (float& coord : coords)
      coord = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
    return coords;
  };

  // Fill the index with one random document per id.
  for (size_t doc_id = 0; doc_id < nvecs; ++doc_id) {
    auto doc_vec = make_random_vec();
    MockedDocument doc{Map{{"pos", ToBytes(doc_vec)}}};
    indices.Add(doc_id, &doc);
  }

  // The query vector is generated once and reused for every search.
  SearchAlgorithm algo{};
  QueryParams params;
  auto query_vec = make_random_vec();
  params["vec"] = ToBytes(query_vec);
  algo.Init("* =>[KNN 1 @pos $vec]", &params);

  constexpr int kBatchSize = 10;
  while (state.KeepRunningBatch(kBatchSize)) {
    for (int run = 0; run < kBatchSize; ++run)
      benchmark::DoNotOptimize(algo.Search(&indices));
  }
}
BENCHMARK(BM_VectorSearch)->Args({120, 10'000});
} // namespace search
} // namespace dfly

View file

@ -1,38 +0,0 @@
// Copyright 2023, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#include "core/search/vector.h"
#include <cmath>
#include <memory>
#include "base/logging.h"
namespace dfly::search {
using namespace std;
// Decodes a raw byte blob into a vector of floats.
// The blob length is expected to be a multiple of sizeof(float) (DCHECK'd); any trailing bytes
// that do not form a whole float are ignored.
FtVector BytesToFtVector(string_view value) {
  DCHECK_EQ(value.size() % sizeof(float), 0u);
  FtVector out(value.size() / sizeof(float));

  // `value` may not be aligned for float access, so the bytes are copied instead of being
  // reinterpreted in place. The vector's storage is already float-aligned, so copy straight into
  // it. Only whole floats are copied: copying value.size() bytes would overrun the destination
  // whenever the length is not a multiple of sizeof(float).
  if (!out.empty())
    memcpy(out.data(), value.data(), out.size() * sizeof(float));
  return out;
}
// Euclidean (L2) distance between u and v: sqrt( sum: (u[i] - v[i])^2 ).
// Both vectors are expected to have the same number of elements (DCHECK'd).
__attribute__((optimize("fast-math"))) float VectorDistance(const FtVector& u, const FtVector& v) {
  DCHECK_EQ(u.size(), v.size());

  const size_t dims = u.size();
  float squared_sum = 0;
  for (size_t i = 0; i < dims; ++i) {
    const float diff = u[i] - v[i];
    squared_sum += diff * diff;
  }
  return sqrt(squared_sum);
}
} // namespace dfly::search

View file

@ -0,0 +1,65 @@
// Copyright 2023, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#include "core/search/vector_utils.h"
#include <cmath>
#include <memory>
#include "base/logging.h"
namespace dfly::search {
using namespace std;
namespace {
// Euclidean (L2) distance between two `dims`-element arrays: sqrt( sum: (u[i] - v[i])^2 ).
// Returns 0 when dims is 0.
__attribute__((optimize("fast-math"))) float L2Distance(const float* u, const float* v,
                                                        size_t dims) {
  float squared_sum = 0;
  for (size_t i = 0; i < dims; ++i) {
    const float diff = u[i] - v[i];
    squared_sum += diff * diff;
  }
  return sqrt(squared_sum);
}
// Cosine distance between two `dims`-element arrays: 1 - (u . v) / (|u| * |v|).
// The result is 0 for vectors pointing the same way, 1 for orthogonal vectors and 2 for opposite
// ones, so a smaller value means "more similar" - the same ordering as L2Distance. Returning the
// raw cosine similarity here would invert the ranking for callers that sort distances ascending.
// If either vector has zero magnitude the cosine is undefined and 0 is returned.
__attribute__((optimize("fast-math"))) float CosineDistance(const float* u, const float* v,
                                                            size_t dims) {
  float sum_uv = 0, sum_uu = 0, sum_vv = 0;
  for (size_t i = 0; i < dims; i++) {
    sum_uv += u[i] * v[i];
    sum_uu += u[i] * u[i];
    sum_vv += v[i] * v[i];
  }

  // denom is (|u| * |v|)^2; it is zero when at least one of the vectors has no magnitude.
  if (float denom = sum_uu * sum_vv; denom != 0.0f)
    return 1.0f - sum_uv / sqrt(denom);
  return 0.0f;
}
} // namespace
// Decodes a raw byte blob into an owned float array paired with its element count.
// The blob length is expected to be a multiple of sizeof(float) (DCHECK'd); only whole floats
// are copied.
OwnedFtVector BytesToFtVector(string_view value) {
  DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size();

  // The blob cannot be reinterpreted as float* in place: it may not be aligned as a float
  // (4 bytes), and misaligned memory access is UB. Copy it into a fresh float buffer instead.
  size_t num_floats = value.size() / sizeof(float);
  unique_ptr<float[]> buffer = make_unique<float[]>(num_floats);
  memcpy(buffer.get(), value.data(), num_floats * sizeof(float));
  return {std::move(buffer), num_floats};
}
// Computes the distance between two `dims`-element arrays using the metric selected by `sim`.
float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim) {
  switch (sim) {
    case VectorSimilarity::COSINE:
      return CosineDistance(u, v, dims);
    case VectorSimilarity::L2:
      return L2Distance(u, v, dims);
  }
  // Only reached if `sim` holds a value outside the enumerators handled above.
  return 0.0f;
}
} // namespace dfly::search

View file

@ -8,8 +8,8 @@
namespace dfly::search {
FtVector BytesToFtVector(std::string_view value);
OwnedFtVector BytesToFtVector(std::string_view value);
float VectorDistance(const FtVector& v1, const FtVector& v2);
float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim);
} // namespace dfly::search

View file

@ -157,6 +157,10 @@ struct CmdArgParser {
return cur_i_ < args_.size() && !error_;
}
bool HasError() {
return error_.has_value();
}
// Get optional error if occurred
std::optional<ErrorInfo> Error() {
return std::exchange(error_, {});

View file

@ -12,7 +12,7 @@
#include "core/json_object.h"
#include "core/search/search.h"
#include "core/search/vector.h"
#include "core/search/vector_utils.h"
#include "core/string_map.h"
#include "server/container_utils.h"
@ -32,10 +32,11 @@ string_view SdsToSafeSv(sds str) {
}
string PrintField(search::SchemaField::FieldType type, string_view value) {
if (type == search::SchemaField::VECTOR)
return absl::StrCat("[", absl::StrJoin(search::BytesToFtVector(value), ","), "]");
else
return string{value};
if (type == search::SchemaField::VECTOR) {
auto [ptr, size] = search::BytesToFtVector(value);
return absl::StrCat("[", absl::StrJoin(absl::Span<const float>{ptr.get(), size}, ","), "]");
}
return string{value};
}
string ExtractValue(const search::Schema& schema, string_view key, string_view value) {
@ -63,7 +64,7 @@ string_view ListPackAccessor::GetString(string_view active_field) const {
return container_utils::LpFind(lp_, active_field, intbuf_[0].data()).value_or(""sv);
}
search::FtVector ListPackAccessor::GetVector(string_view active_field) const {
BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const {
return search::BytesToFtVector(GetString(active_field));
}
@ -89,7 +90,7 @@ string_view StringMapAccessor::GetString(string_view active_field) const {
return SdsToSafeSv(hset_->Find(active_field));
}
search::FtVector StringMapAccessor::GetVector(string_view active_field) const {
BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const {
return search::BytesToFtVector(GetString(active_field));
}
@ -113,16 +114,20 @@ string_view JsonAccessor::GetString(string_view active_field) const {
return buf_;
}
search::FtVector JsonAccessor::GetVector(string_view active_field) const {
BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const {
auto res = GetPath(active_field)->evaluate(json_);
DCHECK(res.is_array());
if (res.empty())
return {};
return {nullptr, 0};
search::FtVector out;
size_t size = res[0].size();
auto ptr = make_unique<float[]>(size);
size_t i = 0;
for (auto v : res[0].array_range())
out.push_back(v.as<float>());
return out;
ptr[i++] = v.as<float>();
return {std::move(ptr), size};
}
JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) const {

View file

@ -40,7 +40,7 @@ struct ListPackAccessor : public BaseAccessor {
}
std::string_view GetString(std::string_view field) const override;
search::FtVector GetVector(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;
private:
@ -54,7 +54,7 @@ struct StringMapAccessor : public BaseAccessor {
}
std::string_view GetString(std::string_view field) const override;
search::FtVector GetVector(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;
private:
@ -69,7 +69,7 @@ struct JsonAccessor : public BaseAccessor {
}
std::string_view GetString(std::string_view field) const override;
search::FtVector GetVector(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;
// The JsonAccessor works with structured types and not plain strings, so an overload is needed

View file

@ -18,7 +18,7 @@
#include "base/logging.h"
#include "core/json_object.h"
#include "core/search/search.h"
#include "core/search/vector.h"
#include "core/search/vector_utils.h"
#include "facade/cmd_arg_parser.h"
#include "facade/error.h"
#include "facade/reply_builder.h"
@ -46,6 +46,30 @@ bool IsValidJsonPath(string_view path) {
return !ec;
}
// Parses the attribute list of a vector field: a leading argument count followed by
// name/value pairs. Recognized names are DIM (vector dimension) and DISTANCE_METRIC
// (L2 or COSINE); any other pair is skipped. Returns {dimension, similarity}, which stay at
// {0, L2} for attributes that are not supplied.
// NOTE(review): `cntx` is not used here; parse errors are presumably recorded inside `parser`.
pair<size_t, search::VectorSimilarity> ParseVectorFieldInfo(CmdArgParser* parser,
                                                            ConnectionContext* cntx) {
  using search::VectorSimilarity;

  size_t dim = 0;
  VectorSimilarity sim = VectorSimilarity::L2;

  // The count covers both names and values, so every iteration accounts for two arguments.
  const size_t num_args = parser->Next().Int<size_t>();
  for (size_t pair_idx = 0; pair_idx * 2 < num_args; ++pair_idx) {
    parser->ToUpper();

    if (parser->Check("DIM").ExpectTail(1)) {
      dim = parser->Next().Int<size_t>();
    } else if (parser->Check("DISTANCE_METRIC").ExpectTail(1)) {
      sim = parser->Next()
                .Case("L2", VectorSimilarity::L2)
                .Case("COSINE", VectorSimilarity::COSINE);
    } else {
      // Unknown attribute: drop its name and value.
      parser->Skip(2);
    }
  }

  return {dim, sim};
}
optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParser parser,
ConnectionContext* cntx) {
search::Schema schema;
@ -74,15 +98,24 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
return nullopt;
}
// Skip {algorithm} {dim} flags
if (*type == search::SchemaField::VECTOR)
parser.Skip(2);
// Vector fields include: {algorithm} num_args args...
size_t knn_dim = 0;
search::VectorSimilarity knn_sim = search::VectorSimilarity::L2;
if (*type == search::SchemaField::VECTOR) {
parser.Skip(1); // algorithm
std::tie(knn_dim, knn_sim) = ParseVectorFieldInfo(&parser, cntx);
if (!parser.HasError() && knn_dim == 0) {
(*cntx)->SendError("Vector dimension cannot be zero");
return nullopt;
}
}
// Skip all trailing ignored parameters
while (kIgnoredOptions.count(parser.Peek()) > 0)
parser.Skip(2);
schema.fields[field] = {*type, string{field_alias}};
schema.fields[field] = {*type, string{field_alias}, knn_dim, knn_sim};
}
// Build field name mapping table
@ -208,8 +241,9 @@ void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchRes
}
}
partial_sort(docs.begin(),
docs.begin() + min(params.limit_offset + params.limit_total, knn_limit), docs.end(),
size_t prefix = min(params.limit_offset + params.limit_total, knn_limit);
partial_sort(docs.begin(), docs.begin() + min(docs.size(), prefix), docs.end(),
[](const auto* l, const auto* r) { return l->knn_distance < r->knn_distance; });
docs.resize(min(docs.size(), knn_limit));