Skip to content

Commit a50773e

Browse files
committed
merge main
2 parents 19886e2 + 83b786e commit a50773e

File tree

172 files changed

+8541
-6591
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

172 files changed

+8541
-6591
lines changed

.github/workflows/deploy-docs.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,14 @@ jobs:
3232
uses: actions/cache@v3
3333
with:
3434
path: ~/.cache/pip
35-
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
35+
key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }}
3636
restore-keys: |
3737
${{ runner.os }}-pip-
3838
3939
- name: Install dependencies
4040
run: |
4141
python -m pip install --upgrade pip
4242
pip install --upgrade jupyter-book==1.0.4.post1
43-
# Install additional dependencies if you have a requirements.txt
44-
# pip install -r requirements.txt
4543
4644
- name: Build the book
4745
run: |

.github/workflows/format-check.yml

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,14 @@ jobs:
1818
- name: Install Python dependencies
1919
run: |
2020
python3 -m pip install --upgrade pip
21-
pip install ruff==0.14.1 black==25.1.0 clang-format==19.1.7 autoflake==2.3.1
21+
pip install ruff==0.14.1 clang-format==19.1.7
2222
23-
- name: Check autoflake formatting
23+
- name: Check Python formatting and linting with ruff
2424
run: |
25-
autoflake --check -r areal/
26-
autoflake --check -r examples/
27-
autoflake --check -r docs/
28-
29-
- name: Check Python formatting with ruff
30-
run: |
31-
ruff check --select I areal/
32-
ruff check --select I examples/
33-
ruff check --select I docs/
34-
35-
- name: Check Python formatting with black
36-
run: black --check .
25+
ruff check areal/
26+
ruff check examples/
27+
ruff format --check areal/
28+
ruff format --check examples/
3729
3830
- name: Check C++ formatting
3931
run: |

AGENTS.md

Lines changed: 222 additions & 202 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@ our project just as you enjoy real-world milk tea (cheers).
3535
## 📰 News
3636

3737
**\[2025/08/30\]** Introducing ASearcher, a state-of-the-art search agent built with
38-
AReaL's end-to-end asynchronous RL training. Check out the
39-
[paper](https://arxiv.org/pdf/2508.07976) and the
40-
[open-source repository](https://github.com/inclusionAI/ASearcher)!
38+
AReaL's end-to-end asynchronous RL training. Check out the [paper](assets/paper.pdf) and
39+
the [open-source repository](https://github.com/inclusionAI/ASearcher)!
4140

4241
**\[2025/07/31\] (AReaL-lite)** We introduce AReaL-lite, a **lightweight** version of
4342
AReaL designed specifically for AI researchers and rapid prototyping. AReaL-lite
@@ -56,7 +55,7 @@ asynchronous RL training, which achieves **2.77× speedup while delivering compa
5655
superior training performance** compared to synchronous systems. Furthermore,
5756
asynchronous RL significantly simplifies multi-turn agentic RL training setup! Check out
5857
[our v0.3 overview blog](/blog/AReaL_v0_3.md) and the
59-
[research paper](https://arxiv.org/pdf/2505.24298).
58+
[research paper](assets/paper.pdf).
6059

6160
**\[2025/03/31\] (v0.2, boba)** Introducing our milestone release—boba! Please call it
6261
A-ReaL-boba! This release features significantly faster training with SGLang support and

areal/api/cli_args.py

Lines changed: 140 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,30 @@
44
from dataclasses import MISSING as dataclass_missing
55
from dataclasses import asdict, dataclass, field, fields
66
from pathlib import Path
7-
from typing import Any
7+
from typing import Any, TypeVar
88

99
import uvloop
1010
import yaml
1111
from hydra import compose as hydra_compose
1212
from hydra import initialize as hydra_init
1313
from hydra.core.global_hydra import GlobalHydra
1414
from omegaconf import MISSING, DictConfig, OmegaConf
15+
from transformers import PreTrainedTokenizerFast
1516

1617
from areal.platforms import current_platform
1718
from areal.utils import logging, name_resolve, pkg_version
19+
from areal.utils.constants import (
20+
PROX_LOGP_METHOD_RECOMPUTE,
21+
PROX_LOGP_METHODS_ALL,
22+
)
1823
from areal.utils.pkg_version import is_version_less
1924

2025
uvloop.install()
2126

2227
logger = logging.getLogger("CLI args")
2328

29+
ConfigT = TypeVar("ConfigT")
30+
2431

2532
@dataclass
2633
class NormConfig:
@@ -157,12 +164,25 @@ class GenerationHyperparameters:
157164
)
158165
},
159166
)
167+
lora_name: str = field(
168+
default="",
169+
metadata={"help": "Lora name to be used for this generation."},
170+
)
160171

161172
def new(self, **kwargs):
    """Build a fresh ``GenerationHyperparameters`` from this one plus overrides.

    The current field values are dumped with ``dataclasses.asdict`` and any
    keyword arguments replace the fields of the same name before the new
    instance is constructed. ``self`` is left untouched.
    """
    merged = {**asdict(self), **kwargs}
    return GenerationHyperparameters(**merged)
165176

177+
def new_with_stop_and_pad_token_ids(self, tokenizer: "PreTrainedTokenizerFast"):
    """Create new generation hyperparameters with pad and EOS ids as stop tokens.

    Args:
        tokenizer: Tokenizer whose ``pad_token_id`` and ``eos_token_id`` are
            appended to the stop token ids when not already present.

    Returns:
        The result of ``self.new(stop_token_ids=...)``; ``self.stop_token_ids``
        itself is not mutated.
    """
    # Work on a copy so the caller's list is never modified in place.
    # NOTE(review): an unset (None) stop list is treated as empty here.
    merged_stop_ids = list(self.stop_token_ids or [])
    for token_id in (tokenizer.pad_token_id, tokenizer.eos_token_id):
        # A tokenizer without a pad/EOS token may report None for that id;
        # None is not a valid stop token id, so it must not be appended.
        if token_id is not None and token_id not in merged_stop_ids:
            merged_stop_ids.append(token_id)
    return self.new(stop_token_ids=merged_stop_ids)
185+
166186
def to_openai_args_dict(
167187
self, exclude_args: list[str] | None = None
168188
) -> dict[str, Any]:
@@ -402,7 +422,6 @@ class SchedulingSpec:
402422
default_factory=dict,
403423
metadata={"help": "Environment variables for the container"},
404424
)
405-
# cmd
406425
cmd: str | None = field(
407426
default=None,
408427
metadata={
@@ -488,13 +507,32 @@ class TrainEngineConfig:
488507
default="lora",
489508
metadata={"help": "peft method type. Only LoRA is supported for now."},
490509
)
491-
scheduling_spec: SchedulingSpec = field(
492-
default_factory=lambda: SchedulingSpec(
493-
cmd="python -m areal.scheduler.rpc.rpc_server"
510+
scheduling_spec: tuple[SchedulingSpec, ...] = field(
511+
default_factory=lambda: (
512+
SchedulingSpec(cmd="python -m areal.scheduler.rpc.rpc_server"),
494513
),
495-
metadata={"help": "train engine schedule specs"},
514+
metadata={
515+
"help": "Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: "
516+
"if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; "
517+
"if 2 specs provided, first one is for worker, second one is for engine. "
518+
"Currently only used by the TrainController."
519+
},
520+
)
521+
scheduling_strategy: SchedulingStrategy = field(
522+
default_factory=SchedulingStrategy,
523+
metadata={
524+
"help": "The scheduling strategy of this TrainEngine, either separation or colocation. "
525+
"Currently only used by the TrainController."
526+
},
496527
)
497-
scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy)
528+
529+
def __post_init__(self):
    """Ensure ``scheduling_spec`` holds exactly one or two entries.

    Raises:
        ValueError: If the number of entries is anything other than 1 or 2.
    """
    spec_count = len(self.scheduling_spec)
    if spec_count == 1 or spec_count == 2:
        return
    raise ValueError(
        f"scheduling_spec must contain 1 or 2 SchedulingSpec, "
        f"got {spec_count}"
    )
498536

499537

500538
@dataclass
@@ -605,6 +643,18 @@ class PPOActorConfig(TrainEngineConfig):
605643
"choices": ["token", "sequence"],
606644
},
607645
)
646+
# Proximal Log-Probability Computation Method
647+
prox_logp_method: str = field(
648+
default=PROX_LOGP_METHOD_RECOMPUTE,
649+
metadata={
650+
"help": "Method for computing proximal policy log-probabilities in decoupled PPO. "
651+
"Only effective when use_decoupled_loss=True. Options: "
652+
"'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. "
653+
"'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). "
654+
"'metrics': Like 'recompute', but also compute approximation metrics for evaluation.",
655+
"choices": PROX_LOGP_METHODS_ALL,
656+
},
657+
)
608658
# Advanced Options
609659
dynamic_sampling: bool = field(
610660
default=False,
@@ -702,6 +752,8 @@ class vLLMConfig:
702752
)
703753
enable_sleep_mode: bool = False
704754
uvicorn_log_level: str = "warning"
755+
enable_lora: bool = False
756+
lora_modules: str = ""
705757

706758
@staticmethod
707759
def build_args(
@@ -726,6 +778,18 @@ def build_args(
726778
args["port"] = port
727779
if host is not None:
728780
args["host"] = host
781+
# handle lora modules separately
782+
lm = args.get("lora_modules")
783+
if lm:
784+
if isinstance(lm, str):
785+
lm = [lm]
786+
if isinstance(lm, (list, tuple)):
787+
try:
788+
args["lora_modules"] = [
789+
json.dumps(json.loads(s), separators=(",", ":")) for s in lm
790+
]
791+
except json.JSONDecodeError as e:
792+
raise ValueError(f"Invalid JSON string in lora_modules: {e}") from e
729793
return args
730794

731795
@staticmethod
@@ -977,13 +1041,36 @@ class InferenceEngineConfig:
9771041
"help": "The grace period after calling /pause_generation. Wait until all requests have been dropped."
9781042
},
9791043
)
980-
scheduling_spec: SchedulingSpec = field(
981-
default_factory=lambda: SchedulingSpec(
982-
cmd="python -m areal.scheduler.rpc.rpc_server"
1044+
scheduling_spec: tuple[SchedulingSpec, ...] = field(
1045+
default_factory=lambda: (
1046+
SchedulingSpec(cmd="python -m areal.scheduler.rpc.rpc_server"),
9831047
),
984-
metadata={"help": "inference engine schedule specs"},
1048+
metadata={
1049+
"help": "inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: "
1050+
"if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; "
1051+
"if 2 specs provided, first one is for worker, second one is for engine. "
1052+
"Currently only used by the RolloutController."
1053+
},
1054+
)
1055+
scheduling_strategy: SchedulingStrategy = field(
1056+
default_factory=SchedulingStrategy,
1057+
metadata={
1058+
"help": "The scheduling strategy of this InferenceEngine, either separation or colocation. "
1059+
"Currently only used by the RolloutController."
1060+
},
1061+
)
1062+
use_lora: bool = field(
1063+
default=False,
1064+
metadata={"help": "Whether to use LoRA. Should be the same as the actor's LoRA option."},
9851065
)
986-
scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy)
1066+
1067+
def __post_init__(self):
    """Check that one or two scheduling specs were supplied.

    Raises:
        ValueError: If ``scheduling_spec`` does not contain 1 or 2 entries.
    """
    num_specs = len(self.scheduling_spec)
    has_valid_length = 1 <= num_specs <= 2
    if not has_valid_length:
        raise ValueError(
            f"scheduling_spec must contain 1 or 2 SchedulingSpec, "
            f"got {num_specs}"
        )
9871074

9881075

9891076
@dataclass
@@ -1148,6 +1235,15 @@ class PerfTracerConfig:
11481235
)
11491236
},
11501237
)
1238+
profile_steps: list[int] | None = field(
1239+
default=None,
1240+
metadata={
1241+
"help": (
1242+
"List of step numbers at which to capture detailed profiling traces. "
1243+
"If None, no detailed profiling traces are captured."
1244+
)
1245+
},
1246+
)
11511247
session_tracer: SessionTracerConfig | None = field(
11521248
default=None,
11531249
metadata={"help": "Session tracing configuration."},
@@ -1223,7 +1319,7 @@ class SchedulerConfig:
12231319

12241320

12251321
@dataclass
1226-
class DatasetConfig:
1322+
class _DatasetConfig:
12271323
"""Configuration for dataset loading and preprocessing."""
12281324

12291325
path: str = field(
@@ -1262,6 +1358,27 @@ class DatasetConfig:
12621358
)
12631359

12641360

1361+
@dataclass
1362+
class TrainDatasetConfig(_DatasetConfig):
1363+
"""Configuration for training dataset loading and preprocessing."""
1364+
1365+
1366+
@dataclass
1367+
class ValidDatasetConfig(_DatasetConfig):
1368+
"""Configuration for validation dataset loading and preprocessing.
1369+
1370+
Its default values differ from those of `TrainDatasetConfig`:
1371+
`shuffle` and `drop_last` default to False.
1372+
"""
1373+
1374+
shuffle: bool = field(
1375+
default=False, metadata={"help": "Whether to shuffle the dataset"}
1376+
)
1377+
drop_last: bool = field(
1378+
default=False, metadata={"help": "Drop the last incomplete batch"}
1379+
)
1380+
1381+
12651382
@dataclass
12661383
class SlurmLauncherConfig:
12671384
"""Configuration for launching the training jobs with Slurm."""
@@ -1359,6 +1476,13 @@ class BaseExperimentConfig:
13591476
metadata={"help": "Pattern-based GPU parallel strategy allocation mode. "},
13601477
)
13611478
seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
1479+
enable_offload: bool = field(
1480+
default=False,
1481+
metadata={
1482+
"help": "Whether to enable training offload using torch_memory_saver. "
1483+
"This requires setting up the environment for TMS (e.g., via LD_PRELOAD)."
1484+
},
1485+
)
13621486
total_train_epochs: int = field(
13631487
default=1, metadata={"help": "Total number of epochs to train the model."}
13641488
)
@@ -1381,8 +1505,8 @@ class BaseExperimentConfig:
13811505
metadata={"help": "Path to the tokenizer."},
13821506
)
13831507

1384-
train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
1385-
valid_dataset: DatasetConfig | None = field(default=None)
1508+
train_dataset: TrainDatasetConfig = field(default_factory=TrainDatasetConfig)
1509+
valid_dataset: ValidDatasetConfig | None = field(default=None)
13861510

13871511
saver: SaverConfig = field(default_factory=SaverConfig)
13881512
evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
@@ -1466,7 +1590,7 @@ def to_structured_cfg(cfg, config_cls):
14661590
return cfg
14671591

14681592

1469-
def load_expr_config(argv: list[str], config_cls):
1593+
def load_expr_config(argv: list[str], config_cls: type[ConfigT]) -> tuple[ConfigT, str]:
14701594
cfg, config_file = parse_cli_args(argv)
14711595
cfg = to_structured_cfg(cfg, config_cls=config_cls)
14721596
cfg = OmegaConf.to_object(cfg)

0 commit comments

Comments
 (0)