Skip to content

Commit 75f7d2d

Browse files
authored
Fix dynamic timeout (#409)
1 parent 404bc13 commit 75f7d2d

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

tests/explorer/scheduler_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,18 @@ async def test_dynamic_timeout(self):
785785
scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
786786
await scheduler.start()
787787
tasks = []
788+
tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=1))
789+
for task in tasks:
790+
task.is_eval = True
791+
scheduler.schedule(
792+
tasks, batch_id="0/eval"
793+
) # eval tasks will not count into dynamic timeout
794+
statuses, exps = await scheduler.get_results(batch_id="0/eval")
795+
self.assertEqual(len(statuses), 4)
796+
self.assertEqual(len(exps), 0)
797+
self.assertEqual(scheduler.total_running_time, 0)
798+
self.assertEqual(scheduler.total_completed_tasks, 0)
799+
tasks = []
788800
# generate 4 tasks that will run 1 second
789801
tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=1))
790802
scheduler.schedule(tasks, batch_id=0) # first step will not use dynamic timeout

tests/trainer/trainer_test.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,10 +1031,18 @@ def test_trainer(self):
10311031
self.config.algorithm.repeat_times = 4
10321032
self.config.buffer.batch_size = 4
10331033
self.config.buffer.total_steps = 2
1034-
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
1034+
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config(
1035+
"countdown", "train"
1036+
)
1037+
self.config.buffer.explorer_input.eval_tasksets = [
1038+
get_unittest_dataset_config("countdown", "test")
1039+
]
1040+
self.config.buffer.eval_interval = 4 # only eval on start
10351041
self.config.name = f"explore-over-rollout-{datetime.now().strftime('%Y%m%d%H%M%S')}"
10361042
self.config.explorer.over_rollout.ratio = 0.5 # set over rollout rate to 50%, which means only wait for 2 (4 * 50%) tasks in each step
10371043
self.config.explorer.over_rollout.wait_after_min = 0
1044+
self.config.explorer.dynamic_timeout.enable = True
1045+
self.config.explorer.dynamic_timeout.ratio = 2
10381046
self.config.algorithm.algorithm_type = "grpo"
10391047
self.config.algorithm.advantage_fn = "grpo"
10401048
self.config.algorithm.advantage_fn_args = {
@@ -1048,7 +1056,7 @@ def test_trainer(self):
10481056
rollout_metrics = parser.metric_list("rollout")
10491057
self.assertTrue(len(rollout_metrics) > 0)
10501058
eval_metrics = parser.metric_list("eval")
1051-
self.assertTrue(len(eval_metrics) == 0)
1059+
self.assertTrue(len(eval_metrics) > 0)
10521060
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
10531061
self.assertTrue(parser.metric_exist("experience_pipeline/experience_count"))
10541062
experience_counts = parser.metric_values("experience_pipeline/experience_count")

trinity/explorer/scheduler.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,22 @@ class TaskWrapper:
2626

2727
task: Task
2828
batch_id: Union[int, str]
29-
sub_task_num: int = 1
29+
sub_task_num: int = 1 # number of sub-tasks split from this task
30+
# if max_repeat_times_per_runner is set, one task may be split into multiple sub-tasks
3031
results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)
3132

3233

3334
def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
34-
"""Calculate task level metrics from experiences."""
35+
"""Calculate task level metrics (mean) from multiple runs of the same task.
36+
37+
Args:
38+
metrics (`List[Dict]`): A list of metric dictionaries from multiple runs of the same task.
39+
40+
Returns:
41+
`Dict[str, float]`: A dictionary of aggregated metrics, where each metric is averaged over all runs.
42+
43+
TODO: support more aggregation methods like max, min.
44+
"""
3545
if not metrics:
3646
return {}
3747
aggregated_metrics: Dict[str, List[float]] = defaultdict(list)
@@ -312,11 +322,13 @@ def task_done_callback(self, async_task: asyncio.Task):
312322
return
313323
else:
314324
status, exps, runner_id, run_time = async_task.result()
315-
self.total_running_time += run_time
316-
self.total_completed_tasks += 1
325+
if not task.task.is_eval: # only count running time for non-eval tasks
326+
self.total_running_time += run_time
327+
self.total_completed_tasks += 1
317328
task.results.append((status, exps))
318329
self.busy_runners.pop(runner_id)
319330
self.idle_runners.add(runner_id)
331+
# If all sub runs in a task are completed
320332
if len(task.results) == task.sub_task_num:
321333
task_experiences = []
322334
task_metrics = []
@@ -326,6 +338,7 @@ def task_done_callback(self, async_task: asyncio.Task):
326338
task_experiences.extend(exp)
327339
if not s.ok:
328340
all_success = False
341+
# calculate task level metrics
329342
task_status = Status(
330343
ok=all_success, metrics=[calculate_task_level_metrics(task_metrics)]
331344
)

0 commit comments

Comments
 (0)