@@ -168,7 +168,10 @@ def __init__(self, config: DictConfig, role: str):
                     self.config.actor.ppo_micro_batch_size
                 )
 
-            if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
+            if (
+                not self.config.actor.use_dynamic_bsz
+                and self.config.actor.ppo_micro_batch_size_per_gpu is not None
+            ):
                 assert (
                     self.config.actor.ppo_mini_batch_size
                     % self.config.actor.ppo_micro_batch_size_per_gpu
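For readers following the change: the new condition runs the divisibility and size checks only when dynamic batch sizing is disabled, presumably because the per-GPU micro batch size is not consulted when `use_dynamic_bsz` is on. A minimal sketch of that guard, using a hypothetical `validate_actor_batch_sizes` helper and a plain dict standing in for the OmegaConf node:

```python
# Sketch only (not verl's code): skip the per-GPU checks entirely when
# token-level dynamic batching is enabled or the micro batch size is unset.
def validate_actor_batch_sizes(actor_cfg: dict) -> None:
    if actor_cfg["use_dynamic_bsz"] or actor_cfg["ppo_micro_batch_size_per_gpu"] is None:
        return
    mini = actor_cfg["ppo_mini_batch_size"]
    micro = actor_cfg["ppo_micro_batch_size_per_gpu"]
    assert mini % micro == 0, f"ppo_mini_batch_size {mini} should be divisible by {micro}"
    assert mini // micro > 0, f"ppo_mini_batch_size {mini} should be larger than {micro}"


# With dynamic batching on, a missing per-GPU micro batch size is fine.
validate_actor_batch_sizes(
    {"use_dynamic_bsz": True, "ppo_mini_batch_size": 64, "ppo_micro_batch_size_per_gpu": None}
)
# With dynamic batching off, the checks apply.
validate_actor_batch_sizes(
    {"use_dynamic_bsz": False, "ppo_mini_batch_size": 64, "ppo_micro_batch_size_per_gpu": 8}
)
```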
@@ -181,7 +184,11 @@ def __init__(self, config: DictConfig, role: str):
                 ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
 
         # normalize ref config
-        if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
+        if (
+            self._is_ref
+            and not self.config.ref.log_prob_use_dynamic_bsz
+            and self.config.ref.log_prob_micro_batch_size is not None
+        ):
             self.config.ref.log_prob_micro_batch_size //= (
                 self.device_mesh.size() // self.ulysses_sequence_parallel_size
             )
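The unchanged context around this hunk divides the global micro batch size by the number of data-parallel ranks, i.e. the device mesh size divided by the Ulysses sequence-parallel size. A small worked example with made-up numbers:

```python
# Illustration only: how a global micro batch size is normalized to a per-rank
# value. With 8 GPUs and Ulysses SP size 2 there are 8 // 2 = 4 data-parallel
# ranks, so a global log_prob_micro_batch_size of 16 becomes 4 per DP rank.
world_size = 8
ulysses_sequence_parallel_size = 2
log_prob_micro_batch_size = 16

dp_size = world_size // ulysses_sequence_parallel_size
log_prob_micro_batch_size //= dp_size
assert dp_size == 4 and log_prob_micro_batch_size == 4
```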
@@ -246,7 +253,7 @@ def _build_model_optimizer(  # noqa: C901
        else:
            self.tokenizer.chat_template = self.config.model.custom_chat_template

-        torch_dtype = fsdp_config.get("model_dtype", None)
+        torch_dtype = fsdp_config.model_dtype
        if torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
        else:
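The switch from `fsdp_config.get("model_dtype", None)` to attribute access suggests `model_dtype` is now a declared field on the FSDP config object rather than an optional dict-style key. A hedged sketch of the fallback logic with a stand-in dataclass (the real config class in verl may differ):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FSDPConfigStub:
    """Stand-in for the FSDP config; only the field used here is modeled."""
    model_dtype: Optional[str] = None  # e.g. "fp32", "fp16", "bf16"


def resolve_torch_dtype(fsdp_config: FSDPConfigStub, is_actor: bool) -> torch.dtype:
    # Mirrors the diff: attribute access returns None when unset, then the
    # actor defaults to fp32 while other roles default to bf16.
    name = fsdp_config.model_dtype
    if name is None:
        return torch.float32 if is_actor else torch.bfloat16
    return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[name]


assert resolve_torch_dtype(FSDPConfigStub(), is_actor=True) is torch.float32
assert resolve_torch_dtype(FSDPConfigStub(), is_actor=False) is torch.bfloat16
assert resolve_torch_dtype(FSDPConfigStub(model_dtype="bf16"), is_actor=True) is torch.bfloat16
```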
@@ -326,9 +333,6 @@ def _build_model_optimizer(  # noqa: C901
                fused_kernels_backend=fused_kernels_backend,
            )

-            # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
-            actor_module.to(torch_dtype)
-
            if enable_gradient_checkpointing:
                actor_module.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
@@ -971,7 +975,7 @@ def __init__(self, config):
            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
            self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size

-        if self.config.ppo_micro_batch_size_per_gpu is not None:
+        if not self.config.use_dynamic_bsz and self.config.ppo_micro_batch_size_per_gpu is not None:
            assert (
                self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
            ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
@@ -1020,7 +1024,7 @@ def _build_critic_model_optimizer(self, config):  # noqa: C901
        if self.rank == 0:
            print(f"Critic overriding config {override_config_kwargs}")

-        torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
+        torch_dtype = self.config.model.fsdp_config.model_dtype or "fp32"
        torch_dtype = PrecisionType.to_dtype(torch_dtype)

        from transformers import AutoConfig
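In the critic path the default is supplied with `or "fp32"` instead of a `.get` default, which appears equivalent here on the assumption that `model_dtype` is either a non-empty dtype string or None. A tiny illustration (the mapping below stands in for the conversion that verl performs via `PrecisionType.to_dtype`):

```python
import torch

# Stand-in mapping for illustration; verl resolves the string elsewhere.
_DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

for model_dtype in (None, "bf16"):
    name = model_dtype or "fp32"  # None falls back to "fp32"
    print(model_dtype, "->", _DTYPES[name])
```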
@@ -1060,9 +1064,6 @@ def _build_critic_model_optimizer(self, config):  # noqa: C901
                ulysses_sp_size=self.ulysses_sequence_parallel_size,
            )

-            # some parameters may not in torch_dtype
-            critic_module.to(torch_dtype)
-
            if config.model.get("enable_gradient_checkpointing", False):
                critic_module.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}