Skip to content

Commit 34c3a53

Browse files
authored
Add benchmark scripts for Guru-Math (#417)
1 parent 93bf2ea commit 34c3a53

File tree

12 files changed

+866
-223
lines changed

12 files changed

+866
-223
lines changed

benchmark/README.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,24 @@ The chart below shows performance based on this [commit](https://github.com/mode
6969
![View Results](../docs/sphinx_doc/assets/gsm8k-bench.png)
7070

7171
### 2. Countdown
72-
First generate data, then run the benchmark:
72+
To reproduce this experiment:
7373
```bash
74-
# Step 1: Generate data
75-
python benchmark/scripts/gen-countdown-data.py --local_dir /your/data/path
76-
# Step 2: Run benchmark
77-
python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct --taskset_path /your/data/path
74+
python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct
7875
```
7976
#### Countdown Results
8077
The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/068da409d215bb2450d93b6b7a56740d4751669d).
8178
![View Results](../docs/sphinx_doc/assets/countdown-bench.png)
8279

80+
### 3. Guru-Math
81+
To reproduce this experiment:
82+
```bash
83+
python bench.py guru_math --model_path /path/to/Qwen/Qwen2.5-7B
84+
```
85+
86+
#### Guru Results
87+
The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/fbf6c967bcd637bfd9f81fb4d7dd4961d7d5a407).
88+
![View Results](../docs/sphinx_doc/assets/guru-bench.png)
89+
8390
*More benchmarks will be added soon!*
8491

8592
---

benchmark/bench.py

Lines changed: 107 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
import argparse
2+
import importlib
23
import os
34
import subprocess
5+
import sys
46
import time
57

68
import torch
79
import torch.distributed as dist
810
import yaml
911

1012
from trinity.algorithm.algorithm import ALGORITHM_TYPE
11-
from trinity.common.constants import MODEL_PATH_ENV_VAR
13+
from trinity.common.constants import MODEL_PATH_ENV_VAR, SyncStyle
1214
from trinity.utils.dlc_utils import get_dlc_env_vars
1315

1416

1517
def set_engine_num(config, args):
1618
config["cluster"]["node_num"] = args.node_num
1719
config["cluster"]["gpu_per_node"] = args.gpu_per_node
18-
batch_size = config["buffer"]["batch_size"]
20+
batch_size = config["buffer"]["batch_size"] * config["algorithm"]["repeat_times"]
1921
if config["mode"] == "train":
2022
return
2123

@@ -61,6 +63,83 @@ def update_opt_explorer_num(trainer_gpu_num, opt_explorer_num, opt_ratio_diff):
6163
config["explorer"]["rollout_model"]["engine_num"] = opt_explorer_num
6264

6365

66+
def check_taskset_path(dataset_name: str, taskset_path: str) -> str:
67+
"""Ensures the taskset path exists for the given dataset; generates it if necessary.
68+
69+
This function checks whether `taskset_path` exists. If not,
70+
it uses a corresponding data generation script (e.g., gen_countdown_data.py) to create
71+
the dataset at the default or provided location. The generator scripts are expected
72+
to be located in the 'scripts/' subdirectory relative to this file.
73+
74+
Args:
75+
dataset_name: Name of the dataset (e.g., "countdown", "guru").
76+
Must be one of the supported datasets defined in `dataset_script_map`.
77+
taskset_path: Path to the dataset.
78+
79+
Returns:
80+
str: The resolved path to the dataset.
81+
82+
Raises:
83+
ValueError: If the `dataset_name` is not supported.
84+
FileNotFoundError: If the corresponding generator script does not exist.
85+
ImportError: If the generator module fails to load.
86+
AttributeError: If the loaded module does not define 'DEFAULT_DATA_PATH'.
87+
subprocess.CalledProcessError: If the generation script fails (due to check=True).
88+
89+
Side Effects:
90+
- May create directories and files on disk via the external generation script.
91+
- Executes a subprocess to run the dataset generation script.
92+
93+
Examples:
94+
For dataset_name='guru_math' and taskset_path=None, this function will runs the
95+
following command and generate the guru_math dataset to default location
96+
(DEFAULT_DATA_PATH in scripts/gen_guru_math_data.py):
97+
98+
```bash
99+
python scripts/gen_guru_math_data.py --local_dir DEFAULT_DATA_PATH
100+
```
101+
"""
102+
if taskset_path:
103+
if os.path.exists(taskset_path):
104+
return taskset_path
105+
if dataset_name == "gsm8k" and taskset_path == "openai/gsm8k":
106+
return taskset_path
107+
108+
dataset_script_map = {
109+
"countdown": "gen_countdown_data.py",
110+
"guru_math": "gen_guru_math_data.py",
111+
}
112+
if dataset_name not in dataset_script_map:
113+
raise ValueError(
114+
f"Unsupported dataset: {dataset_name}. Please specify a valid taskset path."
115+
)
116+
117+
base_dir = os.path.dirname(__file__)
118+
script_filename = dataset_script_map[dataset_name]
119+
script_module_name = script_filename[:-3] # remove .py
120+
121+
script_file_path = os.path.join(base_dir, "scripts", script_filename)
122+
if not os.path.exists(script_file_path):
123+
raise FileNotFoundError(f"Generator script not found: {script_file_path}")
124+
125+
spec = importlib.util.spec_from_file_location(script_module_name, script_file_path)
126+
if spec is None or spec.loader is None:
127+
raise ImportError(f"Could not load spec for module: {script_module_name}")
128+
module = importlib.util.module_from_spec(spec)
129+
spec.loader.exec_module(module)
130+
131+
if taskset_path is None:
132+
if not hasattr(module, "DEFAULT_DATA_PATH"):
133+
raise AttributeError(f"{script_filename} is missing 'DEFAULT_DATA_PATH'")
134+
taskset_path = module.DEFAULT_DATA_PATH
135+
taskset_path = os.path.realpath(taskset_path)
136+
137+
gen_script_path = os.path.join(base_dir, "scripts", script_filename)
138+
subprocess.run([sys.executable, gen_script_path, "--local_dir", taskset_path], check=True)
139+
140+
return taskset_path
141+
142+
64143
def prepare_configs(args, rank, current_time):
65144
base_path = os.path.dirname(os.path.abspath(__file__))
66145

@@ -89,18 +168,19 @@ def prepare_configs(args, rank, current_time):
89168
)
90169
if args.critic_lr:
91170
config["trainer"]["trainer_config"]["critic"]["optim"]["lr"] = args.critic_lr
92-
config["buffer"]["explorer_input"]["taskset"]["path"] = (
93-
args.taskset_path
94-
or os.environ.get("TASKSET_PATH")
95-
or config["buffer"]["explorer_input"]["taskset"]["path"]
171+
taskset_config = config["buffer"]["explorer_input"]["taskset"]
172+
taskset_config["path"] = check_taskset_path(
173+
args.dataset,
174+
args.taskset_path or os.environ.get("TASKSET_PATH") or taskset_config["path"],
96175
)
97-
assert (
98-
config["buffer"]["explorer_input"]["taskset"]["path"] is not None
99-
), "Please specify taskset path."
100176
if args.lr:
101177
config["algorithm"]["optimizer"]["lr"] = args.lr
102178
if args.sync_interval:
103179
config["synchronizer"]["sync_interval"] = args.sync_interval
180+
if args.sync_offset:
181+
config["synchronizer"]["sync_offset"] = args.sync_offset
182+
if args.sync_style:
183+
config["synchronizer"]["sync_style"] = args.sync_style
104184

105185
with open(config_path, "w") as f:
106186
yaml.dump(config, f, allow_unicode=True, sort_keys=False)
@@ -131,7 +211,7 @@ def main(args):
131211
rank, current_time = 0, time.time()
132212
config_path = prepare_configs(args, rank, current_time)
133213
cmd_list = [
134-
"python",
214+
sys.executable,
135215
"-m",
136216
"trinity.cli.launcher",
137217
"run",
@@ -142,12 +222,21 @@ def main(args):
142222
dist.barrier()
143223
dist.destroy_process_group()
144224
cmd_list.append("--dlc")
225+
226+
# load plugins
227+
base_path = os.path.dirname(os.path.abspath(__file__))
228+
plugin_dir = os.path.join(base_path, "plugins", args.dataset)
229+
if os.path.exists(plugin_dir):
230+
cmd_list.append("--plugin-dir")
231+
cmd_list.append(plugin_dir)
232+
233+
# run command
145234
subprocess.run(cmd_list, check=True)
146235

147236

148237
if __name__ == "__main__":
149238
parser = argparse.ArgumentParser()
150-
parser.add_argument("dataset", type=str, choices=["gsm8k", "countdown", "openr1"])
239+
parser.add_argument("dataset", type=str.lower, choices=["gsm8k", "countdown", "guru_math"])
151240
parser.add_argument(
152241
"--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
153242
)
@@ -191,5 +280,12 @@ def main(args):
191280
parser.add_argument(
192281
"--sync_interval", type=int, default=None, help="Specify the sync interval."
193282
)
283+
parser.add_argument("--sync_offset", type=int, default=None, help="Specify the sync offset.")
284+
parser.add_argument(
285+
"--sync_style",
286+
type=str,
287+
default=None,
288+
choices=[sync_style.value for sync_style in SyncStyle],
289+
)
194290
args = parser.parse_args()
195291
main(args)

benchmark/config/countdown-template.yaml

Lines changed: 3 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mode: both
22
project: Trinity-RFT
3-
group: countdown-bench
4-
name: countdown-qwen2.5-1.5B
3+
group: ${oc.env:TRINITY_GROUP,countdown-bench}
4+
name: ${oc.env:TRINITY_NAME,countdown}
55
checkpoint_root_dir: placeholder
66
algorithm:
77
algorithm_type: ppo
@@ -72,102 +72,16 @@ trainer:
7272
total_steps: 1000
7373
enable_preview: true
7474
grad_clip: 1.0
75+
max_token_len_per_gpu: 6400
7576
trainer_config:
76-
actor_rollout_ref:
77-
hybrid_engine: true
78-
model:
79-
external_lib: null
80-
override_config: {}
81-
enable_gradient_checkpointing: true
82-
use_remove_padding: true
83-
actor:
84-
strategy: fsdp
85-
ppo_micro_batch_size_per_gpu: 4
86-
use_dynamic_bsz: true
87-
ppo_max_token_len_per_gpu: 6400
88-
ppo_epochs: 1
89-
shuffle: false
90-
ulysses_sequence_parallel_size: 1
91-
checkpoint:
92-
load_contents:
93-
- model
94-
- optimizer
95-
- extra
96-
save_contents:
97-
- model
98-
- optimizer
99-
- extra
100-
fsdp_config:
101-
wrap_policy:
102-
min_num_params: 0
103-
param_offload: false
104-
optimizer_offload: false
105-
fsdp_size: -1
106-
ref:
107-
fsdp_config:
108-
wrap_policy:
109-
min_num_params: 0
110-
param_offload: false
111-
optimizer_offload: false
112-
fsdp_size: -1
113-
log_prob_micro_batch_size_per_gpu: 8
114-
log_prob_use_dynamic_bsz: true
115-
log_prob_max_token_len_per_gpu: 6400
116-
ulysses_sequence_parallel_size: 1
117-
custom_reward_function:
118-
path: null
119-
name: compute_score
120-
algorithm:
121-
kl_penalty: low_var_kl
122-
kl_ctrl:
123-
type: fixed
124-
kl_coef: 0.001
125-
trainer:
126-
balance_batch: true
127-
resume_mode: auto
128-
resume_from_path: ''
129-
critic_warmup: 0
130-
default_hdfs_dir: null
131-
remove_previous_ckpt_in_save: false
132-
del_local_ckpt_after_load: false
133-
max_actor_ckpt_to_keep: null
134-
max_critic_ckpt_to_keep: null
13577
critic:
136-
strategy: fsdp
13778
optim:
13879
lr: 1e-5
13980
lr_warmup_steps_ratio: 0.0
14081
warmup_style: constant
141-
model:
142-
override_config: {}
143-
external_lib: null
144-
enable_gradient_checkpointing: true
145-
use_remove_padding: true
146-
fsdp_config:
147-
wrap_policy:
148-
min_num_params: 0
149-
param_offload: false
150-
optimizer_offload: false
151-
fsdp_size: -1
152-
ppo_micro_batch_size_per_gpu: 8
153-
forward_micro_batch_size_per_gpu: 8
154-
use_dynamic_bsz: true
15582
ppo_max_token_len_per_gpu: 12800
15683
forward_max_token_len_per_gpu: 12800
157-
ulysses_sequence_parallel_size: 1
158-
ppo_epochs: 1
159-
shuffle: false
160-
grad_clip: 1.0
16184
cliprange_value: 0.5
162-
checkpoint:
163-
load_contents:
164-
- model
165-
- optimizer
166-
- extra
167-
save_contents:
168-
- model
169-
- optimizer
170-
- extra
17185
monitor:
17286
monitor_type: wandb
17387
synchronizer:

benchmark/config/gsm8k-template.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mode: both
22
project: Trinity-RFT
3-
group: gsm8k-bench
4-
name: gsm8k-qwen2.5-1.5B
3+
group: ${oc.env:TRINITY_GROUP,gsm8k-bench}
4+
name: ${oc.env:TRINITY_NAME,gsm8k}
55
checkpoint_root_dir: placeholder
66
algorithm:
77
algorithm_type: grpo

0 commit comments

Comments
 (0)