diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index e9cf4fcded..1b0919074b 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -82,7 +82,7 @@ checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # TRI - `explore`: Only launches the explorer. - `bench`: Used for benchmarking. - `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `///`. -- `continue_from_checkpoint`: If set to `true`, the experiment will continue from the latest checkpoint in the checkpoint path (if any); otherwise, it will rename the current experiment to `_` and start a new experiment. +- `continue_from_checkpoint`: If set to `true`, the experiment will continue from the latest checkpoint in the checkpoint path (if any); otherwise, it will rename the current experiment to `_` and start a new experiment. Due to our decoupled design, during recovery from a checkpoint, we can only guarantee that the Trainer's model parameters and its optional auxiliary buffers (`auxiliary_buffers`) are restored to their latest checkpointed states, while the Explorer and Experience Buffer cannot be guaranteed to be restored to the same point in time. - `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `/`. --- diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 7227423385..9f3017aba1 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -82,7 +82,7 @@ checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # TRI - `explore`: 仅启动 explorer。 - `bench`: 用于 benchmark 测试。 - `checkpoint_root_dir`: 所有检查点和日志的根目录。该实验的检查点将存储在 `///` 路径下。 -- `continue_from_checkpoint`: 若设置为 `true`,实验将从检查点路径中的最新检查点继续;否则,会将当前实验重命名为 `_` 并启动新实验。 +- `continue_from_checkpoint`: 若设置为 `true`,实验将从检查点路径中的最新检查点继续;否则,会将当前实验重命名为 `_` 并启动新实验。由于我们的分离式设计,从检查点恢复的时候,我们只能保证Trainer的模型参数以及其使用的可选缓冲区(`auxiliary_buffers`)可以恢复到最新检查点的状态,而Explorer和Experience Buffer不能保证恢复到同一时点。 - `ray_namespace`: 当前实验中启动模块的命名空间。若未指定,则默认为 `/`。 --- diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 0f085a25ea..425b327a98 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -20,8 +20,8 @@ model: max_response_tokens: 2048 max_model_len: 4096 cluster: # 2 for explorer, 2 for trainer - node_num: 2 - gpu_per_node: 2 + node_num: ${oc.env:NODE_NUM,2} + gpu_per_node: ${oc.env:GPU_PER_NODE,2} buffer: total_epochs: 1 batch_size: 4 diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index a873401599..70431df088 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1,5 +1,6 @@ """Tests for trainer.""" +import json import multiprocessing import os import shutil @@ -809,7 +810,7 @@ def test_trainer(self): self.config.algorithm.policy_loss_fn = "mix" self.config.buffer.batch_size = 4 self.config.buffer.train_batch_size = 32 - self.config.buffer.total_epochs = 1 + self.config.buffer.total_steps = 2 self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") self.config.synchronizer.sync_interval = 1 self.config.trainer.save_interval = 1 @@ -823,6 +824,31 @@ def test_trainer(self): self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20 self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 both(self.config) + ray.shutdown(_exiting_interpreter=True) + + # check trainer resume metadata + trainer_meta_file = os.path.join(self.config.checkpoint_job_dir, "trainer_meta.json") + with open(trainer_meta_file) as f: + trainer_meta = json.load(f) + self.assertEqual(trainer_meta["latest_iteration"], 2) + self.assertEqual( + trainer_meta["sample_strategy_state"]["expert_buffer"]["current_index"], 32 + ) + + self.config.buffer.total_steps = None + self.config.buffer.total_epochs = 1 + self.config.check_and_update() + ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace) + both(self.config) + + # check trainer resume metadata + with open(trainer_meta_file) as f: + trainer_meta = json.load(f) + self.assertEqual(trainer_meta["latest_iteration"], 4) + self.assertEqual( + trainer_meta["sample_strategy_state"]["expert_buffer"]["current_index"], 64 + ) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) # test rollout metrics diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 306e1b0836..d157489b84 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -111,3 +111,15 @@ def default_args(cls) -> Dict: "expert_data_ratio": 0.5, "sft_dataset_name": "sft_dataset", } + + def state_dict(self) -> dict: + return { + "usal_buffer": self.usual_exp_buffer.state_dict(), + "expert_buffer": self.expert_exp_buffer.state_dict(), + } + + def load_state_dict(self, state_dict: dict) -> None: + if state_dict.get("usal_buffer", None): + self.usual_exp_buffer.load_state_dict(state_dict["usal_buffer"]) + if state_dict.get("expert_buffer", None): + self.expert_exp_buffer.load_state_dict(state_dict["expert_buffer"]) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index db3dc4f012..27b021146f 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -43,6 +43,14 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: def default_args(cls) -> dict: """Get the default arguments of the sample strategy.""" + @abstractmethod + def state_dict(self) -> dict: + """Get the state dict of the sample strategy.""" + + @abstractmethod + def load_state_dict(self, state_dict: dict) -> None: + """Load the state dict of the sample strategy.""" + @SAMPLE_STRATEGY.register_module("default") class DefaultSampleStrategy(SampleStrategy): @@ -64,6 +72,13 @@ async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: def default_args(cls) -> dict: return {} + def state_dict(self) -> dict: + return self.exp_buffer.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + if state_dict: + self.exp_buffer.load_state_dict(state_dict) + @Deprecated @SAMPLE_STRATEGY.register_module("warmup") diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index 31ab44d43a..e8debdada3 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -43,9 +43,9 @@ async def read_async(self, batch_size: Optional[int] = None) -> List: return exps def state_dict(self) -> Dict: - # SQL Not supporting state dict yet + # Queue Not supporting state dict yet return {"current_index": 0} def load_state_dict(self, state_dict): - # SQL Not supporting state dict yet + # Queue Not supporting state dict yet return None diff --git a/trinity/manager/state_manager.py b/trinity/manager/state_manager.py index eee97e49a5..5cd93c94f5 100644 --- a/trinity/manager/state_manager.py +++ b/trinity/manager/state_manager.py @@ -101,14 +101,14 @@ def load_explorer_server_url(self) -> Optional[str]: def save_trainer( self, - current_exp_index: int, current_step: int, + sample_strategy_state: dict, ) -> None: with open(self.trainer_state_path, "w", encoding="utf-8") as f: json.dump( { - "latest_exp_index": current_exp_index, "latest_iteration": current_step, + "sample_strategy_state": sample_strategy_state, }, f, indent=2, diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index bcc65ccebd..8b93c32dbf 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -14,6 +14,7 @@ import ray from trinity.algorithm import SAMPLE_STRATEGY +from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy from trinity.common.config import Config from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle from trinity.common.experience import Experiences @@ -38,9 +39,6 @@ def __init__(self, config: Config) -> None: path=config.checkpoint_job_dir, trainer_name=config.trainer.name, config=config ) trainer_state = self.state.load_trainer() - config.buffer.trainer_input.experience_buffer.index = trainer_state.get( - "latest_exp_index", 0 - ) self.last_trainer_sync_step = 0 self.monitor = MONITOR.get(config.monitor.monitor_type)( project=config.project, @@ -50,10 +48,17 @@ def __init__(self, config: Config) -> None: config=config, ) self._sample_exps_to_log = [] - self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)( + self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get( + config.algorithm.sample_strategy + )( buffer_config=config.buffer, **config.algorithm.sample_strategy_args, ) + if "latest_exp_index" in trainer_state: + sample_strategy_state = {"current_index": trainer_state["latest_exp_index"]} + else: + sample_strategy_state = trainer_state.get("sample_strategy_state", {}) + self.sample_strategy.load_state_dict(sample_strategy_state) self.save_interval = config.trainer.save_interval self.last_sync_step = None self.last_sync_time = None @@ -190,8 +195,8 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa self.logger.info(f"Saving checkpoint at step {self.train_step_num}...") self.engine.save_checkpoint(block_until_saved=block_until_saved, save_as_hf=save_as_hf) self.state.save_trainer( - current_exp_index=self.engine.train_step_num * self.config.buffer.train_batch_size, current_step=self.train_step_num, + sample_strategy_state=self.sample_strategy.state_dict(), ) return metrics