Skip to content

Commit aecf41f

Browse files
committed
add terminal agent with openai-agents
Signed-off-by: CormickKneey <[email protected]>
1 parent 16bbfb3 commit aecf41f

19 files changed

+2345
-11
lines changed

areal/dataset/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from transformers.processing_utils import ProcessorMixin
1111
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
1212

13-
VALID_DATASETS = ["gsm8k", "clevr_count_70k", "geometry3k", "hh-rlhf", "torl_data"]
13+
VALID_DATASETS = ["gsm8k", "clevr_count_70k", "geometry3k", "hh-rlhf", "torl_data", "terminal_bench"]
1414

1515
logger = logging.getLogger("Dataset")
1616

@@ -105,6 +105,16 @@ def _get_custom_dataset(
105105
max_length=max_length,
106106
**kwargs,
107107
)
108+
elif "terminal_bench" in path and type == "rl":
109+
from .terminal_bench import get_terminal_bench_rl_dataset
110+
111+
return get_terminal_bench_rl_dataset(
112+
path=path,
113+
split=split,
114+
tokenizer=tokenizer,
115+
max_length=max_length,
116+
**kwargs,
117+
)
108118
else:
109119
raise ValueError(
110120
f"Dataset {path} with split {split} and training type {type} is not supported. "

areal/dataset/terminal_bench.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import TYPE_CHECKING
2+
3+
from datasets import load_dataset
4+
5+
if TYPE_CHECKING:
6+
from transformers import PreTrainedTokenizerFast
7+
8+
9+
def get_terminal_bench_rl_dataset(
10+
path: str,
11+
split: str,
12+
tokenizer: "PreTrainedTokenizerFast",
13+
max_length: int | None = None,
14+
):
15+
"""Load terminal-bench dataset for RL training.
16+
17+
The dataset should be in parquet format with the following columns:
18+
- prompt: The formatted prompt for the task
19+
- task_name: Name of the task
20+
- instruction: Raw instruction text
21+
- extra_info: JSON string containing task metadata
22+
"""
23+
# Load from parquet file
24+
dataset = load_dataset("parquet", data_files={split: path}, split=split)
25+
26+
# The dataset already has the right format from the converter:
27+
# - prompt: contains the formatted conversation
28+
# - task_name, instruction, extra_info: metadata fields
29+
30+
# For RL training, we need to extract messages from the prompt or extra_info
31+
def process(sample):
32+
# The prompt is already formatted, but we need to extract the instruction
33+
# to create a messages structure for the workflow
34+
instruction = sample.get("instruction", "")
35+
task_name = sample.get("task_name", "")
36+
37+
# Return data in the format expected by the workflow
38+
return {
39+
"instruction": instruction,
40+
"task_name": task_name,
41+
"extra_info": sample.get("extra_info", ""),
42+
"data_source": sample.get("data_source", "terminal_bench"),
43+
}
44+
45+
dataset = dataset.map(process)
46+
47+
# Filter out sequences longer than max_length if specified
48+
if max_length is not None:
49+
50+
def filter_length(samples):
51+
# Tokenize instructions in batches for efficiency
52+
instructions = samples["instruction"]
53+
tokens_list = tokenizer(instructions, add_special_tokens=False)["input_ids"]
54+
return [len(tokens) <= max_length for tokens in tokens_list]
55+
56+
dataset = dataset.filter(filter_length, batched=True)
57+
58+
return dataset

areal/experimental/openai/client.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ async def create(
165165
input_ids=prompt_token_ids,
166166
gconfig=gconfig,
167167
rid=str(uuid.uuid4()),
168-
metadata=metadata if metadata is not NOT_GIVEN else {},
168+
metadata=metadata if not is_omitted(metadata) else {},
169169
tokenizer=self.tokenizer,
170170
)
171171

@@ -276,6 +276,27 @@ async def create(
276276
if input is NOT_GIVEN or input is None:
277277
raise ValueError("input is required for Responses.create")
278278

279+
def _convert_tool_output_format(item: dict) -> dict:
280+
"""Convert custom tool output format to standard chat template format.
281+
282+
Converts from: {'call_id': ..., 'output': ..., 'type': 'function_call_output'}
283+
To: {'role': 'tool', 'content': ..., 'tool_call_id': ...}
284+
"""
285+
if (
286+
item
287+
and item.get("output")
288+
and item.get("type") == "function_call_output"
289+
):
290+
converted = {
291+
"role": "tool",
292+
"content": item["output"],
293+
}
294+
# Add tool_call_id if present
295+
if "call_id" in item:
296+
converted["tool_call_id"] = item["call_id"]
297+
return converted
298+
return item
299+
279300
def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
280301
messages_list = []
281302
if "content" in item:
@@ -286,13 +307,17 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
286307
elif isinstance(item["content"], Iterable):
287308
for content in item["content"]:
288309
if isinstance(content, dict):
289-
messages_list.append(deepcopy(content))
310+
# Convert tool output format if needed
311+
converted = _convert_tool_output_format(content)
312+
messages_list.append(deepcopy(converted))
290313
else:
291314
raise ValueError("Unsupported content format")
292315
else:
293316
raise ValueError("Unsupported input item format")
294317
else:
295-
messages_list.append(deepcopy(item))
318+
# Convert tool output format if needed
319+
converted = _convert_tool_output_format(item)
320+
messages_list.append(deepcopy(converted))
296321
return messages_list
297322

298323
if isinstance(input, str):
@@ -335,7 +360,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
335360
temp = 1.0 if temperature is NOT_GIVEN else (temperature or 0.0)
336361
top_p_val = 1.0 if top_p is NOT_GIVEN else (top_p or 1.0)
337362
max_new_tokens = 512
338-
if max_output_tokens is not NOT_GIVEN and max_output_tokens is not None:
363+
if not is_omitted(max_output_tokens):
339364
max_new_tokens = max_output_tokens
340365

341366
stop = kwargs.get("stop", None)
@@ -359,7 +384,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
359384
input_ids=prompt_token_ids,
360385
gconfig=gconfig,
361386
rid=str(uuid.uuid4()),
362-
metadata=metadata if metadata is not NOT_GIVEN else {},
387+
metadata=metadata if not is_omitted(metadata) else {},
363388
tokenizer=self.tokenizer,
364389
)
365390

@@ -420,14 +445,14 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
420445
created_at=current_time,
421446
error=None,
422447
incomplete_details=None,
423-
instructions=None if instructions is NOT_GIVEN else instructions,
424-
metadata=None if metadata is NOT_GIVEN else metadata,
448+
instructions=None if is_omitted(instructions) else instructions,
449+
metadata=None if is_omitted(metadata) else metadata,
425450
model="None",
426451
object="response",
427452
output=resp_output,
428453
parallel_tool_calls=False,
429454
temperature=temp,
430-
tool_choice=tool_choice if tool_choice is not NOT_GIVEN else "none",
455+
tool_choice=tool_choice if not is_omitted(tool_choice) else "none",
431456
tools=tools,
432457
top_p=top_p_val,
433458
background=None,
@@ -751,3 +776,13 @@ def export_responses(
751776
"export_responses is deprecated. Please use export_interactions instead."
752777
)
753778
return self.export_interactions(style)
779+
780+
781+
def is_omitted(value) -> bool:
782+
"""Check if a value is NOT_GIVEN or Omit type."""
783+
if value is NOT_GIVEN or value is None:
784+
return True
785+
# Check by class name to handle both NotGiven and Omit
786+
if hasattr(value, "__class__"):
787+
return value.__class__.__name__ in ("NotGiven", "Omit")
788+
return False
59.3 KB
Loading

examples/__init__.py

Whitespace-only changes.
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import asyncio
2+
import logging
3+
import os
4+
5+
from agents import Agent as OpenAIAgent
6+
from agents import ModelSettings, OpenAIProvider, RunConfig, SQLiteSession
7+
from agents import Runner as OpenAIRunner
8+
from terminal.env import TerminalEnv
9+
from terminal.prompt import SYSTEM_PROMPT
10+
from transformers import PreTrainedTokenizerFast
11+
12+
from areal.api.cli_args import GenerationHyperparameters
13+
from areal.api.workflow_api import RolloutWorkflow
14+
from areal.experimental.openai import ArealOpenAI
15+
from areal.utils import stats_tracker
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class TerminalAgent:
21+
def __init__(
22+
self,
23+
tokenizer: PreTrainedTokenizerFast,
24+
max_tokens_per_turn: int = 1024,
25+
max_turns: int = 8,
26+
max_total_tokens: int = 32768,
27+
dump_dir: str | None = None,
28+
rollout_stat_scope: str = "rollout",
29+
):
30+
self.tokenizer = tokenizer
31+
self.max_tokens_per_turn = max_tokens_per_turn
32+
self.max_turns = max_turns
33+
self.max_total_tokens = max_total_tokens
34+
self.dump_dir = dump_dir
35+
self.rollout_stat_scope = rollout_stat_scope
36+
37+
async def run_agent(self, data, client: ArealOpenAI):
38+
"""Run the agent workflow for terminal task execution."""
39+
run_config = RunConfig(
40+
model_provider=OpenAIProvider(
41+
openai_client=client,
42+
use_responses=True,
43+
),
44+
tracing_disabled=True,
45+
model_settings=ModelSettings(
46+
temperature=1.0,
47+
extra_args={"max_completion_tokens": self.max_tokens_per_turn},
48+
tool_choice="auto",
49+
store=True,
50+
),
51+
)
52+
53+
async with TerminalEnv(
54+
task_name=data["task_name"],
55+
dump_dir=self.dump_dir,
56+
rollout_stat_scope=self.rollout_stat_scope,
57+
) as env:
58+
# Create agent workflow with terminal tools
59+
agent = OpenAIAgent(
60+
name="Terminal Task Agent",
61+
instructions=SYSTEM_PROMPT,
62+
tools=env.get_tools(),
63+
)
64+
session = SQLiteSession("terminal")
65+
content = data["instruction"]
66+
67+
max_attempts = self.max_turns
68+
reward = 0
69+
tracker = stats_tracker.get(self.rollout_stat_scope)
70+
71+
with tracker.record_timing("run_agent_total"):
72+
error_count = 0.0
73+
attempts_used = 0.0
74+
for attempt in range(max_attempts):
75+
attempts_used = float(attempt + 1)
76+
try:
77+
with tracker.record_timing("openai_runner_run"):
78+
result = await OpenAIRunner.run(
79+
agent,
80+
input=content,
81+
session=session,
82+
run_config=run_config,
83+
max_turns=30,
84+
)
85+
except Exception as e:
86+
logger.error(f"Error running agent: {e}")
87+
error_count += 1.0
88+
break
89+
90+
with tracker.record_timing("env_validate_reward"):
91+
reward = env.reward()
92+
if reward is None:
93+
reward = 0.0
94+
tracker.scalar(
95+
attempt_reward=reward,
96+
attempt_index=float(attempt),
97+
input_chars=float(len(content) if content else 0.0),
98+
output_chars=float(
99+
len(getattr(result, "final_output", "") or "")
100+
),
101+
)
102+
103+
if isinstance(reward, float) and reward >= 0.99:
104+
tracker.scalar(success=1.0)
105+
break
106+
107+
if attempt < max_attempts - 1:
108+
content = f"""The previous attempt didn't complete the task successfully.
109+
Please try a different approach.
110+
Original task: {data["instruction"]}
111+
112+
Previous attempt result: {result.final_output}
113+
114+
Please analyze what went wrong and try again with a corrected approach."""
115+
else:
116+
content = f"""This is your final attempt. Please be extremely careful.
117+
Original task: {data["instruction"]}
118+
119+
Previous attempts: {result.final_output}
120+
121+
Please provide a final, carefully executed solution."""
122+
tracker.scalar(success=0.0)
123+
124+
tracker.scalar(
125+
final_reward=reward, attempts_used=attempts_used, errors=error_count
126+
)
127+
128+
client.set_final_reward(reward)
129+
130+
return reward
131+
132+
133+
class TerminalAgentWorkflow(RolloutWorkflow):
134+
def __init__(
135+
self,
136+
gconfig: GenerationHyperparameters,
137+
tokenizer: PreTrainedTokenizerFast,
138+
dump_dir: str | None = None,
139+
rollout_stat_scope: str = "rollout",
140+
n_trajs: int = 1,
141+
max_tokens: int = 32768,
142+
max_turns: int = 8,
143+
):
144+
self.gconfig = gconfig
145+
self.gconfig.n_samples = 1
146+
self.tokenizer = tokenizer
147+
self.dump_dir = dump_dir
148+
self.max_tokens = max_tokens
149+
self.rollout_stat_scope = rollout_stat_scope
150+
if self.dump_dir is not None and not os.path.exists(self.dump_dir):
151+
os.makedirs(self.dump_dir, exist_ok=True)
152+
153+
# Search hyper-parameters
154+
self.n_trajs = n_trajs
155+
self.agent = TerminalAgent(
156+
tokenizer=self.tokenizer,
157+
max_tokens_per_turn=self.gconfig.max_new_tokens,
158+
max_turns=max_turns,
159+
max_total_tokens=max_tokens,
160+
dump_dir=self.dump_dir,
161+
rollout_stat_scope=self.rollout_stat_scope,
162+
)
163+
164+
async def arun_episode(self, engine, data):
165+
clients = [
166+
ArealOpenAI(
167+
engine=engine, tokenizer=self.tokenizer, tool_call_parser="qwen25"
168+
)
169+
for _ in range(self.n_trajs)
170+
]
171+
172+
# Collect trajectories
173+
rewards = await asyncio.gather(
174+
*[
175+
self.agent.run_agent(
176+
data=data,
177+
client=clients[i],
178+
)
179+
for i in range(self.n_trajs)
180+
]
181+
)
182+
for reward in rewards:
183+
stats_tracker.get(self.rollout_stat_scope).scalar(reward=reward)
184+
185+
interactions_with_reward = {}
186+
for client in clients:
187+
client.apply_reward_discount(turn_discount=0.9)
188+
interactions = client.export_interactions(style="individual")
189+
interactions_with_reward.update(interactions)
190+
return interactions_with_reward

examples/openai-agents/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ cluster:
1919
type: nfs
2020
nfs_record_root: /tmp/areal/name_resolve
2121

22-
allocation_mode: sglang.d4p1t1+d4p1t1
22+
allocation_mode: sglang.d4p1t1+d1p1t1c4
2323

2424
rollout:
2525
experiment_name: ${experiment_name}

0 commit comments

Comments
 (0)