Skip to content

Commit 23e9d55

Browse files
authored
Minor updates for BOTS example (#385)
1 parent a52cc3a commit 23e9d55

File tree

2 files changed: 6 additions and 14 deletions (+6 −14 lines changed)

2 files changed: 6 additions and 14 deletions (+6 −14 lines changed)

examples/bots/workflow/bots_math_boxed_reward.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,6 @@
33
from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
44
from trinity.utils.eval_utils import validate_think_pattern
55

6-
from .bots_reward import compute_score
7-
86

97
@REWARD_FUNCTIONS.register_module("bots_math_boxed_reward")
108
class BOTSMathBoxedRewardFn(RewardFn):
@@ -24,6 +22,8 @@ def __call__( # type: ignore
2422
format_score_coef: Optional[float] = 0.1,
2523
**kwargs,
2624
) -> dict[str, float]:
25+
from trinity.plugins.bots_reward import compute_score
26+
2727
accuracy_score = compute_score(response, truth)
2828

2929
format_score = 0.0

examples/bots/workflow/bots_math_boxed_workflow.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3,31 +3,23 @@
33
from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow, Task
44
from trinity.common.workflows.workflow import WORKFLOWS
55

6-
from .bots_math_boxed_reward import BOTSMathBoxedRewardFn
7-
86

97
@WORKFLOWS.register_module("bots_math_boxed_workflow")
108
class BOTSMathBoxedWorkflow(MathBoxedWorkflow):
119
"""A workflow for math tasks that give answers in boxed format for BOTS."""
1210

1311
def reset(self, task: Task):
1412
super().reset(task)
13+
from trinity.plugins.bots_math_boxed_reward import BOTSMathBoxedRewardFn
14+
1515
self.reward_fn = BOTSMathBoxedRewardFn(**self.reward_fn_args)
16+
self.task_desc = nested_query(self.format_args.prompt_key, self.raw_task)
17+
self.truth = nested_query(self.format_args.response_key, self.raw_task)
1618

1719
def format_messages(self):
1820
# the prompts are already in message format
1921
return self.task_desc
2022

21-
@property
22-
def task_desc(self) -> Union[str, None]: # type: ignore [override]
23-
prompt_key = self.format_args.prompt_key
24-
return nested_query(prompt_key, self.raw_task) # type: ignore
25-
26-
@property
27-
def truth(self) -> Union[str, None]: # type: ignore [override]
28-
response_key = self.format_args.response_key
29-
return nested_query(response_key, self.raw_task)
30-
3123

3224
def nested_query(query_key: str, query_obj: Union[dict, None]):
3325
# support nested query for a dict given query_keys split by '.'

0 commit comments

Comments
 (0)