Skip to content

Commit 3044e4a

Browse files
authored
INTPYTHON-717 Test for hybrid search retriever with weighted rrf (#190)
1 parent 57c748c commit 3044e4a

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

libs/langchain-mongodb/tests/integration_tests/test_retrievers.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,57 @@ def test_hybrid_retriever_nested(
224224
assert "New Orleans" in results[0].page_content
225225

226226

227+
def test_hybrid_search_weighted_rrf(
228+
indexed_vectorstore: PatchedMongoDBAtlasVectorSearch,
229+
):
230+
vec_only_retriever = MongoDBAtlasHybridSearchRetriever(
231+
vectorstore=indexed_vectorstore,
232+
search_index_name=SEARCH_INDEX_NAME,
233+
k=3,
234+
vector_weight=1.0,
235+
fulltext_weight=0.0,
236+
)
237+
238+
text_only_retriever = MongoDBAtlasHybridSearchRetriever(
239+
vectorstore=indexed_vectorstore,
240+
search_index_name=SEARCH_INDEX_NAME,
241+
k=3,
242+
vector_weight=0.0,
243+
fulltext_weight=1.0,
244+
)
245+
246+
balanced_retriever = MongoDBAtlasHybridSearchRetriever(
247+
vectorstore=indexed_vectorstore,
248+
search_index_name=SEARCH_INDEX_NAME,
249+
k=3,
250+
vector_weight=1.0,
251+
fulltext_weight=1.0,
252+
)
253+
254+
query = "Sandwiches"
255+
256+
text_only_results = text_only_retriever.invoke(query)
257+
assert len(text_only_results) == 3 # but only one with non-zero text score
258+
single_text_score = text_only_results[0].metadata["fulltext_score"]
259+
assert single_text_score > 0
260+
assert all(
261+
result.metadata["fulltext_score"] == 0 for result in text_only_results[1:]
262+
)
263+
assert all(result.metadata["vector_score"] == 0 for result in text_only_results)
264+
total_score = sum(res.metadata["score"] for res in text_only_results)
265+
assert abs(total_score - single_text_score) < 0.001
266+
267+
vec_only_results = vec_only_retriever.invoke(query)
268+
assert len(vec_only_results) == 3
269+
assert all(result.metadata["vector_score"] > 0 for result in vec_only_results)
270+
assert all(result.metadata["fulltext_score"] == 0 for result in vec_only_results)
271+
total_vec_score = sum(res.metadata["score"] for res in vec_only_results)
272+
273+
balanced_results = balanced_retriever.invoke(query)
274+
total_score = sum(res.metadata["score"] for res in balanced_results)
275+
assert abs(total_score - (total_vec_score + single_text_score)) < 0.001
276+
277+
227278
def test_fulltext_retriever(
228279
indexed_vectorstore: PatchedMongoDBAtlasVectorSearch,
229280
) -> None:

0 commit comments

Comments
 (0)