
Commit 0c58b5e

Add rope_scaling and rope_theta to config (#390)
1 parent ffdf4ff commit 0c58b5e

File tree

7 files changed: +169 -2 lines changed

  tests/common/vllm_test.py
  tests/trainer/trainer_test.py
  trinity/common/config.py
  trinity/common/models/vllm_model.py
  trinity/common/verl_config.py
  trinity/trainer/verl/fsdp_workers.py
  trinity/trainer/verl/megatron_workers.py
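
The new rope_scaling and rope_theta fields live on ModelConfig and are propagated to the rollout and auxiliary model configs by Config.check_and_update(). A minimal usage sketch, assuming the get_template_config test helper used in tests/common/vllm_test.py below; values are illustrative:

# Sketch only: get_template_config is the repo's test helper, not a public API guarantee.
config = get_template_config()
config.model.rope_scaling = {
    "rope_type": "yarn",
    "factor": 2.0,
    "original_max_position_embeddings": 40960,
}
config.model.rope_theta = 10000.0
config.check_and_update()
# After validation, the RoPE settings are mirrored onto the rollout engine config
# (see the config.py hunk below).
assert config.explorer.rollout_model.rope_scaling == config.model.rope_scaling
assert config.explorer.rollout_model.rope_theta == config.model.rope_theta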

tests/common/vllm_test.py
Lines changed: 57 additions & 0 deletions

@@ -1,3 +1,4 @@
+import os
 import unittest

 import ray
@@ -935,3 +936,59 @@ async def test_api_tool_calls(self):
         print_debug(
             "\n" + "=" * 28 + f" test_api_tool_calls PASSED in {total_time:.2f}s " + "=" * 28 + "\n"
         )
+
+
+class TestSuperLongGeneration(RayUnittestBaseAysnc):
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.model.max_model_len = 81920
+        self.config.model.max_prompt_tokens = 61440
+        self.config.model.max_response_tokens = 20480
+        self.config.model.rope_scaling = {
+            "rope_type": "yarn",
+            "factor": 2.0,
+            "original_max_position_embeddings": 40960,
+        }
+        self.config.explorer.rollout_model.engine_type = "vllm"
+        self.config.explorer.rollout_model.engine_num = 1
+        self.config.explorer.rollout_model.tensor_parallel_size = 1
+        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
+
+        self.config.check_and_update()
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+
+    async def test_generate(self):
+        base_dir = os.path.dirname(__file__)
+        target_dir = os.path.join(base_dir, "..", "..", "trinity", "trainer", "verl")
+        with open(os.path.join(target_dir, "fsdp_workers.py")) as f:
+            fsdp_code = f.read()
+        with open(os.path.join(target_dir, "megatron_workers.py")) as f:
+            megatron_code = f.read()
+        target_dir = os.path.join(base_dir, "..", "..", "trinity", "common")
+        with open(os.path.join(target_dir, "config.py")) as f:
+            config_code = f.read()
+        target_dir = os.path.join(base_dir, "..", "..", "trinity", "manager")
+        with open(os.path.join(target_dir, "config_manager.py")) as f:
+            config_manager_code = f.read()
+
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": """# Please add comments and documentation for these following code, """
+                """make sure the code is well-structured and easy to read, """
+                """and the complete code must be shown, do not omit any parts.\n"""
+                f"""## fsdp_workers.py\n{fsdp_code}\n"""
+                f"""## megatron_workers.py\n{megatron_code}\n"""
+                f"""## config.py\n{config_code}\n"""
+                f"""## config_manager.py\n{config_manager_code}\n""",
+            },
+        ]
+        response = self.model_wrapper.chat(messages, n=1, temperature=0.7, logprobs=True)[0]
+        self.assertGreater(
+            response.prompt_length, 40960
+        )  # If not long enough, please add more files to prompt
+        self.assertGreater(response.logprobs.shape[0], 1000)
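
The numbers in TestSuperLongGeneration.setUp follow the usual YaRN convention that the effective context window is roughly the scaling factor times original_max_position_embeddings; a quick arithmetic check (mine, not part of the commit):

rope_scaling = {"rope_type": "yarn", "factor": 2.0, "original_max_position_embeddings": 40960}
# 2.0 * 40960 = 81920, which matches max_model_len above and the
# max_prompt_tokens + max_response_tokens split (61440 + 20480).
effective_len = int(rope_scaling["factor"] * rope_scaling["original_max_position_embeddings"])
assert effective_len == 81920 == 61440 + 20480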

tests/trainer/trainer_test.py
Lines changed: 6 additions & 0 deletions

@@ -73,6 +73,12 @@ class TestTrainerCountdown(BaseTrainerCase):
     def test_trainer(self):
         """Test the both and bench mode."""
         # test both mode
+        self.config.model.rope_scaling = {
+            "rope_type": "yarn",
+            "factor": 2.0,
+            "original_max_position_embeddings": 16384,
+        }
+        self.config.model.rope_theta = 10000
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
             selector_type="shuffle", seed=42

trinity/common/config.py
Lines changed: 12 additions & 2 deletions

@@ -442,6 +442,10 @@ class ModelConfig:
     fully_sharded_loras: bool = False
     max_cpu_loras: Optional[int] = None

+    # rope config
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+

 @dataclass
 class InferenceModelConfig:
@@ -503,6 +507,10 @@ class InferenceModelConfig:
     lora_modules: Optional[List[Dict]] = None
     lora_kwargs: Optional[dict] = field(default_factory=dict)

+    # ! DO NOT SET, rope config
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+

 @dataclass
 class AlgorithmConfig:
@@ -1195,12 +1203,14 @@ def check_and_update(self) -> Config: # noqa: C901
             "max_response_tokens",
             "min_response_tokens",
         ]
-        for args in ["model_path"] + rollout_args + length_args:
+        rope_args = ["rope_scaling", "rope_theta"]
+        model_args = rollout_args + length_args + rope_args
+        for args in ["model_path"] + model_args:
             setattr(self.explorer.rollout_model, args, getattr(self.model, args))
         for aux_model in self.explorer.auxiliary_models:
             if not aux_model.model_path:
                 raise ValueError("auxiliary model's model_path is required.")
-            for args in rollout_args + length_args:
+            for args in model_args:
                 set_if_none(aux_model, args, getattr(self.model, args))

         # for lora configs

trinity/common/models/vllm_model.py
Lines changed: 6 additions & 0 deletions

@@ -77,6 +77,11 @@ def __init__(
         max_model_len = config.max_model_len
         self.enable_lora = config.enable_lora
         self.default_lora_path = config.lora_kwargs.pop("default_lora_path", None)
+        rope_kwargs = {
+            key: getattr(config, key)
+            for key in ["rope_scaling", "rope_theta"]
+            if getattr(config, key) is not None
+        }
         engine_args = vllm.AsyncEngineArgs(
             model=config.model_path,
             enforce_eager=config.enforce_eager,
@@ -101,6 +106,7 @@ def __init__(
             disable_log_stats=True,
             enable_lora=config.enable_lora,
             logprobs_mode="processed_logprobs",
+            **rope_kwargs,
             **config.lora_kwargs,
         )
         if get_vllm_version() > parse_version("0.10.0"):
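
For context, the same overrides can be handed to vLLM directly; a standalone sketch, assuming AsyncEngineArgs accepts rope_scaling and rope_theta keyword arguments (which the call above relies on) and using a placeholder model path:

import vllm

# Sketch only: forward RoPE overrides to the engine, mirroring the rope_kwargs pattern above.
engine_args = vllm.AsyncEngineArgs(
    model="/path/to/model",  # placeholder, not from the commit
    max_model_len=81920,
    rope_scaling={"rope_type": "yarn", "factor": 2.0, "original_max_position_embeddings": 40960},
    rope_theta=10000.0,
)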

trinity/common/verl_config.py
Lines changed: 6 additions & 0 deletions

@@ -40,6 +40,10 @@ class ActorModel:
     lora_alpha: int = 32
     target_modules: Optional[str] = "all-linear"

+    # rope configs
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+

 @dataclass
 class Optim:
@@ -412,6 +416,8 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
         # Actor / Rollout Config
         self.actor_rollout_ref.model.path = config.model.model_path
         self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template
+        self.actor_rollout_ref.model.rope_scaling = config.model.rope_scaling
+        self.actor_rollout_ref.model.rope_theta = config.model.rope_theta
         self.actor_rollout_ref.actor.optim.total_training_steps = self.trainer.total_training_steps
         self.actor_rollout_ref.actor.ppo_mini_batch_size = config.buffer.train_batch_size
         self.actor_rollout_ref.rollout.temperature = (

trinity/trainer/verl/fsdp_workers.py
Lines changed: 6 additions & 0 deletions

@@ -257,6 +257,12 @@ def _build_model_optimizer( # noqa: C901
             local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
         )

+        # patch for rope
+        if self.config.model.rope_scaling is not None:
+            actor_model_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling)
+        if self.config.model.rope_theta is not None:
+            actor_model_config.rope_theta = self.config.model.rope_theta
+
         # patch for kimi-vl
         if getattr(actor_model_config, "model_type", None) == "kimi_vl":
             actor_model_config.text_config.topk_method = "greedy"
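
The OmegaConf.to_container call is there because self.config.model.rope_scaling is an OmegaConf DictConfig at this point, while transformers model configs expect a plain dict; a small conversion sketch with illustrative values:

from omegaconf import OmegaConf

# Sketch only: DictConfig -> plain dict before attaching it to the HF model config.
cfg = OmegaConf.create({"rope_scaling": {"rope_type": "yarn", "factor": 2.0}})
plain = OmegaConf.to_container(cfg.rope_scaling)
assert isinstance(plain, dict) and plain["factor"] == 2.0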

trinity/trainer/verl/megatron_workers.py
Lines changed: 76 additions & 0 deletions

@@ -151,6 +151,82 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
         )
         self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)

+    def _init_hf_config_and_tf_config(
+        self,
+        model_path,
+        tokenizer_or_path,
+        dtype,
+        override_model_config,
+        override_transformer_config,
+        trust_remote_code=False,
+        use_mbridge=False,
+    ):
+        from transformers import AutoConfig
+        from verl.models.mcore import hf_to_mcore_config
+        from verl.utils import hf_processor, hf_tokenizer
+        from verl.utils.fs import copy_to_local
+        from verl.utils.model import update_model_config
+
+        # Step 1: initialize the tokenizer
+        self.local_path = copy_to_local(model_path)
+        if tokenizer_or_path is None:
+            self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)
+            self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code)
+        elif isinstance(tokenizer_or_path, str):
+            self.tokenizer = hf_tokenizer(
+                copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code
+            )
+            self.processor = hf_processor(
+                copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code
+            )
+        else:
+            self.tokenizer = tokenizer_or_path
+            self.processor = tokenizer_or_path
+
+        if self.config.model.get("custom_chat_template", None) is not None:
+            if self.processor is not None:
+                self.processor.chat_template = self.config.model.custom_chat_template
+            else:
+                self.tokenizer.chat_template = self.config.model.custom_chat_template
+
+        # Step 2: get the hf
+        hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)
+
+        # Step 3: override the hf config
+        override_config_kwargs = {
+            "bos_token_id": self.tokenizer.bos_token_id,
+            "eos_token_id": self.tokenizer.eos_token_id,
+            "pad_token_id": self.tokenizer.pad_token_id,
+        }
+        override_config_kwargs.update(override_model_config.get("model_config", {}))
+
+        # patch for rope
+        if self.config.model.rope_scaling is not None:
+            hf_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling)
+        if self.config.model.rope_theta is not None:
+            hf_config.rope_theta = self.config.model.rope_theta
+
+        self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)
+        update_model_config(hf_config, override_config_kwargs=override_config_kwargs)
+        self.architectures = getattr(hf_config, "architectures", None)
+        if self.rank == 0:
+            print(f"Model config after override: {hf_config}")
+        tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)
+
+        if use_mbridge:
+            from verl.models.mcore.mbridge import AutoBridge
+
+            bridge = AutoBridge.from_config(hf_config)
+            bridge.set_extra_args(**override_transformer_config)
+            tf_config = bridge.config
+            self.bridge = bridge
+        else:
+            self.bridge = None
+
+        print(f"TF config: {tf_config}")
+        self.hf_config = hf_config
+        self.tf_config = tf_config
+
     def _build_model_optimizer(
         self, model_path, optim_config, override_model_config, override_transformer_config
     ):
