Skip to content

Commit a4ffc3d

Browse files
gaudybGaudy Blanco
andauthored
Remove embeddings optional new (#2128)
* remove optional embeddings * fix test * fix tests * fix pipeline * fix test * fix test * fix test * fix tests --------- Co-authored-by: Gaudy Blanco <[email protected]>
1 parent 4512ce0 commit a4ffc3d

File tree

13 files changed

+43
-125
lines changed

13 files changed

+43
-125
lines changed

docs/config/yaml.md

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,9 @@ Where to put all vectors for the system. Configured for lancedb by default. This
182182

183183
The supported embeddings are:
184184

185-
- `text_unit.text`
186-
- `document.text`
187-
- `entity.title`
188-
- `entity.description`
189-
- `relationship.description`
190-
- `community.title`
191-
- `community.summary`
192-
- `community.full_content`
185+
- `text_unit_text`
186+
- `entity_description`
187+
- `community_full_content`
193188

194189
For example:
195190

@@ -199,12 +194,12 @@ vector_store:
199194
db_uri: output/lancedb
200195
index_prefix: "christmas-carol"
201196
embeddings_schema:
202-
text_unit.text:
197+
text_unit_text:
203198
index_name: "text-unit-embeddings"
204199
id_field: "id_custom"
205200
vector_field: "vector_custom"
206201
vector_size: 3072
207-
entity.description:
202+
entity_description:
208203
id_field: "id_custom"
209204
210205
```
@@ -224,14 +219,9 @@ By default, the GraphRAG indexer will only export embeddings required for our qu
224219

225220
Supported embeddings names are:
226221

227-
- `text_unit.text`
228-
- `document.text`
229-
- `entity.title`
230-
- `entity.description`
231-
- `relationship.description`
232-
- `community.title`
233-
- `community.summary`
234-
- `community.full_content`
222+
- `text_unit_text`
223+
- `entity_description`
224+
- `community_full_content`
235225

236226
#### Fields
237227

docs/examples_notebooks/api_overview.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@
2828
"from pathlib import Path\n",
2929
"from pprint import pprint\n",
3030
"\n",
31-
"import graphrag.api as api\n",
3231
"import pandas as pd\n",
3332
"from graphrag.config.load_config import load_config\n",
34-
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
33+
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n",
34+
"\n",
35+
"import graphrag.api as api"
3536
]
3637
},
3738
{
@@ -170,7 +171,7 @@
170171
],
171172
"metadata": {
172173
"kernelspec": {
173-
"display_name": ".venv",
174+
"display_name": "graphrag-monorepo",
174175
"language": "python",
175176
"name": "python3"
176177
},
@@ -184,7 +185,7 @@
184185
"name": "python",
185186
"nbconvert_exporter": "python",
186187
"pygments_lexer": "ipython3",
187-
"version": "3.11.9"
188+
"version": "3.12.9"
188189
}
189190
},
190191
"nbformat": 4,

docs/examples_notebooks/index_migration_to_v1.ipynb

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,6 @@
229229
"tokenizer = get_tokenizer(model_config)\n",
230230
"\n",
231231
"await generate_text_embeddings(\n",
232-
" documents=None,\n",
233-
" relationships=None,\n",
234232
" text_units=final_text_units,\n",
235233
" entities=final_entities,\n",
236234
" community_reports=final_community_reports,\n",

docs/examples_notebooks/input_documents.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@
3030
"from pathlib import Path\n",
3131
"from pprint import pprint\n",
3232
"\n",
33-
"import graphrag.api as api\n",
3433
"import pandas as pd\n",
3534
"from graphrag.config.load_config import load_config\n",
36-
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
35+
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n",
36+
"\n",
37+
"import graphrag.api as api"
3738
]
3839
},
3940
{
@@ -171,7 +172,7 @@
171172
],
172173
"metadata": {
173174
"kernelspec": {
174-
"display_name": "graphrag",
175+
"display_name": "graphrag-monorepo",
175176
"language": "python",
176177
"name": "python3"
177178
},
@@ -185,7 +186,7 @@
185186
"name": "python",
186187
"nbconvert_exporter": "python",
187188
"pygments_lexer": "ipython3",
188-
"version": "3.12.10"
189+
"version": "3.12.9"
189190
}
190191
},
191192
"nbformat": 4,

packages/graphrag/graphrag/config/embeddings.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,12 @@
33

44
"""A module containing embeddings values."""
55

6-
entity_title_embedding = "entity.title"
7-
entity_description_embedding = "entity.description"
8-
relationship_description_embedding = "relationship.description"
9-
document_text_embedding = "document.text"
10-
community_title_embedding = "community.title"
11-
community_summary_embedding = "community.summary"
12-
community_full_content_embedding = "community.full_content"
13-
text_unit_text_embedding = "text_unit.text"
6+
entity_description_embedding = "entity_description"
7+
community_full_content_embedding = "community_full_content"
8+
text_unit_text_embedding = "text_unit_text"
149

1510
all_embeddings: set[str] = {
16-
entity_title_embedding,
1711
entity_description_embedding,
18-
relationship_description_embedding,
19-
document_text_embedding,
20-
community_title_embedding,
21-
community_summary_embedding,
2212
community_full_content_embedding,
2313
text_unit_text_embedding,
2414
}
@@ -47,5 +37,5 @@ def create_index_name(
4737
raise KeyError(msg)
4838

4939
if index_prefix:
50-
return f"{index_prefix}-{embedding_name}".replace(".", "-")
51-
return embedding_name.replace(".", "-")
40+
return f"{index_prefix}-{embedding_name}"
41+
return embedding_name

packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py

Lines changed: 2 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,8 @@
1010
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1111
from graphrag.config.embeddings import (
1212
community_full_content_embedding,
13-
community_summary_embedding,
14-
community_title_embedding,
1513
create_index_name,
16-
document_text_embedding,
1714
entity_description_embedding,
18-
entity_title_embedding,
19-
relationship_description_embedding,
2015
text_unit_text_embedding,
2116
)
2217
from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -47,29 +42,14 @@ async def run_workflow(
4742
logger.info("Workflow started: generate_text_embeddings")
4843
embedded_fields = config.embed_text.names
4944
logger.info("Embedding the following fields: %s", embedded_fields)
50-
documents = None
51-
relationships = None
5245
text_units = None
5346
entities = None
5447
community_reports = None
55-
if document_text_embedding in embedded_fields:
56-
documents = await load_table_from_storage("documents", context.output_storage)
57-
if relationship_description_embedding in embedded_fields:
58-
relationships = await load_table_from_storage(
59-
"relationships", context.output_storage
60-
)
6148
if text_unit_text_embedding in embedded_fields:
6249
text_units = await load_table_from_storage("text_units", context.output_storage)
63-
if (
64-
entity_title_embedding in embedded_fields
65-
or entity_description_embedding in embedded_fields
66-
):
50+
if entity_description_embedding in embedded_fields:
6751
entities = await load_table_from_storage("entities", context.output_storage)
68-
if (
69-
community_title_embedding in embedded_fields
70-
or community_summary_embedding in embedded_fields
71-
or community_full_content_embedding in embedded_fields
72-
):
52+
if community_full_content_embedding in embedded_fields:
7353
community_reports = await load_table_from_storage(
7454
"community_reports", context.output_storage
7555
)
@@ -87,8 +67,6 @@ async def run_workflow(
8767
tokenizer = get_tokenizer(model_config)
8868

8969
output = await generate_text_embeddings(
90-
documents=documents,
91-
relationships=relationships,
9270
text_units=text_units,
9371
entities=entities,
9472
community_reports=community_reports,
@@ -115,8 +93,6 @@ async def run_workflow(
11593

11694

11795
async def generate_text_embeddings(
118-
documents: pd.DataFrame | None,
119-
relationships: pd.DataFrame | None,
12096
text_units: pd.DataFrame | None,
12197
entities: pd.DataFrame | None,
12298
community_reports: pd.DataFrame | None,
@@ -131,26 +107,12 @@ async def generate_text_embeddings(
131107
) -> dict[str, pd.DataFrame]:
132108
"""All the steps to generate all embeddings."""
133109
embedding_param_map = {
134-
document_text_embedding: {
135-
"data": documents.loc[:, ["id", "text"]] if documents is not None else None,
136-
"embed_column": "text",
137-
},
138-
relationship_description_embedding: {
139-
"data": relationships.loc[:, ["id", "description"]]
140-
if relationships is not None
141-
else None,
142-
"embed_column": "description",
143-
},
144110
text_unit_text_embedding: {
145111
"data": text_units.loc[:, ["id", "text"]]
146112
if text_units is not None
147113
else None,
148114
"embed_column": "text",
149115
},
150-
entity_title_embedding: {
151-
"data": entities.loc[:, ["id", "title"]] if entities is not None else None,
152-
"embed_column": "title",
153-
},
154116
entity_description_embedding: {
155117
"data": entities.loc[:, ["id", "title", "description"]].assign(
156118
title_description=lambda df: df["title"] + ":" + df["description"]
@@ -159,18 +121,6 @@ async def generate_text_embeddings(
159121
else None,
160122
"embed_column": "title_description",
161123
},
162-
community_title_embedding: {
163-
"data": community_reports.loc[:, ["id", "title"]]
164-
if community_reports is not None
165-
else None,
166-
"embed_column": "title",
167-
},
168-
community_summary_embedding: {
169-
"data": community_reports.loc[:, ["id", "summary"]]
170-
if community_reports is not None
171-
else None,
172-
"embed_column": "summary",
173-
},
174124
community_full_content_embedding: {
175125
"data": community_reports.loc[:, ["id", "full_content"]]
176126
if community_reports is not None

packages/graphrag/graphrag/index/workflows/update_text_embeddings.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ async def run_workflow(
2626
output_storage, _, _ = get_update_storages(
2727
config, context.state["update_timestamp"]
2828
)
29-
30-
final_documents_df = context.state["incremental_update_final_documents"]
31-
merged_relationships_df = context.state["incremental_update_merged_relationships"]
3229
merged_text_units = context.state["incremental_update_merged_text_units"]
3330
merged_entities_df = context.state["incremental_update_merged_entities"]
3431
merged_community_reports = context.state[
@@ -50,8 +47,6 @@ async def run_workflow(
5047
tokenizer = get_tokenizer(model_config)
5148

5249
result = await generate_text_embeddings(
53-
documents=final_documents_df,
54-
relationships=merged_relationships_df,
5550
text_units=merged_text_units,
5651
entities=merged_entities_df,
5752
community_reports=merged_community_reports,

tests/fixtures/min-csv/config.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@
8585
],
8686
"max_runtime": 150,
8787
"expected_artifacts": [
88-
"embeddings.text_unit.text.parquet",
89-
"embeddings.entity.description.parquet",
90-
"embeddings.community.full_content.parquet"
88+
"embeddings.text_unit_text.parquet",
89+
"embeddings.entity_description.parquet",
90+
"embeddings.community_full_content.parquet"
9191
]
9292
}
9393
},

tests/fixtures/text/config.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@
8484
],
8585
"max_runtime": 150,
8686
"expected_artifacts": [
87-
"embeddings.text_unit.text.parquet",
88-
"embeddings.entity.description.parquet",
89-
"embeddings.community.full_content.parquet"
87+
"embeddings.text_unit_text.parquet",
88+
"embeddings.entity_description.parquet",
89+
"embeddings.community_full_content.parquet"
9090
]
9191
}
9292
},

tests/unit/utils/test_embeddings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77

88
def test_create_index_name():
9-
collection = create_index_name("default", "entity.title")
10-
assert collection == "default-entity-title"
9+
collection = create_index_name("default", "entity_description")
10+
assert collection == "default-entity_description"
1111

1212

1313
def test_create_index_name_invalid_embedding_throws():
@@ -16,5 +16,5 @@ def test_create_index_name_invalid_embedding_throws():
1616

1717

1818
def test_create_index_name_invalid_embedding_does_not_throw():
19-
collection = create_index_name("default", "invalid.name", validate=False)
20-
assert collection == "default-invalid-name"
19+
collection = create_index_name("default", "invalid_name", validate=False)
20+
assert collection == "default-invalid_name"

0 commit comments

Comments
 (0)