Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,28 @@ class ScopedThreadName {
std::string _previous_name;
};

// faiss::IDSelector implementation that answers membership queries straight from a
// roaring bitmap. The bitmap is only referenced through a pointer, never copied, so it
// has to stay alive for as long as this selector is used.
class IDSelectorRoaring : public faiss::IDSelector {
public:
    // `roaring` is borrowed, not owned; it must not be null (checked with DCHECK).
    explicit IDSelectorRoaring(const roaring::Roaring* roaring) : _roaring(roaring) {
        DCHECK(_roaring != nullptr);
    }

    // Returns true when `id` is stored in the bitmap.
    // Negative ids are rejected up front, before the id is narrowed to UInt32 via cast_set.
    // NOTE(review): ids above the UInt32 range are presumably rejected by cast_set itself
    // (the unit tests expect an exception there) -- confirm against cast_set.
    bool is_member(faiss::idx_t id) const final {
        const bool is_non_negative = id >= 0;
        // Short-circuit keeps the narrowing cast from ever seeing a negative id.
        return is_non_negative && _roaring->contains(cast_set<vectorized::UInt32>(id));
    }

private:
    // Non-owning view of the bitmap that defines the accepted ids.
    const roaring::Roaring* _roaring;
};

} // namespace
// Builds a faiss ID selector backed directly by the given roaring bitmap.
//
// The selector keeps a non-owning pointer to `roaring` instead of copying its ids into
// an intermediate buffer. Consequences for callers:
//   - `roaring` must outlive the returned selector;
//   - ids added to or removed from `roaring` afterwards are seen by the selector.
//
// @param roaring bitmap holding the ids that should pass the filter.
// @return selector whose is_member() looks ids up in `roaring`.
std::unique_ptr<faiss::IDSelector> FaissVectorIndex::roaring_to_faiss_selector(
        const roaring::Roaring& roaring) {
    // make_unique produces unique_ptr<IDSelectorRoaring>; it converts to the base-class
    // unique_ptr on return.
    return std::make_unique<IDSelectorRoaring>(&roaring);
}

void FaissVectorIndex::update_roaring(const faiss::idx_t* labels, const size_t n,
Expand Down
50 changes: 50 additions & 0 deletions be/test/olap/vector_search/faiss_vector_index_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <algorithm>
#include <cstddef>
#include <limits>
#include <memory>
#include <random>
#include <string>
Expand Down Expand Up @@ -916,6 +917,55 @@ TEST_F(VectorSearchTest, TestIdSelectorWithEmptyRoaring) {
}
}

// The selector must accept exactly the ids stored in the bitmap and reject everything
// else, including negative ids.
TEST_F(VectorSearchTest, TestIdSelectorRoaringBasicMembership) {
    const uint32_t present_ids[] = {0u, 2u, 4u, 1000u};
    const uint32_t absent_ids[] = {1u, 3u, 5u, 999u, 1001u};

    roaring::Roaring bitmap;
    for (const uint32_t id : present_ids) {
        bitmap.add(id);
    }
    const auto selector = FaissVectorIndex::roaring_to_faiss_selector(bitmap);

    // Every id that was inserted must be reported as a member.
    for (const uint32_t id : present_ids) {
        ASSERT_TRUE(selector->is_member(static_cast<faiss::idx_t>(id)))
                << "Expected id " << id << " present";
    }
    // Neighbouring ids that were never inserted must be rejected.
    for (const uint32_t id : absent_ids) {
        ASSERT_FALSE(selector->is_member(static_cast<faiss::idx_t>(id)))
                << "Unexpected id " << id << " present";
    }
    ASSERT_FALSE(selector->is_member(-1)) << "Negative ids should never match";
}

// The selector must observe changes made to the bitmap after the selector was created.
TEST_F(VectorSearchTest, TestIdSelectorRoaringReflectsBitmapUpdates) {
    roaring::Roaring bitmap;
    bitmap.add(10);
    const auto selector = FaissVectorIndex::roaring_to_faiss_selector(bitmap);

    // Initial state: only id 10 is present.
    ASSERT_TRUE(selector->is_member(10));
    ASSERT_FALSE(selector->is_member(20));

    // Mutate the bitmap while the selector is alive.
    bitmap.add(20);
    bitmap.remove(10);

    ASSERT_FALSE(selector->is_member(10)) << "Selector should track removals";
    ASSERT_TRUE(selector->is_member(20)) << "Selector should track additions";
}

// Boundary check: the largest uint32_t id is a valid member, while the next larger id is
// expected to make is_member() throw a std::exception.
TEST_F(VectorSearchTest, TestIdSelectorRoaringHandlesUInt32Max) {
    constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();

    roaring::Roaring bitmap;
    bitmap.add(kMax);
    const auto selector = FaissVectorIndex::roaring_to_faiss_selector(bitmap);

    const auto max_id = static_cast<faiss::idx_t>(kMax);
    ASSERT_TRUE(selector->is_member(max_id)) << "Expected uint32_t max to be present";

    // Reports whether querying `id` raises a std::exception; the lookup result is ignored.
    const auto throws_on = [&selector](faiss::idx_t id) {
        try {
            selector->is_member(id);
        } catch (const std::exception&) {
            return true;
        }
        return false;
    };
    ASSERT_TRUE(throws_on(max_id + 1)) << "Expected exception for value beyond uint32_t max";
}

// New tests: radius == 0 or < 0
TEST_F(VectorSearchTest, L2RangeSearchZeroAndNegativeRadius) {
const int dim = 32;
Expand Down
Loading