Commit 53b11ed

add terminal agent with openai-agents

Signed-off-by: CormickKneey <[email protected]>
1 parent 97764f8

20 files changed: +2777 −19 lines
areal/dataset/__init__.py (18 additions, 2 deletions)
@@ -10,7 +10,14 @@
 from transformers.processing_utils import ProcessorMixin
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

-VALID_DATASETS = ["gsm8k", "clevr_count_70k", "geometry3k", "hh-rlhf", "torl_data"]
+VALID_DATASETS = [
+    "gsm8k",
+    "clevr_count_70k",
+    "geometry3k",
+    "hh-rlhf",
+    "torl_data",
+    "terminal_bench",
+]

 logger = logging.getLogger("Dataset")

@@ -24,7 +31,6 @@ def _get_custom_dataset(
     processor: Optional["ProcessorMixin"] = None,
     **kwargs,
 ) -> "Dataset":
-
     if "gsm8k" in path and type == "sft":
         from .gsm8k import get_gsm8k_sft_dataset

@@ -105,6 +111,16 @@ def _get_custom_dataset(
             max_length=max_length,
             **kwargs,
         )
+    elif "terminal_bench" in path and type == "rl":
+        from .terminal_bench import get_terminal_bench_rl_dataset
+
+        return get_terminal_bench_rl_dataset(
+            path=path,
+            split=split,
+            tokenizer=tokenizer,
+            max_length=max_length,
+            **kwargs,
+        )
     else:
         raise ValueError(
             f"Dataset {path} with split {split} and training type {type} is not supported. "

areal/dataset/terminal_bench.py (new file, 60 additions)
from typing import TYPE_CHECKING

from datasets import load_dataset

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerFast


def get_terminal_bench_rl_dataset(
    path: str,
    split: str,
    tokenizer: "PreTrainedTokenizerFast",
    max_length: int | None = None,
):
    """Load the terminal-bench dataset for RL training.

    The dataset should be in parquet format with the following columns:
    - prompt: the formatted prompt for the task
    - task_name: name of the task
    - instruction: raw instruction text
    - dockerfile_contents: contents of the task's Dockerfile
    - extra_info: JSON string containing task metadata
    """
    # Load from the parquet file.
    dataset = load_dataset("parquet", data_files={split: path}, split=split)

    # The dataset already has the right format from the converter:
    # - prompt: contains the formatted conversation
    # - task_name, instruction, extra_info: metadata fields

    # For RL training, extract the fields the workflow consumes.
    def process(sample):
        # The prompt is already formatted, but the raw instruction is what the
        # workflow uses to build its messages structure.
        instruction = sample.get("instruction", "")
        task_name = sample.get("task_name", "")
        dockerfile_contents = sample.get("dockerfile_contents", "")

        # Return data in the format expected by the workflow.
        return {
            "instruction": instruction,
            "task_name": task_name,
            "dockerfile_contents": dockerfile_contents,
            "extra_info": sample.get("extra_info", ""),
            "data_source": sample.get("data_source", "terminal_bench"),
        }

    dataset = dataset.map(process)

    # Filter out sequences longer than max_length, if specified.
    if max_length is not None:

        def filter_length(samples):
            # Tokenize instructions in batches for efficiency.
            instructions = samples["instruction"]
            tokens_list = tokenizer(instructions, add_special_tokens=False)["input_ids"]
            return [len(tokens) <= max_length for tokens in tokens_list]

        dataset = dataset.filter(filter_length, batched=True)

    return dataset
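As a sanity check on the expected input, a hypothetical converter step that writes a parquet file this loader would accept could look like the sketch below. The column names follow the docstring above; the example row and output path are invented:

import json

import pandas as pd

rows = [
    {
        "prompt": "<formatted conversation>",
        "task_name": "hello-world",
        "instruction": "Create a file named hello.txt containing 'hello'.",
        "dockerfile_contents": "FROM ubuntu:22.04\n",
        "extra_info": json.dumps({"difficulty": "easy"}),
        "data_source": "terminal_bench",
    }
]
pd.DataFrame(rows).to_parquet("data/terminal_bench/train.parquet")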

areal/experimental/openai/client.py (3 additions, 15 deletions)
@@ -15,7 +15,6 @@
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessage,
-    ChatCompletionToolMessageParam,
     ChatCompletionToolParam,
 )
 from openai.types.chat.chat_completion import Choice

@@ -277,22 +276,11 @@ async def create(
         if is_omitted(input):
             raise ValueError("input is required for Responses.create")

-        def _convert_tool_output_format(
-            item: dict,
-        ) -> ChatCompletionToolMessageParam | dict:
+        def _convert_tool_output_format(item: dict) -> dict:
             """Convert custom tool output format to standard chat template format.

-            Converts openai.types.responses.response_input_item_param.FunctionCallOutput
-            to openai.types.chat.ChatCompletionToolMessageParam.
-
-            Args:
-                item: Input dict, could be FunctionCallOutput from openai-agents SDK
-                    with format: {'call_id': str, 'output': str, 'type': 'function_call_output'}
-
-            Returns:
-                ChatCompletionToolMessageParam (TypedDict) with format:
-                {'role': 'tool', 'content': str, 'tool_call_id': str}
-                or the original dict if conversion is not needed.
+            Converts from: {'call_id': ..., 'output': ..., 'type': 'function_call_output'}
+            To: {'role': 'tool', 'content': ..., 'tool_call_id': ...}
             """
             if (
                 isinstance(item, dict)
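To make the simplified docstring concrete, here is a standalone sketch of the conversion it describes. This is an illustrative re-implementation, not the actual body of `_convert_tool_output_format` (whose remaining lines are not shown in this diff):

def convert_function_call_output(item: dict) -> dict:
    # Pass through anything that is not an openai-agents FunctionCallOutput.
    if not (isinstance(item, dict) and item.get("type") == "function_call_output"):
        return item
    # Map to the standard chat-template tool message.
    return {
        "role": "tool",
        "content": item.get("output", ""),
        "tool_call_id": item.get("call_id", ""),
    }


item = {"call_id": "call_123", "output": "ok", "type": "function_call_output"}
assert convert_function_call_output(item) == {
    "role": "tool",
    "content": "ok",
    "tool_call_id": "call_123",
}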

examples/__init__.py (whitespace-only changes)

New file (202 additions; path not shown)
import asyncio
import logging
import os

from agents import Agent as OpenAIAgent
from agents import ModelSettings, OpenAIProvider, RunConfig, SQLiteSession
from agents import Runner as OpenAIRunner
from terminal.env import TerminalEnv
from terminal.judge_agent import JudgeAgent, judge_from_env
from terminal.prompt import SYSTEM_PROMPT
from transformers import PreTrainedTokenizerFast

from areal.api.cli_args import GenerationHyperparameters
from areal.api.workflow_api import RolloutWorkflow
from areal.experimental.openai import ArealOpenAI
from areal.utils import stats_tracker

logger = logging.getLogger(__name__)


class TerminalAgent:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerFast,
        max_tokens_per_turn: int = 1024,
        max_turns: int = 8,
        max_total_tokens: int = 32768,
        dump_dir: str | None = None,
        rollout_stat_scope: str = "rollout",
    ):
        self.tokenizer = tokenizer
        self.max_tokens_per_turn = max_tokens_per_turn
        self.max_turns = max_turns
        self.max_total_tokens = max_total_tokens
        self.dump_dir = dump_dir
        self.rollout_stat_scope = rollout_stat_scope

    async def run_agent(self, data, client: ArealOpenAI, judge_agent: JudgeAgent):
        """Run the agent workflow for terminal task execution."""
        run_config = RunConfig(
            model_provider=OpenAIProvider(
                openai_client=client,
                use_responses=True,
            ),
            tracing_disabled=True,
            model_settings=ModelSettings(
                temperature=1.0,
                extra_args={"max_completion_tokens": self.max_tokens_per_turn},
                tool_choice="auto",
                store=True,
            ),
        )

        async with TerminalEnv(
            task_name=data["task_name"],
            dump_dir=self.dump_dir,
            rollout_stat_scope=self.rollout_stat_scope,
        ) as env:
            # Create the agent with the environment's terminal tools.
            agent = OpenAIAgent(
                name="Terminal Task Agent",
                instructions=SYSTEM_PROMPT,
                tools=env.get_tools(),
            )
            session = SQLiteSession("terminal")
            content = data["instruction"]

            max_attempts = self.max_turns
            reward = 0
            judge_reward = 0
            tracker = stats_tracker.get(self.rollout_stat_scope)

            with tracker.record_timing("run_agent_total"):
                error_count = 0.0
                attempts_used = 0.0
                for attempt in range(max_attempts):
                    attempts_used = float(attempt + 1)
                    try:
                        with tracker.record_timing("openai_runner_run"):
                            result = await OpenAIRunner.run(
                                agent,
                                input=content,
                                session=session,
                                run_config=run_config,
                                max_turns=30,
                            )
                    except Exception as e:
                        logger.error(f"Error running agent: {e}")
                        error_count += 1.0
                        break

                    with tracker.record_timing("env_validate_reward"):
                        reward = env.reward()
                    if judge_agent:
                        with tracker.record_timing("judge_agent_reward"):
                            judge_reward = await judge_agent.get_reward_from_judge(
                                session=session,
                                dockerfile_contents=data["dockerfile_contents"],
                            )
                        # Blend environment and judge rewards unless the task
                        # is already (near-)solved.
                        if judge_reward >= 0 and reward < 0.99:
                            reward = reward * 0.65 + judge_reward * 0.35

                    tracker.scalar(
                        reward=reward,
                        judge_reward=judge_reward,
                        attempt_index=float(attempt),
                        input_chars=float(len(content) if content else 0.0),
                        output_chars=float(
                            len(getattr(result, "final_output", "") or "")
                        ),
                    )

                    if isinstance(reward, float) and reward >= 0.99:
                        tracker.scalar(success=1.0)
                        break

                    if attempt < max_attempts - 1:
                        content = f"""The previous attempt didn't complete the task successfully.
Please try a different approach.
Original task: {data["instruction"]}

Previous attempt result: {result.final_output}

Please analyze what went wrong and try again with a corrected approach."""
                    else:
                        content = f"""This is your final attempt. Please be extremely careful.
Original task: {data["instruction"]}

Previous attempts: {result.final_output}

Please provide a final, carefully executed solution."""
                        tracker.scalar(success=0.0)

            tracker.scalar(
                final_reward=reward, attempts_used=attempts_used, errors=error_count
            )

            client.set_final_reward(reward)

        return reward


class TerminalAgentWorkflow(RolloutWorkflow):
    def __init__(
        self,
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
        dump_dir: str | None = None,
        rollout_stat_scope: str = "rollout",
        n_trajs: int = 1,
        max_tokens: int = 32768,
        max_turns: int = 8,
    ):
        self.gconfig = gconfig
        self.gconfig.n_samples = 1
        self.tokenizer = tokenizer
        self.dump_dir = dump_dir
        self.max_tokens = max_tokens
        self.rollout_stat_scope = rollout_stat_scope
        if self.dump_dir is not None and not os.path.exists(self.dump_dir):
            os.makedirs(self.dump_dir, exist_ok=True)

        # Rollout hyper-parameters: number of trajectories per episode.
        self.n_trajs = n_trajs
        self.agent = TerminalAgent(
            tokenizer=self.tokenizer,
            max_tokens_per_turn=self.gconfig.max_new_tokens,
            max_turns=max_turns,
            max_total_tokens=max_tokens,
            dump_dir=self.dump_dir,
            rollout_stat_scope=self.rollout_stat_scope,
        )
        self.judge_agent = judge_from_env()

    async def arun_episode(self, engine, data):
        clients = [
            ArealOpenAI(
                engine=engine, tokenizer=self.tokenizer, tool_call_parser="qwen25"
            )
            for _ in range(self.n_trajs)
        ]

        # Collect trajectories concurrently, one client per trajectory.
        rewards = await asyncio.gather(
            *[
                self.agent.run_agent(
                    data=data,
                    client=clients[i],
                    judge_agent=self.judge_agent,
                )
                for i in range(self.n_trajs)
            ]
        )
        for reward in rewards:
            stats_tracker.get(self.rollout_stat_scope).scalar(reward=reward)

        interactions_with_reward = {}
        for client in clients:
            client.apply_reward_discount(turn_discount=0.9)
            interactions = client.export_interactions(style="individual")
            interactions_with_reward.update(interactions)
        return interactions_with_reward
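For context, wiring the workflow into a trainer might look roughly like the sketch below. The checkpoint name, dump directory, and hyper-parameter values are assumptions; only the constructor arguments come from the file above:

from transformers import AutoTokenizer

from areal.api.cli_args import GenerationHyperparameters

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # assumed checkpoint
workflow = TerminalAgentWorkflow(
    gconfig=GenerationHyperparameters(max_new_tokens=1024),  # assumed field values
    tokenizer=tokenizer,
    dump_dir="/tmp/terminal_agent_dumps",
    n_trajs=4,     # four concurrent trajectories per episode
    max_turns=8,   # retry attempts inside TerminalAgent
)
# `engine` is supplied by the AReaL trainer; each episode returns the exported
# interactions keyed for training:
# interactions = await workflow.arun_episode(engine, sample)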

examples/openai-agents/config.yaml (1 addition, 1 deletion)
@@ -19,7 +19,7 @@ cluster:
   type: nfs
   nfs_record_root: /tmp/areal/name_resolve

-allocation_mode: sglang.d4p1t1+d4p1t1
+allocation_mode: sglang.d4p1t1+d1p1t1c4

 rollout:
   experiment_name: ${experiment_name}

0 commit comments

Comments
 (0)