44from dataclasses import MISSING as dataclass_missing
55from dataclasses import asdict , dataclass , field , fields
66from pathlib import Path
7- from typing import Any
7+ from typing import Any , TypeVar
88
99import uvloop
1010import yaml
1111from hydra import compose as hydra_compose
1212from hydra import initialize as hydra_init
1313from hydra .core .global_hydra import GlobalHydra
1414from omegaconf import MISSING , DictConfig , OmegaConf
15+ from transformers import PreTrainedTokenizerFast
1516
1617from areal .platforms import current_platform
1718from areal .utils import logging , name_resolve , pkg_version
19+ from areal .utils .constants import (
20+ PROX_LOGP_METHOD_RECOMPUTE ,
21+ PROX_LOGP_METHODS_ALL ,
22+ )
1823from areal .utils .pkg_version import is_version_less
1924
2025uvloop .install ()
2126
2227logger = logging .getLogger ("CLI args" )
2328
29+ ConfigT = TypeVar ("ConfigT" )
30+
2431
2532@dataclass
2633class NormConfig :
@@ -157,12 +164,25 @@ class GenerationHyperparameters:
157164 )
158165 },
159166 )
167+ lora_name : str = field (
168+ default = "" ,
169+ metadata = {"help" : "Lora name to be used for this generation." },
170+ )
160171
161172 def new (self , ** kwargs ):
162173 args = asdict (self )
163174 args .update (kwargs )
164175 return GenerationHyperparameters (** args )
165176
177+ def new_with_stop_and_pad_token_ids (self , tokenizer : PreTrainedTokenizerFast ):
178+ """Create a new generation hyperparameters with stop and pad token ids added."""
179+ new_stop_token_ids = self .stop_token_ids .copy ()
180+ if tokenizer .pad_token_id not in new_stop_token_ids :
181+ new_stop_token_ids .append (tokenizer .pad_token_id )
182+ if tokenizer .eos_token_id not in new_stop_token_ids :
183+ new_stop_token_ids .append (tokenizer .eos_token_id )
184+ return self .new (stop_token_ids = new_stop_token_ids )
185+
166186 def to_openai_args_dict (
167187 self , exclude_args : list [str ] | None = None
168188 ) -> dict [str , Any ]:
@@ -402,7 +422,6 @@ class SchedulingSpec:
402422 default_factory = dict ,
403423 metadata = {"help" : "Environment variables for the container" },
404424 )
405- # cmd
406425 cmd : str | None = field (
407426 default = None ,
408427 metadata = {
@@ -488,13 +507,32 @@ class TrainEngineConfig:
488507 default = "lora" ,
489508 metadata = {"help" : "peft method type. Only LoRA is supported for now." },
490509 )
491- scheduling_spec : SchedulingSpec = field (
492- default_factory = lambda : SchedulingSpec (
493- cmd = "python -m areal.scheduler.rpc.rpc_server"
510+ scheduling_spec : tuple [ SchedulingSpec , ...] = field (
511+ default_factory = lambda : (
512+ SchedulingSpec ( cmd = "python -m areal.scheduler.rpc.rpc_server" ),
494513 ),
495- metadata = {"help" : "train engine schedule specs" },
514+ metadata = {
515+ "help" : "Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: "
516+ "if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; "
517+ "if 2 specs provided, first one is for worker, second one is for engine. "
518+ "Currently only used by the TrainController."
519+ },
520+ )
521+ scheduling_strategy : SchedulingStrategy = field (
522+ default_factory = SchedulingStrategy ,
523+ metadata = {
524+ "help" : "The scheduling strategy of this TrainEngine, either separation or colocation. "
525+ "Currently only used by the TrainController."
526+ },
496527 )
497- scheduling_strategy : SchedulingStrategy = field (default_factory = SchedulingStrategy )
528+
529+ def __post_init__ (self ):
530+ """Validate scheduling_spec length."""
531+ if len (self .scheduling_spec ) not in (1 , 2 ):
532+ raise ValueError (
533+ f"scheduling_spec must contain 1 or 2 SchedulingSpec, "
534+ f"got { len (self .scheduling_spec )} "
535+ )
498536
499537
500538@dataclass
@@ -605,6 +643,18 @@ class PPOActorConfig(TrainEngineConfig):
605643 "choices" : ["token" , "sequence" ],
606644 },
607645 )
646+ # Proximal Log-Probability Computation Method
647+ prox_logp_method : str = field (
648+ default = PROX_LOGP_METHOD_RECOMPUTE ,
649+ metadata = {
650+ "help" : "Method for computing proximal policy log-probabilities in decoupled PPO. "
651+ "Only effective when use_decoupled_loss=True. Options: "
652+ "'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. "
653+ "'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). "
654+ "'metrics': Like 'recompute', but also compute approximation metrics for evaluation." ,
655+ "choices" : PROX_LOGP_METHODS_ALL ,
656+ },
657+ )
608658 # Advanced Options
609659 dynamic_sampling : bool = field (
610660 default = False ,
@@ -702,6 +752,8 @@ class vLLMConfig:
702752 )
703753 enable_sleep_mode : bool = False
704754 uvicorn_log_level : str = "warning"
755+ enable_lora : bool = False
756+ lora_modules : str = ""
705757
706758 @staticmethod
707759 def build_args (
@@ -726,6 +778,18 @@ def build_args(
726778 args ["port" ] = port
727779 if host is not None :
728780 args ["host" ] = host
781+ # handle lora modules separately
782+ lm = args .get ("lora_modules" )
783+ if lm :
784+ if isinstance (lm , str ):
785+ lm = [lm ]
786+ if isinstance (lm , (list , tuple )):
787+ try :
788+ args ["lora_modules" ] = [
789+ json .dumps (json .loads (s ), separators = ("," , ":" )) for s in lm
790+ ]
791+ except json .JSONDecodeError as e :
792+ raise ValueError (f"Invalid JSON string in lora_modules: { e } " ) from e
729793 return args
730794
731795 @staticmethod
@@ -977,13 +1041,36 @@ class InferenceEngineConfig:
9771041 "help" : "The grace period after calling /pause_generation. Wait until all requests have been dropped."
9781042 },
9791043 )
980- scheduling_spec : SchedulingSpec = field (
981- default_factory = lambda : SchedulingSpec (
982- cmd = "python -m areal.scheduler.rpc.rpc_server"
1044+ scheduling_spec : tuple [ SchedulingSpec , ...] = field (
1045+ default_factory = lambda : (
1046+ SchedulingSpec ( cmd = "python -m areal.scheduler.rpc.rpc_server" ),
9831047 ),
984- metadata = {"help" : "inference engine schedule specs" },
1048+ metadata = {
1049+ "help" : "inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: "
1050+ "if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; "
1051+ "if 2 specs provided, first one is for worker, second one is for engine. "
1052+ "Currently only used by the RolloutController."
1053+ },
1054+ )
1055+ scheduling_strategy : SchedulingStrategy = field (
1056+ default_factory = SchedulingStrategy ,
1057+ metadata = {
1058+ "help" : "The scheduling strategy of this TrainEngine, either separation or colocation. "
1059+ "Currently only used by the RolloutController."
1060+ },
1061+ )
1062+ use_lora : bool = field (
1063+ default = False ,
1064+ metadata = {"help" : "Whether to use LoRA. Should be same as actors LORA option." },
9851065 )
986- scheduling_strategy : SchedulingStrategy = field (default_factory = SchedulingStrategy )
1066+
1067+ def __post_init__ (self ):
1068+ """Validate scheduling_spec length."""
1069+ if len (self .scheduling_spec ) not in (1 , 2 ):
1070+ raise ValueError (
1071+ f"scheduling_spec must contain 1 or 2 SchedulingSpec, "
1072+ f"got { len (self .scheduling_spec )} "
1073+ )
9871074
9881075
9891076@dataclass
@@ -1148,6 +1235,15 @@ class PerfTracerConfig:
11481235 )
11491236 },
11501237 )
1238+ profile_steps : list [int ] | None = field (
1239+ default = None ,
1240+ metadata = {
1241+ "help" : (
1242+ "List of step numbers at which to capture detailed profiling traces. "
1243+ "If None, no detailed profiling traces are captured."
1244+ )
1245+ },
1246+ )
11511247 session_tracer : SessionTracerConfig | None = field (
11521248 default = None ,
11531249 metadata = {"help" : "Session tracing configuration." },
@@ -1223,7 +1319,7 @@ class SchedulerConfig:
12231319
12241320
12251321@dataclass
1226- class DatasetConfig :
1322+ class _DatasetConfig :
12271323 """Configuration for dataset loading and preprocessing."""
12281324
12291325 path : str = field (
@@ -1262,6 +1358,27 @@ class DatasetConfig:
12621358 )
12631359
12641360
1361+ @dataclass
1362+ class TrainDatasetConfig (_DatasetConfig ):
1363+ """Configuration for training dataset loading and preprocessing."""
1364+
1365+
1366+ @dataclass
1367+ class ValidDatasetConfig (_DatasetConfig ):
1368+ """Configuration for validation dataset loading and preprocessing.
1369+
1370+ It has different default values with `TrainDatasetConfig`.
1371+ `shuffle` and `drop_last` default to False.
1372+ """
1373+
1374+ shuffle : bool = field (
1375+ default = False , metadata = {"help" : "Whether to shuffle the dataset" }
1376+ )
1377+ drop_last : bool = field (
1378+ default = False , metadata = {"help" : "Drop the last incomplete batch" }
1379+ )
1380+
1381+
12651382@dataclass
12661383class SlurmLauncherConfig :
12671384 """Configuration for launching the training jobs with Slurm."""
@@ -1359,6 +1476,13 @@ class BaseExperimentConfig:
13591476 metadata = {"help" : "Pattern-based GPU parallel strategy allocation mode. " },
13601477 )
13611478 seed : int = field (default = 1 , metadata = {"help" : "Random seed for reproducibility." })
1479+ enable_offload : bool = field (
1480+ default = False ,
1481+ metadata = {
1482+ "help" : "Whether to enable training offload using torch_memory_saver. "
1483+ "This requires setting up the environment for TMS (e.g., via LD_PRELOAD)."
1484+ },
1485+ )
13621486 total_train_epochs : int = field (
13631487 default = 1 , metadata = {"help" : "Total number of epochs to train the model." }
13641488 )
@@ -1381,8 +1505,8 @@ class BaseExperimentConfig:
13811505 metadata = {"help" : "Path to the tokenizer." },
13821506 )
13831507
1384- train_dataset : DatasetConfig = field (default_factory = DatasetConfig )
1385- valid_dataset : DatasetConfig | None = field (default = None )
1508+ train_dataset : TrainDatasetConfig = field (default_factory = TrainDatasetConfig )
1509+ valid_dataset : ValidDatasetConfig | None = field (default = None )
13861510
13871511 saver : SaverConfig = field (default_factory = SaverConfig )
13881512 evaluator : EvaluatorConfig = field (default_factory = EvaluatorConfig )
@@ -1466,7 +1590,7 @@ def to_structured_cfg(cfg, config_cls):
14661590 return cfg
14671591
14681592
1469- def load_expr_config (argv : list [str ], config_cls ) :
1593+ def load_expr_config (argv : list [str ], config_cls : type [ ConfigT ]) -> tuple [ ConfigT , str ] :
14701594 cfg , config_file = parse_cli_args (argv )
14711595 cfg = to_structured_cfg (cfg , config_cls = config_cls )
14721596 cfg = OmegaConf .to_object (cfg )
0 commit comments