Skip to content

Commit 2b10f88

Browse files
author
silas.jiang
committed
fix: support EmbeddingList in AsyncMilvusClient.search
Convert EmbeddingList inputs to flat arrays in search() and hybrid_search() before calling Prepare. Enhance is_legal_search_data() to recognize EmbeddingList. Add unit tests for search() and hybrid_search() with EmbeddingList.

Signed-off-by: silas.jiang <[email protected]>
1 parent da333ee commit 2b10f88

File tree

3 files changed

+184
-2
lines changed

3 files changed

+184
-2
lines changed

pymilvus/client/async_grpc_handler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
is_legal_port,
3737
)
3838
from .constants import ITERATOR_SESSION_TS_FIELD
39+
from .embedding_list import EmbeddingList
3940
from .interceptor import _api_level_md
4041
from .prepare import Prepare
4142
from .search_result import SearchResult
@@ -844,6 +845,12 @@ async def search(
844845
guarantee_timestamp=kwargs.get("guarantee_timestamp"),
845846
timeout=timeout,
846847
)
848+
849+
# Convert EmbeddingList to flat array if present
850+
if isinstance(data, list) and len(data) > 0 and isinstance(data[0], EmbeddingList):
851+
data = [emb_list.to_flat_array() for emb_list in data]
852+
kwargs["is_embedding_list"] = True
853+
847854
request = Prepare.search_requests_with_expr(
848855
collection_name,
849856
data,
@@ -884,17 +891,24 @@ async def hybrid_search(
884891

885892
requests = []
886893
for req in reqs:
894+
data = req.data
895+
req_kwargs = dict(kwargs)
896+
# Convert EmbeddingList to flat array if present
897+
if isinstance(data, list) and len(data) > 0 and isinstance(data[0], EmbeddingList):
898+
data = [emb_list.to_flat_array() for emb_list in data]
899+
req_kwargs["is_embedding_list"] = True
900+
887901
search_request = Prepare.search_requests_with_expr(
888902
collection_name,
889-
req.data,
903+
data,
890904
req.anns_field,
891905
req.param,
892906
req.limit,
893907
req.expr,
894908
partition_names=partition_names,
895909
round_decimal=round_decimal,
896910
expr_params=req.expr_params,
897-
**kwargs,
911+
**req_kwargs,
898912
)
899913
requests.append(search_request)
900914

pymilvus/client/check.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@ def is_legal_search_data(data: Any) -> bool:
205205
if entity_helper.entity_is_sparse_matrix(data):
206206
return True
207207

208+
# Support EmbeddingList for array-of-vector searches
209+
# Check for EmbeddingList by type name to avoid circular dependency
210+
if isinstance(data, list) and len(data) > 0 and type(data[0]).__name__ == "EmbeddingList":
211+
return True
212+
208213
if not isinstance(data, (list, np.ndarray)):
209214
return False
210215

tests/test_async_grpc_handler.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from unittest.mock import AsyncMock, MagicMock, patch
22

3+
import numpy as np
34
import pytest
45
from pymilvus.client.async_grpc_handler import AsyncGrpcHandler
56
from pymilvus.exceptions import MilvusException
7+
from pymilvus.grpc_gen import schema_pb2
68

79

810
class TestAsyncGrpcHandler:
@@ -242,3 +244,164 @@ async def test_load_partitions_with_resource_groups(self) -> None:
242244
replica_number=2,
243245
resource_groups=["rg1", "rg2"]
244246
)
247+
248+
@pytest.mark.asyncio
249+
async def test_search_with_embedding_list(self) -> None:
250+
"""
251+
Test that search works with EmbeddingList input data.
252+
This test verifies the fix for the issue where AsyncMilvusClient.search
253+
failed when using EmbeddingList for array-of-vector searches.
254+
"""
255+
# Setup mock channel and stub
256+
mock_channel = AsyncMock()
257+
mock_channel.channel_ready = AsyncMock()
258+
mock_channel._unary_unary_interceptors = []
259+
260+
handler = AsyncGrpcHandler(channel=mock_channel)
261+
handler._is_channel_ready = True
262+
263+
mock_stub = AsyncMock()
264+
handler._async_stub = mock_stub
265+
266+
# Mock Search response with proper SearchResultData structure
267+
mock_search_result_data = schema_pb2.SearchResultData(
268+
num_queries=2,
269+
top_k=0,
270+
scores=[],
271+
ids=schema_pb2.IDs(int_id=schema_pb2.LongArray(data=[])),
272+
topks=[],
273+
primary_field_name="id"
274+
)
275+
mock_search_response = MagicMock()
276+
mock_status = MagicMock()
277+
mock_status.code = 0
278+
mock_status.reason = ""
279+
mock_search_response.status = mock_status
280+
mock_search_response.results = mock_search_result_data
281+
mock_search_response.session_ts = 0
282+
mock_stub.Search = AsyncMock(return_value=mock_search_response)
283+
284+
# Create EmbeddingList data
285+
from pymilvus.client.embedding_list import EmbeddingList
286+
emb_list1 = EmbeddingList()
287+
emb_list1.add([0.1, 0.2, 0.3, 0.4, 0.5])
288+
emb_list2 = EmbeddingList()
289+
emb_list2.add([0.5, 0.4, 0.3, 0.2, 0.1])
290+
data = [emb_list1, emb_list2]
291+
292+
with patch('pymilvus.client.async_grpc_handler.Prepare') as mock_prepare, \
293+
patch('pymilvus.client.async_grpc_handler.check_pass_param'), \
294+
patch('pymilvus.client.async_grpc_handler.check_status'), \
295+
patch('pymilvus.client.async_grpc_handler._api_level_md', return_value={}):
296+
297+
# Mock search_requests_with_expr to return a request
298+
mock_request = MagicMock()
299+
mock_prepare.search_requests_with_expr.return_value = mock_request
300+
301+
await handler.search(
302+
collection_name="test_collection",
303+
data=data,
304+
anns_field="vector",
305+
param={"metric_type": "COSINE"},
306+
limit=10
307+
)
308+
309+
# Verify that Prepare.search_requests_with_expr was called
310+
mock_prepare.search_requests_with_expr.assert_called_once()
311+
call_args = mock_prepare.search_requests_with_expr.call_args
312+
313+
# Verify that is_embedding_list was passed as True in kwargs
314+
assert call_args.kwargs.get("is_embedding_list") is True
315+
316+
# Verify data was converted (not EmbeddingList objects anymore)
317+
passed_data = call_args[0][1] # data is the second positional argument
318+
assert isinstance(passed_data, list)
319+
assert not isinstance(passed_data[0], EmbeddingList)
320+
# The data should be converted to flat arrays
321+
assert isinstance(passed_data[0], (list, np.ndarray))
322+
323+
# Verify Search was called
324+
mock_stub.Search.assert_called_once()
325+
326+
@pytest.mark.asyncio
327+
async def test_hybrid_search_with_embedding_list(self) -> None:
328+
"""
329+
Test that hybrid_search works with EmbeddingList input data.
330+
"""
331+
# Setup mock channel and stub
332+
mock_channel = AsyncMock()
333+
mock_channel.channel_ready = AsyncMock()
334+
mock_channel._unary_unary_interceptors = []
335+
336+
handler = AsyncGrpcHandler(channel=mock_channel)
337+
handler._is_channel_ready = True
338+
339+
mock_stub = AsyncMock()
340+
handler._async_stub = mock_stub
341+
342+
# Mock HybridSearch response with proper SearchResultData structure
343+
mock_hybrid_result_data = schema_pb2.SearchResultData(
344+
num_queries=1,
345+
top_k=0,
346+
scores=[],
347+
ids=schema_pb2.IDs(int_id=schema_pb2.LongArray(data=[])),
348+
topks=[],
349+
primary_field_name="id"
350+
)
351+
mock_hybrid_response = MagicMock()
352+
mock_status = MagicMock()
353+
mock_status.code = 0
354+
mock_status.reason = ""
355+
mock_hybrid_response.status = mock_status
356+
mock_hybrid_response.results = mock_hybrid_result_data
357+
mock_stub.HybridSearch = AsyncMock(return_value=mock_hybrid_response)
358+
359+
# Create AnnSearchRequest with EmbeddingList
360+
from pymilvus.client.embedding_list import EmbeddingList
361+
from pymilvus.client.abstract import AnnSearchRequest
362+
import numpy as np
363+
364+
emb_list = EmbeddingList()
365+
emb_list.add([0.1, 0.2, 0.3])
366+
req = AnnSearchRequest(
367+
data=[emb_list],
368+
anns_field="vector",
369+
param={"metric_type": "COSINE"},
370+
limit=10
371+
)
372+
373+
with patch('pymilvus.client.async_grpc_handler.Prepare') as mock_prepare, \
374+
patch('pymilvus.client.async_grpc_handler.check_pass_param'), \
375+
patch('pymilvus.client.async_grpc_handler.check_status'), \
376+
patch('pymilvus.client.async_grpc_handler._api_level_md', return_value={}):
377+
378+
# Mock search_requests_with_expr and hybrid_search_request_with_ranker
379+
mock_search_request = MagicMock()
380+
mock_hybrid_request = MagicMock()
381+
mock_prepare.search_requests_with_expr.return_value = mock_search_request
382+
mock_prepare.hybrid_search_request_with_ranker.return_value = mock_hybrid_request
383+
384+
# Mock rerank (BaseRanker)
385+
mock_ranker = MagicMock()
386+
387+
await handler.hybrid_search(
388+
collection_name="test_collection",
389+
reqs=[req],
390+
rerank=mock_ranker,
391+
limit=10
392+
)
393+
394+
# Verify that search_requests_with_expr was called with converted data
395+
mock_prepare.search_requests_with_expr.assert_called_once()
396+
call_args = mock_prepare.search_requests_with_expr.call_args
397+
398+
# Verify is_embedding_list flag was set
399+
assert call_args.kwargs.get("is_embedding_list") is True
400+
401+
# Verify data was converted
402+
passed_data = call_args[0][1]
403+
assert isinstance(passed_data, list)
404+
assert not isinstance(passed_data[0], EmbeddingList)
405+
406+
# Verify HybridSearch was called
407+
mock_stub.HybridSearch.assert_called_once()

0 commit comments

Comments
 (0)