2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/develop_workflow.md
@@ -513,7 +513,7 @@ Here, `<config_file_path>` is the path to a YAML configuration file, which shoul
Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:

```bash
trinity debug --config <config_file_path> --module workflow --output_file <output_file_path> --plugin_dir <plugin_dir>
trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
```

- `<config_file_path>`: Path to the YAML configuration file, usually the same as used for starting the inference model.
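The change above only renames the user-facing options from snake_case to dash-style; the values and meanings are unchanged. For reference, a dash-separated option still surfaces as an underscore attribute in an argparse-based CLI — a minimal sketch, assuming (not confirmed by this diff) that `trinity debug` is built on argparse:

```python
import argparse

# Hypothetical sketch of the renamed options; the real parser is not shown in this PR.
parser = argparse.ArgumentParser(prog="trinity debug")
parser.add_argument("--config", required=True, help="YAML configuration file path")
parser.add_argument("--module", choices=["inference_model", "workflow"], default="workflow")
# Dash-separated flags are exposed as underscore attributes (args.output_file,
# args.plugin_dir), so internal code keeps working after the rename.
parser.add_argument("--output-file", help="file to write debug results to")
parser.add_argument("--plugin-dir", help="directory containing custom workflow plugins")

args = parser.parse_args(["--config", "cfg.yaml", "--output-file", "out.jsonl"])
assert args.output_file == "out.jsonl"
```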
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
@@ -509,7 +509,7 @@ trinity debug --config <config_file_path> --module inference_model
Once started, the model keeps running and waits for debug instructions; it does not exit automatically. You can then run the following command in another terminal to debug the workflow:

```bash
trinity debug --config <config_file_path> --module workflow --output_file <output_file_path> --plugin_dir <plugin_dir>
trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
```

- `config_file_path`: Path to the YAML configuration file, usually the same one used to start the inference model.
180 changes: 149 additions & 31 deletions tests/explorer/scheduler_test.py
@@ -5,10 +5,11 @@

import ray
import torch
from parameterized import parameterized

from tests.tools import get_template_config
from trinity.common.config import ExperienceBufferConfig
from trinity.common.constants import StorageType
from trinity.common.constants import StorageType, SyncStyle
from trinity.common.experience import EID, Experience
from trinity.common.models.model import InferenceModel
from trinity.common.workflows import Task
@@ -46,17 +47,23 @@ def run(self) -> List[Experience]:
elif self.error_type == "auxiliary_models":
assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2

return [
Experience(
tokens=torch.zeros(5),
prompt_length=2,
prompt_text=self.error_type or "success",
eid=EID(run=i + self.run_id_base, step=step),
info={"repeat_times": self.repeat_times},
)
for step in range(self.step_num)
for i in range(self.repeat_times)
]
exps = []
for i in range(self.repeat_times):
run_level_metrics = {"run_metrics": float(i + self.run_id_base)}
run_level_exps = []
for step in range(self.step_num):
run_level_exps.append(
Experience(
tokens=torch.zeros(5),
prompt_length=2,
prompt_text=self.error_type or "success",
eid=EID(run=i + self.run_id_base, step=step),
info={"repeat_times": self.repeat_times},
)
)
run_level_exps[-1].metrics = run_level_metrics
exps.extend(run_level_exps)
return exps
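The rewritten loop attaches a `metrics` dict only to the last experience of each run, so every run reports its metrics exactly once. The metric-calculation tests below expect the scheduler to average these per-run values into the returned task status; a minimal sketch of that aggregation (the real Scheduler logic is outside this diff, and `aggregate_run_metrics` is a hypothetical helper):

```python
from statistics import mean
from typing import Dict, List

def aggregate_run_metrics(exps) -> Dict[str, float]:
    """Average each run-level metric over the runs that reported it."""
    collected: Dict[str, List[float]] = {}
    for exp in exps:
        for name, value in (getattr(exp, "metrics", None) or {}).items():
            collected.setdefault(name, []).append(float(value))
    return {name: mean(values) for name, values in collected.items()}

# With repeat_times=4 and run_id_base=0 the run metrics are 0, 1, 2, 3,
# giving a mean of 1.5 -- the value asserted in
# test_metric_calculation_with_repeatable_workflow below.
```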


@WORKFLOWS.register_module("dummy_nonrepeat_workflow")
@@ -67,22 +74,29 @@ def __init__(self, *, task, model, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.reset_flag = False
self.step_num = task.workflow_args.get("step_num", 1)
self.metrics = task.workflow_args.get("metrics", [0])

def reset(self, task: Task):
self.task = task
self.reset_flag = True
self.step_num = task.workflow_args.get("step_num", 1)
self.metrics = task.workflow_args.get("metrics", [0])

def run(self) -> List[Experience]:
return [
exps = [
Experience(
eid=EID(run=self.run_id_base, step=step),
tokens=torch.zeros(5),
prompt_length=2,
prompt_text="success",
info={"reset_flag": self.reset_flag},
metrics={
"run_metrics": self.metrics[step % len(self.metrics)],
},
)
for step in range(self.step_num)
]
return exps


@WORKFLOWS.register_module("dummy_async_workflow")
@@ -99,16 +113,22 @@ def set_repeat_times(self, repeat_times, run_id_base):
self.run_id_base = run_id_base

async def run_async(self):
return [
Experience(
eid=EID(run=i + self.run_id_base, step=step),
tokens=torch.zeros(5),
prompt_length=2,
prompt_text="success",
)
for step in range(self.step_num)
for i in range(self.repeat_times)
]
exps = []
for i in range(self.repeat_times):
run_level_metrics = {"run_metrics": float(i + self.run_id_base)}
run_level_exps = []
for step in range(self.step_num):
run_level_exps.append(
Experience(
eid=EID(run=i + self.run_id_base, step=step),
tokens=torch.zeros(5),
prompt_length=2,
prompt_text="success",
)
)
run_level_exps[-1].metrics = run_level_metrics
exps.extend(run_level_exps)
return exps

def run(self):
raise RuntimeError("This method should not be called")
@@ -490,7 +510,7 @@ async def test_split_tasks(self):
tasks = generate_tasks(4, repeat_times=8) # ceil(8 / 2) == 4
scheduler.schedule(tasks, batch_id=1)
statuses, exps = await scheduler.get_results(batch_id=1)
self.assertEqual(len(statuses), 4 * 4)
self.assertEqual(len(statuses), 4)
self.assertEqual(len(exps), 4 * 8)
exp_list.extend(exps)
_, exps = await scheduler.get_results(batch_id=1, min_num=1, timeout=1)
@@ -499,7 +519,7 @@ async def test_split_tasks(self):
tasks = generate_tasks(4, repeat_times=5) # ceil(5 / 2) == 3
scheduler.schedule(tasks, batch_id=2)
statuses, exps = await scheduler.get_results(batch_id=2)
self.assertEqual(len(statuses), 4 * 3)
self.assertEqual(len(statuses), 4)
self.assertEqual(len(exps), 4 * 5)
exp_list.extend(exps)
_, exps = await scheduler.get_results(batch_id=2, min_num=1, timeout=1)
@@ -508,7 +528,7 @@ async def test_split_tasks(self):
tasks = generate_tasks(3, repeat_times=1) # ceil(1 / 2) == 1
scheduler.schedule(tasks, batch_id=3)
statuses, exps = await scheduler.get_results(batch_id=3)
self.assertEqual(len(statuses), 3 * 1)
self.assertEqual(len(statuses), 3)
self.assertEqual(len(exps), 3 * 1)
exp_list.extend(exps)
_, exps = await scheduler.get_results(batch_id=3, min_num=1, timeout=1)
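The expected status counts change because `get_results` now reports one status per original task, even when a task is split across runners; only the experience count still scales with `repeat_times`. The split factor itself is the `ceil(...)` noted in the comments above — a small sketch, with `runner_split_count` as a hypothetical name:

```python
import math

def runner_split_count(repeat_times: int, max_repeat_times_per_runner) -> int:
    """Number of runner sub-tasks a single task is split into."""
    if max_repeat_times_per_runner is None:
        return 1
    return math.ceil(repeat_times / max_repeat_times_per_runner)

# repeat_times=8 with a per-runner cap of 2 -> 4 splits, yet the scheduler
# still returns 4 statuses for 4 tasks (not 4 * 4), as asserted above.
```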
@@ -535,7 +555,7 @@ async def test_multi_step_execution(self):
for i in range(1, n_steps + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), 2 * 4)
self.assertEqual(len(statuses), 2)
self.assertEqual(len(exps), 2 * 4)

await scheduler.stop()
@@ -553,7 +573,7 @@ async def test_non_repeatable_workflow(self):
for i in range(1, batch_num + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), task_num * repeat_times / 2)
self.assertEqual(len(statuses), task_num)
self.assertEqual(len(exps), task_num * repeat_times)
exp_list.extend(exps)

@@ -594,7 +614,7 @@ async def test_async_workflow(self):
for i in range(1, batch_num + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), task_num * repeat_times / 2)
self.assertEqual(len(statuses), task_num)
self.assertEqual(len(exps), task_num * repeat_times * step_num)
exp_list.extend(exps)

@@ -624,7 +644,7 @@ async def test_stepwise_experience_eid(self):
for i in range(1, batch_num + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), task_num * repeat_times / 2)
self.assertEqual(len(statuses), task_num)
self.assertEqual(len(exps), task_num * repeat_times * step_num)
exp_list.extend(exps)

@@ -644,7 +664,7 @@ async def test_stepwise_experience_eid(self):
for i in range(1, batch_num + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), task_num * repeat_times / 2)
self.assertEqual(len(statuses), task_num)
self.assertEqual(len(exps), task_num * repeat_times * step_num)
exp_list.extend(exps)

@@ -656,6 +676,104 @@ async def test_stepwise_experience_eid(self):
unique_ids = [exp.eid.uid for exp in exp_list]
self.assertEqual(len(unique_ids), len(set(unique_ids)))

@parameterized.expand(
[
(2,),
(None,),
]
)
async def test_metric_calculation_with_repeatable_workflow(self, max_repeat_times_per_runner):
self.config.explorer.max_repeat_times_per_runner = max_repeat_times_per_runner
self.config.check_and_update()
scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
await scheduler.start()
tasks = []
tasks.extend(generate_tasks(total_num=1, step_num=1, repeat_times=4, repeatable=True))
tasks.extend(generate_tasks(total_num=1, step_num=4, repeat_times=8, repeatable=True))
scheduler.schedule(tasks, batch_id=0)
statuses, exps = await scheduler.get_results(batch_id=0)
self.assertEqual(len(statuses), 2)
self.assertEqual(len(exps), 1 * 4 * 1 + 1 * 8 * 4)
self.assertAlmostEqual(statuses[0].metrics[0]["run_metrics"], 1.5) # (0+1+2+3)/4
self.assertAlmostEqual(statuses[1].metrics[0]["run_metrics"], 3.5) # (0+1+2+3+4+5+6+7)/8

@parameterized.expand(
[
(2,),
(None,),
]
)
async def test_metric_calculation_with_non_repeatable_workflow(
self, max_repeat_times_per_runner
):
self.config.explorer.max_repeat_times_per_runner = max_repeat_times_per_runner
self.config.check_and_update()
scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
await scheduler.start()
tasks = []
tasks.extend(generate_tasks(total_num=1, step_num=3, repeat_times=4, repeatable=False))
tasks[-1].workflow_args["metrics"] = [1.0, 2.0, 3.0]
tasks.extend(generate_tasks(total_num=1, step_num=8, repeat_times=5, repeatable=False))
tasks[-1].workflow_args["metrics"] = [2 * i for i in range(8)]
scheduler.schedule(tasks, batch_id=0)
statuses, exps = await scheduler.get_results(batch_id=0)
self.assertEqual(len(statuses), 2)
self.assertEqual(len(exps), 1 * 4 * 3 + 1 * 5 * 8)
self.assertAlmostEqual(statuses[0].metrics[0]["run_metrics"], 2.0) # (1+2+3)/3
self.assertAlmostEqual(statuses[1].metrics[0]["run_metrics"], 7.0) # (0+2+4+6+8+10+12+14)/8

async def test_over_rollout_min_wait(self):
self.config.explorer.over_rollout.over_rollout_rate = 0.5
self.config.explorer.over_rollout.wait_time_after_min_threshold = 3
self.config.explorer.max_repeat_times_per_runner = None
self.config.buffer.batch_size = 4
self.config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
self.config.check_and_update()
scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
await scheduler.start()
tasks = []
tasks.extend(generate_tasks(0, timeout_num=2, repeat_times=1, timeout_seconds=1))
tasks.extend(generate_tasks(0, timeout_num=1, repeat_times=1, timeout_seconds=3))
tasks.extend(generate_tasks(0, timeout_num=1, repeat_times=1, timeout_seconds=6))
scheduler.schedule(tasks, batch_id=0)
statuses, exps = await scheduler.get_results(batch_id=0, min_num=2)
self.assertEqual(len(statuses), 3)
self.assertEqual(len(exps), 3 * 1)
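This test encodes the over-rollout contract: once the minimum number of tasks has finished, the scheduler waits only `wait_time_after_min_threshold` more seconds before returning, so the 6-second straggler above does not block the batch. A rough sketch of such a wait loop, assuming a hypothetical `finished_count` callable (the actual Scheduler internals are not part of this diff):

```python
import asyncio
import time

async def wait_with_grace(finished_count, total: int, min_num: int, grace_seconds: float):
    """Return when all tasks finish, or `grace_seconds` after `min_num` have finished."""
    deadline = None
    while finished_count() < total:
        if deadline is None and finished_count() >= min_num:
            deadline = time.monotonic() + grace_seconds  # grace window opens
        if deadline is not None and time.monotonic() >= deadline:
            break  # give up on stragglers (the 6-second task above)
        await asyncio.sleep(0.05)

# In the test: two 1s tasks satisfy min_num=2, the 3s task lands inside the
# 3s grace window, and the 6s task does not -- hence 3 statuses / 3 experiences.
```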

async def test_dynamic_timeout(self):
self.config.explorer.dynamic_timeout.enable = True
self.config.explorer.dynamic_timeout.dynamic_timeout_ratio = 3.0
self.config.buffer.batch_size = 4
self.config.explorer.max_timeout = 20
self.config.explorer.max_retry_times = 0 # no retry here
scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
await scheduler.start()
tasks = []
# generate 4 tasks that each run for 1 second
tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=1))
scheduler.schedule(tasks, batch_id=0) # first step will not use dynamic timeout
statuses, exps = await scheduler.get_results(batch_id=0)
self.assertEqual(len(statuses), 4)
# dynamic timeout will be set to 3.0 * 1.0 = 3.0 seconds for the next step
tasks = []
tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=4))
st = time.time()
scheduler.schedule(tasks, batch_id=1)
statuses, exps = await scheduler.get_results(batch_id=1)
et = time.time()
self.assertTrue(
et - st < 4
) # should wait about 1 * 3.0 = 3.0 seconds; assert under a 4-second bound
self.assertEqual(len(exps), 0)
self.assertEqual(len(statuses), 4)
# tasks take 2 seconds, which is within the dynamic timeout 3.0 * 1.0 = 3.0 seconds
tasks = []
tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=2))
scheduler.schedule(tasks, batch_id=2)
statuses, exps = await scheduler.get_results(batch_id=2)
self.assertEqual(len(statuses), 4)
self.assertEqual(len(exps), 4)
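The assertions pin down the dynamic-timeout rule: the first step uses the static timeout, and each later step allows `dynamic_timeout_ratio` times the previously observed task duration, capped by `max_timeout`. A minimal sketch under those assumptions (names here are illustrative, not the Scheduler's actual fields):

```python
def next_step_timeout(observed_seconds, ratio: float = 3.0, max_timeout: float = 20.0) -> float:
    """Dynamic timeout: scale the last observed duration, capped at max_timeout."""
    if observed_seconds is None:
        return max_timeout  # first step: nothing observed yet
    return min(ratio * observed_seconds, max_timeout)

# batch 0: tasks take ~1s        -> next timeout = 3.0 * 1.0 = 3.0s
# batch 1: tasks need 4s > 3.0s  -> all time out (0 experiences)
# batch 2: tasks take 2s <= 3.0s -> all 4 succeed
```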

def tearDown(self):
try:
ray.shutdown()
Expand Down
45 changes: 45 additions & 0 deletions tests/trainer/trainer_test.py
@@ -987,3 +987,48 @@ def test_trainer(self):

def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)


class TestOverRollout(BaseTrainerCase):
def test_trainer(self):
self.config.algorithm.repeat_times = 4
self.config.buffer.batch_size = 4
self.config.buffer.total_steps = 2
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.name = f"explore-over-rollout-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.explorer.over_rollout.over_rollout_rate = 0.5  # 50% over-rollout: wait for only 2 (4 * 50%) tasks in each step
self.config.explorer.over_rollout.wait_time_after_min_threshold = 1
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.advantage_fn_args = {
"epsilon": 1e-6,
}
self.config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
self.config.synchronizer.sync_interval = 1
self.config.check_and_update()
both(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
eval_metrics = parser.metric_list("eval")
self.assertTrue(len(eval_metrics) == 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
self.assertTrue(parser.metric_exist("experience_pipeline/experience_count"))
experience_counts = parser.metric_values("experience_pipeline/experience_count")
self.assertTrue(len(experience_counts) == 2)
for count in experience_counts:
self.assertTrue(
count > 2 * 4
) # at least 2 tasks processed per step, each with repeat_times = 4
pg_loss = parser.metric_values("actor/pg_loss")
self.assertEqual(len(pg_loss), 1) # trainer only has 1 step
exp_save_path = self.config.buffer.trainer_input.experience_buffer.path
with open(exp_save_path, "r", encoding="utf-8") as f:
lines = f.readlines()
self.assertTrue(
len(lines) > 2 * 4 * 2
) # total_steps * repeat_times * min_waited_tasks = 2 * 4 * 2
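
# For reference, the worked numbers behind the two thresholds asserted above,
# using this test's config (the 50% rate is read as "minimum fraction of the
# batch to wait for", per the inline comment -- an assumption, not confirmed here):
#
#   total_steps, batch_size, repeat_times = 2, 4, 4
#   min_waited_tasks = int(batch_size * 0.5)                       # 2 tasks
#   per_step_min_count = min_waited_tasks * repeat_times           # 8, so count > 2 * 4
#   min_saved_lines = total_steps * repeat_times * min_waited_tasks  # 16, so len(lines) > 2 * 4 * 2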

def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)