88
99
def record_to_sample(record, field_spec: dict):
    """Convert a dataset record to a multiple-choice Sample based on field_spec.

    Used for multiple choice tasks because we often need to convert numeric
    labels to letters for the target.

    Args:
        record: One dataset row, indexable by column name.
        field_spec: Maps "input", "target" and "choices" to the record
            columns holding those values. An optional "metadata" entry is
            an iterable of record column names to copy into the sample's
            metadata.

    Returns:
        A Sample built from the selected fields. An integer target is
        replaced by the uppercase letter at that index (0 -> "A"); any
        other target value is passed through unchanged.

    Raises:
        KeyError: If a required field_spec key, or a column it names, is
            missing.
        IndexError: If an integer target is negative or has no matching
            uppercase letter.
    """
    input_text = record[field_spec["input"]]

    target = record[field_spec["target"]]
    if isinstance(target, int):
        # Negative indices would silently wrap around (-1 -> "Z"), turning an
        # invalid or placeholder label into a plausible-looking answer.
        if target < 0:
            raise IndexError(f"target label {target} is out of range")
        target = ascii_uppercase[target]

    choices_list = record[field_spec["choices"]]

    # An absent or empty metadata spec means "no metadata": hand Sample None
    # rather than leaking the (empty) spec value through as the metadata.
    metadata_fields = field_spec.get("metadata")
    if metadata_fields:
        metadata = {name: record[name] for name in metadata_fields}
    else:
        metadata = None

    return Sample(
        input=input_text,
        target=target,
        choices=choices_list,
        metadata=metadata,
    )
3036
@@ -35,28 +41,20 @@ def load_dataset(repo_id: str, revision: str = "main", task_config: dict = None)
3541 split = task_config .get ("splits" , "test" )
3642 field_spec = task_config ["field_spec" ]
3743
38- # Use custom function if choices are specified (for multiple choice with label conversion)
3944 if "choices" in field_spec :
40- dataset = hf_dataset (
41- path = repo_id ,
42- revision = revision ,
43- name = subset ,
44- split = split ,
45- sample_fields = lambda record : record_to_sample (record , field_spec ),
46- )
45+
46+ def sample_fields (record ):
47+ return record_to_sample (record , field_spec )
4748 else :
48- # For non-multiple-choice, use FieldSpec
49- dataset = hf_dataset (
50- path = repo_id ,
51- revision = revision ,
52- name = subset ,
53- split = split ,
54- sample_fields = FieldSpec (
55- input = field_spec ["input" ],
56- target = field_spec ["target" ],
57- metadata = field_spec .get ("metadata" , []),
58- ),
59- )
49+ sample_fields = FieldSpec (** field_spec )
50+
51+ dataset = hf_dataset (
52+ path = repo_id ,
53+ revision = revision ,
54+ name = subset ,
55+ split = split ,
56+ sample_fields = sample_fields ,
57+ )
6058
6159 return dataset
6260
0 commit comments