33import os
44from dataclasses import asdict , dataclass , field
55from pathlib import Path
6- from typing import Dict , List
76
87import uvloop
98import yaml
10-
11- from areal .utils .pkg_version import is_version_less
12-
13- uvloop .install ()
149from hydra import compose as hydra_compose
1510from hydra import initialize as hydra_init
1611from hydra .core .global_hydra import GlobalHydra
1712from omegaconf import MISSING , DictConfig , OmegaConf
1813
1914from areal .platforms import current_platform
2015from areal .utils import name_resolve , pkg_version
16+ from areal .utils .pkg_version import is_version_less
17+
18+ uvloop .install ()
2119
2220
2321@dataclass
@@ -129,11 +127,11 @@ class GenerationHyperparameters:
129127 default = 1.0 ,
130128 metadata = {"help" : "Sampling temperature. Higher values increase diversity." },
131129 )
132- stop_token_ids : List [int ] = field (
130+ stop_token_ids : list [int ] = field (
133131 default_factory = list ,
134132 metadata = {"help" : "Stop generation when encountering these token IDs." },
135133 )
136- stop : List [str ] | None = field (
134+ stop : list [str ] | None = field (
137135 default = None ,
138136 metadata = {
139137 "help" : "One or multiple stop words. Generation will stop if one of these words is sampled."
@@ -232,7 +230,7 @@ class OptimizerConfig:
232230class FSDPWrapPolicy :
233231 """Policy configuration for FSDP model layer wrapping. None defaults to wrapping transformer decoder layers defined by transformers."""
234232
235- transformer_layer_cls_to_wrap : List [str ] | None = field (
233+ transformer_layer_cls_to_wrap : list [str ] | None = field (
236234 default = None ,
237235 metadata = {"help" : "A list of transformer layer names for FSDP to wrap." },
238236 )
@@ -310,7 +308,7 @@ class MegatronEngineConfig:
310308 recompute_method : str | None = "uniform"
311309 recompute_num_layers : int | None = 1
312310 distribute_saved_activations : bool | None = None
313- recompute_modules : List [str ] | None = None
311+ recompute_modules : list [str ] | None = None
314312
315313
316314@dataclass
@@ -378,7 +376,7 @@ class TrainEngineConfig:
378376 )
379377 lora_rank : int = field (default = 32 , metadata = {"help" : "lora rank" })
380378 lora_alpha : int = field (default = 16 , metadata = {"help" : "lora alpha" })
381- target_modules : List [str ] = field (
379+ target_modules : list [str ] = field (
382380 default_factory = list ,
383381 metadata = {"help" : "lora target_modules." },
384382 )
@@ -500,7 +498,7 @@ class PPOActorConfig(TrainEngineConfig):
500498 default = False ,
501499 metadata = {"help" : "Log statistics for agent trajectories" },
502500 )
503- log_agent_stats_keys : List [str ] = field (
501+ log_agent_stats_keys : list [str ] = field (
504502 default_factory = lambda : [],
505503 metadata = {"help" : "Keys for logging agent trajectory statistics" },
506504 )
@@ -574,7 +572,7 @@ def build_args(
574572 port ,
575573 dist_init_addr : str | None = None ,
576574 ):
577- args : Dict = conf_as_dict (vllm_config )
575+ args : dict = conf_as_dict (vllm_config )
578576 args = dict (
579577 host = host ,
580578 port = port ,
@@ -608,11 +606,11 @@ def build_cmd(
608606 if v is None or v is False or v == "" :
609607 continue
610608 if v is True :
611- flags .append (f"--{ k .replace ('_' ,'-' )} " )
609+ flags .append (f"--{ k .replace ('_' , '-' )} " )
612610 elif isinstance (v , list ):
613- flags .append (f"--{ k .replace ('_' ,'-' )} { ' ' .join (map (str , v ))} " )
611+ flags .append (f"--{ k .replace ('_' , '-' )} { ' ' .join (map (str , v ))} " )
614612 else :
615- flags .append (f"--{ k .replace ('_' ,'-' )} { v } " )
613+ flags .append (f"--{ k .replace ('_' , '-' )} { v } " )
616614 return f"python3 -m areal.thirdparty.vllm.areal_vllm_server { ' ' .join (flags )} "
617615
618616
@@ -638,7 +636,7 @@ class SGLangConfig:
638636 enable_torch_compile : bool = False
639637 torch_compile_max_bs : int = 32
640638 cuda_graph_max_bs : int | None = None
641- cuda_graph_bs : List [int ] | None = None
639+ cuda_graph_bs : list [int ] | None = None
642640 torchao_config : str = ""
643641 enable_nan_detection : bool = False
644642 enable_p2p_check : bool = False
@@ -667,8 +665,8 @@ class SGLangConfig:
667665 # lora
668666 enable_lora : bool | None = None
669667 max_lora_rank : int | None = None
670- lora_target_modules : List [str ] | None = None
671- lora_paths : List [str ] | None = None
668+ lora_target_modules : list [str ] | None = None
669+ lora_paths : list [str ] | None = None
672670 max_loaded_loras : int = 1
673671 max_loras_per_batch : int = 1
674672 lora_backend : str = "triton"
@@ -719,11 +717,11 @@ def build_cmd(
719717 if v is None or v is False or v == "" :
720718 continue
721719 if v is True :
722- flags .append (f"--{ k .replace ('_' ,'-' )} " )
720+ flags .append (f"--{ k .replace ('_' , '-' )} " )
723721 elif isinstance (v , list ):
724- flags .append (f"--{ k .replace ('_' ,'-' )} { ' ' .join (map (str , v ))} " )
722+ flags .append (f"--{ k .replace ('_' , '-' )} { ' ' .join (map (str , v ))} " )
725723 else :
726- flags .append (f"--{ k .replace ('_' ,'-' )} { v } " )
724+ flags .append (f"--{ k .replace ('_' , '-' )} { v } " )
727725 return f"python3 -m sglang.launch_server { ' ' .join (flags )} "
728726
729727 @staticmethod
@@ -738,11 +736,12 @@ def build_args(
738736 node_rank : int = 0 ,
739737 ):
740738 # Map "all-linear" to "all"
741- args : Dict = conf_as_dict (sglang_config )
739+ args : dict = conf_as_dict (sglang_config )
742740 if sglang_config .enable_multithread_load or sglang_config .enable_fast_load :
743- assert pkg_version .is_version_equal (
744- "sglang" , "0.5.2"
745- ), f"Customized model loading requires exact SGLang version 0.5.2"
741+ if not pkg_version .is_version_equal ("sglang" , "0.5.2" ):
742+ raise RuntimeError (
743+ "Customized model loading requires exact SGLang version 0.5.2"
744+ )
746745 model_loader_extra_config = dict (
747746 enable_multithread_load = sglang_config .enable_multithread_load ,
748747 enable_fast_load = sglang_config .enable_fast_load ,
@@ -791,7 +790,8 @@ class InferenceEngineConfig:
791790 max_concurrent_rollouts : None | int = field (
792791 default = None ,
793792 metadata = {
794- "help" : "Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size."
793+ "help" : "Maximum number of concurrent rollouts to "
794+ "the inference engine. Defaults to consumer_batch_size."
795795 },
796796 )
797797 queue_size : None | int = field (
@@ -915,8 +915,8 @@ class WandBConfig:
915915 job_type : str | None = None
916916 group : str | None = None
917917 notes : str | None = None
918- tags : List [str ] | None = None
919- config : Dict | None = None
918+ tags : list [str ] | None = None
919+ config : dict | None = None
920920 id_suffix : str | None = "train"
921921
922922
@@ -926,7 +926,7 @@ class SwanlabConfig:
926926
927927 project : str | None = None
928928 name : str | None = None
929- config : Dict | None = None
929+ config : dict | None = None
930930 logdir : str | None = None
931931 mode : str | None = "disabled"
932932 api_key : str | None = os .getenv ("SWANLAB_API_KEY" , None )
@@ -1023,7 +1023,7 @@ class SchedulerConfig:
10231023 endpoint : str = field (default = "http://localhost:8081" )
10241024 deploy_mode : str = field (default = "separation" )
10251025 functioncall_service_domain : str = field (default = "http://localhost:8080" )
1026- reward_functioncall_config : Dict = field (default_factory = dict )
1026+ reward_functioncall_config : dict = field (default_factory = dict )
10271027 reward_model_path : str = field (default = "" )
10281028 reward_model_service_url : str = field (default = "http://localhost:30000/classify" )
10291029
@@ -1076,7 +1076,7 @@ class SlurmLauncherConfig:
10761076 default = "--mpi=pmi2 -K --chdir $PWD" ,
10771077 metadata = {"help" : "Additional arguments to pass to the srun command." },
10781078 )
1079- additional_bash_cmds : List [str ] | None = field (
1079+ additional_bash_cmds : list [str ] | None = field (
10801080 default = None ,
10811081 metadata = {
10821082 "help" : "Additional bash commands to setup the container before running "
@@ -1244,7 +1244,7 @@ class PPOConfig(GRPOConfig):
12441244 critic : PPOCriticConfig = field (default_factory = PPOCriticConfig )
12451245
12461246
1247- def parse_cli_args (argv : List [str ]):
1247+ def parse_cli_args (argv : list [str ]):
12481248 parser = argparse .ArgumentParser ()
12491249 parser .add_argument (
12501250 "--config" , help = "Path to the main configuration file" , required = True
@@ -1277,7 +1277,7 @@ def to_structured_cfg(cfg, config_cls):
12771277 return cfg
12781278
12791279
1280- def load_expr_config (argv : List [str ], config_cls ):
1280+ def load_expr_config (argv : list [str ], config_cls ):
12811281 cfg , config_file = parse_cli_args (argv )
12821282 cfg = to_structured_cfg (cfg , config_cls = config_cls )
12831283 cfg = OmegaConf .to_object (cfg )
@@ -1305,7 +1305,7 @@ def save_config(cfg, log_dir):
13051305 os .makedirs (log_dir , exist_ok = True )
13061306 config_save_path = os .path .join (log_dir , "config.yaml" )
13071307 with open (config_save_path , "w" ) as f :
1308- config_dict : Dict = asdict (cfg )
1308+ config_dict : dict = asdict (cfg )
13091309 yaml .dump (
13101310 config_dict ,
13111311 f ,
0 commit comments