Skip to content

Commit

Permalink
fix(search_family): Process wrong field types in indexes for the FT.S…
Browse files Browse the repository at this point in the history
…EARCH and FT.AGGREGATE commands (#4070)

* fix(search_family): Process wrong field types in indexes for the FT.SEARCH and FT.AGGREGATE commands

fixes  #3986
---------

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan authored and romange committed Nov 10, 2024
1 parent 2d49a28 commit a506795
Show file tree
Hide file tree
Showing 16 changed files with 682 additions and 215 deletions.
9 changes: 9 additions & 0 deletions src/core/search/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "core/search/base.h"

#include <absl/strings/numbers.h>

namespace dfly::search {

std::string_view QueryParams::operator[](std::string_view name) const {
Expand Down Expand Up @@ -37,4 +39,11 @@ WrappedStrPtr::operator std::string_view() const {
return std::string_view{ptr.get(), std::strlen(ptr.get())};
}

std::optional<double> ParseNumericField(std::string_view value) {
double value_as_double;
if (absl::SimpleAtod(value, &value_as_double))
return value_as_double;
return std::nullopt;
}

} // namespace dfly::search
33 changes: 29 additions & 4 deletions src/core/search/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,18 @@ using SortableValue = std::variant<std::monostate, double, std::string>;
struct DocumentAccessor {
using VectorInfo = search::OwnedFtVector;
using StringList = absl::InlinedVector<std::string_view, 1>;
using NumsList = absl::InlinedVector<double, 1>;

virtual ~DocumentAccessor() = default;

virtual StringList GetStrings(std::string_view active_field) const = 0;
virtual VectorInfo GetVector(std::string_view active_field) const = 0;
/* Returns nullopt if the specified field is not a list of strings */
virtual std::optional<StringList> GetStrings(std::string_view active_field) const = 0;

/* Returns nullopt if the specified field is not a vector */
virtual std::optional<VectorInfo> GetVector(std::string_view active_field) const = 0;

/* Return nullopt if the specified field is not a list of doubles */
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const = 0;
};

// Base class for type-specific indices.
Expand All @@ -81,8 +88,10 @@ struct DocumentAccessor {
// query functions. All results for all index types should be sorted.
struct BaseIndex {
virtual ~BaseIndex() = default;
virtual void Add(DocId id, DocumentAccessor* doc, std::string_view field) = 0;
virtual void Remove(DocId id, DocumentAccessor* doc, std::string_view field) = 0;

// Returns true if the document was added / indexed
virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
};

// Base class for type-specific sorting indices.
Expand All @@ -91,4 +100,20 @@ struct BaseSortIndex : BaseIndex {
virtual std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const = 0;
};

/* Used for converting field values to double. Returns std::nullopt if the conversion fails */
std::optional<double> ParseNumericField(std::string_view value);

/* Temporary method to create an empty std::optional<InlinedVector> in DocumentAccessor::GetString
and DocumentAccessor::GetNumbers methods. The problem is that due to internal implementation
details of absl::InlineVector, we are getting a -Wmaybe-uninitialized compiler warning. To
suppress this false warning, we temporarily disable it around this block of code using GCC
diagnostic directives. */
template <typename InlinedVector> std::optional<InlinedVector> EmptyAccessResult() {
// GCC 13.1 throws spurious warnings around this code.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
return InlinedVector{};
#pragma GCC diagnostic pop
}

} // namespace dfly::search
76 changes: 50 additions & 26 deletions src/core/search/indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,22 @@ absl::flat_hash_set<string> NormalizeTags(string_view taglist, bool case_sensiti
NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} {
}

void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
for (auto str : doc->GetStrings(field)) {
double num;
if (absl::SimpleAtod(str, &num))
entries_.emplace(num, id);
bool NumericIndex::Add(DocId id, const DocumentAccessor& doc, string_view field) {
auto numbers = doc.GetNumbers(field);
if (!numbers) {
return false;
}

for (auto num : numbers.value()) {
entries_.emplace(num, id);
}
return true;
}

void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
for (auto str : doc->GetStrings(field)) {
double num;
if (absl::SimpleAtod(str, &num))
entries_.erase({num, id});
void NumericIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
auto numbers = doc.GetNumbers(field).value();
for (auto num : numbers) {
entries_.erase({num, id});
}
}

Expand Down Expand Up @@ -139,19 +142,27 @@ typename BaseStringIndex<C>::Container* BaseStringIndex<C>::GetOrCreate(string_v
}

template <typename C>
void BaseStringIndex<C>::Add(DocId id, DocumentAccessor* doc, string_view field) {
bool BaseStringIndex<C>::Add(DocId id, const DocumentAccessor& doc, string_view field) {
auto strings_list = doc.GetStrings(field);
if (!strings_list) {
return false;
}

absl::flat_hash_set<std::string> tokens;
for (string_view str : doc->GetStrings(field))
for (string_view str : strings_list.value())
tokens.merge(Tokenize(str));

for (string_view token : tokens)
GetOrCreate(token)->Insert(id);
return true;
}

template <typename C>
void BaseStringIndex<C>::Remove(DocId id, DocumentAccessor* doc, string_view field) {
void BaseStringIndex<C>::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
auto strings_list = doc.GetStrings(field).value();

absl::flat_hash_set<std::string> tokens;
for (string_view str : doc->GetStrings(field))
for (string_view str : strings_list)
tokens.merge(Tokenize(str));

for (const auto& token : tokens) {
Expand Down Expand Up @@ -192,26 +203,39 @@ std::pair<size_t /*dim*/, VectorSimilarity> BaseVectorIndex::Info() const {
return {dim_, sim_};
}

bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
auto vector = doc.GetVector(field);
if (!vector)
return false;

auto& [ptr, size] = vector.value();
if (ptr && size != dim_) {
return false;
}

AddVector(id, ptr);
return true;
}

FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
PMR_NS::memory_resource* mr)
: BaseVectorIndex{params.dim, params.sim}, entries_{mr} {
DCHECK(!params.use_hnsw);
entries_.reserve(params.capacity * params.dim);
}

void FlatVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
void FlatVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
DCHECK_LE(id * dim_, entries_.size());
if (id * dim_ == entries_.size())
entries_.resize((id + 1) * dim_);

// TODO: Let get vector write to buf itself
auto [ptr, size] = doc->GetVector(field);

if (size == dim_)
memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float));
if (vector) {
memcpy(&entries_[id * dim_], vector.get(), dim_ * sizeof(float));
}
}

void FlatVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
void FlatVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
// noop
}

Expand All @@ -229,7 +253,7 @@ struct HnswlibAdapter {
100 /* seed*/} {
}

void Add(float* data, DocId id) {
void Add(const float* data, DocId id) {
if (world_.cur_element_count + 1 >= world_.max_elements_)
world_.resizeIndex(world_.cur_element_count * 2);
world_.addPoint(data, id);
Expand Down Expand Up @@ -298,10 +322,10 @@ HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS
HnswVectorIndex::~HnswVectorIndex() {
}

void HnswVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
auto [ptr, size] = doc->GetVector(field);
if (size == dim_)
adapter_->Add(ptr.get(), id);
void HnswVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
if (vector) {
adapter_->Add(vector.get(), id);
}
}

std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
Expand All @@ -314,7 +338,7 @@ std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t
return adapter_->Knn(target, k, ef, allowed);
}

void HnswVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
void HnswVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
adapter_->Remove(id);
}

Expand Down
27 changes: 18 additions & 9 deletions src/core/search/indices.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ namespace dfly::search {
struct NumericIndex : public BaseIndex {
explicit NumericIndex(PMR_NS::memory_resource* mr);

void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;

std::vector<DocId> Range(double l, double r) const;

Expand All @@ -44,16 +44,16 @@ template <typename C> struct BaseStringIndex : public BaseIndex {

BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive);

void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;

// Used by Add & Remove to tokenize text value
virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0;

// Pointer is valid as long as index is not mutated. Nullptr if not found
const Container* Matching(std::string_view str) const;

// Iterate over all Machting on prefix.
// Iterate over all Matching on prefix.
void MatchingPrefix(std::string_view prefix, absl::FunctionRef<void(const Container*)> cb) const;

// Returns all the terms that appear as keys in the reverse index.
Expand Down Expand Up @@ -97,9 +97,14 @@ struct TagIndex : public BaseStringIndex<SortedVector> {
struct BaseVectorIndex : public BaseIndex {
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;

bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override final;

protected:
BaseVectorIndex(size_t dim, VectorSimilarity sim);

using VectorPtr = decltype(std::declval<OwnedFtVector>().first);
virtual void AddVector(DocId id, const VectorPtr& vector) = 0;

size_t dim_;
VectorSimilarity sim_;
};
Expand All @@ -109,11 +114,13 @@ struct BaseVectorIndex : public BaseIndex {
struct FlatVectorIndex : public BaseVectorIndex {
FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);

void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;

const float* Get(DocId doc) const;

protected:
void AddVector(DocId id, const VectorPtr& vector) override;

private:
PMR_NS::vector<float> entries_;
};
Expand All @@ -124,13 +131,15 @@ struct HnswVectorIndex : public BaseVectorIndex {
HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
~HnswVectorIndex();

void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;

std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) const;
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
const std::vector<DocId>& allowed) const;

protected:
void AddVector(DocId id, const VectorPtr& vector) override;

private:
std::unique_ptr<HnswlibAdapter> adapter_;
};
Expand Down
39 changes: 32 additions & 7 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -571,23 +571,48 @@ void FieldIndices::CreateSortIndices(PMR_NS::memory_resource* mr) {
}
}

void FieldIndices::Add(DocId doc, DocumentAccessor* access) {
for (auto& [field, index] : indices_)
index->Add(doc, access, field);
for (auto& [field, sort_index] : sort_indices_)
sort_index->Add(doc, access, field);
bool FieldIndices::Add(DocId doc, const DocumentAccessor& access) {
bool was_added = true;

std::vector<std::pair<std::string_view, BaseIndex*>> successfully_added_indices;
successfully_added_indices.reserve(indices_.size() + sort_indices_.size());

auto try_add = [&](const auto& indices_container) {
for (auto& [field, index] : indices_container) {
if (index->Add(doc, access, field)) {
successfully_added_indices.emplace_back(field, index.get());
} else {
was_added = false;
break;
}
}
};

try_add(indices_);

if (was_added) {
try_add(sort_indices_);
}

if (!was_added) {
for (auto& [field, index] : successfully_added_indices) {
index->Remove(doc, access, field);
}
return false;
}

all_ids_.insert(upper_bound(all_ids_.begin(), all_ids_.end(), doc), doc);
return true;
}

void FieldIndices::Remove(DocId doc, DocumentAccessor* access) {
void FieldIndices::Remove(DocId doc, const DocumentAccessor& access) {
for (auto& [field, index] : indices_)
index->Remove(doc, access, field);
for (auto& [field, sort_index] : sort_indices_)
sort_index->Remove(doc, access, field);

auto it = lower_bound(all_ids_.begin(), all_ids_.end(), doc);
CHECK(it != all_ids_.end() && *it == doc);
DCHECK(it != all_ids_.end() && *it == doc);
all_ids_.erase(it);
}

Expand Down
5 changes: 3 additions & 2 deletions src/core/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ class FieldIndices {
// Create indices based on schema and options. Both must outlive the indices
FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr);

void Add(DocId doc, DocumentAccessor* access);
void Remove(DocId doc, DocumentAccessor* access);
// Returns true if document was added
bool Add(DocId doc, const DocumentAccessor& access);
void Remove(DocId doc, const DocumentAccessor& access);

BaseIndex* GetIndex(std::string_view field) const;
BaseSortIndex* GetSortIndex(std::string_view field) const;
Expand Down
Loading

0 comments on commit a506795

Please sign in to comment.