diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index c2c4593d23..2a74f6f92a 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -174,7 +174,7 @@ model:
 - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
 - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
 - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
-- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and there is a risk that the prompt length plus response length exceeds `max_model_len`.
+- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and there is a risk that the prompt length plus response length exceeds `max_model_len`. This option does not take effect in OpenAI API mode.
 
 ```{tip}
 If you are using the openai API provided by Explorer, only `max_model_len` will take effect, and the value of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` will be ignored. When `max_tokens` is not independently specified, each API call will generate up to `max_model_len - prompt_length` tokens. Therefore, please ensure that the prompt length is less than `max_model_len` when using the API.
diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
index 10f836fa62..804c2527a7 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -174,7 +174,7 @@ model:
 - `max_prompt_tokens`: 输入 prompt 中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
 - `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
 - `min_response_tokens`: 模型生成的回复中允许的最小 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
-- `enable_prompt_truncation`: 是否截断 prompt。默认为 `true`。若设置为 `true`,则 prompt 将被截断为 `max_prompt_tokens` 个 token;若设置为 `false`,则 prompt 不会被截断,存在 prompt 和 response 长度之和超过 `max_model_len` 的风险。
+- `enable_prompt_truncation`: 是否截断 prompt。默认为 `true`。若设置为 `true`,则 prompt 将被截断为 `max_prompt_tokens` 个 token;若设置为 `false`,则 prompt 不会被截断,存在 prompt 和 response 长度之和超过 `max_model_len` 的风险。该选项在 OpenAI API 模式下不生效。
 
 ```{tip}
 如果使用的是 Explorer 提供的 openai API,则只有 `max_model_len` 会生效,而 `max_response_tokens`、`max_prompt_tokens` 和 `min_response_tokens` 的值将被忽略,在没有独立指定 `max_tokens` 时,每次 API 调用将生成最多 `max_model_len - prompt_length` 个 token,因此在使用时请确保 prompt 长度小于 `max_model_len`。
diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py
index 826c29d546..195aaa61ae 100644
--- a/tests/common/experience_test.py
+++ b/tests/common/experience_test.py
@@ -175,7 +175,7 @@ def test_assertions(self):
         # prompt_length must be > 0
         with self.assertRaises(AssertionError):
             Experience(tokens=[1, 2, 3], prompt_length=0)
-        # tokens must be longer than prompt_length for single-turn
+        # len(tokens) must be greater than prompt_length for single-turn
         with self.assertRaises(AssertionError):
             Experience(tokens=[1, 2], prompt_length=2)
         # DPO: tokens must match prompt_length
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 2f281f4b32..ef3a5e7f7a 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -224,6 +224,7 @@ async def test_generate(
     [
         (20, 19, None),
         (20, None, 1),
+        (20, 5, 15),
     ],
 )
 class TestModelLen(RayUnittestBaseAysnc):
@@ -240,6 +241,7 @@ def setUp(self):
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)

     async def test_model_len(self):
         await self.model_wrapper.prepare()
@@ -248,18 +250,30 @@
             {"role": "user", "content": "What's the weather like today?"},
         ]

+        def _check_experience(exp):
+            # check prompt content and length
+            encoded_prompt = self.tokenizer.encode(exp.prompt_text, add_special_tokens=False)
+            self.assertEqual(len(encoded_prompt), exp.prompt_length)
+            self.assertLessEqual(exp.prompt_length, self.config.model.max_prompt_tokens)
+            # check response content and length
+            encoded_response = self.tokenizer.encode(exp.response_text, add_special_tokens=False)
+            self.assertEqual(len(encoded_response), len(exp.tokens) - exp.prompt_length)
+            self.assertLessEqual(
+                len(exp.tokens) - exp.prompt_length, self.config.model.max_response_tokens
+            )
+            # check full sequence
+            self.assertLessEqual(len(exp.tokens), self.config.model.max_model_len)
+
         # For vllm engine, max_prompt_tokens and max_response_tokens work
         response = self.model_wrapper.chat(messages)
         self.assertEqual(len(response), 1)
-        self.assertEqual(len(response[0].tokens), self.config.model.max_model_len)
+        if self.max_prompt_tokens == 5:
+            self.assertEqual(response[0].truncate_status, "prompt_truncated")
+        _check_experience(response[0])
+
         exps = self.model_wrapper.extract_experience_from_history()
         self.assertEqual(len(exps), 1)
-        # check prompt length, response length, max_model_len
-        self.assertEqual(exps[0].prompt_length, self.config.model.max_prompt_tokens)
-        self.assertEqual(
-            len(exps[0].tokens) - exps[0].prompt_length, self.config.model.max_response_tokens
-        )
-        self.assertLessEqual(len(response[0].tokens), self.config.model.max_model_len)
+        _check_experience(exps[0])

         # For openai api, max_prompt_tokens and max_response_tokens do not work
         openai_client = self.model_wrapper.get_openai_client()
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index a873401599..8a87efe759 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -1043,3 +1043,44 @@ def test_trainer(self):
     def tearDown(self):
         # remove dir only when the test passed
         shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerPromptTruncation(BaseTrainerCase):
+    def test_trainer(self):
+        self.config.model.max_model_len = 20
+        self.config.model.max_prompt_tokens = 5
+        self.config.model.max_response_tokens = 15
+        self.config.model.enable_prompt_truncation = True
+        self.config.algorithm.algorithm_type = "grpo"
+        self.config.algorithm.advantage_fn = "grpo"
+        self.config.algorithm.kl_loss_fn = "none"
+        self.config.algorithm.repeat_times = 2
+        self.config.buffer.batch_size = 4
+        self.config.buffer.total_steps = 2
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.check_and_update()
+        both(self.config)
+
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)
+        max_prompt_length = parser.metric_values("prompt_length/max")
+        self.assertEqual(max(max_prompt_length), 5)
+        min_prompt_length = parser.metric_values("prompt_length/min")
+        self.assertEqual(min(min_prompt_length), 5)
+        max_response_length = parser.metric_values("response_length/max")
+        self.assertEqual(max(max_response_length), 1)
+        min_response_length = parser.metric_values("response_length/min")
+        self.assertEqual(min(min_response_length), 1)
+        final_loss = parser.metric_values("actor/final_loss")
+        self.assertEqual(final_loss[0], 0.0)
+        grad_norm = parser.metric_values("actor/grad_norm")
+        self.assertEqual(grad_norm[0], 0.0)
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 8ea02095f7..f63f57c2f4 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -457,7 +457,8 @@ class ModelConfig:
     max_response_tokens: Optional[int] = None
     # the minimum number of tokens for the response
     min_response_tokens: int = 1
-    # whether to truncate the prompt; if set to True, the prompt will be truncated to `max_prompt_tokens` tokens.
+    # whether to truncate the prompt; if set to True, the prompt will be truncated to `max_prompt_tokens` tokens;
+    # this option does not take effect in OpenAI API mode
     enable_prompt_truncation: bool = True

     # lora config
@@ -1192,7 +1193,7 @@ def _check_model(self) -> None:
         if model.enable_prompt_truncation is True:
             if model.max_prompt_tokens is None:
                 raise ValueError(
-                    "When `model.enable_prompt_truncation` is True, `model.max_prompt_tokens` must be set properly."
+                    "When `model.enable_prompt_truncation` is True, `model.max_prompt_tokens` must be set properly. Note that prompt truncation does not take effect in OpenAI API mode."
                 )
             logger.warning(
                 f"`enable_prompt_truncation` is set to True; the prompt will be truncated to `max_prompt_tokens`={model.max_prompt_tokens} tokens if it is too long."
diff --git a/trinity/common/experience.py b/trinity/common/experience.py
index 4e3aa936be..2a734144d2 100644
--- a/trinity/common/experience.py
+++ b/trinity/common/experience.py
@@ -104,6 +104,9 @@ class Experience:
     token_level_reward: Optional[Tensor] = None  # [resp_length]
    advantages: Optional[Tensor] = None  # [resp_length]
    returns: Optional[Tensor] = None  # [resp_length]
+    truncate_status: Optional[
+        str
+    ] = None  # The truncation status, e.g., "prompt_truncated", "response_truncated"; not set in OpenAI API mode
     info: dict = field(
         default_factory=dict
     )  # Additional information about the experience, can also be used to store custom fields
@@ -140,6 +143,7 @@ def __init__(  # noqa: C901
         token_level_reward=None,
         advantages=None,
         returns=None,
+        truncate_status=None,
         info=None,
         metrics=None,
         prompt_length=1,
@@ -165,10 +169,13 @@ def __init__(  # noqa: C901
             assert (
                 prompt_length > 0
             ), "Prompt length must be greater than 0 for single-turn experiences."
-            assert (
-                len(tokens) > prompt_length
-            ), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
-            action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
+            if truncate_status != "prompt_truncated":
+                assert (
+                    len(tokens) > prompt_length
+                ), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
+                action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
+            else:
+                action_mask = torch.zeros(len(logprobs), dtype=torch.bool)
         elif experience_type == "dpo":
             prompt_length = len(tokens)
             if eid is None:
@@ -196,6 +203,7 @@ def __init__(  # noqa: C901
         self.experience_type = experience_type
         self.info = info or {}
         self.metrics = metrics or {}
+        self.truncate_status = truncate_status
         self.prompt_length = prompt_length
         self.response_text = response_text
         self.prompt_text = prompt_text
@@ -264,6 +272,8 @@ def to_dict(self) -> dict:
             res["rejected_messages"] = self.rejected_messages
         if self.reward is not None:
             res["reward"] = float(self.reward)
+        if self.truncate_status is not None:
+            res["truncate_status"] = self.truncate_status
         return res

     @classmethod
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index aefafb315b..ba6ccf0f5f 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -191,12 +191,36 @@ async def generate(
         """
         if self.tokenizer is None:
             await self._initialize_tokenizer()
+
+        # Tokenize once without truncation to check if truncation is needed
         token_ids = self.tokenizer(  # type: ignore
             prompt,
-            truncation=self.config.enable_prompt_truncation,
-            max_length=self.config.max_prompt_tokens,
+            truncation=False,
             return_tensors="pt",
-        )["input_ids"][0].tolist()
+        )[
+            "input_ids"
+        ][0].tolist()
+
+        # Check if truncation is needed and apply it
+        if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None:
+            if len(token_ids) > self.config.max_prompt_tokens:
+                self.logger.warning(
+                    f"Prompt was truncated to {self.config.max_prompt_tokens} tokens"
+                )
+                token_ids = token_ids[: self.config.max_prompt_tokens + 1]  # leave one for response
+                return [
+                    Experience(
+                        tokens=token_ids,
+                        logprobs=torch.zeros(1, dtype=torch.float32),
+                        prompt_length=len(token_ids) - 1,
+                        prompt_text=self.tokenizer.decode(token_ids[:-1]),
+                        response_text=self.tokenizer.decode(token_ids[-1]),
+                        truncate_status="prompt_truncated",
+                        reward=0.0,
+                    )
+                    for i in range(kwargs.get("n", 1))
+                ]
+
         output = await self._generate_internal(
             prompt={"prompt_token_ids": token_ids}, lora_request=lora_request, **kwargs
         )
@@ -397,10 +421,10 @@ async def convert_messages_to_experience(
         # Truncate tokens if they exceed the length limit
         assert token_ids is not None
-        is_truncated = False  # TODO: add to experience itself
+        truncate_status = None
         if self.config.max_model_len is not None and self.config.max_model_len > 0:
             if len(token_ids) > self.config.max_model_len - 1:
-                is_truncated = True
+                truncate_status = "response_truncated"
                 self.logger.warning(
                     f"Warning: {len(token_ids) = } exceeds the length limit {self.config.max_model_len-1 = }"
                 )
@@ -417,7 +441,7 @@
             prompt_length=prompt_length,
             action_mask=action_mask[prompt_length:],  # Exclude the prompt tokens
             messages=messages,
-            info={"is_truncated": is_truncated},
+            truncate_status=truncate_status,
         )

     async def shutdown(self):
diff --git a/trinity/common/workflows/envs/frozen_lake/workflow.py b/trinity/common/workflows/envs/frozen_lake/workflow.py
index 35fc8bce98..c7a13c17fd 100644
--- a/trinity/common/workflows/envs/frozen_lake/workflow.py
+++ b/trinity/common/workflows/envs/frozen_lake/workflow.py
@@ -280,6 +280,7 @@ async def run_async(self) -> List[Experience]:
         self.step_count = 0
         self.action = None
         terminate_reason = None
+        truncate_status = None

         # Initialize messages
         messages = []
@@ -318,6 +319,7 @@ async def run_async(self) -> List[Experience]:
                 self.done = False
                 self.step_rewards.append(0)
                 terminate_reason = "max_tokens_reached"
+                truncate_status = "response_truncated"
                 break

             # Get action from the model
@@ -360,6 +362,7 @@ async def run_async(self) -> List[Experience]:
                 "env_done": 1 if self.done else 0,
                 "test_score": final_reward,
             },
+            truncate_status=truncate_status,
         )
         return [experience]
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 0798c3e65a..90f52a2505 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -165,10 +165,12 @@ def set_repeat_times(self, repeat_times, run_id_base):
         self.repeat_times = repeat_times
         self.run_id_base = run_id_base

-    def process_messages_to_experience(self, messages, reward, info={}) -> Experience:
+    def process_messages_to_experience(
+        self, messages, reward, info={}, truncate_status=None
+    ) -> Experience:
         converted_experience = self.model.convert_messages_to_experience(messages)

-        if converted_experience.info.get("is_truncated", False):
+        if converted_experience.truncate_status == "response_truncated":
             reward = 0.0

         tokens = converted_experience.tokens
@@ -188,6 +190,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc
             prompt_length=converted_experience.prompt_length,
             prompt_text=converted_experience.prompt_text,
             response_text=converted_experience.response_text,
+            truncate_status=converted_experience.truncate_status or truncate_status,
             reward=reward,
             logprobs=log_probs,
             info=info,
diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py
index 9a35eb8a29..ab50b7c877 100644
--- a/trinity/trainer/verl/utils.py
+++ b/trinity/trainer/verl/utils.py
@@ -154,7 +154,10 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict:
     if "advantages" in batch.batch:
         # adv
         advantages = batch.batch["advantages"]
-        valid_adv = torch.masked_select(advantages, response_mask)
+        if response_mask.numel() > 0:
+            valid_adv = torch.masked_select(advantages, response_mask)
+        else:
+            valid_adv = torch.zeros(1)
         metrics.update(
             {
                 # adv
@@ -166,7 +169,10 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict:
     if "returns" in batch.batch:
         # returns
         returns = batch.batch["returns"]
-        valid_returns = torch.masked_select(returns, response_mask)
+        if response_mask.numel() > 0:
+            valid_returns = torch.masked_select(returns, response_mask)
+        else:
+            valid_returns = torch.zeros(1)
         metrics.update(
             {
                 "critic/returns/mean": torch.mean(valid_returns).detach().item(),
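To make the intent of the `Experience` change above easier to review, here is a minimal standalone sketch (not part of the patch) of the action-mask rule it introduces. `build_action_mask` is a hypothetical helper that mirrors the `truncate_status` branch added to `Experience.__init__`: a prompt-truncated experience gets an all-zero mask, so its dummy response token contributes nothing to training.

```python
from typing import List, Optional

import torch


def build_action_mask(
    tokens: List[int],
    prompt_length: int,
    logprobs: torch.Tensor,
    truncate_status: Optional[str] = None,
) -> torch.Tensor:
    """Illustrative mirror of the branch added to Experience.__init__ in this diff."""
    if truncate_status != "prompt_truncated":
        # Normal single-turn case: every response token is trainable.
        assert len(tokens) > prompt_length, "len(tokens) must be greater than prompt_length"
        return torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
    # Prompt was truncated: mask out the dummy response token(s) entirely.
    return torch.zeros(len(logprobs), dtype=torch.bool)


# Regular experience with 2 response tokens -> tensor([True, True])
print(build_action_mask([1, 2, 3, 4], prompt_length=2, logprobs=torch.zeros(2)))
# Prompt-truncated experience -> tensor([False]); the real code also sets reward=0.0 upstream
print(build_action_mask([1, 2, 3, 4, 5], 4, torch.zeros(1), "prompt_truncated"))
```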