docs/sphinx_doc/source/tutorial/trinity_configs.md (1 addition & 1 deletion)
@@ -174,7 +174,7 @@ model:
- `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
- `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
- `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and there is a risk that the prompt length plus response length exceeds `max_model_len`.
- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and there is a risk that the prompt length plus response length exceeds `max_model_len`. This option does not take effect in OpenAI API mode.
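
The interaction of these limits can be summarized with a short, self-contained sketch. This is a minimal illustration of the documented rule, not the library's actual implementation; the helper name `clip_prompt` and the concrete numbers are made up for the example:

```python
# Illustrative sketch of the documented truncation rule; not Trinity-RFT code.
from typing import List


def clip_prompt(
    prompt_token_ids: List[int],
    max_prompt_tokens: int,
    enable_prompt_truncation: bool = True,
) -> List[int]:
    """Clip a tokenized prompt to at most `max_prompt_tokens` tokens."""
    if enable_prompt_truncation and len(prompt_token_ids) > max_prompt_tokens:
        return prompt_token_ids[:max_prompt_tokens]
    # With truncation disabled the prompt passes through unchanged, so the
    # prompt plus the response may exceed `max_model_len`.
    return prompt_token_ids


# Example: a 30-token prompt with max_prompt_tokens=5, max_model_len=20,
# and max_response_tokens=15.
clipped = clip_prompt(list(range(30)), max_prompt_tokens=5)
assert len(clipped) == 5
response_budget = min(15, 20 - len(clipped))  # bounded by both limits
assert response_budget == 15
```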

```{tip}
If you are using the openai API provided by Explorer, only `max_model_len` will take effect, and the value of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` will be ignored. When `max_tokens` is not independently specified, each API call will generate up to `max_model_len - prompt_length` tokens. Therefore, please ensure that the prompt length is less than `max_model_len` when using the API.
docs/sphinx_doc/source_zh/tutorial/trinity_configs.md (1 addition & 1 deletion)
@@ -174,7 +174,7 @@ model:
- `max_prompt_tokens`: Maximum number of tokens allowed in the input prompt. Only effective for the `chat` and `generate` methods in `InferenceModel`.
- `max_response_tokens`: Maximum number of tokens allowed in model-generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
- `min_response_tokens`: Minimum number of tokens allowed in model-generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt is truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt is not truncated and the combined prompt and response length may exceed `max_model_len`.
- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt is truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt is not truncated and the combined prompt and response length may exceed `max_model_len`. This option does not take effect in OpenAI API mode.

```{tip}
If you are using the OpenAI API provided by Explorer, only `max_model_len` takes effect; the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` are ignored. When `max_tokens` is not specified explicitly, each API call generates up to `max_model_len - prompt_length` tokens, so make sure the prompt length is less than `max_model_len` when using the API.
tests/common/experience_test.py (1 addition & 1 deletion)
@@ -175,7 +175,7 @@ def test_assertions(self):
# prompt_length must be > 0
with self.assertRaises(AssertionError):
Experience(tokens=[1, 2, 3], prompt_length=0)
# tokens must be longer than prompt_length for single-turn
# tokens must be larger than prompt_length for single-turn
with self.assertRaises(AssertionError):
Experience(tokens=[1, 2], prompt_length=2)
# DPO: tokens must match prompt_length
tests/common/vllm_test.py (21 additions & 7 deletions)
@@ -224,6 +224,7 @@ async def test_generate(
[
(20, 19, None),
(20, None, 1),
(20, 5, 15),
],
)
class TestModelLen(RayUnittestBaseAysnc):
@@ -240,6 +241,7 @@ def setUp(self):

self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)

async def test_model_len(self):
await self.model_wrapper.prepare()
@@ -248,18 +250,30 @@ async def test_model_len(self):
{"role": "user", "content": "What's the weather like today?"},
]

def _check_experience(exp):
# check prompt content and length
encoded_prompt = self.tokenizer.encode(exp.prompt_text, add_special_tokens=False)
self.assertEqual(len(encoded_prompt), exp.prompt_length)
self.assertLessEqual(exp.prompt_length, self.config.model.max_prompt_tokens)
# check response content and length
encoded_response = self.tokenizer.encode(exp.response_text, add_special_tokens=False)
self.assertEqual(len(encoded_response), len(exp.tokens) - exp.prompt_length)
self.assertLessEqual(
len(exp.tokens) - exp.prompt_length, self.config.model.max_response_tokens
)
# check full sequence
self.assertLessEqual(len(exp.tokens), self.config.model.max_model_len)

# For vllm engine, max_prompt_tokens and max_response_tokens work
response = self.model_wrapper.chat(messages)
self.assertEqual(len(response), 1)
self.assertEqual(len(response[0].tokens), self.config.model.max_model_len)
if self.max_prompt_tokens == 5:
self.assertEqual(response[0].truncate_status, "prompt_truncated")
_check_experience(response[0])

exps = self.model_wrapper.extract_experience_from_history()
self.assertEqual(len(exps), 1)
# check prompt length, response length, max_model_len
self.assertEqual(exps[0].prompt_length, self.config.model.max_prompt_tokens)
self.assertEqual(
len(exps[0].tokens) - exps[0].prompt_length, self.config.model.max_response_tokens
)
self.assertLessEqual(len(response[0].tokens), self.config.model.max_model_len)
_check_experience(exps[0])

# For openai api, max_prompt_tokens and max_response_tokens do not work
openai_client = self.model_wrapper.get_openai_client()
tests/trainer/trainer_test.py (41 additions & 0 deletions)
@@ -1043,3 +1043,44 @@ def test_trainer(self):
def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)


class TestTrainerPromptTruncation(BaseTrainerCase):
def test_trainer(self):
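# Force prompt truncation: max_prompt_tokens=5 is far below any GSM8K prompt,
# so every rollout exercises the prompt_truncated path end to end.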
self.config.model.max_model_len = 20
self.config.model.max_prompt_tokens = 5
self.config.model.max_response_tokens = 15
self.config.model.enable_prompt_truncation = True
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.kl_loss_fn = "none"
self.config.algorithm.repeat_times = 2
self.config.buffer.batch_size = 4
self.config.buffer.total_steps = 2
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.check_and_update()
both(self.config)

parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)
max_prompt_length = parser.metric_values("prompt_length/max")
self.assertEqual(max(max_prompt_length), 5)
min_prompt_length = parser.metric_values("prompt_length/min")
self.assertEqual(min(min_prompt_length), 5)
max_response_length = parser.metric_values("response_length/max")
self.assertEqual(max(max_response_length), 1)
min_response_length = parser.metric_values("response_length/min")
self.assertEqual(min(min_response_length), 1)
final_loss = parser.metric_values("actor/final_loss")
self.assertEqual(final_loss[0], 0.0)
grad_norm = parser.metric_values("actor/grad_norm")
self.assertEqual(grad_norm[0], 0.0)

def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)
trinity/common/config.py (3 additions & 2 deletions)
@@ -457,7 +457,8 @@ class ModelConfig:
max_response_tokens: Optional[int] = None
# the minimum number of tokens for the response
min_response_tokens: int = 1
# whether to truncate the prompt; if set to True, the prompt will be truncated to `max_prompt_tokens` tokens.
# whether to truncate the prompt; if set to True, the prompt will be truncated to `max_prompt_tokens` tokens;
# not applicable for OpenAI API
enable_prompt_truncation: bool = True

# lora config
@@ -1192,7 +1193,7 @@ def _check_model(self) -> None:
if model.enable_prompt_truncation is True:
if model.max_prompt_tokens is None:
raise ValueError(
"When `model.enable_prompt_truncation` is True, `model.max_prompt_tokens` must be set properly."
"When `model.enable_prompt_truncation` is True, `model.max_prompt_tokens` must be set properly. This function does not work with OpenAI API mode."
)
logger.warning(
f"`enable_prompt_truncation` is set to True; the prompt will be truncated to `max_prompt_tokens`={model.max_prompt_tokens} tokens if it is too long."
trinity/common/experience.py (14 additions & 4 deletions)
@@ -104,6 +104,9 @@ class Experience:
token_level_reward: Optional[Tensor] = None # [resp_length]
advantages: Optional[Tensor] = None # [resp_length]
returns: Optional[Tensor] = None # [resp_length]
truncate_status: Optional[
str
] = None  # The status of truncation, e.g., "prompt_truncated", "response_truncated"; not applicable to the OpenAI API
info: dict = field(
default_factory=dict
) # Additional information about the experience, can also be used to store custom fields
@@ -140,6 +143,7 @@ def __init__( # noqa: C901
token_level_reward=None,
advantages=None,
returns=None,
truncate_status=None,
info=None,
metrics=None,
prompt_length=1,
@@ -165,10 +169,13 @@ def __init__( # noqa: C901
assert (
prompt_length > 0
), "Prompt length must be greater than 0 for single-turn experiences."
assert (
len(tokens) > prompt_length
), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
if truncate_status != "prompt_truncated":
assert (
len(tokens) > prompt_length
), f"Token ids must be larger than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
else:
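# A prompt-truncated experience keeps only a one-token placeholder response;
# mask every response position so it contributes nothing to training.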
action_mask = torch.zeros(len(logprobs), dtype=torch.bool)
elif experience_type == "dpo":
prompt_length = len(tokens)
if eid is None:
@@ -196,6 +203,7 @@ def __init__( # noqa: C901
self.experience_type = experience_type
self.info = info or {}
self.metrics = metrics or {}
self.truncate_status = truncate_status
self.prompt_length = prompt_length
self.response_text = response_text
self.prompt_text = prompt_text
@@ -264,6 +272,8 @@ def to_dict(self) -> dict:
res["rejected_messages"] = self.rejected_messages
if self.reward is not None:
res["reward"] = float(self.reward)
if self.truncate_status is not None:
res["truncate_status"] = self.truncate_status
return res

@classmethod
trinity/common/models/vllm_model.py (30 additions & 6 deletions)
@@ -191,12 +191,36 @@ async def generate(
"""
if self.tokenizer is None:
await self._initialize_tokenizer()

# Tokenize once without truncation to check if truncation is needed
token_ids = self.tokenizer( # type: ignore
prompt,
truncation=self.config.enable_prompt_truncation,
max_length=self.config.max_prompt_tokens,
truncation=False,
return_tensors="pt",
)["input_ids"][0].tolist()
)[
"input_ids"
][0].tolist()

# Check if truncation is needed and apply it
if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None:
if len(token_ids) > self.config.max_prompt_tokens:
self.logger.warning(
f"Prompt was truncated to {self.config.max_prompt_tokens} tokens"
)
token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response
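# The extra token serves as a one-token placeholder response, so prompt_length
# stays at max_prompt_tokens and downstream code still sees a non-empty response.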
return [
Experience(
tokens=token_ids,
logprobs=torch.zeros(1, dtype=torch.float32),
prompt_length=len(token_ids) - 1,
prompt_text=self.tokenizer.decode(token_ids[:-1]),
response_text=self.tokenizer.decode(token_ids[-1]),
truncate_status="prompt_truncated",
reward=0.0,
)
for i in range(kwargs.get("n", 1))
]

output = await self._generate_internal(
prompt={"prompt_token_ids": token_ids}, lora_request=lora_request, **kwargs
)
@@ -397,10 +421,10 @@ async def convert_messages_to_experience(

# Truncate tokens if they exceed the length limit
assert token_ids is not None
is_truncated = False # TODO: add to experience itself
truncate_status = None
if self.config.max_model_len is not None and self.config.max_model_len > 0:
if len(token_ids) > self.config.max_model_len - 1:
is_truncated = True
truncate_status = "response_truncated"
self.logger.warning(
f"Warning: {len(token_ids) = } exceeds the length limit {self.config.max_model_len-1 = }"
)
@@ -417,7 +441,7 @@
prompt_length=prompt_length,
action_mask=action_mask[prompt_length:], # Exclude the prompt tokens
messages=messages,
info={"is_truncated": is_truncated},
truncate_status=truncate_status,
)

async def shutdown(self):
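
As a usage note for the new `truncate_status` field, a downstream workflow could branch on it roughly as follows. This is a hedged sketch, not part of the PR: the import path is assumed from the file layout above, `filter_for_training` is a hypothetical helper, and the reward handling simply mirrors the `response_truncated` behaviour in `trinity/common/workflows/workflow.py` later in this diff:

```python
# Hypothetical consumer of Experience.truncate_status; illustrative only.
from typing import List

from trinity.common.experience import Experience  # path assumed from this PR


def filter_for_training(experiences: List[Experience]) -> List[Experience]:
    kept = []
    for exp in experiences:
        if exp.truncate_status == "prompt_truncated":
            # Placeholder experiences carry no trainable response tokens.
            continue
        if exp.truncate_status == "response_truncated":
            # Mirror workflow.py: a clipped response earns no reward.
            exp.reward = 0.0
        kept.append(exp)
    return kept
```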
trinity/common/workflows/envs/frozen_lake/workflow.py (3 additions & 0 deletions)
@@ -280,6 +280,7 @@ async def run_async(self) -> List[Experience]:
self.step_count = 0
self.action = None
terminate_reason = None
truncate_status = None

# Initialize messages
messages = []
@@ -318,6 +319,7 @@ async def run_async(self) -> List[Experience]:
self.done = False
self.step_rewards.append(0)
terminate_reason = "max_tokens_reached"
truncate_status = "response_truncated"
break

# Get action from the model
@@ -360,6 +362,7 @@
"env_done": 1 if self.done else 0,
"test_score": final_reward,
},
truncate_status=truncate_status,
)
return [experience]

trinity/common/workflows/workflow.py (5 additions & 2 deletions)
@@ -165,10 +165,12 @@ def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base

def process_messages_to_experience(self, messages, reward, info={}) -> Experience:
def process_messages_to_experience(
self, messages, reward, info={}, truncate_status=None
) -> Experience:
converted_experience = self.model.convert_messages_to_experience(messages)

if converted_experience.info.get("is_truncated", False):
if converted_experience.truncate_status == "response_truncated":
reward = 0.0

tokens = converted_experience.tokens
@@ -188,6 +190,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experience:
prompt_length=converted_experience.prompt_length,
prompt_text=converted_experience.prompt_text,
response_text=converted_experience.response_text,
truncate_status=converted_experience.truncate_status or truncate_status,
reward=reward,
logprobs=log_probs,
info=info,
trinity/trainer/verl/utils.py (8 additions & 2 deletions)
@@ -154,7 +154,10 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict:
if "advantages" in batch.batch:
# adv
advantages = batch.batch["advantages"]
valid_adv = torch.masked_select(advantages, response_mask)
if response_mask.numel() > 0:
valid_adv = torch.masked_select(advantages, response_mask)
else:
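# An empty response mask would make masked_select return an empty tensor whose
# mean is NaN, so fall back to a zero placeholder.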
valid_adv = torch.zeros(1)
metrics.update(
{
# adv
@@ -166,7 +169,7 @@
if "returns" in batch.batch:
# returns
returns = batch.batch["returns"]
valid_returns = torch.masked_select(returns, response_mask)
if response_mask.numel() > 0:
valid_returns = torch.masked_select(returns, response_mask)
else:
valid_returns = torch.zeros(1)
metrics.update(
{
"critic/returns/mean": torch.mean(valid_returns).detach().item(),