Skip to content

Commit c83315d

Browse files
jac0626silas.jiang
andcommitted
fix: AsyncMilvusClient.search support EmbeddingList (milvus-io#3101)
Convert EmbeddingList to flat array in search() and hybrid_search() before calling Prepare(just like the sync client). Enhance is_legal_search_data() to recognize EmbeddingList. Add unit tests for search and hybrid_search with EmbeddingList. also see milvus-io#3059 Signed-off-by: silas.jiang <[email protected]> Co-authored-by: silas.jiang <[email protected]> (cherry picked from commit ddff574) Signed-off-by: silas.jiang <[email protected]>
1 parent 57db59f commit c83315d

File tree

3 files changed

+184
-2
lines changed

3 files changed

+184
-2
lines changed

pymilvus/client/async_grpc_handler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
is_legal_port,
3636
)
3737
from .constants import ITERATOR_SESSION_TS_FIELD
38+
from .embedding_list import EmbeddingList
3839
from .interceptor import _api_level_md
3940
from .prepare import Prepare
4041
from .search_result import SearchResult
@@ -842,6 +843,12 @@ async def search(
842843
guarantee_timestamp=kwargs.get("guarantee_timestamp"),
843844
timeout=timeout,
844845
)
846+
847+
# Convert EmbeddingList to flat array if present
848+
if isinstance(data, list) and len(data) > 0 and isinstance(data[0], EmbeddingList):
849+
data = [emb_list.to_flat_array() for emb_list in data]
850+
kwargs["is_embedding_list"] = True
851+
845852
request = Prepare.search_requests_with_expr(
846853
collection_name,
847854
data,
@@ -883,17 +890,24 @@ async def hybrid_search(
883890

884891
requests = []
885892
for req in reqs:
893+
data = req.data
894+
req_kwargs = dict(kwargs)
895+
# Convert EmbeddingList to flat array if present
896+
if isinstance(data, list) and len(data) > 0 and isinstance(data[0], EmbeddingList):
897+
data = [emb_list.to_flat_array() for emb_list in data]
898+
req_kwargs["is_embedding_list"] = True
899+
886900
search_request = Prepare.search_requests_with_expr(
887901
collection_name,
888-
req.data,
902+
data,
889903
req.anns_field,
890904
req.param,
891905
req.limit,
892906
req.expr,
893907
partition_names=partition_names,
894908
round_decimal=round_decimal,
895909
expr_params=req.expr_params,
896-
**kwargs,
910+
**req_kwargs,
897911
)
898912
requests.append(search_request)
899913

pymilvus/client/check.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@ def is_legal_search_data(data: Any) -> bool:
205205
if entity_helper.entity_is_sparse_matrix(data):
206206
return True
207207

208+
# Support EmbeddingList for array-of-vector searches
209+
# Check for EmbeddingList by type name to avoid circular dependency
210+
if isinstance(data, list) and len(data) > 0 and type(data[0]).__name__ == "EmbeddingList":
211+
return True
212+
208213
if not isinstance(data, (list, np.ndarray)):
209214
return False
210215

tests/test_async_grpc_handler.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from unittest.mock import AsyncMock, MagicMock, patch
22

3+
import numpy as np
34
import pytest
45
from pymilvus.client.async_grpc_handler import AsyncGrpcHandler
56
from pymilvus.exceptions import MilvusException
7+
from pymilvus.grpc_gen import schema_pb2
68

79

810
class TestAsyncGrpcHandler:
@@ -312,3 +314,164 @@ async def test_create_index_with_nested_field(self) -> None:
312314

313315
# Verify wait_for_creating_index was called
314316
handler.wait_for_creating_index.assert_called_once()
317+
318+
@pytest.mark.asyncio
319+
async def test_search_with_embedding_list(self) -> None:
320+
"""
321+
Test that search works with EmbeddingList input data.
322+
This test verifies the fix for issue where AsyncMilvusClient.search
323+
failed when using EmbeddingList for array-of-vector searches.
324+
"""
325+
# Setup mock channel and stub
326+
mock_channel = AsyncMock()
327+
mock_channel.channel_ready = AsyncMock()
328+
mock_channel._unary_unary_interceptors = []
329+
330+
handler = AsyncGrpcHandler(channel=mock_channel)
331+
handler._is_channel_ready = True
332+
333+
mock_stub = AsyncMock()
334+
handler._async_stub = mock_stub
335+
336+
# Mock Search response with proper SearchResultData structure
337+
mock_search_result_data = schema_pb2.SearchResultData(
338+
num_queries=2,
339+
top_k=0,
340+
scores=[],
341+
ids=schema_pb2.IDs(int_id=schema_pb2.LongArray(data=[])),
342+
topks=[],
343+
primary_field_name="id"
344+
)
345+
mock_search_response = MagicMock()
346+
mock_status = MagicMock()
347+
mock_status.code = 0
348+
mock_status.reason = ""
349+
mock_search_response.status = mock_status
350+
mock_search_response.results = mock_search_result_data
351+
mock_search_response.session_ts = 0
352+
mock_stub.Search = AsyncMock(return_value=mock_search_response)
353+
354+
# Create EmbeddingList data
355+
from pymilvus.client.embedding_list import EmbeddingList
356+
emb_list1 = EmbeddingList()
357+
emb_list1.add([0.1, 0.2, 0.3, 0.4, 0.5])
358+
emb_list2 = EmbeddingList()
359+
emb_list2.add([0.5, 0.4, 0.3, 0.2, 0.1])
360+
data = [emb_list1, emb_list2]
361+
362+
with patch('pymilvus.client.async_grpc_handler.Prepare') as mock_prepare, \
363+
patch('pymilvus.client.async_grpc_handler.check_pass_param'), \
364+
patch('pymilvus.client.async_grpc_handler.check_status'), \
365+
patch('pymilvus.client.async_grpc_handler._api_level_md', return_value={}):
366+
367+
# Mock search_requests_with_expr to return a request
368+
mock_request = MagicMock()
369+
mock_prepare.search_requests_with_expr.return_value = mock_request
370+
371+
await handler.search(
372+
collection_name="test_collection",
373+
data=data,
374+
anns_field="vector",
375+
param={"metric_type": "COSINE"},
376+
limit=10
377+
)
378+
379+
# Verify that Prepare.search_requests_with_expr was called
380+
mock_prepare.search_requests_with_expr.assert_called_once()
381+
call_args = mock_prepare.search_requests_with_expr.call_args
382+
383+
# Verify that is_embedding_list was passed as True in kwargs
384+
assert call_args.kwargs.get("is_embedding_list") is True
385+
386+
# Verify data was converted (not EmbeddingList objects anymore)
387+
passed_data = call_args[0][1] # data is the second positional argument
388+
assert isinstance(passed_data, list)
389+
assert not isinstance(passed_data[0], EmbeddingList)
390+
# The data should be converted to flat arrays
391+
assert isinstance(passed_data[0], (list, np.ndarray))
392+
393+
# Verify Search was called
394+
mock_stub.Search.assert_called_once()
395+
396+
@pytest.mark.asyncio
397+
async def test_hybrid_search_with_embedding_list(self) -> None:
398+
"""
399+
Test that hybrid_search works with EmbeddingList input data.
400+
"""
401+
# Setup mock channel and stub
402+
mock_channel = AsyncMock()
403+
mock_channel.channel_ready = AsyncMock()
404+
mock_channel._unary_unary_interceptors = []
405+
406+
handler = AsyncGrpcHandler(channel=mock_channel)
407+
handler._is_channel_ready = True
408+
409+
mock_stub = AsyncMock()
410+
handler._async_stub = mock_stub
411+
412+
# Mock HybridSearch response with proper SearchResultData structure
413+
mock_hybrid_result_data = schema_pb2.SearchResultData(
414+
num_queries=1,
415+
top_k=0,
416+
scores=[],
417+
ids=schema_pb2.IDs(int_id=schema_pb2.LongArray(data=[])),
418+
topks=[],
419+
primary_field_name="id"
420+
)
421+
mock_hybrid_response = MagicMock()
422+
mock_status = MagicMock()
423+
mock_status.code = 0
424+
mock_status.reason = ""
425+
mock_hybrid_response.status = mock_status
426+
mock_hybrid_response.results = mock_hybrid_result_data
427+
mock_stub.HybridSearch = AsyncMock(return_value=mock_hybrid_response)
428+
429+
# Create AnnSearchRequest with EmbeddingList
430+
from pymilvus.client.embedding_list import EmbeddingList
431+
from pymilvus.client.abstract import AnnSearchRequest
432+
import numpy as np
433+
434+
emb_list = EmbeddingList()
435+
emb_list.add([0.1, 0.2, 0.3])
436+
req = AnnSearchRequest(
437+
data=[emb_list],
438+
anns_field="vector",
439+
param={"metric_type": "COSINE"},
440+
limit=10
441+
)
442+
443+
with patch('pymilvus.client.async_grpc_handler.Prepare') as mock_prepare, \
444+
patch('pymilvus.client.async_grpc_handler.check_pass_param'), \
445+
patch('pymilvus.client.async_grpc_handler.check_status'), \
446+
patch('pymilvus.client.async_grpc_handler._api_level_md', return_value={}):
447+
448+
# Mock search_requests_with_expr and hybrid_search_request_with_ranker
449+
mock_search_request = MagicMock()
450+
mock_hybrid_request = MagicMock()
451+
mock_prepare.search_requests_with_expr.return_value = mock_search_request
452+
mock_prepare.hybrid_search_request_with_ranker.return_value = mock_hybrid_request
453+
454+
# Mock rerank (BaseRanker)
455+
mock_ranker = MagicMock()
456+
457+
await handler.hybrid_search(
458+
collection_name="test_collection",
459+
reqs=[req],
460+
rerank=mock_ranker,
461+
limit=10
462+
)
463+
464+
# Verify that search_requests_with_expr was called with converted data
465+
mock_prepare.search_requests_with_expr.assert_called_once()
466+
call_args = mock_prepare.search_requests_with_expr.call_args
467+
468+
# Verify is_embedding_list flag was set
469+
assert call_args.kwargs.get("is_embedding_list") is True
470+
471+
# Verify data was converted
472+
passed_data = call_args[0][1]
473+
assert isinstance(passed_data, list)
474+
assert not isinstance(passed_data[0], EmbeddingList)
475+
476+
# Verify HybridSearch was called
477+
mock_stub.HybridSearch.assert_called_once()

0 commit comments

Comments
 (0)