|
3 | 3 | from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow, Task |
4 | 4 | from trinity.common.workflows.workflow import WORKFLOWS |
5 | 5 |
|
6 | | -from .bots_math_boxed_reward import BOTSMathBoxedRewardFn |
7 | | - |
8 | 6 |
|
9 | 7 | @WORKFLOWS.register_module("bots_math_boxed_workflow") |
10 | 8 | class BOTSMathBoxedWorkflow(MathBoxedWorkflow): |
11 | 9 | """A workflow for math tasks that give answers in boxed format for BOTS.""" |
12 | 10 |
|
13 | 11 | def reset(self, task: Task): |
14 | 12 | super().reset(task) |
| 13 | + from trinity.plugins.bots_math_boxed_reward import BOTSMathBoxedRewardFn |
| 14 | + |
15 | 15 | self.reward_fn = BOTSMathBoxedRewardFn(**self.reward_fn_args) |
| 16 | + self.task_desc = nested_query(self.format_args.prompt_key, self.raw_task) |
| 17 | + self.truth = nested_query(self.format_args.response_key, self.raw_task) |
16 | 18 |
|
17 | 19 | def format_messages(self): |
18 | 20 | # the prompts are already in message format |
19 | 21 | return self.task_desc |
20 | 22 |
|
21 | | - @property |
22 | | - def task_desc(self) -> Union[str, None]: # type: ignore [override] |
23 | | - prompt_key = self.format_args.prompt_key |
24 | | - return nested_query(prompt_key, self.raw_task) # type: ignore |
25 | | - |
26 | | - @property |
27 | | - def truth(self) -> Union[str, None]: # type: ignore [override] |
28 | | - response_key = self.format_args.response_key |
29 | | - return nested_query(response_key, self.raw_task) |
30 | | - |
31 | 23 |
|
32 | 24 | def nested_query(query_key: str, query_obj: Union[dict, None]): |
33 | 25 | # support nested query for a dict given query_keys split by '.' |
|
0 commit comments