2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -17,6 +17,8 @@
title: Use the Python API
- local: adding-a-custom-task
title: Add a custom task
- local: offline-evaluation
title: Offline evaluation
- local: adding-a-new-metric
title: Add a custom metric
- local: evaluating-a-custom-model
46 changes: 46 additions & 0 deletions docs/source/offline-evaluation.md
@@ -0,0 +1,46 @@
# Offline evaluation using local data files

If you are prototyping a task based on files that are not yet hosted on the
Hub, you can use the `hf_data_files` argument to point lighteval at local
JSON/CSV files. This makes it easy to evaluate datasets that live in your
repository or that are generated on the fly.

Internally, `hf_data_files` is passed directly to the `data_files` parameter of `datasets.load_dataset` ([docs](https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset)).
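
As a rough sketch of what this means in practice (the `samples/faq.jsonl` path below is only an illustration), the same file could be loaded directly with `datasets` like this; lighteval issues a call of this shape for you when the task's dataset is downloaded.

```python
from datasets import load_dataset

# Illustrative only: loading a local JSONL file with the built-in "json" builder,
# which is what the hf_repo="json" / hf_data_files combination maps to.
dataset = load_dataset("json", data_files="samples/faq.jsonl")
print(dataset["train"][0])
```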

See [adding a custom task](adding-a-custom-task) for more information on how to create a custom task.

```python
from pathlib import Path

from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


def local_prompt(line: dict, task_name: str) -> Doc:
return Doc(
task_name=task_name,
query=line["question"],
choices=line["choices"],
gold_index=line["answer"]
)


local_data = Path(__file__).parent / "samples" / "faq.jsonl"

local_task = LightevalTaskConfig(
name="faq_eval",
prompt_function=local_prompt,
hf_repo="json", # Built-in streaming loader for json/jsonl files
hf_subset="default",
hf_data_files=str(local_data), # Can also be a dict mapping split names to paths
evaluation_splits=["train"],
    metrics=[Metrics.loglikelihood_acc],  # standard multiple-choice accuracy
)
```

Once the config is registered in `TASKS_TABLE`, running the task with
`--custom-tasks path/to/your_file.py` will automatically load the local data
files. You can also pass a dictionary to `hf_data_files` (e.g.
`{"train": "train.jsonl", "validation": "val.jsonl"}`) to expose multiple
splits.
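As a minimal sketch (the file names are placeholders), a multi-split configuration reusing `local_prompt` from the example above could look like this:

```python
local_task_multi = LightevalTaskConfig(
    name="faq_eval_multi",
    prompt_function=local_prompt,
    hf_repo="json",
    hf_subset="default",
    # Placeholder file names: each split maps to one or more local files
    hf_data_files={"train": "train.jsonl", "validation": "val.jsonl"},
    evaluation_splits=["validation"],
    metrics=[Metrics.loglikelihood_acc],
)
```

Only the splits listed in `evaluation_splits` are scored; the other splits remain available, for example as a `few_shots_split`.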
105 changes: 105 additions & 0 deletions examples/custom_tasks_templates/custom_yourbench_task_from_files.py
@@ -0,0 +1,105 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import json
import logging
import tempfile
from functools import partial
from pathlib import Path

from custom_yourbench_task_mcq import yourbench_prompt
from datasets import Dataset, DatasetDict

from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig


logger = logging.getLogger(__name__)

save_dir = str(tempfile.mkdtemp())

ds = DatasetDict(
{
"train": Dataset.from_dict(
{
"question": ["What is 2+2?", "Capital of France?"],
"choices": [["1", "2", "3", "4"], ["Paris", "Berlin", "Rome", "Madrid"]],
"gold": [[3], [0]],
}
)
}
)


CustomTaskConfig = partial(
LightevalTaskConfig,
prompt_function=yourbench_prompt,
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=16,
metrics=[Metrics.gpqa_instruct_metric],
version=0,
)

# Example 1: dataset saved to disk in the Hugging Face (Arrow) format ####

ds.save_to_disk(save_dir)

yourbench_mcq = CustomTaskConfig(
name="tiny_mcqa_dataset",
hf_repo="arrow",
hf_subset="default",
hf_data_files=f"{save_dir}/**/*.arrow",
)

task = LightevalTask(yourbench_mcq)
eval_docs = task.eval_docs()

print("\n>>READING TASK FROM ARROW<<")
for doc in eval_docs:
print(doc)


# Example 2: jsonlines format ####

jsonl_path = Path(save_dir) / "train.jsonl"
with open(jsonl_path, "w") as f:
for row in ds["train"]:
f.write(json.dumps(row) + "\n")

yourbench_mcq = CustomTaskConfig(
name="tiny_mcqa_dataset",
hf_repo="json",
hf_subset="default",
hf_data_files=str(jsonl_path),
)

task = LightevalTask(yourbench_mcq)
eval_docs = task.eval_docs()

print("\n>>READING TASK FROM JSONLINES<<")
for doc in eval_docs:
print(doc)

# Uncomment to register the task for use with `--custom-tasks`:
# TASKS_TABLE = [yourbench_mcq]
7 changes: 6 additions & 1 deletion src/lighteval/tasks/lighteval_task.py
@@ -24,7 +24,7 @@
import logging
import random
from dataclasses import asdict, dataclass, field
from typing import Callable
from typing import Callable, Mapping, Sequence

from datasets import DatasetDict, load_dataset
from huggingface_hub import TextGenerationInputGrammarType
@@ -59,6 +59,8 @@ class LightevalTaskConfig:
row to Doc objects for evaluation. Takes a dataset row dict and task
name as input.
hf_repo (str): HuggingFace Hub repository path containing the evaluation dataset.
hf_data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]] | None):
Data files to load. Same as `data_files` argument of `datasets.load_dataset`.
hf_subset (str): Dataset subset/configuration name to use for this task.
metrics (ListLike[Metric | Metrics]): List of metrics or metric enums to compute for this task.

@@ -113,6 +115,7 @@ class LightevalTaskConfig:
hf_repo: str
hf_subset: str
metrics: ListLike[Metric | Metrics] # Accept both Metric objects and Metrics enums
hf_data_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None

# Inspect AI compatible parameters
solver: None = None
@@ -219,6 +222,7 @@ def __init__(

# Dataset info
self.dataset_path = config.hf_repo
self.data_files = config.hf_data_files
self.dataset_config_name = config.hf_subset
self.dataset_revision = config.hf_revision
self.dataset_filter = config.hf_filter
@@ -454,6 +458,7 @@ def download_dataset_worker(
path=task.dataset_path,
name=task.dataset_config_name,
revision=task.dataset_revision,
data_files=task.data_files,
)

if task.dataset_filter is not None:
21 changes: 21 additions & 0 deletions tests/unit/tasks/test_lighteval_task.py
@@ -63,3 +63,24 @@ def test_dataset_filter():
filtered_docs = task.eval_docs()
assert len(filtered_docs) == 1
assert filtered_docs[0].query == "hi"


def test_hf_data_files(tmp_path):
# create a small jsonl dataset
data_file = tmp_path / "data.jsonl"
src_docs = [f"document {i}" for i in range(3)]
data_file.write_text("\n".join([f'{{"text": "{doc}"}}' for doc in src_docs]))

cfg = LightevalTaskConfig(
name="test_data_files",
prompt_function=dummy_prompt_function,
hf_repo="json",
hf_subset="default",
metrics=[],
evaluation_splits=["train"],
hf_data_files=str(data_file),
)
task = LightevalTask(cfg)

eval_docs = task.eval_docs()
assert [doc.query for doc in eval_docs] == src_docs
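

# Sketch of a possible follow-up test (assumption: the dict form of
# `hf_data_files` maps split names to local files, mirroring the
# `data_files` argument of `datasets.load_dataset`):
def test_hf_data_files_dict(tmp_path):
    train_file = tmp_path / "train.jsonl"
    val_file = tmp_path / "val.jsonl"
    train_file.write_text('{"text": "train doc"}')
    val_file.write_text('{"text": "val doc"}')

    cfg = LightevalTaskConfig(
        name="test_data_files_dict",
        prompt_function=dummy_prompt_function,
        hf_repo="json",
        hf_subset="default",
        metrics=[],
        evaluation_splits=["validation"],
        hf_data_files={"train": str(train_file), "validation": str(val_file)},
    )
    task = LightevalTask(cfg)

    # Only the "validation" split should be used for evaluation
    eval_docs = task.eval_docs()
    assert [doc.query for doc in eval_docs] == ["val doc"]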