1111Backwards compatibility is not guaranteed at this time.
1212"""
1313
14+ from typing import Annotated
15+
16+ import annotated_types
1417from pydantic import PositiveInt , validate_call
1518
1619from graphrag .callbacks .noop_workflow_callbacks import NoopWorkflowCallbacks
17- from graphrag .config .defaults import language_model_defaults
20+ from graphrag .config .defaults import graphrag_config_defaults , language_model_defaults
1821from graphrag .config .models .graph_rag_config import GraphRagConfig
1922from graphrag .language_model .manager import ModelManager
20- from graphrag .logger .print_progress import PrintProgressLogger
23+ from graphrag .logger .base import ProgressLogger
2124from graphrag .prompt_tune .defaults import MAX_TOKEN_COUNT , PROMPT_TUNING_MODEL_ID
2225from graphrag .prompt_tune .generator .community_report_rating import (
2326 generate_community_report_rating ,
4144)
4245from graphrag .prompt_tune .generator .language import detect_language
4346from graphrag .prompt_tune .generator .persona import generate_persona
44- from graphrag .prompt_tune .loader .input import MIN_CHUNK_SIZE , load_docs_in_chunks
47+ from graphrag .prompt_tune .loader .input import load_docs_in_chunks
4548from graphrag .prompt_tune .types import DocSelectionType
4649
4750
48- @validate_call
51+ @validate_call ( config = { "arbitrary_types_allowed" : True })
4952async def generate_indexing_prompts (
5053 config : GraphRagConfig ,
54+ logger : ProgressLogger ,
5155 root : str ,
52- chunk_size : PositiveInt = MIN_CHUNK_SIZE ,
56+ chunk_size : PositiveInt = graphrag_config_defaults .chunks .size ,
57+ overlap : Annotated [
58+ int , annotated_types .Gt (- 1 )
59+ ] = graphrag_config_defaults .chunks .overlap ,
5360 limit : PositiveInt = 15 ,
5461 selection_method : DocSelectionType = DocSelectionType .RANDOM ,
5562 domain : str | None = None ,
@@ -65,6 +72,8 @@ async def generate_indexing_prompts(
6572 Parameters
6673 ----------
6774 - config: The GraphRag configuration.
75+ - logger: The logger to use for progress updates.
76+ - root: The root directory.
6877 - output_path: The path to store the prompts.
6978 - chunk_size: The chunk token size to use for input text units.
7079 - limit: The limit of chunks to load.
@@ -81,22 +90,23 @@ async def generate_indexing_prompts(
8190 -------
8291 tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
8392 """
84- logger = PrintProgressLogger ("" )
85-
8693 # Retrieve documents
94+ logger .info ("Chunking documents..." )
8795 doc_list = await load_docs_in_chunks (
8896 root = root ,
8997 config = config ,
9098 limit = limit ,
9199 select_method = selection_method ,
92100 logger = logger ,
93101 chunk_size = chunk_size ,
102+ overlap = overlap ,
94103 n_subset_max = n_subset_max ,
95104 k = k ,
96105 )
97106
98107 # Create LLM from config
99108 # TODO: Expose a way to specify Prompt Tuning model ID through config
109+ logger .info ("Retrieving language model configuration..." )
100110 default_llm_settings = config .get_language_model_config (PROMPT_TUNING_MODEL_ID )
101111
102112 # if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
@@ -105,7 +115,10 @@ async def generate_indexing_prompts(
105115 default_llm_settings .max_retries = min (
106116 len (doc_list ), language_model_defaults .max_retries
107117 )
118+ msg = f"max_retries not set, using default value: { default_llm_settings .max_retries } "
119+ logger .warning (msg )
108120
121+ logger .info ("Creating language model..." )
109122 llm = ModelManager ().register_chat (
110123 name = "prompt_tuning" ,
111124 model_type = default_llm_settings .type ,
@@ -117,7 +130,6 @@ async def generate_indexing_prompts(
117130 if not domain :
118131 logger .info ("Generating domain..." )
119132 domain = await generate_domain (llm , doc_list )
120- logger .info (f"Generated domain: { domain } " ) # noqa
121133
122134 if not language :
123135 logger .info ("Detecting language..." )
@@ -186,6 +198,10 @@ async def generate_indexing_prompts(
186198 language = language ,
187199 )
188200
201+ logger .info (f"\n Generated domain: { domain } " ) # noqa: G004
202+ logger .info (f"\n Detected language: { language } " ) # noqa: G004
203+ logger .info (f"\n Generated persona: { persona } " ) # noqa: G004
204+
189205 return (
190206 extract_graph_prompt ,
191207 entity_summarization_prompt ,