Skip to content

Commit e39d869

Browse files
nievespg1 (Gabriel Nieves), AlonsoGuevara (Alonso Guevara)
authored
Added support for verbose logging and csv-metadata to the prompt tune… (#1789)
* Added support for verbose logging and csv-metadata to the prompt tune client. * Updated community report summarization file name and prompt template * updated semversioner * ran ruff linter * Ran poe format * Fix Ruff complains * Fix a new ruff complain :P * Pyright * Fix tests --------- Co-authored-by: Gabriel Nieves <[email protected]> Co-authored-by: Alonso Guevara <[email protected]>
1 parent 66c2cfb commit e39d869

File tree

12 files changed

+309
-212
lines changed

12 files changed

+309
-212
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "minor",
3+
    "description": "Updated the prompt tuning client to support csv-metadata injection and updated output file types to match the new naming convention."
4+
}

graphrag/api/prompt_tune.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@
1111
Backwards compatibility is not guaranteed at this time.
1212
"""
1313

14+
from typing import Annotated
15+
16+
import annotated_types
1417
from pydantic import PositiveInt, validate_call
1518

1619
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
17-
from graphrag.config.defaults import language_model_defaults
20+
from graphrag.config.defaults import graphrag_config_defaults, language_model_defaults
1821
from graphrag.config.models.graph_rag_config import GraphRagConfig
1922
from graphrag.language_model.manager import ModelManager
20-
from graphrag.logger.print_progress import PrintProgressLogger
23+
from graphrag.logger.base import ProgressLogger
2124
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT, PROMPT_TUNING_MODEL_ID
2225
from graphrag.prompt_tune.generator.community_report_rating import (
2326
generate_community_report_rating,
@@ -41,15 +44,19 @@
4144
)
4245
from graphrag.prompt_tune.generator.language import detect_language
4346
from graphrag.prompt_tune.generator.persona import generate_persona
44-
from graphrag.prompt_tune.loader.input import MIN_CHUNK_SIZE, load_docs_in_chunks
47+
from graphrag.prompt_tune.loader.input import load_docs_in_chunks
4548
from graphrag.prompt_tune.types import DocSelectionType
4649

4750

48-
@validate_call
51+
@validate_call(config={"arbitrary_types_allowed": True})
4952
async def generate_indexing_prompts(
5053
config: GraphRagConfig,
54+
logger: ProgressLogger,
5155
root: str,
52-
chunk_size: PositiveInt = MIN_CHUNK_SIZE,
56+
chunk_size: PositiveInt = graphrag_config_defaults.chunks.size,
57+
overlap: Annotated[
58+
int, annotated_types.Gt(-1)
59+
] = graphrag_config_defaults.chunks.overlap,
5360
limit: PositiveInt = 15,
5461
selection_method: DocSelectionType = DocSelectionType.RANDOM,
5562
domain: str | None = None,
@@ -65,6 +72,8 @@ async def generate_indexing_prompts(
6572
Parameters
6673
----------
6774
- config: The GraphRag configuration.
75+
- logger: The logger to use for progress updates.
76+
- root: The root directory.
6877
- output_path: The path to store the prompts.
6978
- chunk_size: The chunk token size to use for input text units.
7079
- limit: The limit of chunks to load.
@@ -81,22 +90,23 @@ async def generate_indexing_prompts(
8190
-------
8291
tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
8392
"""
84-
logger = PrintProgressLogger("")
85-
8693
# Retrieve documents
94+
logger.info("Chunking documents...")
8795
doc_list = await load_docs_in_chunks(
8896
root=root,
8997
config=config,
9098
limit=limit,
9199
select_method=selection_method,
92100
logger=logger,
93101
chunk_size=chunk_size,
102+
overlap=overlap,
94103
n_subset_max=n_subset_max,
95104
k=k,
96105
)
97106

98107
# Create LLM from config
99108
# TODO: Expose a way to specify Prompt Tuning model ID through config
109+
logger.info("Retrieving language model configuration...")
100110
default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
101111

102112
# if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
@@ -105,7 +115,10 @@ async def generate_indexing_prompts(
105115
default_llm_settings.max_retries = min(
106116
len(doc_list), language_model_defaults.max_retries
107117
)
118+
msg = f"max_retries not set, using default value: {default_llm_settings.max_retries}"
119+
logger.warning(msg)
108120

121+
logger.info("Creating language model...")
109122
llm = ModelManager().register_chat(
110123
name="prompt_tuning",
111124
model_type=default_llm_settings.type,
@@ -117,7 +130,6 @@ async def generate_indexing_prompts(
117130
if not domain:
118131
logger.info("Generating domain...")
119132
domain = await generate_domain(llm, doc_list)
120-
logger.info(f"Generated domain: {domain}") # noqa
121133

122134
if not language:
123135
logger.info("Detecting language...")
@@ -186,6 +198,10 @@ async def generate_indexing_prompts(
186198
language=language,
187199
)
188200

201+
logger.info(f"\nGenerated domain: {domain}") # noqa: G004
202+
logger.info(f"\nDetected language: {language}") # noqa: G004
203+
logger.info(f"\nGenerated persona: {persona}") # noqa: G004
204+
189205
return (
190206
extract_graph_prompt,
191207
entity_summarization_prompt,

graphrag/cli/main.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,10 @@
1111

1212
import typer
1313

14+
from graphrag.config.defaults import graphrag_config_defaults
1415
from graphrag.config.enums import IndexingMethod, SearchMethod
1516
from graphrag.logger.types import LoggerType
16-
from graphrag.prompt_tune.defaults import (
17-
MAX_TOKEN_COUNT,
18-
MIN_CHUNK_SIZE,
19-
N_SUBSET_MAX,
20-
K,
21-
)
17+
from graphrag.prompt_tune.defaults import LIMIT, MAX_TOKEN_COUNT, N_SUBSET_MAX, K
2218
from graphrag.prompt_tune.types import DocSelectionType
2319

2420
INVALID_METHOD_ERROR = "Invalid method"
@@ -274,6 +270,12 @@ def _prompt_tune_cli(
274270
),
275271
),
276272
] = None,
273+
verbose: Annotated[
274+
bool, typer.Option(help="Run the prompt tuning pipeline with verbose logging")
275+
] = False,
276+
logger: Annotated[
277+
LoggerType, typer.Option(help="The progress logger to use.")
278+
] = LoggerType.RICH,
277279
domain: Annotated[
278280
str | None,
279281
typer.Option(
@@ -300,7 +302,7 @@ def _prompt_tune_cli(
300302
typer.Option(
301303
help="The number of documents to load when --selection-method={random,top}."
302304
),
303-
] = 15,
305+
] = LIMIT,
304306
max_tokens: Annotated[
305307
int, typer.Option(help="The max token count for prompt generation.")
306308
] = MAX_TOKEN_COUNT,
@@ -311,8 +313,17 @@ def _prompt_tune_cli(
311313
),
312314
] = 2,
313315
chunk_size: Annotated[
314-
int, typer.Option(help="The max token count for prompt generation.")
315-
] = MIN_CHUNK_SIZE,
316+
int,
317+
typer.Option(
318+
help="The size of each example text chunk. Overrides chunks.size in the configuration file."
319+
),
320+
] = graphrag_config_defaults.chunks.size,
321+
overlap: Annotated[
322+
int,
323+
typer.Option(
324+
help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file"
325+
),
326+
] = graphrag_config_defaults.chunks.overlap,
316327
language: Annotated[
317328
str | None,
318329
typer.Option(
@@ -343,10 +354,13 @@ def _prompt_tune_cli(
343354
root=root,
344355
config=config,
345356
domain=domain,
357+
verbose=verbose,
358+
logger=logger,
346359
selection_method=selection_method,
347360
limit=limit,
348361
max_tokens=max_tokens,
349362
chunk_size=chunk_size,
363+
overlap=overlap,
350364
language=language,
351365
discover_entity_types=discover_entity_types,
352366
output=output,

graphrag/cli/prompt_tune.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from pathlib import Path
77

88
import graphrag.api as api
9+
from graphrag.cli.index import _logger
910
from graphrag.config.load_config import load_config
10-
from graphrag.logger.print_progress import PrintProgressLogger
11+
from graphrag.config.logging import enable_logging_with_config
12+
from graphrag.logger.factory import LoggerFactory, LoggerType
1113
from graphrag.prompt_tune.generator.community_report_summarization import (
1214
COMMUNITY_SUMMARIZATION_FILENAME,
1315
)
@@ -17,16 +19,20 @@
1719
from graphrag.prompt_tune.generator.extract_graph_prompt import (
1820
EXTRACT_GRAPH_FILENAME,
1921
)
22+
from graphrag.utils.cli import redact
2023

2124

2225
async def prompt_tune(
2326
root: Path,
2427
config: Path | None,
2528
domain: str | None,
29+
verbose: bool,
30+
logger: LoggerType,
2631
selection_method: api.DocSelectionType,
2732
limit: int,
2833
max_tokens: int,
2934
chunk_size: int,
35+
overlap: int,
3036
language: str | None,
3137
discover_entity_types: bool,
3238
output: Path,
@@ -41,6 +47,8 @@ async def prompt_tune(
4147
- config: The configuration file.
4248
- root: The root directory.
4349
- domain: The domain to map the input documents to.
50+
- verbose: Whether to enable verbose logging.
51+
- logger: The logger to use.
4452
- selection_method: The chunk selection method.
4553
- limit: The limit of chunks to load.
4654
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
@@ -52,14 +60,36 @@ async def prompt_tune(
5260
- k: The number of documents to select when using auto selection method.
5361
- min_examples_required: The minimum number of examples required for entity extraction prompts.
5462
"""
55-
logger = PrintProgressLogger("")
5663
root_path = Path(root).resolve()
5764
graph_config = load_config(root_path, config)
5865

66+
# override chunking config in the configuration
67+
if chunk_size != graph_config.chunks.size:
68+
graph_config.chunks.size = chunk_size
69+
70+
if overlap != graph_config.chunks.overlap:
71+
graph_config.chunks.overlap = overlap
72+
73+
progress_logger = LoggerFactory().create_logger(logger)
74+
info, error, success = _logger(progress_logger)
75+
76+
enabled_logging, log_path = enable_logging_with_config(
77+
graph_config, verbose, filename="prompt-tune.log"
78+
)
79+
if enabled_logging:
80+
info(f"Logging enabled at {log_path}", verbose)
81+
else:
82+
info(
83+
f"Logging not enabled for config {redact(graph_config.model_dump())}",
84+
verbose,
85+
)
86+
5987
prompts = await api.generate_indexing_prompts(
6088
config=graph_config,
6189
root=str(root_path),
90+
logger=progress_logger,
6291
chunk_size=chunk_size,
92+
overlap=overlap,
6393
limit=limit,
6494
selection_method=selection_method,
6595
domain=domain,
@@ -73,7 +103,7 @@ async def prompt_tune(
73103

74104
output_path = output.resolve()
75105
if output_path:
76-
logger.info(f"Writing prompts to {output_path}") # noqa: G004
106+
info(f"Writing prompts to {output_path}")
77107
output_path.mkdir(parents=True, exist_ok=True)
78108
extract_graph_prompt_path = output_path / EXTRACT_GRAPH_FILENAME
79109
entity_summarization_prompt_path = output_path / ENTITY_SUMMARIZATION_FILENAME
@@ -87,3 +117,6 @@ async def prompt_tune(
87117
file.write(prompts[1].encode(encoding="utf-8", errors="strict"))
88118
with community_summarization_prompt_path.open("wb") as file:
89119
file.write(prompts[2].encode(encoding="utf-8", errors="strict"))
120+
success(f"Prompts written to {output_path}")
121+
else:
122+
error("No output path provided. Skipping writing prompts.")

graphrag/config/logging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
3434

3535

3636
def enable_logging_with_config(
37-
config: GraphRagConfig, verbose: bool = False
37+
config: GraphRagConfig, verbose: bool = False, filename: str = "indexing-engine.log"
3838
) -> tuple[bool, str]:
3939
"""Enable logging to a file based on the config.
4040
@@ -55,7 +55,7 @@ def enable_logging_with_config(
5555
(True, str) if logging was enabled.
5656
"""
5757
if config.reporting.type == ReportingType.file:
58-
log_path = Path(config.reporting.base_dir) / "indexing-engine.log"
58+
log_path = Path(config.reporting.base_dir) / filename
5959
enable_logging(log_path, verbose)
6060
return (True, str(log_path))
6161
return (False, "")

graphrag/language_model/protocol/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ async def achat_stream(
120120
-------
121121
A generator that yields strings representing the response.
122122
"""
123+
yield "" # Yield an empty string so that the function is recognized as a generator
123124
...
124125

125126
def chat(

graphrag/prompt_tune/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"""
1313

1414
K = 15
15+
LIMIT = 15
1516
MAX_TOKEN_COUNT = 2000
1617
MIN_CHUNK_SIZE = 200
1718
N_SUBSET_MAX = 300

graphrag/prompt_tune/generator/community_report_summarization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
COMMUNITY_REPORT_SUMMARIZATION_PROMPT,
1010
)
1111

12-
COMMUNITY_SUMMARIZATION_FILENAME = "community_report.txt"
12+
COMMUNITY_SUMMARIZATION_FILENAME = "community_report_graph.txt"
1313

1414

1515
def create_community_summarization_prompt(

0 commit comments

Comments
 (0)