diff --git a/requirements-dev.txt b/requirements-dev.txt
index 3646e43..8c1ba6d 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,4 +1,4 @@
-weaviate-client>=4.16.7
+weaviate-client@git+https://github.com/weaviate/weaviate-python-client.git@rob/spfresh
 click==8.1.7
 twine
 pytest
diff --git a/setup.cfg b/setup.cfg
index c9e6125..612aadb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,7 +37,7 @@ classifiers =
 include_package_data = True
 python_requires = >=3.9
 install_requires =
-    weaviate-client>=4.16.7
+    weaviate-client@git+https://github.com/weaviate/weaviate-python-client.git@rob/spfresh
     click==8.1.7
     semver>=3.0.2
     numpy>=1.24.0
diff --git a/weaviate_cli/commands/create.py b/weaviate_cli/commands/create.py
index 27ba916..e8ab0ed 100644
--- a/weaviate_cli/commands/create.py
+++ b/weaviate_cli/commands/create.py
@@ -76,6 +76,7 @@ def create() -> None:
             "hnsw_acorn",
             "hnsw_multivector",
             "flat_bq",
+            "spfresh",
         ]
     ),
     help="Vector index type (default: 'hnsw').",
@@ -157,6 +158,48 @@ def create() -> None:
     ),
     help="Replication deletion strategy (default: 'delete_on_conflict').",
 )
+@click.option(
+    "--spfresh_max_posting_size",
+    default=CreateCollectionDefaults.spfresh_max_posting_size,
+    type=int,
+    help="SPFresh max posting size (default: None).",
+)
+@click.option(
+    "--spfresh_min_posting_size",
+    default=CreateCollectionDefaults.spfresh_min_posting_size,
+    type=int,
+    help="SPFresh min posting size (default: None).",
+)
+@click.option(
+    "--spfresh_replicas",
+    default=CreateCollectionDefaults.spfresh_replicas,
+    type=int,
+    help="SPFresh replicas (default: None).",
+)
+@click.option(
+    "--spfresh_rng_factor",
+    default=CreateCollectionDefaults.spfresh_rng_factor,
+    type=int,
+    help="SPFresh RNG factor (default: None).",
+)
+@click.option(
+    "--spfresh_search_probe",
+    default=CreateCollectionDefaults.spfresh_search_probe,
+    type=int,
+    help="SPFresh search probe (default: None).",
+)
+@click.option(
+    "--spfresh_centroids_index_type",
+    default=CreateCollectionDefaults.spfresh_centroids_index_type,
+    type=click.Choice(["flat", "hnsw"]),
+    help="SPFresh centroids index type (default: None).",
+)
+@click.option(
+    "--spfresh_quantizer",
+    default=CreateCollectionDefaults.spfresh_quantizer,
+    type=click.Choice(["rq8", "rq1"]),
+    help="SPFresh quantizer type (default: 'rq8').",
+)
 @click.pass_context
 def create_collection_cli(
     ctx: click.Context,
@@ -176,6 +219,13 @@ def create_collection_cli(
     replication_deletion_strategy: str,
     named_vector: bool,
     named_vector_name: Optional[str],
+    spfresh_max_posting_size: Optional[int],
+    spfresh_min_posting_size: Optional[int],
+    spfresh_replicas: Optional[int],
+    spfresh_rng_factor: Optional[int],
+    spfresh_search_probe: Optional[int],
+    spfresh_centroids_index_type: Optional[str],
+    spfresh_quantizer: Optional[str],
 ) -> None:
     """Create a collection in Weaviate."""
 
@@ -201,6 +251,13 @@ def create_collection_cli(
             replication_deletion_strategy=replication_deletion_strategy,
             named_vector=named_vector,
             named_vector_name=named_vector_name,
+            spfresh_max_posting_size=spfresh_max_posting_size,
+            spfresh_min_posting_size=spfresh_min_posting_size,
+            spfresh_replicas=spfresh_replicas,
+            spfresh_rng_factor=spfresh_rng_factor,
+            spfresh_search_probe=spfresh_search_probe,
+            spfresh_centroids_index_type=spfresh_centroids_index_type,
+            spfresh_quantizer=spfresh_quantizer,
         )
     except Exception as e:
         click.echo(f"Error: {e}")
diff --git a/weaviate_cli/defaults.py b/weaviate_cli/defaults.py
index 444f4ac..4bd56e1 100644
--- a/weaviate_cli/defaults.py
+++ b/weaviate_cli/defaults.py
@@ -69,6 +69,13 @@ class CreateCollectionDefaults:
     replication_deletion_strategy: str = "no_automated_resolution"
     named_vector: bool = False
     named_vector_name: Optional[str] = "default"
+    spfresh_max_posting_size: Optional[int] = None
+    spfresh_min_posting_size: Optional[int] = None
+    spfresh_replicas: Optional[int] = None
+    spfresh_rng_factor: Optional[int] = None
+    spfresh_search_probe: Optional[int] = None
+    spfresh_centroids_index_type: Optional[str] = None
+    spfresh_quantizer: Optional[str] = None
 
 
 @dataclass
diff --git a/weaviate_cli/managers/collection_manager.py b/weaviate_cli/managers/collection_manager.py
index b0aaddb..92c93fc 100644
--- a/weaviate_cli/managers/collection_manager.py
+++ b/weaviate_cli/managers/collection_manager.py
@@ -5,7 +5,8 @@
 from weaviate.collections import Collection
 from weaviate.collections.classes.config import _CollectionConfigSimple
 from weaviate.collections.classes.tenants import TenantActivityStatus
-from weaviate.classes.config import VectorFilterStrategy
+from weaviate.collections.classes.config_vector_index import VectorCentroidsIndexType
+from weaviate.collections.classes.config_vector_index import VectorFilterStrategy
 from weaviate_cli.defaults import (
     CreateCollectionDefaults,
     UpdateCollectionDefaults,
@@ -117,6 +118,62 @@ def get_collection(
 
     def get_all_collections(self) -> dict[str, _CollectionConfigSimple]:
         return self.client.collections.list_all()
+
+    def _build_spfresh_config(
+        self,
+        max_posting_size: Optional[int] = None,
+        min_posting_size: Optional[int] = None,
+        replicas: Optional[int] = None,
+        rng_factor: Optional[int] = None,
+        search_probe: Optional[int] = None,
+        centroids_index_type: Optional[str] = None,
+        quantizer: Optional[str] = None,
+    ):
+        """Build the SPFresh vector index configuration.
+
+        Only options that are not None are forwarded to
+        ``wvc.Configure.VectorIndex.spfresh``; anything left unset is omitted
+        from the call. The quantizer is the one exception: when ``quantizer``
+        is None, RQ with 8 bits is used.
+
+        Args:
+            max_posting_size: Forwarded as ``max_posting_size`` when set.
+            min_posting_size: Forwarded as ``min_posting_size`` when set.
+            replicas: Forwarded as ``replicas`` when set.
+            rng_factor: Forwarded as ``rng_factor`` when set.
+            search_probe: Forwarded as ``search_probe`` when set.
+            centroids_index_type: "flat" or "hnsw"; other values are ignored.
+            quantizer: "rq8" or "rq1"; None selects RQ with 8 bits, and any
+                other value leaves the quantizer unset.
+
+        Returns:
+            The value returned by ``wvc.Configure.VectorIndex.spfresh``.
+        """
+        passthrough = {
+            "max_posting_size": max_posting_size,
+            "min_posting_size": min_posting_size,
+            "replicas": replicas,
+            "rng_factor": rng_factor,
+            "search_probe": search_probe,
+        }
+        kwargs = {
+            name: value for name, value in passthrough.items() if value is not None
+        }
+
+        index_types = {
+            "flat": VectorCentroidsIndexType.FLAT,
+            "hnsw": VectorCentroidsIndexType.HNSW,
+        }
+        if centroids_index_type in index_types:
+            kwargs["centroids_index_type"] = index_types[centroids_index_type]
+
+        # RQ8 is the default quantizer when none is requested.
+        quantizer_bits = {None: 8, "rq8": 8, "rq1": 1}
+        bits = quantizer_bits.get(quantizer)
+        if bits is not None:
+            kwargs["quantizer"] = wvc.Configure.VectorIndex.Quantizer.rq(bits=bits)
+
+        return wvc.Configure.VectorIndex.spfresh(**kwargs)
 
     def create_collection(
         self,
         collection: str = CreateCollectionDefaults.collection,
@@ -139,6 +196,21 @@ def create_collection(
         ] = CreateCollectionDefaults.replication_deletion_strategy,
         named_vector: bool = CreateCollectionDefaults.named_vector,
         named_vector_name: Optional[str] = CreateCollectionDefaults.named_vector_name,
+        spfresh_max_posting_size: Optional[
+            int
+        ] = CreateCollectionDefaults.spfresh_max_posting_size,
+        spfresh_min_posting_size: Optional[
+            int
+        ] = CreateCollectionDefaults.spfresh_min_posting_size,
+        spfresh_replicas: Optional[int] = CreateCollectionDefaults.spfresh_replicas,
+        spfresh_rng_factor: Optional[int] = CreateCollectionDefaults.spfresh_rng_factor,
+        spfresh_search_probe: Optional[
+            int
+        ] = CreateCollectionDefaults.spfresh_search_probe,
+        spfresh_centroids_index_type: Optional[
+            str
+        ] = CreateCollectionDefaults.spfresh_centroids_index_type,
+        spfresh_quantizer: Optional[str] = CreateCollectionDefaults.spfresh_quantizer,
     ) -> None:
 
         if self.client.collections.exists(collection):
@@ -239,6 +311,15 @@ def create_collection(
             "flat_bq_cache": wvc.Configure.VectorIndex.flat(
                 quantizer=wvc.Configure.VectorIndex.Quantizer.bq(cache=True)
             ),
+            "spfresh": self._build_spfresh_config(
+                max_posting_size=spfresh_max_posting_size,
+                min_posting_size=spfresh_min_posting_size,
+                replicas=spfresh_replicas,
+                rng_factor=spfresh_rng_factor,
+                search_probe=spfresh_search_probe,
+                centroids_index_type=spfresh_centroids_index_type,
+                quantizer=spfresh_quantizer,
+            ),
         }
 
         # Vectorizer configurations