3 changes: 2 additions & 1 deletion docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -29,13 +29,14 @@ Download the GSM8K dataset to the local directory `$DATASET_PATH/gsm8k`:

```bash
# Using Modelscope
modelscope download --dataset modelscope/gsm8k --local_dir $DATASET_PATH/gsm8k
modelscope download --dataset AI-ModelScope/gsm8k --local_dir $DATASET_PATH/gsm8k

# Using Huggingface
huggingface-cli download openai/gsm8k --repo-type dataset --local-dir $DATASET_PATH/gsm8k
```

For more details on dataset downloading, refer to [ModelScope](https://modelscope.cn/docs/datasets/download) or [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#download-a-dataset-or-a-space).
The dataset downloaded from ModelScope may lack the `dtype` field, which causes an error when loading the dataset. To resolve this, delete the `dataset_infos.json` file and run the experiment again.
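
For instance, the stale metadata file can be removed with a minimal snippet like the following (a sketch assuming the dataset was downloaded to `$DATASET_PATH/gsm8k` as in the commands above):

```python
import os

# Remove the incomplete metadata file so the dataset schema is re-inferred on the next load.
# Assumes the DATASET_PATH environment variable is set, as in the download commands above.
info_file = os.path.join(os.environ["DATASET_PATH"], "gsm8k", "dataset_infos.json")
if os.path.exists(info_file):
    os.remove(info_file)
    print(f"Removed {info_file}; re-run the experiment to reload the dataset.")
```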

## Step 2: Set up Configuration and Run Experiment

@@ -30,13 +30,14 @@ huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct --local-dir $MODEL_PATH/Qwen

```bash
# Using Modelscope
modelscope download --dataset modelscope/gsm8k --local_dir $DATASET_PATH/gsm8k
modelscope download --dataset AI-ModelScope/gsm8k --local_dir $DATASET_PATH/gsm8k

# Using Huggingface
huggingface-cli download openai/gsm8k --repo-type dataset --local-dir $DATASET_PATH/gsm8k
```

For more details on dataset downloading, refer to [ModelScope](https://modelscope.cn/docs/datasets/download) or [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#download-a-dataset-or-a-space).
The dataset downloaded from ModelScope may lack the `dtype` field, which causes an error when loading the dataset. To resolve this, delete the `dataset_infos.json` file and run the experiment again.

## Step 2: Set up Configuration and Run Experiment

40 changes: 40 additions & 0 deletions examples/grpo_frozen_lake/README.md
@@ -0,0 +1,40 @@
# Frozen Lake

This example shows the usage of GRPO on the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task.

## Data and Environment Preparation

After setting up the basic environment following the [installation section of Quickstart](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md#step-0-environment-preparation), you need to install the additional dependencies by running the following command:

```bash
pip install gymnasium[toy_text]
```

Then, we prepare the dataset by running the following command:

```bash
cd examples/grpo_frozen_lake
python get_frozen_lake_data.py
```

This command saves the dataset to the local directory `{DATA_ROOT_DIR}/frozenlake` and prints the dataset path. Afterwards, make sure to set the environment variable `TRINITY_TASKSET_PATH` to that path:
```bash
export TRINITY_TASKSET_PATH={DATA_ROOT_DIR}/frozenlake
```
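
As a quick sanity check, you can load the generated split and inspect one task record. This is only an illustrative snippet, assuming `pandas` is installed and `TRINITY_TASKSET_PATH` has been exported as above:

```python
import os

import pandas as pd

# Each record describes one FrozenLake task instance generated by get_frozen_lake_data.py:
# the map seed, the map size, the probability parameter `p`, an index, and a `uid`.
taskset_path = os.environ["TRINITY_TASKSET_PATH"]
train = pd.read_parquet(os.path.join(taskset_path, "train.parquet"))
print(len(train), "train tasks")
print(train.iloc[0].to_dict())
```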


## Workflow Configuration and Training

We use a concatenated multi-turn workflow `FrozenLakeWorkflow` to solve the Frozen Lake task. For each rollout, the multi-turn interaction between the agent and the environment is stored in a single `Experience` object.
The specific configuration is located in [`frozen_lake.yaml`](frozen_lake.yaml).
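
As a rough illustration of this concatenated multi-turn pattern (not the actual `FrozenLakeWorkflow` implementation; `ToyEnv`, `act`, and `Rollout` below are made-up stand-ins), each agent action and the environment feedback are appended to one running message list, and the whole exchange is returned as a single record, mirroring how one rollout maps to a single `Experience`:

```python
from dataclasses import dataclass, field


@dataclass
class Rollout:
    """Stand-in for a single `Experience`: one multi-turn exchange plus its final reward."""

    messages: list = field(default_factory=list)
    reward: float = 0.0


class ToyEnv:
    """Toy stand-in environment: reach position 3 by moving right within a step budget."""

    def __init__(self, max_steps: int = 8):
        self.pos, self.steps, self.max_steps = 0, 0, max_steps

    def step(self, action: str):
        self.steps += 1
        self.pos += 1 if action == "right" else -1
        done = self.pos >= 3 or self.steps >= self.max_steps
        reward = 1.0 if self.pos >= 3 else 0.0
        return f"You are now at position {self.pos}.", reward, done


def act(messages: list) -> str:
    """Placeholder for the model call; a real workflow would query the rollout model here."""
    return "right"


def run_rollout(agent_max_steps: int = 10) -> Rollout:
    env = ToyEnv()
    rollout = Rollout(
        messages=[{"role": "user", "content": "You are at position 0. Reach position 3."}]
    )
    for _ in range(agent_max_steps):
        action = act(rollout.messages)
        rollout.messages.append({"role": "assistant", "content": action})
        feedback, reward, done = env.step(action)
        rollout.messages.append({"role": "user", "content": feedback})
        rollout.reward = reward
        if done:
            break
    return rollout  # the entire multi-turn exchange becomes one training example


if __name__ == "__main__":
    rollout = run_rollout()
    print(f"turns: {len(rollout.messages)}, reward: {rollout.reward}")
```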

To run this example, you can use the following command:

```bash
trinity run --config examples/grpo_frozen_lake/frozen_lake.yaml
```

## Results
We show the result with a Qwen2.5-3B-Instruct model below. The figure shows that the reward increases over training steps.

![reward](frozen_lake_reward.png)
85 changes: 85 additions & 0 deletions examples/grpo_frozen_lake/frozen_lake.yaml
@@ -0,0 +1,85 @@
project: "FrozenLake"
name: "trinity-frozen-lake"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
repeat_times: 8
optimizer:
lr: 1e-6
policy_loss_fn_args:
loss_agg_mode: "seq-mean-token-sum"
clip_range_low: 0.2
clip_range_high: 0.28
kl_loss_fn_args:
kl_coef: 0.0
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct}
enable_prompt_truncation: false
max_response_tokens: 10240
max_model_len: 14436
temperature: 0.7
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 64
explorer_input:
taskset:
name: frozenlake
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH}
split: train
workflow_args:
env_max_steps: 8
agent_max_steps: 10
is_slippery: false
eval_tasksets:
- name: frozenlake
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH}
split: test
workflow_args:
env_max_steps: 8
agent_max_steps: 10
is_slippery: false
rollout_args:
n: 4
top_p: 0.8
top_k: 20
default_workflow_type: 'frozen_lake_workflow'
explorer:
eval_on_startup: true
eval_interval: 10
runner_per_model: 8
rollout_model:
engine_num: 6
tensor_parallel_size: 1
enable_chunked_prefill: true
enforce_eager: false
dtype: bfloat16
seed: 42
gpu_memory_utilization: 0.85
trainer:
trainer_type: 'verl'
save_interval: 1000
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
trainer_config:
actor_rollout_ref:
hybrid_engine: true
model:
use_remove_padding: true
enable_gradient_checkpointing: true
actor:
fsdp_config:
param_offload: true
optimizer_offload: true
ref:
fsdp_config:
param_offload: true
synchronizer:
sync_method: nccl
sync_interval: 1
sync_timeout: 1200
Binary file added examples/grpo_frozen_lake/frozen_lake_reward.png
92 changes: 92 additions & 0 deletions examples/grpo_frozen_lake/get_frozen_lake_data.py
@@ -0,0 +1,92 @@
"""
Modified from https://github.com/rllm-org/rllm/blob/main/examples/frozenlake/prepare_frozenlake_data.py
"""
import os

import numpy as np
import pandas as pd

from trinity.common.constants import TASKSET_PATH_ENV_VAR

path_from_env = os.environ.get(TASKSET_PATH_ENV_VAR)
if path_from_env is not None:
    DATA_ROOT_DIR = os.path.dirname(path_from_env)
else:
    DATA_ROOT_DIR = os.path.join(os.path.dirname(__file__), "data")


def save_dataset_to_local(name: str, data: list[dict], split: str = "default") -> str:
    """Save dataset directly to the local DATA_ROOT_DIR.

    Args:
        name: Name of the dataset
        data: List of dictionaries containing the dataset examples
        split: Split name (e.g., 'train', 'test', 'default')

    Returns:
        str: Path to the saved parquet file
    """
    dataset_dir = os.path.join(DATA_ROOT_DIR, name)
    os.makedirs(dataset_dir, exist_ok=True)

    # Convert to DataFrame and save
    data_df = pd.DataFrame(data)
    dataset_path = os.path.join(dataset_dir, f"{split}.parquet")
    data_df.to_parquet(dataset_path)

    print(
        f"Saved dataset '{name}' split '{split}' with {len(data)} examples at {dataset_path}. Make sure to set the environment variable {TASKSET_PATH_ENV_VAR} to {DATA_ROOT_DIR}/{name}."
    )

    return dataset_path


def prepare_frozenlake_data(train_size=10000, test_size=100, map_max_size=6):
    """
    Prepare and save FrozenLake datasets for training and testing.

    Args:
        train_size (int): Number of training examples to generate
        test_size (int): Number of test examples to generate
        map_max_size (int): Exclusive upper bound on the randomly sampled map size

    Returns:
        tuple: (train_data, test_data) - Lists of data dictionaries
    """
    # Set random seed for reproducibility
    np.random.seed(42)

    # Generate random parameters for train and test sets
    train_seeds = np.random.randint(0, 100000, size=train_size)
    test_seeds = np.random.randint(0, 100000, size=test_size)
    train_sizes = np.random.randint(2, map_max_size, size=train_size)
    test_sizes = np.random.randint(2, map_max_size, size=test_size)
    train_ps = np.random.uniform(0.6, 0.85, size=train_size)
    test_ps = np.random.uniform(0.6, 0.85, size=test_size)

    def frozenlake_process_fn(seed, size, p, idx):
        """Process function to create FrozenLake task instances."""
        return {"seed": seed, "size": size, "p": p, "index": idx, "uid": f"{seed}_{size}_{p}"}

    # Create train and test data
    train_data = [
        frozenlake_process_fn(seed, train_sizes[idx], train_ps[idx], idx)
        for idx, seed in enumerate(train_seeds)
    ]
    test_data = [
        frozenlake_process_fn(seed, test_sizes[idx], test_ps[idx], idx)
        for idx, seed in enumerate(test_seeds)
    ]

    # Save datasets directly to the local DATA_ROOT_DIR
    save_dataset_to_local("frozenlake", train_data, "train")
    save_dataset_to_local("frozenlake", test_data, "test")

    return train_data, test_data


if __name__ == "__main__":
    train_data, test_data = prepare_frozenlake_data()
    print(f"Train dataset: {len(train_data)} examples")
    print(f"Test dataset: {len(test_data)} examples")
    print("Sample train example:", train_data[0])
    print("Sample test example:", test_data[0])
68 changes: 62 additions & 6 deletions tests/common/vllm_test.py
@@ -234,6 +234,7 @@ def setUp(self):
        self.config.model.max_model_len = self.max_model_len
        self.config.model.max_prompt_tokens = self.max_prompt_tokens
        self.config.model.max_response_tokens = self.max_response_tokens
        self.config.model.enable_prompt_truncation = True
        self.config.explorer.rollout_model.enable_openai_api = True
        self.config.check_and_update()

@@ -246,14 +247,21 @@ async def test_model_len(self):
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather like today?"},
        ]

        # For vllm engine, max_prompt_tokens and max_response_tokens work
        response = self.model_wrapper.chat(messages)
        self.assertEqual(len(response), 1)
        self.assertEqual(len(response[0].tokens), self.max_model_len)
        self.assertEqual(len(response[0].tokens), self.config.model.max_model_len)
        exps = self.model_wrapper.extract_experience_from_history()
        self.assertEqual(len(exps), 1)
        self.assertEqual(len(exps[0].tokens), self.max_model_len)
        # check prompt length, response length, max_model_len
        self.assertEqual(exps[0].prompt_length, self.config.model.max_prompt_tokens)
        self.assertEqual(
            len(exps[0].tokens) - exps[0].prompt_length, self.config.model.max_response_tokens
        )
        self.assertLessEqual(len(response[0].tokens), self.config.model.max_model_len)

        # max_prompt_tokens and max_response_tokens do not work with openai api
        # For openai api, max_prompt_tokens and max_response_tokens do not work
        openai_client = self.model_wrapper.get_openai_client()
        model_id = openai_client.models.list().data[0].id
        with self.assertRaises(BadRequestError):
@@ -267,9 +275,57 @@
        exps = self.model_wrapper.extract_experience_from_history()
        self.assertEqual(len(exps), 1)
        # only generate max_response_tokens tokens
        self.assertEqual(
            len(exps[0].tokens),
            response.usage.prompt_tokens + self.config.model.max_response_tokens,
        self.assertLessEqual(
            len(exps[0].tokens) - response.usage.prompt_tokens,
            self.config.model.max_response_tokens,
        )


class TestModelLenWithoutPromptTruncation(RayUnittestBaseAysnc):
    def setUp(self):
        self.config = get_template_config()
        self.config.mode = "explore"
        self.config.model.model_path = get_model_path()
        self.config.model.max_model_len = 20
        self.config.model.max_prompt_tokens = 1
        self.config.model.max_response_tokens = None
        self.config.model.enable_prompt_truncation = False
        self.config.explorer.rollout_model.enable_openai_api = True
        self.config.check_and_update()

        self.engines, self.auxiliary_engines = create_inference_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

    async def test_model_len(self):
        await self.model_wrapper.prepare()
        messages = [
            {"role": "user", "content": "How are you?"},
        ]

        # For vllm engine, max_prompt_tokens and max_response_tokens work
        response = self.model_wrapper.chat(messages)
        self.assertEqual(len(response), 1)
        self.assertLessEqual(
            len(response[0].tokens) - response[0].prompt_length,
            self.config.model.max_response_tokens,
        )
        exps = self.model_wrapper.extract_experience_from_history()
        self.assertEqual(len(exps), 1)
        self.assertLessEqual(
            len(exps[0].tokens) - exps[0].prompt_length,
            self.config.model.max_response_tokens,
        )

        # For openai api
        openai_client = self.model_wrapper.get_openai_client()
        model_id = openai_client.models.list().data[0].id
        response = openai_client.chat.completions.create(model=model_id, messages=messages, n=1)
        self.assertEqual(len(response.choices), 1)
        exps = self.model_wrapper.extract_experience_from_history()
        self.assertEqual(len(exps), 1)
        self.assertLessEqual(
            len(exps[0].tokens) - response.usage.prompt_tokens,
            self.config.model.max_response_tokens,
        )


1 change: 1 addition & 0 deletions trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -176,6 +176,7 @@ def calculate_group_advantage(

metrics["reward_mean"] = group_reward_mean.item()
metrics["reward_std"] = group_reward_std.item()
metrics["advantage_std"] = exp.advantages.std().item()

return exps, metrics
