Commit 83b786e

feat: implement proximal log-probability approximation for decoupled PPO (#600)
* feat: implement proximal log-probability approximation for decoupled PPO
  Implement proximal log-probability approximation to eliminate expensive forward passes in decoupled (off-policy) PPO training.
* docs: remove rollout from user-facing documentation
* feat: always log compute_logp metrics and add importance_weight tracking
1 parent 1f73719 commit 83b786e
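
The commit message describes approximating the proximal policy's log-probabilities instead of running an extra forward pass. The rendered diff does not show how this is done, so the following is only a minimal sketch of generic log-linear interpolation between the behavior (rollout) policy and the current policy. The function names, the `alpha` weight, and how `alpha` would be chosen are all assumptions made for illustration, not code from this commit.

```python
import math


def approx_prox_logp(logp_behave, logp_current, alpha):
    """Interpolate per-token log-probs between two policies (hypothetical helper).

    alpha = 0.0 reproduces the behavior log-probs, alpha = 1.0 the current ones.
    How alpha is derived (e.g. from policy versions) is not shown in the source.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return [b + alpha * (c - b) for b, c in zip(logp_behave, logp_current)]


def behavior_importance_weight(logp_prox, logp_behave):
    """Per-token ratio pi_prox / pi_behave, as used in decoupled PPO objectives."""
    return [math.exp(p - b) for p, b in zip(logp_prox, logp_behave)]


# Example: two tokens, proximal policy assumed halfway between the two policies.
logp_behave = [-1.0, -2.0]
logp_current = [-3.0, -2.0]
logp_prox = approx_prox_logp(logp_behave, logp_current, alpha=0.5)
weights = behavior_importance_weight(logp_prox, logp_behave)
```

Because the interpolation only needs log-probabilities that are already available, it costs a few elementwise operations per token, which is the saving the commit message refers to when it mentions skipping the forward pass.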

10 files changed, +1899 -63 lines changed


areal/api/cli_args.py

Lines changed: 16 additions & 0 deletions
@@ -16,6 +16,10 @@

 from areal.platforms import current_platform
 from areal.utils import logging, name_resolve, pkg_version
+from areal.utils.constants import (
+    PROX_LOGP_METHOD_RECOMPUTE,
+    PROX_LOGP_METHODS_ALL,
+)
 from areal.utils.pkg_version import is_version_less


 uvloop.install()
@@ -639,6 +643,18 @@ class PPOActorConfig(TrainEngineConfig):
             "choices": ["token", "sequence"],
         },
     )
+    # Proximal Log-Probability Computation Method
+    prox_logp_method: str = field(
+        default=PROX_LOGP_METHOD_RECOMPUTE,
+        metadata={
+            "help": "Method for computing proximal policy log-probabilities in decoupled PPO. "
+            "Only effective when use_decoupled_loss=True. Options: "
+            "'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. "
+            "'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). "
+            "'metrics': Like 'recompute', but also compute approximation metrics for evaluation.",
+            "choices": PROX_LOGP_METHODS_ALL,
+        },
+    )
     # Advanced Options
     dynamic_sampling: bool = field(
         default=False,
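
To illustrate how a field like `prox_logp_method` might be used, here is a small self-contained sketch. It is not the AReaL code: `ActorConfigSketch`, its `__post_init__` validation, and `skips_prox_forward_pass` are hypothetical, and the constant values are only inferred from the help text above, since `areal/utils/constants.py` is not shown in this view.

```python
from dataclasses import dataclass, field, fields

# Assumed values, inferred from the help string; the real definitions live in
# areal/utils/constants.py, which this page does not render.
PROX_LOGP_METHOD_RECOMPUTE = "recompute"
PROX_LOGP_METHODS_ALL = ["recompute", "loglinear", "metrics"]


@dataclass
class ActorConfigSketch:
    """Hypothetical stand-in for PPOActorConfig, reduced to two fields."""

    use_decoupled_loss: bool = False
    prox_logp_method: str = field(
        default=PROX_LOGP_METHOD_RECOMPUTE,
        metadata={"choices": PROX_LOGP_METHODS_ALL},
    )

    def __post_init__(self):
        # Reject any value outside the "choices" declared in field metadata.
        for f in fields(self):
            choices = f.metadata.get("choices")
            value = getattr(self, f.name)
            if choices is not None and value not in choices:
                raise ValueError(f"{f.name}={value!r} not in {choices}")

    def skips_prox_forward_pass(self) -> bool:
        # Per the help text, the option only matters with the decoupled loss,
        # and only 'loglinear' avoids the extra forward pass.
        return self.use_decoupled_loss and self.prox_logp_method == "loglinear"


cfg = ActorConfigSketch(use_decoupled_loss=True, prox_logp_method="loglinear")
```

Keeping the allowed values in one shared list, as the diff does with `PROX_LOGP_METHODS_ALL`, means the CLI choices and any runtime check can refer to the same source of truth.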

areal/engine/ppo/actor.py

Lines changed: 405 additions & 9 deletions
Large diffs are not rendered by default.
