Skip to content

Commit 82cd3b7

Browse files
gaudybGaudy Blanco
andauthored
Custom vector store schema implementation (#2062)
* progress on vector customization * fix for lancedb vectors * cosmosdb implementation * uv run poe format * clean test for vector store * semversioner update * test_factory.py integration test fixes * fixes for cosmosdb test * integration test fix for lancedb * uv fix for format * test fixes * fixes for tests * fix cosmosdb bug * print statement * test * test * fix cosmosdb bug * test validation * validation cosmosdb * validate cosmosdb * fix cosmosdb * fix small feedback from PR --------- Co-authored-by: Gaudy Blanco <[email protected]>
1 parent 075cadd commit 82cd3b7

File tree

19 files changed

+778
-272
lines changed

19 files changed

+778
-272
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "minor",
3+
"description": "add customization to vector store"
4+
}

.vscode/launch.json

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,24 @@
66
"name": "Indexer",
77
"type": "debugpy",
88
"request": "launch",
9-
"module": "uv",
9+
"module": "graphrag",
1010
"args": [
11-
"poe", "index",
12-
"--root", "<path_to_ragtest_root_demo>"
11+
"index",
12+
"--root",
13+
"<path_to_index_folder>"
1314
],
15+
"console": "integratedTerminal"
1416
},
1517
{
1618
"name": "Query",
1719
"type": "debugpy",
1820
"request": "launch",
19-
"module": "uv",
21+
"module": "graphrag",
2022
"args": [
21-
"poe", "query",
22-
"--root", "<path_to_ragtest_root_demo>",
23-
"--method", "global",
23+
"query",
24+
"--root",
25+
"<path_to_index_folder>",
26+
"--method", "basic",
2427
"--query", "What are the top themes in this story",
2528
]
2629
},
@@ -34,6 +37,42 @@
3437
"--config",
3538
"<path_to_ragtest_root_demo>/settings.yaml",
3639
]
37-
}
40+
},
41+
{
42+
"name": "Debug Integration Pytest",
43+
"type": "debugpy",
44+
"request": "launch",
45+
"module": "pytest",
46+
"args": [
47+
"./tests/integration/vector_stores",
48+
"-k", "test_azure_ai_search"
49+
],
50+
"console": "integratedTerminal",
51+
"justMyCode": false
52+
},
53+
{
54+
"name": "Debug Verbs Pytest",
55+
"type": "debugpy",
56+
"request": "launch",
57+
"module": "pytest",
58+
"args": [
59+
"./tests/verbs",
60+
"-k", "test_generate_text_embeddings"
61+
],
62+
"console": "integratedTerminal",
63+
"justMyCode": false
64+
},
65+
{
66+
"name": "Debug Smoke Pytest",
67+
"type": "debugpy",
68+
"request": "launch",
69+
"module": "pytest",
70+
"args": [
71+
"./tests/smoke",
72+
"-k", "test_fixtures"
73+
],
74+
"console": "integratedTerminal",
75+
"justMyCode": false
76+
},
3877
]
3978
}

graphrag/config/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ class VectorStoreDefaults:
394394
api_key: None = None
395395
audience: None = None
396396
database_name: None = None
397+
schema: None = None
397398

398399

399400
@dataclass

graphrag/config/embeddings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@
2929
]
3030

3131

32-
def create_collection_name(
32+
def create_index_name(
3333
container_name: str, embedding_name: str, validate: bool = True
3434
) -> str:
3535
"""
36-
Create a collection name for the embedding store.
36+
Create a index name for the embedding store.
3737
3838
Within any given vector store, we can have multiple sets of embeddings organized into projects.
39-
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
39+
The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.
4040
4141
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
4242

graphrag/config/models/vector_store_config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from pydantic import BaseModel, Field, model_validator
77

88
from graphrag.config.defaults import vector_store_defaults
9+
from graphrag.config.embeddings import all_embeddings
910
from graphrag.config.enums import VectorStoreType
11+
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
1012

1113

1214
class VectorStoreConfig(BaseModel):
@@ -85,9 +87,25 @@ def _validate_url(self) -> None:
8587
default=vector_store_defaults.overwrite,
8688
)
8789

90+
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
91+
92+
def _validate_embeddings_schema(self) -> None:
93+
"""Validate the embeddings schema."""
94+
for name in self.embeddings_schema:
95+
if name not in all_embeddings:
96+
msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please update your settings.yaml and select the correct embedding schema names."
97+
raise ValueError(msg)
98+
99+
if self.type == VectorStoreType.CosmosDB:
100+
for id_field in self.embeddings_schema:
101+
if id_field != "id":
102+
msg = "When using CosmosDB, the id_field in embeddings_schema must be 'id'. Please update your settings.yaml and set the id_field to 'id'."
103+
raise ValueError(msg)
104+
88105
@model_validator(mode="after")
89106
def _validate_model(self):
90107
"""Validate the model."""
91108
self._validate_db_uri()
92109
self._validate_url()
110+
self._validate_embeddings_schema()
93111
return self
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""Parameterization settings for the default configuration."""
5+
6+
import re
7+
8+
from pydantic import BaseModel, Field, model_validator
9+
10+
DEFAULT_VECTOR_SIZE: int = 1536
11+
12+
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
13+
14+
15+
def is_valid_field_name(field: str) -> bool:
16+
"""Check if a field name is valid for CosmosDB."""
17+
return bool(VALID_IDENTIFIER_REGEX.match(field))
18+
19+
20+
class VectorStoreSchemaConfig(BaseModel):
21+
"""The default configuration section for Vector Store Schema."""
22+
23+
id_field: str = Field(
24+
description="The ID field to use.",
25+
default="id",
26+
)
27+
28+
vector_field: str = Field(
29+
description="The vector field to use.",
30+
default="vector",
31+
)
32+
33+
text_field: str = Field(
34+
description="The text field to use.",
35+
default="text",
36+
)
37+
38+
attributes_field: str = Field(
39+
description="The attributes field to use.",
40+
default="attributes",
41+
)
42+
43+
vector_size: int = Field(
44+
description="The vector size to use.",
45+
default=DEFAULT_VECTOR_SIZE,
46+
)
47+
48+
index_name: str | None = Field(description="The index name to use.", default=None)
49+
50+
def _validate_schema(self) -> None:
51+
"""Validate the schema."""
52+
for field in [
53+
self.id_field,
54+
self.vector_field,
55+
self.text_field,
56+
self.attributes_field,
57+
]:
58+
if not is_valid_field_name(field):
59+
msg = f"Unsafe or invalid field name: {field}"
60+
raise ValueError(msg)
61+
62+
@model_validator(mode="after")
63+
def _validate_model(self):
64+
"""Validate the model."""
65+
self._validate_schema()
66+
return self

graphrag/index/operations/embed_text/embed_text.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
from graphrag.cache.pipeline_cache import PipelineCache
1414
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
15-
from graphrag.config.embeddings import create_collection_name
15+
from graphrag.config.embeddings import create_index_name
16+
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
1617
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
1718
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
1819
from graphrag.vector_stores.factory import VectorStoreFactory
@@ -49,9 +50,9 @@ async def embed_text(
4950
vector_store_config = strategy.get("vector_store")
5051

5152
if vector_store_config:
52-
collection_name = _get_collection_name(vector_store_config, embedding_name)
53+
index_name = _get_index_name(vector_store_config, embedding_name)
5354
vector_store: BaseVectorStore = _create_vector_store(
54-
vector_store_config, collection_name
55+
vector_store_config, index_name, embedding_name
5556
)
5657
vector_store_workflow_config = vector_store_config.get(
5758
embedding_name, vector_store_config
@@ -183,27 +184,46 @@ async def _text_embed_with_vector_store(
183184

184185

185186
def _create_vector_store(
186-
vector_store_config: dict, collection_name: str
187+
vector_store_config: dict, index_name: str, embedding_name: str | None = None
187188
) -> BaseVectorStore:
188189
vector_store_type: str = str(vector_store_config.get("type"))
189-
if collection_name:
190-
vector_store_config.update({"collection_name": collection_name})
190+
191+
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
192+
"embeddings_schema", {}
193+
)
194+
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
195+
196+
if (
197+
embeddings_schema is not None
198+
and embedding_name is not None
199+
and embedding_name in embeddings_schema
200+
):
201+
raw_config = embeddings_schema[embedding_name]
202+
if isinstance(raw_config, dict):
203+
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
204+
else:
205+
single_embedding_config = raw_config
206+
207+
if single_embedding_config.index_name is None:
208+
single_embedding_config.index_name = index_name
191209

192210
vector_store = VectorStoreFactory().create_vector_store(
193-
vector_store_type, kwargs=vector_store_config
211+
vector_store_schema_config=single_embedding_config,
212+
vector_store_type=vector_store_type,
213+
kwargs=vector_store_config,
194214
)
195215

196216
vector_store.connect(**vector_store_config)
197217
return vector_store
198218

199219

200-
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
220+
def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
201221
container_name = vector_store_config.get("container_name", "default")
202-
collection_name = create_collection_name(container_name, embedding_name)
222+
index_name = create_index_name(container_name, embedding_name)
203223

204-
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
224+
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
205225
logger.info(msg)
206-
return collection_name
226+
return index_name
207227

208228

209229
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:

graphrag/utils/api.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88

99
from graphrag.cache.factory import CacheFactory
1010
from graphrag.cache.pipeline_cache import PipelineCache
11-
from graphrag.config.embeddings import create_collection_name
11+
from graphrag.config.embeddings import create_index_name
1212
from graphrag.config.models.cache_config import CacheConfig
1313
from graphrag.config.models.storage_config import StorageConfig
14+
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
1415
from graphrag.data_model.types import TextEmbedder
1516
from graphrag.storage.factory import StorageFactory
1617
from graphrag.storage.pipeline_storage import PipelineStorage
@@ -103,12 +104,33 @@ def get_embedding_store(
103104
index_names = []
104105
for index, store in config_args.items():
105106
vector_store_type = store["type"]
106-
collection_name = create_collection_name(
107+
index_name = create_index_name(
107108
store.get("container_name", "default"), embedding_name
108109
)
110+
111+
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
112+
"embeddings_schema", {}
113+
)
114+
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
115+
116+
if (
117+
embeddings_schema is not None
118+
and embedding_name is not None
119+
and embedding_name in embeddings_schema
120+
):
121+
raw_config = embeddings_schema[embedding_name]
122+
if isinstance(raw_config, dict):
123+
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
124+
else:
125+
single_embedding_config = raw_config
126+
127+
if single_embedding_config.index_name is None:
128+
single_embedding_config.index_name = index_name
129+
109130
embedding_store = VectorStoreFactory().create_vector_store(
110131
vector_store_type=vector_store_type,
111-
kwargs={**store, "collection_name": collection_name},
132+
vector_store_schema_config=single_embedding_config,
133+
kwargs={**store},
112134
)
113135
embedding_store.connect(**store)
114136
# If there is only a single index, return the embedding store directly

0 commit comments

Comments
 (0)