Skip to content

Commit 97764f8

Browse files
authored
chore: refactor boba GRPO for tracing (#527)
* chore: add an example of tracing boba GRPO
* feat: enable rollout tracing in boba_grpo configuration
* refactor: enhance performance tracing in boba_grpo and remove unused tracer files
* feat: add initial configuration for boba_grpo experiment
1 parent 4d9e863 commit 97764f8

File tree

2 files changed

+248
-33
lines changed

2 files changed

+248
-33
lines changed

examples/math/boba_grpo.py

Lines changed: 89 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,47 @@
11
import os
22
import sys
33

4-
import torch
54
import torch.distributed as dist
65
from datasets import load_dataset
76

87
from areal.api.cli_args import GRPOConfig, load_expr_config
9-
from areal.api.io_struct import AllocationMode, FinetuneSpec, StepInfo
8+
from areal.api.io_struct import AllocationMode, FinetuneSpec, StepInfo, WeightUpdateMeta
109
from areal.engine.ppo.actor import FSDPPPOActor
1110
from areal.engine.sglang_remote import RemoteSGLangEngine
1211
from areal.engine.vllm_remote import RemotevLLMEngine
1312
from areal.platforms import current_platform
14-
from areal.utils import logging, seeding, stats_tracker
13+
from areal.utils import logging, perf_tracer, seeding, stats_tracker
1514
from areal.utils.data import (
1615
cycle_dataloader,
1716
)
1817
from areal.utils.dataloader import create_dataloader
1918
from areal.utils.device import log_gpu_stats
2019
from areal.utils.evaluator import Evaluator
2120
from areal.utils.hf_utils import load_hf_tokenizer
22-
from areal.utils.model import get_model_update_meta
21+
from areal.utils.perf_tracer import Category
2322
from areal.utils.recover import RecoverHandler
2423
from areal.utils.saver import Saver
2524
from areal.utils.stats_logger import StatsLogger
2625
from areal.workflow.rlvr import RLVRWorkflow
2726

2827
logger = logging.getLogger("boba_grpo")
2928

30-
REWARD_TIMEOUT_SECONDS = 30
31-
3229

3330
def get_input_ids_fn(data, tokenizer, enable_thinking):
3431
user_token = "<|User|>"
3532
assistant_token = "<|Assistant|>"
3633
think_token = "<think>"
37-
if user_token in data:
38-
data = data.replace("<|User|>", "")
39-
if assistant_token in data:
40-
data = data.replace("<|Assistant|>", "")
41-
if think_token in data:
42-
enable_thinking = True
43-
data = data.replace("<think>", "")
34+
has_think_token = think_token in data
35+
data = (
36+
data.replace(user_token, "")
37+
.replace(assistant_token, "")
38+
.replace(think_token, "")
39+
)
4440
input_ids = tokenizer.apply_chat_template(
4541
[{"role": "user", "content": data}],
4642
tokenize=True,
4743
add_generation_prompt=True,
48-
enable_thinking=enable_thinking,
44+
enable_thinking=enable_thinking or has_think_token,
4945
)
5046
return input_ids
5147

@@ -91,6 +87,10 @@ def main(args):
9187
actor = FSDPPPOActor(config=config.actor)
9288
actor.create_process_group(parallel_strategy=parallel_strategy)
9389

90+
# Configure performance tracer
91+
if config.perf_tracer is not None:
92+
perf_tracer.configure(config.perf_tracer, rank=rank)
93+
9494
world_size = actor.data_parallel_world_size
9595
if config.train_dataset.batch_size < world_size:
9696
raise ValueError(
@@ -107,12 +107,7 @@ def main(args):
107107
dataset_config=config.train_dataset,
108108
)
109109

110-
device = torch.device(int(os.environ["LOCAL_RANK"]))
111110
train_dataset_len = len(train_dataloader)
112-
dataset_len_tensor = torch.tensor(
113-
[train_dataset_len], dtype=torch.long, device=device
114-
)
115-
train_dataset_len = int(dataset_len_tensor.item())
116111
ft_spec = FinetuneSpec(
117112
total_train_epochs=config.total_train_epochs,
118113
dataset_size=train_dataset_len * config.train_dataset.batch_size,
@@ -126,7 +121,7 @@ def main(args):
126121
rollout = RemoteSGLangEngine(config.rollout)
127122
rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size)
128123

129-
weight_update_meta = get_model_update_meta(config)
124+
weight_update_meta = WeightUpdateMeta.from_fsdp_xccl(allocation_mode)
130125

131126
# Initialize train engine
132127
actor.initialize(None, ft_spec)
@@ -146,6 +141,7 @@ def main(args):
146141
reward_fn=boba_reward_fn,
147142
gconfig=config.gconfig,
148143
tokenizer=tokenizer,
144+
enable_thinking=True,
149145
dump_dir=os.path.join(
150146
StatsLogger.get_log_path(config.stats_logger), "generated"
151147
),
@@ -193,7 +189,14 @@ def main(args):
193189
steps_per_epoch=steps_per_epoch,
194190
)
195191

196-
with stats_tracker.record_timing("rollout"):
192+
with (
193+
stats_tracker.record_timing("rollout"),
194+
perf_tracer.trace_scope(
195+
"train.rollout",
196+
category=Category.COMPUTE,
197+
args={"global_step": global_step},
198+
),
199+
):
197200
if config.async_training:
198201
batch = actor.prepare_batch(
199202
train_dataloader,
@@ -210,23 +213,49 @@ def main(args):
210213
)
211214

212215
if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
213-
with stats_tracker.record_timing("recompute_logp"):
216+
with (
217+
stats_tracker.record_timing("recompute_logp"),
218+
perf_tracer.trace_scope(
219+
"train.recompute_logp",
220+
category=Category.COMPUTE,
221+
args={"global_step": global_step},
222+
),
223+
):
214224
logp = actor.compute_logp(batch)
215225
batch["prox_logp"] = logp
216226
log_gpu_stats("recompute logp")
217227

218228
if ref is not None:
219-
with stats_tracker.record_timing("ref_logp"):
229+
with (
230+
stats_tracker.record_timing("ref_logp"),
231+
perf_tracer.trace_scope(
232+
"train.ref_logp",
233+
category=Category.COMPUTE,
234+
args={"global_step": global_step},
235+
),
236+
):
220237
batch["ref_logp"] = ref.compute_logp(batch)
221238
log_gpu_stats("ref logp")
222239

223-
with stats_tracker.record_timing("compute_advantage"):
240+
with (
241+
stats_tracker.record_timing("compute_advantage"),
242+
perf_tracer.trace_scope(
243+
"train.compute_advantage",
244+
category=Category.COMPUTE,
245+
args={"global_step": global_step},
246+
),
247+
):
224248
actor.compute_advantages(batch)
225249
log_gpu_stats("compute advantages")
226250

227251
with (
228252
stats_tracker.record_timing("train_step"),
229253
stats_tracker.scope("grpo_actor"),
254+
perf_tracer.trace_scope(
255+
"train.ppo_update",
256+
category=Category.COMPUTE,
257+
args={"global_step": global_step},
258+
),
230259
):
231260
stats = actor.ppo_update(batch)
232261
actor.step_lr_scheduler()
@@ -235,18 +264,37 @@ def main(args):
235264
# pause inference for updating weights, save, and evaluation
236265
rollout.pause()
237266

238-
with stats_tracker.record_timing("update_weights"):
267+
with (
268+
stats_tracker.record_timing("update_weights"),
269+
perf_tracer.trace_scope(
270+
"train.update_weights",
271+
category=Category.COMM,
272+
args={"global_step": global_step},
273+
),
274+
):
239275
actor.update_weights(weight_update_meta)
240276

241277
actor.set_version(global_step + 1)
242278
rollout.set_version(global_step + 1)
243279

244-
rollout.resume()
245-
246-
with stats_tracker.record_timing("save"):
280+
with (
281+
stats_tracker.record_timing("save"),
282+
perf_tracer.trace_scope(
283+
"train.save",
284+
category=Category.IO,
285+
args={"global_step": global_step},
286+
),
287+
):
247288
saver.save(actor, epoch, step, global_step, tokenizer=tokenizer)
248289

249-
with stats_tracker.record_timing("checkpoint_for_recover"):
290+
with (
291+
stats_tracker.record_timing("checkpoint_for_recover"),
292+
perf_tracer.trace_scope(
293+
"train.checkpoint",
294+
category=Category.IO,
295+
args={"global_step": global_step},
296+
),
297+
):
250298
recover_handler.dump(
251299
actor,
252300
step_info,
@@ -261,22 +309,30 @@ def main(args):
261309
current_platform.synchronize()
262310

263311
# Upload statistics to the logger (e.g., wandb)
264-
stats[0].update(
265-
stats_tracker.export_all(reduce_group=actor.data_parallel_group)
266-
)
267-
stats_logger.commit(epoch, step, global_step, stats)
312+
with perf_tracer.trace_scope(
313+
"train.log_stats",
314+
category=Category.INSTR,
315+
args={"global_step": global_step},
316+
):
317+
stats[0].update(
318+
stats_tracker.export_all(reduce_group=actor.data_parallel_group)
319+
)
320+
stats_logger.commit(epoch, step, global_step, stats)
268321

269322
dist.barrier(device_ids=[actor.device.index])
270323
current_platform.synchronize()
271324

272325
# Resume rollout
273326
rollout.resume()
274327

328+
perf_tracer.save(step=global_step)
329+
275330
stats_logger.close()
276331
rollout.destroy()
277332
if ref is not None:
278333
ref.destroy()
279334
actor.destroy()
335+
perf_tracer.save(force=True)
280336

281337

282338
if __name__ == "__main__":

0 commit comments

Comments (0)