import os
import sys

from areal.api.alloc_mode import AllocationMode
from areal.api.cli_args import GRPOConfig, SGLangConfig, load_expr_config
from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta
from areal.controller.rollout_controller import RolloutController
from areal.controller.train_controller import TrainController
from areal.dataset import get_custom_dataset
from areal.engine.ppo.actor import FSDPPPOActor
from areal.engine.sglang_remote import RemoteSGLangEngine
from areal.scheduler.local import LocalScheduler
from areal.utils import stats_tracker
from areal.utils.data import (
    cycle_dataloader,
)
from areal.utils.dataloader import create_dataloader
from areal.utils.device import log_gpu_stats
from areal.utils.evaluator import Evaluator
from areal.utils.hf_utils import load_hf_tokenizer
from areal.utils.recover import RecoverHandler
from areal.utils.saver import Saver
from areal.utils.stats_logger import StatsLogger


def main(args):
    config, _ = load_expr_config(args, GRPOConfig)
    config: GRPOConfig

    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    # Create dataset and dataloaders
    train_dataset = get_custom_dataset(
        split="train", dataset_config=config.train_dataset, tokenizer=tokenizer
    )

    train_dataloader = create_dataloader(
        train_dataset,
        rank=0,
        world_size=1,
        dataset_config=config.train_dataset,
    )
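    # FinetuneSpec summarizes the run length; it is passed to the controllers and
    # the saver/evaluator/recover utilities below.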
    ft_spec = FinetuneSpec(
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )

    # Initialize scheduler
    scheduler = LocalScheduler(exp_config=config)

    # Initialize train controller
    allocation_mode = AllocationMode.from_str(config.allocation_mode)
    actor = TrainController(FSDPPPOActor, config=config.actor, scheduler=scheduler)
    actor.initialize(
        role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None
    )

    # Initialize inference engine
    rollout = RolloutController(
        RemoteSGLangEngine, config=config.rollout, scheduler=scheduler
    )
    rollout.initialize(
        role="rollout",
        alloc_mode=allocation_mode,
        engine_args=SGLangConfig.build_args(
            sglang_config=config.sglang,
            tp_size=allocation_mode.gen.tp_size,
            base_gpu_id=0,
        ),
    )
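    # Weight updates are staged on shared disk (WeightUpdateMeta.from_disk);
    # connecting the engine lets the actor push new weights to the rollout side.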
    weight_update_meta = WeightUpdateMeta.from_disk(
        experiment_name=config.experiment_name,
        trial_name=config.trial_name,
        file_root=config.cluster.fileroot,
    )
    actor.connect_engine(rollout, weight_update_meta)
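    # Optional reference policy, used for the KL penalty when kl_ctl > 0.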
    ref = None
    if config.actor.kl_ctl > 0 and config.ref is not None:
        ref = TrainController(FSDPPPOActor, config=config.ref, scheduler=scheduler)
        ref.initialize(
            role="ref", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None
        )

    # Run training.
    saver = Saver(config.saver, ft_spec)
    stats_logger = StatsLogger(config, ft_spec)
    evaluator = Evaluator(config.evaluator, ft_spec)
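    # The recover handler checkpoints and restores training state so interrupted
    # runs can resume.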
    recover_handler = RecoverHandler(config.recover, ft_spec)

    try:
        recover_info = recover_handler.load(
            actor,
            saver,
            evaluator,
            stats_logger,
            train_dataloader,
            inference_engine=rollout,
            weight_update_meta=weight_update_meta,
        )
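        # Resume from the step after the last recovered checkpoint, or start from step 0.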
        start_step = (
            recover_info.last_step_info.next().global_step
            if recover_info is not None
            else 0
        )

        total_epochs = config.total_train_epochs
        steps_per_epoch = len(train_dataloader)
        max_steps = total_epochs * steps_per_epoch

        data_generator = cycle_dataloader(train_dataloader)
        for global_step in range(start_step, max_steps):
            epoch = global_step // steps_per_epoch
            step = global_step % steps_per_epoch
            step_info = StepInfo(
                global_step=global_step,
                epoch=epoch,
                epoch_step=step,
                steps_per_epoch=steps_per_epoch,
            )
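            # Generate rollouts: asynchronously via prepare_batch, or synchronously
            # on the next dataloader batch via rollout_batch.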
            with stats_tracker.record_timing("rollout"):
                if config.async_training:
                    batch = actor.prepare_batch(
                        train_dataloader,
                        workflow_path="areal.workflow.rlvr.RLVRWorkflow",
                        workflow_kwargs=dict(
                            reward_fn="areal.reward.gsm8k.gsm8k_reward_fn",
                            gconfig=config.gconfig,
                            tokenizer=config.tokenizer_path,
                            enable_thinking=False,
                            dump_dir=os.path.join(
                                StatsLogger.get_log_path(config.stats_logger),
                                "generated",
                            ),
                        ),
                    )
                else:
                    batch = actor.rollout_batch(
                        next(data_generator),
                        workflow_path="areal.workflow.rlvr.RLVRWorkflow",
                        workflow_kwargs=dict(
                            reward_fn="areal.reward.gsm8k.gsm8k_reward_fn",
                            gconfig=config.gconfig,
                            tokenizer=config.tokenizer_path,
                            enable_thinking=False,
                            dump_dir=os.path.join(
                                StatsLogger.get_log_path(config.stats_logger),
                                "generated",
                            ),
                        ),
                    )
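            # Optionally recompute log-probabilities with the trainer's current weights
            # and store them as the proximal log-probs used by the decoupled loss.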
            if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
                with stats_tracker.record_timing("recompute_logp"):
                    logp = actor.compute_logp(batch)
                    batch["prox_logp"] = logp
                    log_gpu_stats("recompute logp")

            if ref is not None:
                with stats_tracker.record_timing("ref_logp"):
                    batch["ref_logp"] = ref.compute_logp(batch)
                    log_gpu_stats("ref logp")

            with stats_tracker.record_timing("compute_advantage"):
                batch = actor.compute_advantages(batch)
                log_gpu_stats("compute advantages")

            with stats_tracker.record_timing("train_step"):
                actor.ppo_update(batch)
                actor.step_lr_scheduler()
                log_gpu_stats("ppo update")

            # Pause inference while updating weights, saving, and evaluating
            rollout.pause()

            with stats_tracker.record_timing("update_weights"):
                actor.update_weights(weight_update_meta)
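            # Bump the weight version so subsequent rollouts are attributed to the
            # updated policy.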
            actor.set_version(global_step + 1)
            rollout.set_version(global_step + 1)

            with stats_tracker.record_timing("save"):
                saver.save(actor, epoch, step, global_step, tokenizer=tokenizer)

            with stats_tracker.record_timing("checkpoint_for_recover"):
                recover_handler.dump(
                    actor,
                    step_info,
                    saver,
                    evaluator,
                    stats_logger,
                    train_dataloader,
                    tokenizer=tokenizer,
                )

            # Upload statistics to the logger (e.g., wandb)
            stats_logger.commit(epoch, step, global_step, actor.export_stats())

            # Resume rollout
            rollout.resume()

    finally:
        stats_logger.close()
        rollout.destroy()
        if ref is not None:
            ref.destroy()
        actor.destroy()


if __name__ == "__main__":
    main(sys.argv[1:])