11import argparse
2+ import importlib
23import os
34import subprocess
5+ import sys
46import time
57
68import torch
79import torch .distributed as dist
810import yaml
911
1012from trinity .algorithm .algorithm import ALGORITHM_TYPE
11- from trinity .common .constants import MODEL_PATH_ENV_VAR
13+ from trinity .common .constants import MODEL_PATH_ENV_VAR , SyncStyle
1214from trinity .utils .dlc_utils import get_dlc_env_vars
1315
1416
1517def set_engine_num (config , args ):
1618 config ["cluster" ]["node_num" ] = args .node_num
1719 config ["cluster" ]["gpu_per_node" ] = args .gpu_per_node
18- batch_size = config ["buffer" ]["batch_size" ]
20+ batch_size = config ["buffer" ]["batch_size" ] * config [ "algorithm" ][ "repeat_times" ]
1921 if config ["mode" ] == "train" :
2022 return
2123
@@ -61,6 +63,83 @@ def update_opt_explorer_num(trainer_gpu_num, opt_explorer_num, opt_ratio_diff):
6163 config ["explorer" ]["rollout_model" ]["engine_num" ] = opt_explorer_num
6264
6365
66+ def check_taskset_path (dataset_name : str , taskset_path : str ) -> str :
67+ """Ensures the taskset path exists for the given dataset; generates it if necessary.
68+
69+ This function checks whether `taskset_path` exists. If not,
70+ it uses a corresponding data generation script (e.g., gen_countdown_data.py) to create
71+ the dataset at the default or provided location. The generator scripts are expected
72+ to be located in the 'scripts/' subdirectory relative to this file.
73+
74+ Args:
75+ dataset_name: Name of the dataset (e.g., "countdown", "guru").
76+ Must be one of the supported datasets defined in `dataset_script_map`.
77+ taskset_path: Path to the dataset.
78+
79+ Returns:
80+ str: The resolved path to the dataset.
81+
82+ Raises:
83+ ValueError: If the `dataset_name` is not supported.
84+ FileNotFoundError: If the corresponding generator script does not exist.
85+ ImportError: If the generator module fails to load.
86+ AttributeError: If the loaded module does not define 'DEFAULT_DATA_PATH'.
87+ subprocess.CalledProcessError: If the generation script fails (due to check=True).
88+
89+ Side Effects:
90+ - May create directories and files on disk via the external generation script.
91+ - Executes a subprocess to run the dataset generation script.
92+
93+ Examples:
94+ For dataset_name='guru_math' and taskset_path=None, this function will runs the
95+ following command and generate the guru_math dataset to default location
96+ (DEFAULT_DATA_PATH in scripts/gen_guru_math_data.py):
97+
98+ ```bash
99+ python scripts/gen_guru_math_data.py --local_dir DEFAULT_DATA_PATH
100+ ```
101+ """
102+ if taskset_path :
103+ if os .path .exists (taskset_path ):
104+ return taskset_path
105+ if dataset_name == "gsm8k" and taskset_path == "openai/gsm8k" :
106+ return taskset_path
107+
108+ dataset_script_map = {
109+ "countdown" : "gen_countdown_data.py" ,
110+ "guru_math" : "gen_guru_math_data.py" ,
111+ }
112+ if dataset_name not in dataset_script_map :
113+ raise ValueError (
114+ f"Unsupported dataset: { dataset_name } . Please specify a valid taskset path."
115+ )
116+
117+ base_dir = os .path .dirname (__file__ )
118+ script_filename = dataset_script_map [dataset_name ]
119+ script_module_name = script_filename [:- 3 ] # remove .py
120+
121+ script_file_path = os .path .join (base_dir , "scripts" , script_filename )
122+ if not os .path .exists (script_file_path ):
123+ raise FileNotFoundError (f"Generator script not found: { script_file_path } " )
124+
125+ spec = importlib .util .spec_from_file_location (script_module_name , script_file_path )
126+ if spec is None or spec .loader is None :
127+ raise ImportError (f"Could not load spec for module: { script_module_name } " )
128+ module = importlib .util .module_from_spec (spec )
129+ spec .loader .exec_module (module )
130+
131+ if taskset_path is None :
132+ if not hasattr (module , "DEFAULT_DATA_PATH" ):
133+ raise AttributeError (f"{ script_filename } is missing 'DEFAULT_DATA_PATH'" )
134+ taskset_path = module .DEFAULT_DATA_PATH
135+ taskset_path = os .path .realpath (taskset_path )
136+
137+ gen_script_path = os .path .join (base_dir , "scripts" , script_filename )
138+ subprocess .run ([sys .executable , gen_script_path , "--local_dir" , taskset_path ], check = True )
139+
140+ return taskset_path
141+
142+
64143def prepare_configs (args , rank , current_time ):
65144 base_path = os .path .dirname (os .path .abspath (__file__ ))
66145
@@ -89,18 +168,19 @@ def prepare_configs(args, rank, current_time):
89168 )
90169 if args .critic_lr :
91170 config ["trainer" ]["trainer_config" ]["critic" ]["optim" ]["lr" ] = args .critic_lr
92- config ["buffer" ]["explorer_input" ]["taskset" ][ "path" ] = (
93- args . taskset_path
94- or os . environ . get ( "TASKSET_PATH" )
95- or config [ "buffer" ][ "explorer_input" ][ "taskset" ][ " path" ]
171+ taskset_config = config ["buffer" ]["explorer_input" ]["taskset" ]
172+ taskset_config [ "path" ] = check_taskset_path (
173+ args . dataset ,
174+ args . taskset_path or os . environ . get ( "TASKSET_PATH" ) or taskset_config [ " path" ],
96175 )
97- assert (
98- config ["buffer" ]["explorer_input" ]["taskset" ]["path" ] is not None
99- ), "Please specify taskset path."
100176 if args .lr :
101177 config ["algorithm" ]["optimizer" ]["lr" ] = args .lr
102178 if args .sync_interval :
103179 config ["synchronizer" ]["sync_interval" ] = args .sync_interval
180+ if args .sync_offset :
181+ config ["synchronizer" ]["sync_offset" ] = args .sync_offset
182+ if args .sync_style :
183+ config ["synchronizer" ]["sync_style" ] = args .sync_style
104184
105185 with open (config_path , "w" ) as f :
106186 yaml .dump (config , f , allow_unicode = True , sort_keys = False )
@@ -131,7 +211,7 @@ def main(args):
131211 rank , current_time = 0 , time .time ()
132212 config_path = prepare_configs (args , rank , current_time )
133213 cmd_list = [
134- "python" ,
214+ sys . executable ,
135215 "-m" ,
136216 "trinity.cli.launcher" ,
137217 "run" ,
@@ -142,12 +222,21 @@ def main(args):
142222 dist .barrier ()
143223 dist .destroy_process_group ()
144224 cmd_list .append ("--dlc" )
225+
226+ # load plugins
227+ base_path = os .path .dirname (os .path .abspath (__file__ ))
228+ plugin_dir = os .path .join (base_path , "plugins" , args .dataset )
229+ if os .path .exists (plugin_dir ):
230+ cmd_list .append ("--plugin-dir" )
231+ cmd_list .append (plugin_dir )
232+
233+ # run command
145234 subprocess .run (cmd_list , check = True )
146235
147236
148237if __name__ == "__main__" :
149238 parser = argparse .ArgumentParser ()
150- parser .add_argument ("dataset" , type = str , choices = ["gsm8k" , "countdown" , "openr1 " ])
239+ parser .add_argument ("dataset" , type = str . lower , choices = ["gsm8k" , "countdown" , "guru_math " ])
151240 parser .add_argument (
152241 "--dlc" , action = "store_true" , help = "Specify when running in Aliyun PAI DLC."
153242 )
@@ -191,5 +280,12 @@ def main(args):
191280 parser .add_argument (
192281 "--sync_interval" , type = int , default = None , help = "Specify the sync interval."
193282 )
283+ parser .add_argument ("--sync_offset" , type = int , default = None , help = "Specify the sync offset." )
284+ parser .add_argument (
285+ "--sync_style" ,
286+ type = str ,
287+ default = None ,
288+ choices = [sync_style .value for sync_style in SyncStyle ],
289+ )
194290 args = parser .parse_args ()
195291 main (args )
0 commit comments