Commit babeec9

One file one task definition (#1059)
Moves all the prompts from `default_prompts.py` to their respective task files.
1 parent ad58fed commit babeec9
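
For downstream code that defines custom tasks, the practical effect of this commit is an import change. A minimal before/after sketch, using only imports that appear in the diffs below:

# Before #1059: prompt functions were looked up on the shared module.
# import lighteval.tasks.default_prompts as prompt
# ... prompt_function=prompt.gsm8k, ...

# After #1059: each task module under lighteval.tasks.tasks exports its own
# prompt function, which custom configs import directly.
from lighteval.tasks.tasks.gpqa import gpqa_instruct_prompt
from lighteval.tasks.tasks.gsm8k import gsm8k_prompt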

110 files changed: +2989 −4050 lines (large commit; only two of the changed files are shown below)

examples/custom_tasks_tests.py

Lines changed: 4 additions & 3 deletions
@@ -20,14 +20,15 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-import lighteval.tasks.default_prompts as prompt
 from lighteval.metrics.metrics import Metrics
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.tasks.gpqa import gpqa_instruct_prompt
+from lighteval.tasks.tasks.gsm8k import gsm8k_prompt
 
 
 gsm8k_test = LightevalTaskConfig(
     name="gsm8k_test",
-    prompt_function=prompt.gsm8k,
+    prompt_function=gsm8k_prompt,
     hf_repo="gsm8k",
     hf_subset="main",
     hf_avail_splits=["train", "test"],
@@ -42,7 +43,7 @@
 
 gpqa_diamond_test = LightevalTaskConfig(
     name="gpqa:diamond_test",
-    prompt_function=prompt.gpqa_instruct,
+    prompt_function=gpqa_instruct_prompt,
     hf_repo="Idavidrein/gpqa",
     hf_subset="gpqa_diamond",
     hf_avail_splits=["train"],
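
The new import targets live under lighteval.tasks.tasks, where, per the commit title, each file is meant to hold one task's prompt function alongside its definition. The sketch below only illustrates that shape; the actual contents of lighteval/tasks/tasks/gsm8k.py are not part of this diff, so the prompt body and config fields shown here are assumptions.

from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


def gsm8k_prompt(line, task_name: str = None):
    # Assumed shape only: turn one GSM8K row into a Doc, with the reference
    # answer as the single gold choice. The real function may differ.
    return Doc(
        task_name=task_name,
        query=f"Question: {line['question']}\nAnswer:",
        choices=[f" {line['answer']}"],
        gold_index=0,
    )


# Hypothetical co-located config; field values beyond those visible in the
# diff above are placeholders.
gsm8k = LightevalTaskConfig(
    name="gsm8k",
    prompt_function=gsm8k_prompt,  # defined in the same file
    hf_repo="gsm8k",
    hf_subset="main",
    hf_avail_splits=["train", "test"],
)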

examples/nanotron/custom_evaluation_tasks.py

Lines changed: 24 additions & 17 deletions
@@ -28,14 +28,21 @@
 """
 
 import re
+from string import ascii_uppercase
 from typing import List, Tuple
 
-import lighteval.tasks.default_prompts as prompt
 from lighteval.metrics.metrics import Metrics
 from lighteval.metrics.normalizations import LogProbCharNorm, helm_normalizer, math_normalizer
-from lighteval.tasks.default_prompts import LETTER_INDICES
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
+from lighteval.tasks.tasks.arc import arc_prompt
+from lighteval.tasks.tasks.gsm8k import gsm8k_prompt
+from lighteval.tasks.tasks.math import math_prompt
+from lighteval.tasks.tasks.openbookqa import openbookqa_prompt
+from lighteval.tasks.tasks.piqa import piqa_prompt
+from lighteval.tasks.tasks.quac import quac_prompt
+from lighteval.tasks.tasks.triviaqa import triviaqa_prompt
+from lighteval.tasks.tasks.winogrande import winogrande_prompt
 
 
 _TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = []
@@ -48,7 +55,7 @@ def commonsense_qa_prompt(line, task_name: str = None):
         task_name=task_name,
         query=line["question"],
         choices=[f" {c}" for c in line["choices"]["text"]],
-        gold_index=LETTER_INDICES.index(line["answerKey"].strip()),
+        gold_index=ascii_uppercase.index(line["answerKey"].strip()),
         instruction="",
     )
 
@@ -99,7 +106,7 @@ def preprocess(text):
     ),
     LightevalTaskConfig(
         name="winogrande",
-        prompt_function=prompt.winogrande,
+        prompt_function=winogrande_prompt,
        hf_repo="winogrande",
        hf_subset="winogrande_xl",
        metrics=[
@@ -112,7 +119,7 @@ def preprocess(text):
     ),
     LightevalTaskConfig(
         name="piqa",
-        prompt_function=prompt.piqa_harness,
+        prompt_function=piqa_prompt,
        hf_repo="piqa",
        hf_subset="plain_text",
        metrics=[
@@ -139,7 +146,7 @@ def preprocess(text):
     ),
     LightevalTaskConfig(
         name="openbookqa",
-        prompt_function=prompt.openbookqa,
+        prompt_function=openbookqa_prompt,
        hf_repo="openbookqa",
        hf_subset="main",
        metrics=[
@@ -152,7 +159,7 @@ def preprocess(text):
     ),
     LightevalTaskConfig(
         name="arc:easy",
-        prompt_function=prompt.arc,
+        prompt_function=arc_prompt,
        hf_repo="ai2_arc",
        hf_subset="ARC-Easy",
        evaluation_splits=["test"],
@@ -167,7 +174,7 @@ def preprocess(text):
     ),
     LightevalTaskConfig(
         name="arc:challenge",
-        prompt_function=prompt.arc,
+        prompt_function=arc_prompt,
        hf_repo="ai2_arc",
        hf_subset="ARC-Challenge",
        evaluation_splits=["test"],
@@ -216,7 +223,7 @@ def natural_questions_prompt(line, task_name: str = None):
 WORLD_KNOWLEDGE_TASKS = [
     LightevalTaskConfig(
         name="trivia_qa",
-        prompt_function=prompt.triviaqa,
+        prompt_function=triviaqa_prompt,
        hf_repo="trivia_qa",
        hf_subset="rc.nocontext",
        metrics=[
@@ -266,7 +273,7 @@ def boolq_prompt(line, task_name: str = None):
     ),
     LightevalTaskConfig(
         name="quac",
-        prompt_function=prompt.quac,
+        prompt_function=quac_prompt,
        hf_repo="lighteval/quac_helm",
        hf_subset="deault",
        metrics=[
@@ -290,7 +297,7 @@ class CustomMathEvaluationTask(LightevalTaskConfig):
     def __init__(
         self,
         name,
-        prompt_function=prompt.math,
+        prompt_function=math_prompt,
        hf_repo="DigitalLearningGmbH/MATH-lighteval",
        hf_subset=None,
        metrics=[
@@ -329,7 +336,7 @@ def __init__(
 ]
 GSM8K = LightevalTaskConfig(
     name="gsm8k",
-    prompt_function=prompt.gsm8k,
+    prompt_function=gsm8k_prompt,
     hf_repo="gsm8k",
     hf_subset="main",
     hf_avail_splits=["train", "test"],
@@ -352,10 +359,10 @@ def mmlu_harness(line, task_name: str = None):
     topic = line["subject"]
     prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n"
     prompt += line["question"] + "\n"
-    prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
+    prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["choices"])])
     prompt += "Answer:"
 
-    gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
+    gold_ix = ascii_uppercase.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
     "__few_shots" in line and line["__few_shots"] is True  # We are adding few shots
 
     return Doc(
@@ -590,7 +597,7 @@ def agi_eval_prompt(line, task_name: str = None):
     prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n"
     prompt += "Answer: "
 
-    choices = LETTER_INDICES[: len(line["options"])]
+    choices = ascii_uppercase[: len(line["options"])]
 
     output = Doc(
         query=prompt,
@@ -599,7 +606,7 @@ def agi_eval_prompt(line, task_name: str = None):
 
     if line["label"]:
         output.choices = choices
-        output.gold_index = LETTER_INDICES.index(line["label"].strip())
+        output.gold_index = ascii_uppercase.index(line["label"].strip())
     else:
         output.choices = [line["answer"]]
         output.gold_index = 0
@@ -616,7 +623,7 @@ def agi_eval_prompt_no_letters(line, task_name: str = None):
     output = Doc(
         query=line["question"],
         choices=cleaned_options,
-        gold_index=LETTER_INDICES.index(line["label"].strip()),
+        gold_index=ascii_uppercase.index(line["label"].strip()),
         instruction="",
     )
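
The other recurring change in this file replaces the removed LETTER_INDICES constant with the standard library's string.ascii_uppercase. Assuming LETTER_INDICES enumerated the uppercase letters A-Z in order (which the one-for-one substitution implies), the lookups are equivalent; a small self-contained check:

from string import ascii_uppercase

# ascii_uppercase == "ABCDEFGHIJKLMNOPQRSTUVWXYZ", so .index() maps an answer
# letter to its zero-based gold index, and slicing yields the first N letters.
assert ascii_uppercase.index("A") == 0
assert ascii_uppercase.index("C".strip()) == 2
assert ascii_uppercase[:4] == "ABCD"  # as in ascii_uppercase[: len(line["options"])]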
