
Commit 5c0ffc6

chore: format areal/api with ruff (#491)

* format areal/api with ruff
* format
* .

1 parent fdd7cdb · commit 5c0ffc6
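A hypothetical way to reproduce a pass like this (not part of the commit): the whitespace changes below match `ruff format`, while the typing rewrites (`List` to `list`, `Optional[X]` to `X | None`) match ruff's pyupgrade-style autofixes, assuming those rules are enabled in the project's ruff config.

```python
# Hypothetical reproduction of the formatting pass (not repo code);
# equivalent to running the two ruff commands from the project root,
# assuming ruff is installed in the environment.
import subprocess

subprocess.run(["ruff", "check", "--fix", "areal/api"], check=True)
subprocess.run(["ruff", "format", "areal/api"], check=True)
```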

File tree: 10 files changed, +125 −133 lines


areal/api/alloc_mode.py (6 additions, 10 deletions)

```diff
@@ -9,7 +9,6 @@
 import enum
 import math
 from dataclasses import dataclass, field
-from typing import Optional
 
 from lark import Lark, Transformer
 
@@ -216,7 +215,7 @@ def parallelism_eq(this, other):
 class MegatronParallelStrategy(ParallelStrategy):
     """Megatron parallel strategy with additional sequence parallelism and virtual pipeline parallelism."""
 
-    virtual_pipeline_parallel_size: Optional[int] = field(
+    virtual_pipeline_parallel_size: int | None = field(
         default=None,
         metadata={
             "help": "Virtual pipeline parallelism size for megatron modules "
@@ -234,10 +233,7 @@ class MegatronParallelStrategy(ParallelStrategy):
     def parallelism_eq(this, other):
         """Compare Megatron parallelism configurations (excluding sequence parallelism)."""
         return ParallelStrategy.parallelism_eq(this, other) and (
-            (
-                this.virtual_pipeline_parallel_size
-                == other.virtual_pipeline_parallel_size
-            )
+            this.virtual_pipeline_parallel_size == other.virtual_pipeline_parallel_size
         )
 
 
@@ -274,9 +270,9 @@ class AllocationMode:
 
     type_: AllocationType
    gen: ParallelStrategy = field(default_factory=ParallelStrategy)
-    train: Optional[ParallelStrategy] = None
-    gen_backend: Optional[str] = None
-    train_backend: Optional[str] = None
+    train: ParallelStrategy | None = None
+    gen_backend: str | None = None
+    train_backend: str | None = None
 
     @property
     def gen_instance_size(self) -> int:
@@ -407,7 +403,7 @@ class TrainingParallelism:
     and comprehensive validation rules.
     """
 
-    backend: Optional[str] = None
+    backend: str | None = None
     strategy: ParallelStrategy = field(default_factory=lambda: ParallelStrategy())
 
     def __post_init__(self):
```
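The non-whitespace change in this file is modernizing `typing.Optional[X]` to `X | None` (PEP 604), which lets the `typing` import be dropped. A minimal standalone sketch of the equivalence (illustrative class names, not repo code):

```python
# Illustrative only: both spellings accept the same values and behave
# identically at runtime on Python 3.10+, where PEP 604 unions are valid.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Legacy:
    virtual_pipeline_parallel_size: Optional[int] = None


@dataclass
class Modern:
    virtual_pipeline_parallel_size: int | None = None


# Both default to None; only the annotation objects differ.
assert Legacy().virtual_pipeline_parallel_size is None
assert Modern().virtual_pipeline_parallel_size is None
```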

areal/api/cli_args.py (34 additions, 34 deletions)
```diff
@@ -3,21 +3,19 @@
 import os
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import Dict, List
 
 import uvloop
 import yaml
-
-from areal.utils.pkg_version import is_version_less
-
-uvloop.install()
 from hydra import compose as hydra_compose
 from hydra import initialize as hydra_init
 from hydra.core.global_hydra import GlobalHydra
 from omegaconf import MISSING, DictConfig, OmegaConf
 
 from areal.platforms import current_platform
 from areal.utils import name_resolve, pkg_version
+from areal.utils.pkg_version import is_version_less
+
+uvloop.install()
 
 
 @dataclass
@@ -129,11 +127,11 @@ class GenerationHyperparameters:
         default=1.0,
         metadata={"help": "Sampling temperature. Higher values increase diversity."},
     )
-    stop_token_ids: List[int] = field(
+    stop_token_ids: list[int] = field(
         default_factory=list,
         metadata={"help": "Stop generation when encountering these token IDs."},
     )
-    stop: List[str] | None = field(
+    stop: list[str] | None = field(
         default=None,
         metadata={
             "help": "One or multiple stop words. Generation will stop if one of these words is sampled."
@@ -232,7 +230,7 @@ class OptimizerConfig:
 class FSDPWrapPolicy:
     """Policy configuration for FSDP model layer wrapping. None defaults to wrapping transformer decoder layers defined by transformers."""
 
-    transformer_layer_cls_to_wrap: List[str] | None = field(
+    transformer_layer_cls_to_wrap: list[str] | None = field(
         default=None,
         metadata={"help": "A list of transformer layer names for FSDP to wrap."},
     )
@@ -310,7 +308,7 @@ class MegatronEngineConfig:
     recompute_method: str | None = "uniform"
     recompute_num_layers: int | None = 1
     distribute_saved_activations: bool | None = None
-    recompute_modules: List[str] | None = None
+    recompute_modules: list[str] | None = None
 
 
 @dataclass
@@ -378,7 +376,7 @@ class TrainEngineConfig:
     )
     lora_rank: int = field(default=32, metadata={"help": "lora rank"})
     lora_alpha: int = field(default=16, metadata={"help": "lora alpha"})
-    target_modules: List[str] = field(
+    target_modules: list[str] = field(
         default_factory=list,
         metadata={"help": "lora target_modules."},
     )
@@ -500,7 +498,7 @@ class PPOActorConfig(TrainEngineConfig):
         default=False,
         metadata={"help": "Log statistics for agent trajectories"},
     )
-    log_agent_stats_keys: List[str] = field(
+    log_agent_stats_keys: list[str] = field(
         default_factory=lambda: [],
         metadata={"help": "Keys for logging agent trajectory statistics"},
     )
@@ -574,7 +572,7 @@ def build_args(
         port,
         dist_init_addr: str | None = None,
     ):
-        args: Dict = conf_as_dict(vllm_config)
+        args: dict = conf_as_dict(vllm_config)
         args = dict(
             host=host,
             port=port,
@@ -608,11 +606,11 @@ def build_cmd(
             if v is None or v is False or v == "":
                 continue
             if v is True:
-                flags.append(f"--{k.replace('_','-')}")
+                flags.append(f"--{k.replace('_', '-')}")
             elif isinstance(v, list):
-                flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}")
+                flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}")
             else:
-                flags.append(f"--{k.replace('_','-')} {v}")
+                flags.append(f"--{k.replace('_', '-')} {v}")
         return f"python3 -m areal.thirdparty.vllm.areal_vllm_server {' '.join(flags)}"
 
 
@@ -638,7 +636,7 @@ class SGLangConfig:
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: int | None = None
-    cuda_graph_bs: List[int] | None = None
+    cuda_graph_bs: list[int] | None = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -667,8 +665,8 @@ class SGLangConfig:
     # lora
     enable_lora: bool | None = None
     max_lora_rank: int | None = None
-    lora_target_modules: List[str] | None = None
-    lora_paths: List[str] | None = None
+    lora_target_modules: list[str] | None = None
+    lora_paths: list[str] | None = None
     max_loaded_loras: int = 1
     max_loras_per_batch: int = 1
     lora_backend: str = "triton"
@@ -719,11 +717,11 @@ def build_cmd(
             if v is None or v is False or v == "":
                 continue
             if v is True:
-                flags.append(f"--{k.replace('_','-')}")
+                flags.append(f"--{k.replace('_', '-')}")
             elif isinstance(v, list):
-                flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}")
+                flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}")
             else:
-                flags.append(f"--{k.replace('_','-')} {v}")
+                flags.append(f"--{k.replace('_', '-')} {v}")
         return f"python3 -m sglang.launch_server {' '.join(flags)}"
 
     @staticmethod
```
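Both `build_cmd` hunks above touch the same dict-to-CLI-flag loop (the only change is a space after the comma in `replace('_', '-')`). A self-contained sketch of that pattern, with a hypothetical helper name rather than the repo's exact code:

```python
# Sketch of the flag-building pattern from the hunks above: skip
# None/False/"" values, render True as a bare flag, space-join lists,
# and stringify everything else.
def build_flags(args: dict) -> list[str]:
    flags: list[str] = []
    for k, v in args.items():
        if v is None or v is False or v == "":
            continue
        if v is True:
            flags.append(f"--{k.replace('_', '-')}")
        elif isinstance(v, list):
            flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}")
        else:
            flags.append(f"--{k.replace('_', '-')} {v}")
    return flags


print(build_flags({"enable_lora": True, "cuda_graph_bs": [1, 2, 4], "port": 30000}))
# ['--enable-lora', '--cuda-graph-bs 1 2 4', '--port 30000']
```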
```diff
@@ -738,11 +736,12 @@ def build_args(
         node_rank: int = 0,
     ):
         # Map "all-linear" to "all"
-        args: Dict = conf_as_dict(sglang_config)
+        args: dict = conf_as_dict(sglang_config)
         if sglang_config.enable_multithread_load or sglang_config.enable_fast_load:
-            assert pkg_version.is_version_equal(
-                "sglang", "0.5.2"
-            ), f"Customized model loading requires exact SGLang version 0.5.2"
+            if not pkg_version.is_version_equal("sglang", "0.5.2"):
+                raise RuntimeError(
+                    "Customized model loading requires exact SGLang version 0.5.2"
+                )
             model_loader_extra_config = dict(
                 enable_multithread_load=sglang_config.enable_multithread_load,
                 enable_fast_load=sglang_config.enable_fast_load,
```
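This hunk goes beyond formatting: it replaces an `assert` (whose message was also a placeholder-less f-string, the kind of thing ruff flags as F541) with an explicit `raise`. A general-Python sketch of why that matters, not specific to this repo:

```python
# General Python note: assert statements are stripped when the interpreter
# runs with optimizations (`python -O`), so guard clauses that must always
# execute should raise explicitly instead.
def check_sglang_version(installed: str, required: str = "0.5.2") -> None:
    if installed != required:
        raise RuntimeError(
            f"Customized model loading requires exact SGLang version {required}"
        )


check_sglang_version("0.5.2")    # passes
# check_sglang_version("0.5.3")  # raises RuntimeError, even under -O
```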
```diff
@@ -791,7 +790,8 @@ class InferenceEngineConfig:
     max_concurrent_rollouts: None | int = field(
         default=None,
         metadata={
-            "help": "Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size."
+            "help": "Maximum number of concurrent rollouts to "
+            "the inference engine. Defaults to consumer_batch_size."
         },
     )
     queue_size: None | int = field(
```
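The reflowed help string relies on implicit concatenation of adjacent string literals, so the stored text is unchanged; a quick illustration of that general Python behavior:

```python
# Adjacent string literals are concatenated at compile time, so the
# reflowed metadata above stores a single, unchanged string.
help_text = (
    "Maximum number of concurrent rollouts to "
    "the inference engine. Defaults to consumer_batch_size."
)
assert help_text == (
    "Maximum number of concurrent rollouts to the inference engine. "
    "Defaults to consumer_batch_size."
)
```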
```diff
@@ -915,8 +915,8 @@ class WandBConfig:
     job_type: str | None = None
     group: str | None = None
     notes: str | None = None
-    tags: List[str] | None = None
-    config: Dict | None = None
+    tags: list[str] | None = None
+    config: dict | None = None
     id_suffix: str | None = "train"
 
 
@@ -926,7 +926,7 @@ class SwanlabConfig:
 
     project: str | None = None
     name: str | None = None
-    config: Dict | None = None
+    config: dict | None = None
     logdir: str | None = None
     mode: str | None = "disabled"
     api_key: str | None = os.getenv("SWANLAB_API_KEY", None)
@@ -1023,7 +1023,7 @@ class SchedulerConfig:
     endpoint: str = field(default="http://localhost:8081")
     deploy_mode: str = field(default="separation")
     functioncall_service_domain: str = field(default="http://localhost:8080")
-    reward_functioncall_config: Dict = field(default_factory=dict)
+    reward_functioncall_config: dict = field(default_factory=dict)
     reward_model_path: str = field(default="")
     reward_model_service_url: str = field(default="http://localhost:30000/classify")
 
@@ -1076,7 +1076,7 @@ class SlurmLauncherConfig:
         default="--mpi=pmi2 -K --chdir $PWD",
         metadata={"help": "Additional arguments to pass to the srun command."},
     )
-    additional_bash_cmds: List[str] | None = field(
+    additional_bash_cmds: list[str] | None = field(
         default=None,
         metadata={
             "help": "Additional bash commands to setup the container before running "
@@ -1244,7 +1244,7 @@ class PPOConfig(GRPOConfig):
     critic: PPOCriticConfig = field(default_factory=PPOCriticConfig)
 
 
-def parse_cli_args(argv: List[str]):
+def parse_cli_args(argv: list[str]):
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--config", help="Path to the main configuration file", required=True
@@ -1277,7 +1277,7 @@ def to_structured_cfg(cfg, config_cls):
     return cfg
 
 
-def load_expr_config(argv: List[str], config_cls):
+def load_expr_config(argv: list[str], config_cls):
     cfg, config_file = parse_cli_args(argv)
     cfg = to_structured_cfg(cfg, config_cls=config_cls)
     cfg = OmegaConf.to_object(cfg)
@@ -1305,7 +1305,7 @@ def save_config(cfg, log_dir):
     os.makedirs(log_dir, exist_ok=True)
     config_save_path = os.path.join(log_dir, "config.yaml")
     with open(config_save_path, "w") as f:
-        config_dict: Dict = asdict(cfg)
+        config_dict: dict = asdict(cfg)
         yaml.dump(
             config_dict,
             f,
```
