
Commit c70d414

1 parent 83b786e commit c70d414

3 files changed: +391 additions, −0 deletions


docs/cli_reference.md

Lines changed: 4 additions & 0 deletions
@@ -465,6 +465,7 @@ Controls text generation behavior for rollout.
 | `stop_token_ids` | list of integer | **Required** | Stop generation when encountering these token IDs. |
 | `stop` | list of string \| None | `None` | One or multiple stop words. Generation stops if one of these words is sampled. |
 | `frequency_penalty` | float | `0.0` | Penalizes tokens based on their frequency in the generation so far. Must be between -2 and 2, where negative values encourage repetition. |
+| `lora_name` | string | `""` | LoRA adapter name to use for this generation. |

(section-inference-engine)=

@@ -489,6 +490,7 @@ Configuration for inference servers, including off-policyness control.
 | `pause_grace_period` | float | `0.0` | Grace period after calling /pause_generation; waits until all requests have been dropped. |
 | `scheduling_spec` | `tuple` | **Required** | Inference engine scheduling specs. Accepts 1 or 2 SchedulingSpec entries: if 1 spec is provided, it is used for both worker and engine, with the engine embedded in the worker; if 2 specs are provided, the first is for the worker and the second for the engine. Currently only used by the RolloutController. |
 | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the RolloutController. |
+| `use_lora` | boolean | `False` | Whether to use LoRA. Must match the actor's LoRA option. |

(section-sg-lang)=

@@ -585,6 +587,8 @@ https://docs.vllm.ai/en/stable/api/index.html for detailed documentation.
 | `worker_extension_cls` | string | `"areal.thirdparty.vllm.vllm_worker_extension.VLLMWorkerExtension"` | - |
 | `enable_sleep_mode` | boolean | `False` | - |
 | `uvicorn_log_level` | string | `"warning"` | - |
+| `enable_lora` | boolean | `False` | - |
+| `lora_modules` | string | `""` | - |

(section-train-dataset)=
Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
import os
import sys

from areal.api.alloc_mode import AllocationMode
from areal.api.cli_args import GRPOConfig, SGLangConfig, load_expr_config
from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta
from areal.controller.rollout_controller import RolloutController
from areal.controller.train_controller import TrainController
from areal.dataset import get_custom_dataset
from areal.engine.ppo.actor import FSDPPPOActor
from areal.engine.sglang_remote import RemoteSGLangEngine
from areal.scheduler.local import LocalScheduler
from areal.utils import stats_tracker
from areal.utils.data import cycle_dataloader
from areal.utils.dataloader import create_dataloader
from areal.utils.device import log_gpu_stats
from areal.utils.evaluator import Evaluator
from areal.utils.hf_utils import load_hf_tokenizer
from areal.utils.recover import RecoverHandler
from areal.utils.saver import Saver
from areal.utils.stats_logger import StatsLogger


def main(args):
    config, _ = load_expr_config(args, GRPOConfig)
    config: GRPOConfig

    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    # Create dataset and dataloaders
    train_dataset = get_custom_dataset(
        split="train", dataset_config=config.train_dataset, tokenizer=tokenizer
    )

    train_dataloader = create_dataloader(
        train_dataset,
        rank=0,
        world_size=1,
        dataset_config=config.train_dataset,
    )

    ft_spec = FinetuneSpec(
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )

    # Initialize scheduler
    scheduler = LocalScheduler(exp_config=config)

    # Initialize train controller
    allocation_mode = AllocationMode.from_str(config.allocation_mode)
    actor = TrainController(FSDPPPOActor, config=config.actor, scheduler=scheduler)
    actor.initialize(
        role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None
    )

    # Initialize inference engine
    rollout = RolloutController(
        RemoteSGLangEngine, config=config.rollout, scheduler=scheduler
    )
    rollout.initialize(
        role="rollout",
        alloc_mode=allocation_mode,
        engine_args=SGLangConfig.build_args(
            sglang_config=config.sglang,
            tp_size=allocation_mode.gen.tp_size,
            base_gpu_id=0,
        ),
    )

    weight_update_meta = WeightUpdateMeta.from_disk(
        experiment_name=config.experiment_name,
        trial_name=config.trial_name,
        file_root=config.cluster.fileroot,
    )
    actor.connect_engine(rollout, weight_update_meta)

    ref = None
    if config.actor.kl_ctl > 0 and config.ref is not None:
        ref = TrainController(FSDPPPOActor, config=config.ref, scheduler=scheduler)
        ref.initialize(
            role="ref", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None
        )

    # Run training.
    saver = Saver(config.saver, ft_spec)
    stats_logger = StatsLogger(config, ft_spec)
    evaluator = Evaluator(config.evaluator, ft_spec)

    recover_handler = RecoverHandler(config.recover, ft_spec)

    try:
        recover_info = recover_handler.load(
            actor,
            saver,
            evaluator,
            stats_logger,
            train_dataloader,
            inference_engine=rollout,
            weight_update_meta=weight_update_meta,
        )
        start_step = (
            recover_info.last_step_info.next().global_step
            if recover_info is not None
            else 0
        )

        total_epochs = config.total_train_epochs
        steps_per_epoch = len(train_dataloader)
        max_steps = total_epochs * steps_per_epoch

        data_generator = cycle_dataloader(train_dataloader)
        for global_step in range(start_step, max_steps):
            epoch = global_step // steps_per_epoch
            step = global_step % steps_per_epoch
            step_info = StepInfo(
                global_step=global_step,
                epoch=epoch,
                epoch_step=step,
                steps_per_epoch=steps_per_epoch,
            )

            with stats_tracker.record_timing("rollout"):
                if config.async_training:
                    batch = actor.prepare_batch(
                        train_dataloader,
                        workflow_path="areal.workflow.rlvr.RLVRWorkflow",
                        workflow_kwargs=dict(
                            reward_fn="areal.reward.gsm8k.gsm8k_reward_fn",
                            gconfig=config.gconfig,
                            tokenizer=config.tokenizer_path,
                            enable_thinking=False,
                            dump_dir=os.path.join(
                                StatsLogger.get_log_path(config.stats_logger),
                                "generated",
                            ),
                        ),
                    )
                else:
                    batch = actor.rollout_batch(
                        next(data_generator),
                        workflow_path="areal.workflow.rlvr.RLVRWorkflow",
                        workflow_kwargs=dict(
                            reward_fn="areal.reward.gsm8k.gsm8k_reward_fn",
                            gconfig=config.gconfig,
                            tokenizer=config.tokenizer_path,
                            enable_thinking=False,
                            dump_dir=os.path.join(
                                StatsLogger.get_log_path(config.stats_logger),
                                "generated",
                            ),
                        ),
                    )

            if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
                with stats_tracker.record_timing("recompute_logp"):
                    logp = actor.compute_logp(batch)
                    batch["prox_logp"] = logp
                    log_gpu_stats("recompute logp")

            if ref is not None:
                with stats_tracker.record_timing("ref_logp"):
                    batch["ref_logp"] = ref.compute_logp(batch)
                    log_gpu_stats("ref logp")

            with stats_tracker.record_timing("compute_advantage"):
                batch = actor.compute_advantages(batch)
                log_gpu_stats("compute advantages")

            with stats_tracker.record_timing("train_step"):
                actor.ppo_update(batch)
                actor.step_lr_scheduler()
                log_gpu_stats("ppo update")

            # Pause inference while updating weights, saving, and evaluating
            rollout.pause()

            with stats_tracker.record_timing("update_weights"):
                actor.update_weights(weight_update_meta)

            actor.set_version(global_step + 1)
            rollout.set_version(global_step + 1)

            with stats_tracker.record_timing("save"):
                saver.save(actor, epoch, step, global_step, tokenizer=tokenizer)

            with stats_tracker.record_timing("checkpoint_for_recover"):
                recover_handler.dump(
                    actor,
                    step_info,
                    saver,
                    evaluator,
                    stats_logger,
                    train_dataloader,
                    tokenizer=tokenizer,
                )

            # Upload statistics to the logger (e.g., wandb)
            stats_logger.commit(epoch, step, global_step, actor.export_stats())

            # Resume rollout
            rollout.resume()

    finally:
        stats_logger.close()
        rollout.destroy()
        if ref is not None:
            ref.destroy()
        actor.destroy()


if __name__ == "__main__":
    main(sys.argv[1:])
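The training loop pulls batches through `cycle_dataloader`, which lets `global_step` run continuously across epoch boundaries. A minimal sketch of the behavior it is assumed to provide (the actual implementation lives in `areal.utils.data` and is not shown in this commit):

```python
def cycle_dataloader(dataloader):
    # Yield batches indefinitely, restarting from the top of the
    # dataloader whenever it is exhausted (i.e., at each epoch end).
    while True:
        for batch in dataloader:
            yield batch
```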

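Because the epoch bookkeeping derives everything from `global_step`, resuming from a recovered step lands in the right place without replaying earlier epochs. A small worked example using the loop's own formulas, with illustrative numbers:

```python
# Hypothetical values: a dataloader of 250 steps per epoch,
# resuming from a recovered global_step of 612.
steps_per_epoch = 250
global_step = 612
epoch = global_step // steps_per_epoch  # -> 2 (third epoch, zero-indexed)
step = global_step % steps_per_epoch    # -> 112 (position within that epoch)
```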