@@ -151,6 +151,82 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
         )
         self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)
 
+    def _init_hf_config_and_tf_config(
+        self,
+        model_path,
+        tokenizer_or_path,
+        dtype,
+        override_model_config,
+        override_transformer_config,
+        trust_remote_code=False,
+        use_mbridge=False,
+    ):
+        from transformers import AutoConfig
+        from verl.models.mcore import hf_to_mcore_config
+        from verl.utils import hf_processor, hf_tokenizer
+        from verl.utils.fs import copy_to_local
+        from verl.utils.model import update_model_config
+
+        # Step 1: initialize the tokenizer
+        self.local_path = copy_to_local(model_path)
+        if tokenizer_or_path is None:
+            self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)
+            self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code)
+        elif isinstance(tokenizer_or_path, str):
+            self.tokenizer = hf_tokenizer(
+                copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code
+            )
+            self.processor = hf_processor(
+                copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code
+            )
+        else:
+            self.tokenizer = tokenizer_or_path
+            self.processor = tokenizer_or_path
+
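+        # optionally override the chat template with the one given in the model config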
+        if self.config.model.get("custom_chat_template", None) is not None:
+            if self.processor is not None:
+                self.processor.chat_template = self.config.model.custom_chat_template
+            else:
+                self.tokenizer.chat_template = self.config.model.custom_chat_template
+
+        # Step 2: get the hf config
+        hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)
+
+        # Step 3: override the hf config
+        override_config_kwargs = {
+            "bos_token_id": self.tokenizer.bos_token_id,
+            "eos_token_id": self.tokenizer.eos_token_id,
+            "pad_token_id": self.tokenizer.pad_token_id,
+        }
+        override_config_kwargs.update(override_model_config.get("model_config", {}))
+
+        # patch for rope
+        if self.config.model.rope_scaling is not None:
+            hf_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling)
+        if self.config.model.rope_theta is not None:
+            hf_config.rope_theta = self.config.model.rope_theta
+
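+        # whether the input embedding and the output head share weights (tie_word_embeddings)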
+        self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)
+        update_model_config(hf_config, override_config_kwargs=override_config_kwargs)
+        self.architectures = getattr(hf_config, "architectures", None)
+        if self.rank == 0:
+            print(f"Model config after override: {hf_config}")
+        tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)
+
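+        # when mbridge is enabled, derive the Megatron transformer config via AutoBridge
+        # instead of the hf_to_mcore_config result above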
+        if use_mbridge:
+            from verl.models.mcore.mbridge import AutoBridge
+
+            bridge = AutoBridge.from_config(hf_config)
+            bridge.set_extra_args(**override_transformer_config)
+            tf_config = bridge.config
+            self.bridge = bridge
+        else:
+            self.bridge = None
+
+        print(f"TF config: {tf_config}")
+        self.hf_config = hf_config
+        self.tf_config = tf_config
+
     def _build_model_optimizer(
         self, model_path, optim_config, override_model_config, override_transformer_config
     ):