Skip to content

Commit 5f444ae

Browse files
jac0626 and silas.jiang authored
[enhance]: Add some missing features to AsyncMilvusClient and refactor common code (#3108)
## Summary

Adds some missing features to `AsyncMilvusClient` and extracts common code into a `BaseMilvusClient` base class.

## New Features

- `get_server_type()`, `flush_all()`, `get_flush_all_state()`, `get_compaction_plans()`, `update_replicate_configuration()`
- Enhanced `get_load_state()` to return progress when loading
- Enhanced `describe_collection()` to convert struct_array_fields
- Added `ranker` parameter to `search()` and `Function` ranker support to `hybrid_search()`
- Added `is_l0` parameter to `compact()`
- Added `create_struct_field_schema()` class method

## Refactoring

- Created `BaseMilvusClient` with shared class/instance methods
- Both clients now inherit from the base class

Signed-off-by: silas.jiang <[email protected]>
Co-authored-by: silas.jiang <[email protected]>
1 parent 7130008 commit 5f444ae

File tree

5 files changed

+507
-132
lines changed

5 files changed

+507
-132
lines changed

pymilvus/milvus_client/async_milvus_client.py

Lines changed: 111 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,29 @@
44
from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL
55
from pymilvus.client.types import (
66
ExceptionsMessage,
7+
LoadState,
78
OmitZeroDict,
89
ResourceGroupConfig,
910
RoleInfo,
1011
UserInfo,
1112
)
13+
from pymilvus.client.utils import convert_struct_fields_to_user_format
1214
from pymilvus.exceptions import (
1315
DataTypeNotMatchException,
1416
ParamError,
1517
PrimaryKeyException,
1618
)
17-
from pymilvus.orm import utility
18-
from pymilvus.orm.collection import CollectionSchema
19+
from pymilvus.orm.collection import CollectionSchema, Function, FunctionScore
1920
from pymilvus.orm.connections import connections
20-
from pymilvus.orm.schema import FieldSchema
2121
from pymilvus.orm.types import DataType
2222

2323
from ._utils import create_connection
24+
from .base import BaseMilvusClient
2425
from .check import validate_param
2526
from .index import IndexParam, IndexParams
2627

2728

28-
class AsyncMilvusClient:
29+
class AsyncMilvusClient(BaseMilvusClient):
2930
"""AsyncMilvusClient is an EXPERIMENTAL class
3031
which only provides part of MilvusClient's methods"""
3132

@@ -49,7 +50,7 @@ def __init__(
4950
timeout=timeout,
5051
**kwargs,
5152
)
52-
self.is_self_hosted = bool(utility.get_server_type(using=self._using) == "milvus")
53+
self.is_self_hosted = bool(self.get_server_type() == "milvus")
5354

5455
async def create_collection(
5556
self,
@@ -361,7 +362,7 @@ async def hybrid_search(
361362
self,
362363
collection_name: str,
363364
reqs: List[AnnSearchRequest],
364-
ranker: BaseRanker,
365+
ranker: Union[BaseRanker, Function],
365366
limit: int = 10,
366367
output_fields: Optional[List[str]] = None,
367368
timeout: Optional[float] = None,
@@ -391,6 +392,7 @@ async def search(
391392
timeout: Optional[float] = None,
392393
partition_names: Optional[List[str]] = None,
393394
anns_field: Optional[str] = None,
395+
ranker: Optional[Union[Function, FunctionScore]] = None,
394396
**kwargs,
395397
) -> List[List[dict]]:
396398
conn = self._get_connection()
@@ -405,6 +407,7 @@ async def search(
405407
partition_names=partition_names,
406408
expr_params=kwargs.pop("filter_params", {}),
407409
timeout=timeout,
410+
ranker=ranker,
408411
**kwargs,
409412
)
410413

@@ -559,7 +562,14 @@ async def describe_collection(
559562
self, collection_name: str, timeout: Optional[float] = None, **kwargs
560563
) -> dict:
561564
conn = self._get_connection()
562-
return await conn.describe_collection(collection_name, timeout=timeout, **kwargs)
565+
result = await conn.describe_collection(collection_name, timeout=timeout, **kwargs)
566+
# Convert internal struct_array_fields to user-friendly format
567+
if isinstance(result, dict) and "struct_array_fields" in result:
568+
converted_fields = convert_struct_fields_to_user_format(result["struct_array_fields"])
569+
result["fields"].extend(converted_fields)
570+
# Remove internal struct_array_fields from user-facing response
571+
result.pop("struct_array_fields")
572+
return result
563573

564574
async def has_collection(
565575
self, collection_name: str, timeout: Optional[float] = None, **kwargs
@@ -601,10 +611,19 @@ async def get_load_state(
601611
**kwargs,
602612
):
603613
conn = self._get_connection()
604-
return await conn.get_load_state(
614+
state = await conn.get_load_state(
605615
collection_name, partition_names, timeout=timeout, **kwargs
606616
)
607617

618+
ret = {"state": state}
619+
if state == LoadState.Loading:
620+
progress = await conn.get_loading_progress(
621+
collection_name, partition_names, timeout=timeout
622+
)
623+
ret["progress"] = progress
624+
625+
return ret
626+
608627
async def refresh_load(
609628
self,
610629
collection_name: str,
@@ -683,56 +702,9 @@ async def add_collection_field(
683702
**kwargs,
684703
)
685704

686-
@classmethod
687-
def create_schema(cls, **kwargs):
688-
kwargs["check_fields"] = False # do not check fields for now
689-
return CollectionSchema([], **kwargs)
690-
691-
@classmethod
692-
def create_field_schema(
693-
cls, name: str, data_type: DataType, desc: str = "", **kwargs
694-
) -> FieldSchema:
695-
return FieldSchema(name, data_type, desc, **kwargs)
696-
697-
@classmethod
698-
def prepare_index_params(cls, field_name: str = "", **kwargs) -> IndexParams:
699-
index_params = IndexParams()
700-
if field_name:
701-
validate_param("field_name", field_name, str)
702-
index_params.add_index(field_name, **kwargs)
703-
return index_params
704-
705705
async def close(self):
706706
await connections.async_remove_connection(self._using)
707707

708-
def _get_connection(self):
709-
return connections._fetch_handler(self._using)
710-
711-
def _extract_primary_field(self, schema_dict: Dict) -> dict:
712-
fields = schema_dict.get("fields", [])
713-
if not fields:
714-
return {}
715-
716-
for field_dict in fields:
717-
if field_dict.get("is_primary", None) is not None:
718-
return field_dict
719-
720-
return {}
721-
722-
def _pack_pks_expr(self, schema_dict: Dict, pks: List) -> str:
723-
primary_field = self._extract_primary_field(schema_dict)
724-
pk_field_name = primary_field["name"]
725-
data_type = primary_field["type"]
726-
727-
# Varchar pks need double quotes around the values
728-
if data_type == DataType.VARCHAR:
729-
ids = ["'" + str(entry) + "'" for entry in pks]
730-
expr = f"""{pk_field_name} in [{','.join(ids)}]"""
731-
else:
732-
ids = [str(entry) for entry in pks]
733-
expr = f"{pk_field_name} in [{','.join(ids)}]"
734-
return expr
735-
736708
async def list_indexes(self, collection_name: str, field_name: Optional[str] = "", **kwargs):
737709
conn = self._get_connection()
738710
indexes = await conn.list_indexes(collection_name, **kwargs)
@@ -1080,16 +1052,40 @@ async def flush(self, collection_name: str, timeout: Optional[float] = None, **k
10801052
conn = self._get_connection()
10811053
await conn.flush([collection_name], timeout=timeout, **kwargs)
10821054

1055+
async def flush_all(self, timeout: Optional[float] = None, **kwargs) -> None:
1056+
"""Flush all collections.
1057+
1058+
Args:
1059+
timeout (Optional[float]): An optional duration of time in seconds to allow for the RPC.
1060+
**kwargs: Additional arguments.
1061+
"""
1062+
conn = self._get_connection()
1063+
await conn.flush_all(timeout=timeout, **kwargs)
1064+
1065+
async def get_flush_all_state(self, timeout: Optional[float] = None, **kwargs) -> bool:
1066+
"""Get the flush all state.
1067+
1068+
Args:
1069+
timeout (Optional[float]): An optional duration of time in seconds to allow for the RPC.
1070+
**kwargs: Additional arguments.
1071+
1072+
Returns:
1073+
bool: True if flush all operation is completed, False otherwise.
1074+
"""
1075+
conn = self._get_connection()
1076+
return await conn.get_flush_all_state(timeout=timeout, **kwargs)
1077+
10831078
async def compact(
10841079
self,
10851080
collection_name: str,
10861081
is_clustering: Optional[bool] = False,
1082+
is_l0: Optional[bool] = False,
10871083
timeout: Optional[float] = None,
10881084
**kwargs,
10891085
) -> int:
10901086
conn = self._get_connection()
10911087
return await conn.compact(
1092-
collection_name, is_clustering=is_clustering, timeout=timeout, **kwargs
1088+
collection_name, is_clustering=is_clustering, is_l0=is_l0, timeout=timeout, **kwargs
10931089
)
10941090

10951091
async def get_compaction_state(
@@ -1099,6 +1095,25 @@ async def get_compaction_state(
10991095
result = await conn.get_compaction_state(job_id, timeout=timeout, **kwargs)
11001096
return result.state_name
11011097

1098+
async def get_compaction_plans(
1099+
self,
1100+
job_id: int,
1101+
timeout: Optional[float] = None,
1102+
**kwargs,
1103+
):
1104+
"""Get compaction plans for a specific job.
1105+
1106+
Args:
1107+
job_id (int): The ID of the compaction job.
1108+
timeout (Optional[float]): An optional duration of time in seconds to allow for the RPC.
1109+
**kwargs: Additional arguments.
1110+
1111+
Returns:
1112+
CompactionPlans: The compaction plans for the specified job.
1113+
"""
1114+
conn = self._get_connection()
1115+
return await conn.get_compaction_plans(job_id, timeout=timeout, **kwargs)
1116+
11021117
async def run_analyzer(
11031118
self,
11041119
texts: Union[str, List[str]],
@@ -1123,3 +1138,43 @@ async def run_analyzer(
11231138
timeout=timeout,
11241139
**kwargs,
11251140
)
1141+
1142+
async def update_replicate_configuration(
1143+
self,
1144+
clusters: Optional[List[Dict]] = None,
1145+
cross_cluster_topology: Optional[List[Dict]] = None,
1146+
timeout: Optional[float] = None,
1147+
**kwargs,
1148+
):
1149+
"""
1150+
Update replication configuration across Milvus clusters.
1151+
1152+
Args:
1153+
clusters (List[Dict], optional): List of cluster configurations.
1154+
Each dict should contain:
1155+
- cluster_id (str): Unique identifier for the cluster
1156+
- connection_param (Dict): Connection parameters with 'uri' and 'token'
1157+
- pchannels (List[str], optional): Physical channels for the cluster
1158+
1159+
cross_cluster_topology (List[Dict], optional): List of replication relationships.
1160+
Each dict should contain:
1161+
- source_cluster_id (str): ID of the source cluster
1162+
- target_cluster_id (str): ID of the target cluster
1163+
1164+
timeout (float, optional): An optional duration of time in seconds to allow for the RPC
1165+
**kwargs: Additional arguments
1166+
1167+
Returns:
1168+
Status: The status of the operation
1169+
1170+
Raises:
1171+
ParamError: If neither clusters nor cross_cluster_topology is provided
1172+
MilvusException: If the operation fails
1173+
"""
1174+
conn = self._get_connection()
1175+
return await conn.update_replicate_configuration(
1176+
clusters=clusters,
1177+
cross_cluster_topology=cross_cluster_topology,
1178+
timeout=timeout,
1179+
**kwargs,
1180+
)

0 commit comments

Comments (0)