@@ -1,51 +1,48 @@
 import os
 import sys
 
-import torch
 import torch.distributed as dist
 from datasets import load_dataset
 
 from areal.api.cli_args import GRPOConfig, load_expr_config
-from areal.api.io_struct import AllocationMode, FinetuneSpec, StepInfo
+from areal.api.io_struct import AllocationMode, FinetuneSpec, StepInfo, WeightUpdateMeta
 from areal.engine.ppo.actor import FSDPPPOActor
 from areal.engine.sglang_remote import RemoteSGLangEngine
 from areal.engine.vllm_remote import RemotevLLMEngine
 from areal.platforms import current_platform
-from areal.utils import logging, seeding, stats_tracker
+from areal.utils import logging, perf_tracer, seeding, stats_tracker
 from areal.utils.data import (
     cycle_dataloader,
 )
 from areal.utils.dataloader import create_dataloader
 from areal.utils.device import log_gpu_stats
 from areal.utils.evaluator import Evaluator
 from areal.utils.hf_utils import load_hf_tokenizer
-from areal.utils.model import get_model_update_meta
+from areal.utils.perf_tracer import Category
 from areal.utils.recover import RecoverHandler
 from areal.utils.saver import Saver
 from areal.utils.stats_logger import StatsLogger
 from areal.workflow.rlvr import RLVRWorkflow
 
 logger = logging.getLogger("boba_grpo")
 
-REWARD_TIMEOUT_SECONDS = 30
-
 
 def get_input_ids_fn(data, tokenizer, enable_thinking):
     user_token = "<|User|>"
     assistant_token = "<|Assistant|>"
     think_token = "<think>"
-    if user_token in data:
-        data = data.replace("<|User|>", "")
-    if assistant_token in data:
-        data = data.replace("<|Assistant|>", "")
-    if think_token in data:
-        enable_thinking = True
-        data = data.replace("<think>", "")
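+    # Note whether the raw prompt explicitly requests thinking via <think>, then strip all special markers before templating.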
+    has_think_token = think_token in data
+    data = (
+        data.replace(user_token, "")
+        .replace(assistant_token, "")
+        .replace(think_token, "")
+    )
     input_ids = tokenizer.apply_chat_template(
         [{"role": "user", "content": data}],
         tokenize=True,
         add_generation_prompt=True,
-        enable_thinking=enable_thinking,
+        enable_thinking=enable_thinking or has_think_token,
     )
     return input_ids
 
@@ -91,6 +88,10 @@ def main(args):
     actor = FSDPPPOActor(config=config.actor)
     actor.create_process_group(parallel_strategy=parallel_strategy)
 
+    # Configure performance tracer
+    if config.perf_tracer is not None:
+        perf_tracer.configure(config.perf_tracer, rank=rank)
+
     world_size = actor.data_parallel_world_size
     if config.train_dataset.batch_size < world_size:
         raise ValueError(
@@ -107,12 +108,7 @@
         dataset_config=config.train_dataset,
     )
 
-    device = torch.device(int(os.environ["LOCAL_RANK"]))
     train_dataset_len = len(train_dataloader)
-    dataset_len_tensor = torch.tensor(
-        [train_dataset_len], dtype=torch.long, device=device
-    )
-    train_dataset_len = int(dataset_len_tensor.item())
     ft_spec = FinetuneSpec(
         total_train_epochs=config.total_train_epochs,
         dataset_size=train_dataset_len * config.train_dataset.batch_size,
@@ -126,7 +122,8 @@
     rollout = RemoteSGLangEngine(config.rollout)
     rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size)
 
-    weight_update_meta = get_model_update_meta(config)
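+    # Derive weight-update metadata from the allocation mode: FSDP weights are pushed to the inference engine over the collective backend.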
+    weight_update_meta = WeightUpdateMeta.from_fsdp_xccl(allocation_mode)
 
     # Initialize train engine
     actor.initialize(None, ft_spec)
@@ -146,6 +143,8 @@
         reward_fn=boba_reward_fn,
         gconfig=config.gconfig,
         tokenizer=tokenizer,
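+        # Thinking-mode prompting is enabled unconditionally for rollouts.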
+        enable_thinking=True,
         dump_dir=os.path.join(
             StatsLogger.get_log_path(config.stats_logger), "generated"
         ),
@@ -193,7 +192,15 @@
             steps_per_epoch=steps_per_epoch,
         )
 
-        with stats_tracker.record_timing("rollout"):
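+        # Pair each stats-tracker timer with a perf-tracer scope so timings and trace events cover the same region.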
+        with (
+            stats_tracker.record_timing("rollout"),
+            perf_tracer.trace_scope(
+                "train.rollout",
+                category=Category.COMPUTE,
+                args={"global_step": global_step},
+            ),
+        ):
             if config.async_training:
                 batch = actor.prepare_batch(
                     train_dataloader,
@@ -210,23 +217,49 @@
                 )
 
         if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
-            with stats_tracker.record_timing("recompute_logp"):
+            with (
+                stats_tracker.record_timing("recompute_logp"),
+                perf_tracer.trace_scope(
+                    "train.recompute_logp",
+                    category=Category.COMPUTE,
+                    args={"global_step": global_step},
+                ),
+            ):
                 logp = actor.compute_logp(batch)
                 batch["prox_logp"] = logp
                 log_gpu_stats("recompute logp")
 
         if ref is not None:
-            with stats_tracker.record_timing("ref_logp"):
+            with (
+                stats_tracker.record_timing("ref_logp"),
+                perf_tracer.trace_scope(
+                    "train.ref_logp",
+                    category=Category.COMPUTE,
+                    args={"global_step": global_step},
+                ),
+            ):
                 batch["ref_logp"] = ref.compute_logp(batch)
                 log_gpu_stats("ref logp")
 
-        with stats_tracker.record_timing("compute_advantage"):
+        with (
+            stats_tracker.record_timing("compute_advantage"),
+            perf_tracer.trace_scope(
+                "train.compute_advantage",
+                category=Category.COMPUTE,
+                args={"global_step": global_step},
+            ),
+        ):
             actor.compute_advantages(batch)
             log_gpu_stats("compute advantages")
 
         with (
             stats_tracker.record_timing("train_step"),
             stats_tracker.scope("grpo_actor"),
+            perf_tracer.trace_scope(
+                "train.ppo_update",
+                category=Category.COMPUTE,
+                args={"global_step": global_step},
+            ),
         ):
             stats = actor.ppo_update(batch)
             actor.step_lr_scheduler()
@@ -235,18 +268,37 @@
         # pause inference for updating weights, save, and evaluation
         rollout.pause()
 
-        with stats_tracker.record_timing("update_weights"):
+        with (
+            stats_tracker.record_timing("update_weights"),
+            perf_tracer.trace_scope(
+                "train.update_weights",
+                category=Category.COMM,
+                args={"global_step": global_step},
+            ),
+        ):
             actor.update_weights(weight_update_meta)
 
         actor.set_version(global_step + 1)
         rollout.set_version(global_step + 1)
 
-        rollout.resume()
-
-        with stats_tracker.record_timing("save"):
+        with (
+            stats_tracker.record_timing("save"),
+            perf_tracer.trace_scope(
+                "train.save",
+                category=Category.IO,
+                args={"global_step": global_step},
+            ),
+        ):
             saver.save(actor, epoch, step, global_step, tokenizer=tokenizer)
 
-        with stats_tracker.record_timing("checkpoint_for_recover"):
+        with (
+            stats_tracker.record_timing("checkpoint_for_recover"),
+            perf_tracer.trace_scope(
+                "train.checkpoint",
+                category=Category.IO,
+                args={"global_step": global_step},
+            ),
+        ):
             recover_handler.dump(
                 actor,
                 step_info,
@@ -261,22 +313,32 @@
         current_platform.synchronize()
 
         # Upload statistics to the logger (e.g., wandb)
-        stats[0].update(
-            stats_tracker.export_all(reduce_group=actor.data_parallel_group)
-        )
-        stats_logger.commit(epoch, step, global_step, stats)
+        with perf_tracer.trace_scope(
+            "train.log_stats",
+            category=Category.INSTR,
+            args={"global_step": global_step},
+        ):
+            stats[0].update(
+                stats_tracker.export_all(reduce_group=actor.data_parallel_group)
+            )
+            stats_logger.commit(epoch, step, global_step, stats)
 
         dist.barrier(device_ids=[actor.device.index])
         current_platform.synchronize()
 
         # Resume rollout
         rollout.resume()
 
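+        # Persist this step's trace events (perf_tracer presumably buffers them until save()).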
+        perf_tracer.save(step=global_step)
+
     stats_logger.close()
     rollout.destroy()
     if ref is not None:
         ref.destroy()
     actor.destroy()
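+    # Final save with force=True to flush any remaining buffered trace events on shutdown.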
+    perf_tracer.save(force=True)
 
 
 if __name__ == "__main__":