Commit a690509

Bug fix in benchmark ckpt loading and megatron hf save (#392)
1 parent d9d2135 commit a690509

9 files changed: +103 -64 lines changed

examples/bots/workflow/bots_reward.py

Lines changed: 16 additions & 6 deletions
@@ -1,8 +1,10 @@
 # Adapted from Reasoning360: https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py
 
+import concurrent
 import contextlib
 import math
 import re
+import resource
 from math import isclose
 from typing import Optional, Union
 
@@ -585,17 +587,25 @@ def should_allow_eval(expr: str):
 
 # @timeout(timeout_seconds=10)
 def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
-    are_equal = False
-    try:
+    def check_equal():
+        memory_size = 1024**3
+        resource.setrlimit(resource.RLIMIT_AS, (memory_size, memory_size))
+
         expr = f"({ground_truth_normalized})-({given_normalized})"
         if should_allow_eval(expr):
             sympy_diff = _sympy_parse(expr)
             simplified = sympy.simplify(sympy_diff)
             if simplified == 0:
-                are_equal = True
-    except Exception:
-        pass
-    return are_equal
+                return True
+        return False
+
+    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
+        future = executor.submit(check_equal)
+        try:
+            return future.result(timeout=10)
+        except (concurrent.futures.TimeoutError, Exception):
+            future.cancel()
+            return False
 
 
 def split_tuple(expr: str):

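The rewritten `are_equal_under_sympy` above moves the sympy comparison into a child process so that a pathological expression can neither hang the worker nor exhaust its memory. A minimal, standalone sketch of the same pattern (the `slow_check` body and the limits here are illustrative, not taken from the repo):

```python
import concurrent.futures
import resource


def slow_check() -> bool:
    # cap the child's address space at 1 GiB before doing anything expensive
    limit = 1024**3
    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
    total = sum(i * i for i in range(1_000_000))  # stand-in for sympy.simplify
    return total > 0


def run_isolated(timeout_s: float = 10.0) -> bool:
    # one worker per call: a timeout or MemoryError only kills the child process
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(slow_check)
        try:
            return future.result(timeout=timeout_s)
        except Exception:  # TimeoutError, MemoryError, broken pool, ...
            future.cancel()
            return False


if __name__ == "__main__":
    print(run_isolated())
```

Note that `RLIMIT_AS` is POSIX-specific, and spinning up a fresh `ProcessPoolExecutor` per comparison trades some latency for isolation.
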
tests/trainer/trainer_test.py

Lines changed: 5 additions & 0 deletions
@@ -113,6 +113,8 @@ def test_trainer(self):
         self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
         actor_kl_metrics = parser.metric_list("actor/kl")
         self.assertTrue(len(actor_kl_metrics) > 0)
+        actor_kl_loss = parser.metric_values("actor/kl_loss")
+        self.assertEqual(actor_kl_loss[0], 0.0)
         critic_kl_metrics = parser.metric_list("critic/kl")
         self.assertTrue(len(critic_kl_metrics) > 0)
         response_metrics = parser.metric_list("response_length")
@@ -138,6 +140,9 @@ def test_trainer(self):
         self.config.mode = "bench"
         self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
         self.config.explorer.bench_on_latest_checkpoint = False
+        self.config.buffer.explorer_input.taskset = None
+        self.config.buffer.explorer_input.tasksets = []
+        self.config.buffer.trainer_input.experience_buffer = None
         self.config.check_and_update()
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))

trinity/buffer/pipelines/task_pipeline.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@
 
 
 def check_and_run_task_pipeline(config: Config) -> Dict:
+    if config.mode not in {"explore", "train", "both"}:
+        return {}
     if config.data_processor.task_pipeline is None:
         return {}

trinity/common/config.py

Lines changed: 17 additions & 9 deletions
@@ -853,8 +853,8 @@ def _check_interval(self) -> None:
             )
 
     def _check_explorer_input(self) -> None:
-        if self.mode == "train":
-            # no need to check explorer_input in train mode
+        if self.mode in {"train", "serve"}:
+            # no need to check explorer_input in serve mode
             return
 
         explorer_input = self.buffer.explorer_input
@@ -864,12 +864,11 @@ def _check_explorer_input(self) -> None:
                 raise ValueError("Do not support setting `taskset` and `tasksets` simultaneously!")
             explorer_input.tasksets = [explorer_input.taskset]
             explorer_input.taskset = None
-        elif len(explorer_input.tasksets) == 0:
+        elif self.mode != "bench" and len(explorer_input.tasksets) == 0:
             raise ValueError("At least one taskset should be provided in explorer_input!")
-        tasksets = explorer_input.tasksets
 
-        for i, taskset in enumerate(tasksets):
-            if self.mode != "train" and not taskset.path:
+        for i, taskset in enumerate(explorer_input.tasksets):
+            if not taskset.path:
                 raise ValueError(
                     "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
                 )
@@ -914,6 +913,10 @@ def _check_explorer_input(self) -> None:
             set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)
 
     def _check_trainer_input(self) -> None:
+        if self.mode == "bench":
+            # no need to check trainer_input in bench mode
+            return
+
         trainer_input = self.buffer.trainer_input
         experience_buffer = trainer_input.experience_buffer
 
@@ -973,7 +976,7 @@ def _default_storage_path(self, storage_type: StorageType, name: str) -> str:
     def _check_data_processor(self) -> None:
         # check input/output buffers in pipelines
        experience_pipeline = self.data_processor.experience_pipeline
-        if experience_pipeline is not None:
+        if experience_pipeline is not None and self.mode in {"explore", "both", "serve"}:
             if experience_pipeline.save_input and experience_pipeline.input_save_path is None:
                 experience_pipeline.input_save_path = os.path.join(
                     self.buffer.cache_dir, "explorer_output.jsonl"  # type: ignore[arg-type]
@@ -983,10 +986,15 @@
                 )
 
         task_pipeline = self.data_processor.task_pipeline
-        if task_pipeline is not None:
+        if task_pipeline is not None and self.mode in {"explore", "train", "both"}:
             if task_pipeline.output is None:
                 if self.mode != "train":
-                    task_pipeline.output = self.buffer.explorer_input.tasksets[0]
+                    if len(self.buffer.explorer_input.tasksets) > 0:
+                        task_pipeline.output = self.buffer.explorer_input.tasksets[0]
+                    else:
+                        raise ValueError(
+                            "At least one taskset should be provided in explorer_input!"
+                        )
                 elif self.mode == "train" and self.algorithm.algorithm_type in {"dpo", "sft"}:
                     task_pipeline.output = self.buffer.trainer_input.experience_buffer
                 else:

trinity/common/verl_config.py

Lines changed: 1 addition & 2 deletions
@@ -85,6 +85,7 @@ class FSDPConfig:
     wrap_policy: WrapPolicy = field(default_factory=WrapPolicy)
     fsdp_size: int = -1
     forward_prefetch: bool = False
+    model_dtype: Optional[str] = None
 
 
 @dataclass
@@ -163,8 +164,6 @@ class Actor:
     clip_ratio_high: Optional[float] = None
     entropy_coeff: float = 0.001
     use_kl_loss: bool = False
-    kl_loss_coef: float = 0.0
-    kl_loss_type: str = "low_var_kl"
 
 
 @dataclass

trinity/explorer/explorer.py

Lines changed: 7 additions & 2 deletions
@@ -52,7 +52,9 @@ def __init__(self, config: Config):
         self.models, self.auxiliary_models = create_inference_models(config)
         self.experience_pipeline = self._init_experience_pipeline()
         self.taskset = (
-            TasksetScheduler(explorer_state, config) if self.config.mode != "serve" else None
+            TasksetScheduler(explorer_state, config)
+            if self.config.mode not in {"bench", "serve"}
+            else None
         )
         self.scheduler = None
         self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
@@ -151,7 +153,8 @@ async def prepare(self) -> None:
         """Preparation before running."""
         try:
             # prepare experience pipeline
-            await self.experience_pipeline.prepare.remote()
+            if self.experience_pipeline:
+                await self.experience_pipeline.prepare.remote()
             self.logger.info("Experience pipeline is ready.")
             # make sure all rollout models are ready
             run_api_ref = [model.run_api_server.remote() for model in self.models]
@@ -406,6 +409,8 @@ async def is_alive(self) -> bool:
 
     def _init_experience_pipeline(self) -> ray.actor.ActorHandle:
         """Init experience pipeline for the explorer."""
+        if self.config.mode == "bench":
+            return None
         node_id = ray.get_runtime_context().get_node_id()
         return (
             ray.remote(ExperiencePipeline)

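In bench mode the explorer above no longer creates the experience-pipeline actor, so every caller has to tolerate a missing handle. A small self-contained sketch of that optional-actor pattern in Ray (the `Pipeline` class and the `mode` flag are illustrative, not the project's actual classes):

```python
import ray


@ray.remote
class Pipeline:
    def prepare(self) -> str:
        return "ready"


ray.init(ignore_reinit_error=True)

mode = "bench"  # hypothetical mode flag
# create the actor only in modes that actually need it
pipeline = Pipeline.remote() if mode not in {"bench", "serve"} else None

# every call site guards on the handle, mirroring `if self.experience_pipeline:`
if pipeline is not None:
    print(ray.get(pipeline.prepare.remote()))
else:
    print("bench mode: no pipeline actor")

ray.shutdown()
```
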
trinity/manager/synchronizer.py

Lines changed: 4 additions & 1 deletion
@@ -44,7 +44,10 @@ def __init__(self, config: Config, module_ref: ray.actor.ActorHandle):
         self._modules = {module_ref}
         self._modules_lock = asyncio.Lock()
         asyncio.create_task(self._check_modules())
-        if self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT:
+        if (
+            self.config.mode != "bench"
+            and self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT
+        ):
             asyncio.create_task(self._find_latest_state_dict())
 
     async def add_module(self, module_ref: ray.actor.ActorHandle) -> None:

trinity/trainer/verl/fsdp_workers.py

Lines changed: 12 additions & 11 deletions
@@ -168,7 +168,10 @@ def __init__(self, config: DictConfig, role: str):
                 self.config.actor.ppo_micro_batch_size
             )
 
-            if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
+            if (
+                not self.config.actor.use_dynamic_bsz
+                and self.config.actor.ppo_micro_batch_size_per_gpu is not None
+            ):
                 assert (
                     self.config.actor.ppo_mini_batch_size
                     % self.config.actor.ppo_micro_batch_size_per_gpu
@@ -181,7 +184,11 @@ def __init__(self, config: DictConfig, role: str):
                ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
 
        # normalize ref config
-        if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
+        if (
+            self._is_ref
+            and not self.config.ref.log_prob_use_dynamic_bsz
+            and self.config.ref.log_prob_micro_batch_size is not None
+        ):
            self.config.ref.log_prob_micro_batch_size //= (
                self.device_mesh.size() // self.ulysses_sequence_parallel_size
            )
@@ -246,7 +253,7 @@ def _build_model_optimizer(  # noqa: C901
         else:
             self.tokenizer.chat_template = self.config.model.custom_chat_template
 
-        torch_dtype = fsdp_config.get("model_dtype", None)
+        torch_dtype = fsdp_config.model_dtype
         if torch_dtype is None:
             torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
         else:
@@ -326,9 +333,6 @@
                 fused_kernels_backend=fused_kernels_backend,
             )
 
-            # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
-            actor_module.to(torch_dtype)
-
             if enable_gradient_checkpointing:
                 actor_module.gradient_checkpointing_enable(
                     gradient_checkpointing_kwargs={"use_reentrant": False}
@@ -971,7 +975,7 @@ def __init__(self, config):
             self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
             self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
 
-        if self.config.ppo_micro_batch_size_per_gpu is not None:
+        if not self.config.use_dynamic_bsz and self.config.ppo_micro_batch_size_per_gpu is not None:
            assert (
                self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
            ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
@@ -1020,7 +1024,7 @@ def _build_critic_model_optimizer(self, config): # noqa: C901
         if self.rank == 0:
             print(f"Critic overriding config {override_config_kwargs}")
 
-        torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
+        torch_dtype = self.config.model.fsdp_config.model_dtype or "fp32"
         torch_dtype = PrecisionType.to_dtype(torch_dtype)
 
         from transformers import AutoConfig
@@ -1060,9 +1064,6 @@ def _build_critic_model_optimizer(self, config): # noqa: C901
             ulysses_sp_size=self.ulysses_sequence_parallel_size,
         )
 
-        # some parameters may not in torch_dtype
-        critic_module.to(torch_dtype)
-
         if config.model.get("enable_gradient_checkpointing", False):
             critic_module.gradient_checkpointing_enable(

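With `model_dtype` now a typed field on `FSDPConfig`, the worker reads it by attribute and falls back via `model_dtype or "fp32"` before converting the string to a torch dtype. A hedged sketch of that resolution step (the `to_torch_dtype` helper below is illustrative; verl's own `PrecisionType.to_dtype` is what the diff actually calls):

```python
from typing import Optional

import torch

_DTYPE_MAP = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def to_torch_dtype(model_dtype: Optional[str], default: str = "fp32") -> torch.dtype:
    # mirror the `fsdp_config.model_dtype or "fp32"` fallback used in the diff
    name = model_dtype or default
    try:
        return _DTYPE_MAP[name]
    except KeyError as exc:
        raise ValueError(f"unsupported model_dtype: {name!r}") from exc


print(to_torch_dtype(None))    # torch.float32
print(to_torch_dtype("bf16"))  # torch.bfloat16
```
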
trinity/trainer/verl/megatron_checkpoint_manager.py

Lines changed: 39 additions & 33 deletions
@@ -233,44 +233,50 @@ def save_checkpoint( # noqa: C901
                 json.dump(transformer_config_dict, f, indent=2)
 
         if self.should_save_hf_model or save_as_hf:
-            # wait for everyone to dump to local
-            state_dict = self.weight_saver(
-                self.model,
-                self.hf_config,
-                dtype=self.param_dtype,
-                is_value_model=self.is_value_model,
-                tie_word_embeddings=self.share_embeddings_and_output_weights,
-            )
+            try:
+                # wait for everyone to dump to local
+                state_dict = self.weight_saver(
+                    self.model,
+                    self.hf_config,
+                    dtype=self.param_dtype,
+                    is_value_model=self.is_value_model,
+                    tie_word_embeddings=self.share_embeddings_and_output_weights,
+                )
 
-            torch.distributed.barrier()
-            if self.rank == 0:
-                # TODO: async save or use mbridge to save hf model
-                hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
-                import warnings
+                torch.distributed.barrier()
+                if self.rank == 0:
+                    # TODO: async save or use mbridge to save hf model
+                    hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
+                    import warnings
 
-                from accelerate import init_empty_weights
+                    from accelerate import init_empty_weights
 
-                with init_empty_weights(), warnings.catch_warnings():
-                    warnings.simplefilter("ignore")
-                    if "mistral7b-rm" in self.config.model.path:
-                        from transformers import MistralForSequenceClassification
+                    with init_empty_weights(), warnings.catch_warnings():
+                        warnings.simplefilter("ignore")
+                        if "mistral7b-rm" in self.config.model.path:
+                            from transformers import MistralForSequenceClassification
 
-                        model = MistralForSequenceClassification.from_pretrained(
-                            self.config.model.path
-                        )  # use score head instead of lm_head
-                        state_dict["score.weight"] = state_dict["score.weight"]
-                    else:
-                        from transformers import AutoModelForCausalLM
+                            model = MistralForSequenceClassification.from_pretrained(
+                                self.config.model.path
+                            )  # use score head instead of lm_head
+                            state_dict["score.weight"] = state_dict["score.weight"]
+                        else:
+                            from transformers import AutoModelForCausalLM
 
-                        model = AutoModelForCausalLM.from_pretrained(
-                            self.config.model.path, torch_dtype="auto"
-                        )
-                model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
-                log_with_rank(
-                    f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
-                    rank=self.rank,
-                    logger=logger,
-                    log_only_rank_0=True,
+                            model = AutoModelForCausalLM.from_pretrained(
+                                self.config.model.path, torch_dtype="auto"
+                            )
+                    model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
+                    log_with_rank(
+                        f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
+                        rank=self.rank,
+                        logger=logger,
+                        log_only_rank_0=True,
+                    )
+            except Exception:
+                logger.error(
+                    f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it.",
                    exc_info=True,
                )
 
        ray.get(

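The try/except above keeps a failed Hugging Face export from aborting the whole Megatron checkpoint and points users at `use_mbridge=true` as a fallback. The export itself instantiates the architecture without allocating weights and then saves an externally assembled state dict; a rough sketch under the assumption that a full `state_dict` is already in hand (model name and output path below are placeholders):

```python
import warnings

from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

model_path = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder reference checkpoint
out_dir = "/tmp/hf_export"                 # placeholder output directory

# in the trainer this dict comes from weight_saver (gathered Megatron weights);
# here we simply reuse the reference checkpoint's weights for illustration
state_dict = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto").state_dict()

with init_empty_weights(), warnings.catch_warnings():
    warnings.simplefilter("ignore")
    # meta-device instantiation: architecture and config only, no weight allocation
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto")

try:
    # the explicit state_dict, not the (empty) module, provides the saved weights
    model.save_pretrained(out_dir, state_dict=state_dict)
except Exception:
    print(f"Failed to save Huggingface model to {out_dir}")
    raise
```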