diff --git a/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.cpp b/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.cpp index 140fc800482991..5c29384b876030 100644 --- a/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.cpp +++ b/be/src/olap/rowset/segment_v2/ann_index/faiss_ann_index.cpp @@ -117,18 +117,28 @@ class ScopedThreadName { std::string _previous_name; }; +class IDSelectorRoaring : public faiss::IDSelector { +public: + explicit IDSelectorRoaring(const roaring::Roaring* roaring) : _roaring(roaring) { + DCHECK(_roaring != nullptr); + } + + bool is_member(faiss::idx_t id) const final { + if (id < 0) { + return false; + } + return _roaring->contains(cast_set(id)); + } + +private: + const roaring::Roaring* _roaring; +}; + } // namespace std::unique_ptr FaissVectorIndex::roaring_to_faiss_selector( const roaring::Roaring& roaring) { - std::vector ids; - ids.resize(roaring.cardinality()); - - size_t i = 0; - for (roaring::Roaring::const_iterator it = roaring.begin(); it != roaring.end(); ++it, ++i) { - ids[i] = cast_set(*it); - } - // construct derived and wrap into base unique_ptr explicitly - return std::unique_ptr(new faiss::IDSelectorBatch(ids.size(), ids.data())); + // Wrap the roaring bitmap directly to avoid copying ids into an intermediate buffer. + return std::unique_ptr(new IDSelectorRoaring(&roaring)); } void FaissVectorIndex::update_roaring(const faiss::idx_t* labels, const size_t n, diff --git a/be/test/olap/vector_search/faiss_vector_index_test.cpp b/be/test/olap/vector_search/faiss_vector_index_test.cpp index 60da7e939ee23a..82869837c59501 100644 --- a/be/test/olap/vector_search/faiss_vector_index_test.cpp +++ b/be/test/olap/vector_search/faiss_vector_index_test.cpp @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -916,6 +917,55 @@ TEST_F(VectorSearchTest, TestIdSelectorWithEmptyRoaring) { } } +TEST_F(VectorSearchTest, TestIdSelectorRoaringBasicMembership) { + auto roaring = std::make_unique(); + for (uint32_t i : {0u, 2u, 4u, 1000u}) { + roaring->add(i); + } + auto sel = FaissVectorIndex::roaring_to_faiss_selector(*roaring); + for (uint32_t i : {0u, 2u, 4u, 1000u}) { + ASSERT_TRUE(sel->is_member(static_cast(i))) + << "Expected id " << i << " present"; + } + for (uint32_t i : {1u, 3u, 5u, 999u, 1001u}) { + ASSERT_FALSE(sel->is_member(static_cast(i))) + << "Unexpected id " << i << " present"; + } + ASSERT_FALSE(sel->is_member(-1)) << "Negative ids should never match"; +} + +TEST_F(VectorSearchTest, TestIdSelectorRoaringReflectsBitmapUpdates) { + auto roaring = std::make_unique(); + roaring->add(10); + auto sel = FaissVectorIndex::roaring_to_faiss_selector(*roaring); + + ASSERT_TRUE(sel->is_member(10)); + ASSERT_FALSE(sel->is_member(20)); + + roaring->add(20); + roaring->remove(10); + + ASSERT_FALSE(sel->is_member(10)) << "Selector should track removals"; + ASSERT_TRUE(sel->is_member(20)) << "Selector should track additions"; +} + +TEST_F(VectorSearchTest, TestIdSelectorRoaringHandlesUInt32Max) { + auto roaring = std::make_unique(); + constexpr uint32_t kMax = std::numeric_limits::max(); + roaring->add(kMax); + auto sel = FaissVectorIndex::roaring_to_faiss_selector(*roaring); + + ASSERT_TRUE(sel->is_member(static_cast(kMax))) + << "Expected uint32_t max to be present"; + bool exception_thrown = false; + try { + sel->is_member(static_cast(kMax) + 1); + } catch (const std::exception& e) { + exception_thrown = true; + } + ASSERT_TRUE(exception_thrown) << "Expected exception for value beyond uint32_t max"; +} + // New tests: radius == 0 or < 0 TEST_F(VectorSearchTest, L2RangeSearchZeroAndNegativeRadius) { const int dim = 32;