Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # TRI
- `explore`: Only launches the explorer.
- `bench`: Used for benchmarking.
- `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `<checkpoint_root_dir>/<project>/<name>/`.
- `continue_from_checkpoint`: If set to `true`, the experiment will continue from the latest checkpoint in the checkpoint path (if any); otherwise, it will rename the current experiment to `<name>_<timestamp>` and start a new experiment.
- `continue_from_checkpoint`: If set to `true`, the experiment will continue from the latest checkpoint in the checkpoint path (if any); otherwise, it will rename the current experiment to `<name>_<timestamp>` and start a new experiment. Note that because the Trainer and the Explorer are decoupled, resuming from a checkpoint only guarantees that the Trainer's model parameters and its optional auxiliary buffers (`auxiliary_buffers`) are restored to their latest checkpointed state; the Explorer and the Experience Buffer are not guaranteed to be restored to the same point in time.
- `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `<project>/<name>`.

---
Expand Down
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # TRI
- `explore`: 仅启动 explorer。
- `bench`: 用于 benchmark 测试。
- `checkpoint_root_dir`: 所有检查点和日志的根目录。该实验的检查点将存储在 `<checkpoint_root_dir>/<project>/<name>/` 路径下。
- `continue_from_checkpoint`: 若设置为 `true`,实验将从检查点路径中的最新检查点继续;否则,会将当前实验重命名为 `<name>_<timestamp>` 并启动新实验。
- `continue_from_checkpoint`: 若设置为 `true`,实验将从检查点路径中的最新检查点继续;否则,会将当前实验重命名为 `<name>_<timestamp>` 并启动新实验。由于采用分离式设计,从检查点恢复时,只能保证 Trainer 的模型参数及其可选的辅助缓冲区(`auxiliary_buffers`)恢复到最新检查点的状态,而 Explorer 和 Experience Buffer 无法保证恢复到同一时间点。
- `ray_namespace`: 当前实验中启动模块的命名空间。若未指定,则默认为 `<project>/<name>`。

---
Expand Down
4 changes: 2 additions & 2 deletions tests/template/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ model:
max_response_tokens: 2048
max_model_len: 4096
cluster: # 2 for explorer, 2 for trainer
node_num: 2
gpu_per_node: 2
node_num: ${oc.env:NODE_NUM,2}
gpu_per_node: ${oc.env:GPU_PER_NODE,2}
buffer:
total_epochs: 1
batch_size: 4
Expand Down
28 changes: 27 additions & 1 deletion tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for trainer."""

import json
import multiprocessing
import os
import shutil
Expand Down Expand Up @@ -809,7 +810,7 @@ def test_trainer(self):
self.config.algorithm.policy_loss_fn = "mix"
self.config.buffer.batch_size = 4
self.config.buffer.train_batch_size = 32
self.config.buffer.total_epochs = 1
self.config.buffer.total_steps = 2
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.synchronizer.sync_interval = 1
self.config.trainer.save_interval = 1
Expand All @@ -823,6 +824,31 @@ def test_trainer(self):
self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
both(self.config)
ray.shutdown(_exiting_interpreter=True)

# check trainer resume metadata
trainer_meta_file = os.path.join(self.config.checkpoint_job_dir, "trainer_meta.json")
with open(trainer_meta_file) as f:
trainer_meta = json.load(f)
self.assertEqual(trainer_meta["latest_iteration"], 2)
self.assertEqual(
trainer_meta["sample_strategy_state"]["expert_buffer"]["current_index"], 32
)

self.config.buffer.total_steps = None
self.config.buffer.total_epochs = 1
self.config.check_and_update()
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
both(self.config)

# check trainer resume metadata
with open(trainer_meta_file) as f:
trainer_meta = json.load(f)
self.assertEqual(trainer_meta["latest_iteration"], 4)
self.assertEqual(
trainer_meta["sample_strategy_state"]["expert_buffer"]["current_index"], 64
)

parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))

# test rollout metrics
Expand Down
12 changes: 12 additions & 0 deletions trinity/algorithm/sample_strategy/mix_sample_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,15 @@ def default_args(cls) -> Dict:
"expert_data_ratio": 0.5,
"sft_dataset_name": "sft_dataset",
}

def state_dict(self) -> dict:
    """Collect the state of both underlying experience buffers.

    Returns:
        dict: A mapping with the keys ``"usual_buffer"`` and
        ``"expert_buffer"``; each value is whatever the corresponding
        buffer's own ``state_dict()`` returns.
    """
    return {
        "usual_buffer": self.usual_exp_buffer.state_dict(),
        "expert_buffer": self.expert_exp_buffer.state_dict(),
    }

def load_state_dict(self, state_dict: dict) -> None:
    """Restore the underlying experience buffers from a saved state.

    Args:
        state_dict: A mapping as produced by ``state_dict()``. Missing or
            empty entries are skipped, so the matching buffer keeps its
            current state.
    """
    # Earlier checkpoints stored the usual buffer under the misspelled key
    # "usal_buffer"; keep reading it so those checkpoints remain loadable.
    usual_state = state_dict.get("usual_buffer") or state_dict.get("usal_buffer")
    if usual_state:
        self.usual_exp_buffer.load_state_dict(usual_state)
    expert_state = state_dict.get("expert_buffer")
    if expert_state:
        self.expert_exp_buffer.load_state_dict(expert_state)
15 changes: 15 additions & 0 deletions trinity/algorithm/sample_strategy/sample_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
def default_args(cls) -> dict:
"""Get the default arguments of the sample strategy."""

@abstractmethod
def state_dict(self) -> dict:
    """Return the current state of this sample strategy as a dict.

    The trainer writes the returned dict into its JSON metadata file when
    saving a checkpoint, so the content must be JSON-serializable.
    """

@abstractmethod
def load_state_dict(self, state_dict: dict) -> None:
    """Restore this sample strategy from a previously saved state.

    Args:
        state_dict: A dict as returned by ``state_dict()``. The trainer
            passes an empty dict when no saved state is available.
    """


@SAMPLE_STRATEGY.register_module("default")
class DefaultSampleStrategy(SampleStrategy):
Expand All @@ -64,6 +72,13 @@ async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
def default_args(cls) -> dict:
return {}

def state_dict(self) -> dict:
    """Return the state of the single underlying experience buffer."""
    buffer_state = self.exp_buffer.state_dict()
    return buffer_state

def load_state_dict(self, state_dict: dict) -> None:
    """Forward a saved state to the underlying experience buffer.

    Args:
        state_dict: The state to restore. A falsy value (for example an
            empty dict) leaves the buffer untouched.
    """
    if not state_dict:
        # Nothing was saved, so keep the buffer as constructed.
        return
    self.exp_buffer.load_state_dict(state_dict)


@Deprecated
@SAMPLE_STRATEGY.register_module("warmup")
Expand Down
4 changes: 2 additions & 2 deletions trinity/buffer/reader/queue_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ async def read_async(self, batch_size: Optional[int] = None) -> List:
return exps

def state_dict(self) -> Dict:
    """Return a placeholder state for the queue reader.

    State dicts are not supported for queues yet, so the reported
    ``current_index`` is always 0.
    """
    # Queue does not support state dict yet
    placeholder_state = {"current_index": 0}
    return placeholder_state

def load_state_dict(self, state_dict):
    """Accept a saved state and ignore it.

    State dicts are not supported for queues yet, so nothing is restored.
    """
    # Queue does not support state dict yet
    return
4 changes: 2 additions & 2 deletions trinity/manager/state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ def load_explorer_server_url(self) -> Optional[str]:

def save_trainer(
self,
current_exp_index: int,
current_step: int,
sample_strategy_state: dict,
) -> None:
with open(self.trainer_state_path, "w", encoding="utf-8") as f:
json.dump(
{
"latest_exp_index": current_exp_index,
"latest_iteration": current_step,
"sample_strategy_state": sample_strategy_state,
},
f,
indent=2,
Expand Down
15 changes: 10 additions & 5 deletions trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import ray

from trinity.algorithm import SAMPLE_STRATEGY
from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
from trinity.common.config import Config
from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle
from trinity.common.experience import Experiences
Expand All @@ -38,9 +39,6 @@ def __init__(self, config: Config) -> None:
path=config.checkpoint_job_dir, trainer_name=config.trainer.name, config=config
)
trainer_state = self.state.load_trainer()
config.buffer.trainer_input.experience_buffer.index = trainer_state.get(
"latest_exp_index", 0
)
self.last_trainer_sync_step = 0
self.monitor = MONITOR.get(config.monitor.monitor_type)(
project=config.project,
Expand All @@ -50,10 +48,17 @@ def __init__(self, config: Config) -> None:
config=config,
)
self._sample_exps_to_log = []
self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)(
self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get(
config.algorithm.sample_strategy
)(
buffer_config=config.buffer,
**config.algorithm.sample_strategy_args,
)
if "latest_exp_index" in trainer_state:
sample_strategy_state = {"current_index": trainer_state["latest_exp_index"]}
else:
sample_strategy_state = trainer_state.get("sample_strategy_state", {})
self.sample_strategy.load_state_dict(sample_strategy_state)
self.save_interval = config.trainer.save_interval
self.last_sync_step = None
self.last_sync_time = None
Expand Down Expand Up @@ -190,8 +195,8 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa
self.logger.info(f"Saving checkpoint at step {self.train_step_num}...")
self.engine.save_checkpoint(block_until_saved=block_until_saved, save_as_hf=save_as_hf)
self.state.save_trainer(
current_exp_index=self.engine.train_step_num * self.config.buffer.train_batch_size,
current_step=self.train_step_num,
sample_strategy_state=self.sample_strategy.state_dict(),
)
return metrics

Expand Down