|
12 | 12 |
|
13 | 13 | from graphrag.cache.pipeline_cache import PipelineCache |
14 | 14 | from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks |
15 | | -from graphrag.config.embeddings import create_collection_name |
| 15 | +from graphrag.config.embeddings import create_index_name |
| 16 | +from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig |
16 | 17 | from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy |
17 | 18 | from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument |
18 | 19 | from graphrag.vector_stores.factory import VectorStoreFactory |
@@ -49,9 +50,9 @@ async def embed_text( |
49 | 50 | vector_store_config = strategy.get("vector_store") |
50 | 51 |
|
51 | 52 | if vector_store_config: |
52 | | - collection_name = _get_collection_name(vector_store_config, embedding_name) |
| 53 | + index_name = _get_index_name(vector_store_config, embedding_name) |
53 | 54 | vector_store: BaseVectorStore = _create_vector_store( |
54 | | - vector_store_config, collection_name |
| 55 | + vector_store_config, index_name, embedding_name |
55 | 56 | ) |
56 | 57 | vector_store_workflow_config = vector_store_config.get( |
57 | 58 | embedding_name, vector_store_config |
@@ -183,27 +184,46 @@ async def _text_embed_with_vector_store( |
183 | 184 |
|
184 | 185 |
|
185 | 186 | def _create_vector_store( |
186 | | - vector_store_config: dict, collection_name: str |
| 187 | + vector_store_config: dict, index_name: str, embedding_name: str | None = None |
187 | 188 | ) -> BaseVectorStore: |
188 | 189 | vector_store_type: str = str(vector_store_config.get("type")) |
189 | | - if collection_name: |
190 | | - vector_store_config.update({"collection_name": collection_name}) |
| 190 | + |
| 191 | + embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get( |
| 192 | + "embeddings_schema", {} |
| 193 | + ) |
| 194 | + single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig() |
| 195 | + |
| 196 | + if ( |
| 197 | + embeddings_schema is not None |
| 198 | + and embedding_name is not None |
| 199 | + and embedding_name in embeddings_schema |
| 200 | + ): |
| 201 | + raw_config = embeddings_schema[embedding_name] |
| 202 | + if isinstance(raw_config, dict): |
| 203 | + single_embedding_config = VectorStoreSchemaConfig(**raw_config) |
| 204 | + else: |
| 205 | + single_embedding_config = raw_config |
| 206 | + |
| 207 | + if single_embedding_config.index_name is None: |
| 208 | + single_embedding_config.index_name = index_name |
191 | 209 |
|
192 | 210 | vector_store = VectorStoreFactory().create_vector_store( |
193 | | - vector_store_type, kwargs=vector_store_config |
| 211 | + vector_store_schema_config=single_embedding_config, |
| 212 | + vector_store_type=vector_store_type, |
| 213 | + kwargs=vector_store_config, |
194 | 214 | ) |
195 | 215 |
|
196 | 216 | vector_store.connect(**vector_store_config) |
197 | 217 | return vector_store |
198 | 218 |
|
199 | 219 |
|
200 | | -def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str: |
| 220 | +def _get_index_name(vector_store_config: dict, embedding_name: str) -> str: |
201 | 221 | container_name = vector_store_config.get("container_name", "default") |
202 | | - collection_name = create_collection_name(container_name, embedding_name) |
| 222 | + index_name = create_index_name(container_name, embedding_name) |
203 | 223 |
|
204 | | - msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}" |
| 224 | + msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}" |
205 | 225 | logger.info(msg) |
206 | | - return collection_name |
| 226 | + return index_name |
207 | 227 |
|
208 | 228 |
|
209 | 229 | def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy: |
|
0 commit comments