Skip to content

Commit 43d616f

Browse files
authored
VisualPuzzles (EvolvingLMMs-Lab#637)
1 parent bbff985 commit 43d616f

File tree

3 files changed

+162
-0
lines changed

3 files changed

+162
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
dataset_path: neulab/VisualPuzzles
2+
dataset_kwargs:
3+
token: True
4+
task: "VisualPuzzles_cot"
5+
test_split: train
6+
output_type: generate_until
7+
doc_to_visual: !function utils.VisualPuzzles_doc_to_visual
8+
doc_to_text: !function utils.VisualPuzzles_doc_to_text
9+
doc_to_target: "answer"
10+
generation_kwargs:
11+
max_new_tokens: 4096
12+
temperature: 0
13+
top_p: 1.0
14+
num_beams: 1
15+
do_sample: false
16+
metric_list:
17+
- metric: exact_match
18+
aggregation: mean
19+
higher_is_better: true
20+
ignore_case: true
21+
ignore_punctuation: true
22+
process_results: !function utils.VisualPuzzles_process_result
23+
metadata:
24+
- version: 0.0
25+
26+
lmms_eval_specific_kwargs:
27+
default:
28+
prompt: "COT_PROMPT"
29+
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
dataset_path: neulab/VisualPuzzles
2+
dataset_kwargs:
3+
token: True
4+
task: "VisualPuzzles_direct"
5+
test_split: train
6+
output_type: generate_until
7+
doc_to_visual: !function utils.VisualPuzzles_doc_to_visual
8+
doc_to_text: !function utils.VisualPuzzles_doc_to_text
9+
doc_to_target: "answer"
10+
generation_kwargs:
11+
max_new_tokens: 4096
12+
temperature: 0
13+
top_p: 1.0
14+
num_beams: 1
15+
do_sample: false
16+
metric_list:
17+
- metric: exact_match
18+
aggregation: mean
19+
higher_is_better: true
20+
ignore_case: true
21+
ignore_punctuation: true
22+
process_results: !function utils.VisualPuzzles_process_result
23+
metadata:
24+
- version: 0.0
25+
26+
lmms_eval_specific_kwargs:
27+
default:
28+
prompt: "MULTI_CHOICE_DIRECT_PROMPT"
29+
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from PIL import Image
2+
import random
3+
import numpy as np
4+
import re
5+
import json
6+
import os
7+
import random
8+
9+
MULTI_CHOICE_DIRECT_PROMPT = "Answer the question with the option's letter from the given choices directly."
10+
COT_PROMPT = "Solve the multiple-choice question and then answer with the option letter from the given choices. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of options. Think step by step before answering."
11+
PROMPTS = {'MULTI_CHOICE_DIRECT_PROMPT': MULTI_CHOICE_DIRECT_PROMPT, 'COT_PROMPT': COT_PROMPT}
12+
13+
def VisualPuzzles_doc_to_visual(doc):
14+
return [doc['image']]
15+
16+
def VisualPuzzles_doc_to_text(doc, lmms_eval_specific_kwargs):
17+
question = 'Question: ' + doc["question"].strip()
18+
options = doc['options']
19+
if options != None: question += '\nOptions:\n(A) ' + options[0] + '\n(B) ' + options[1] + '\n(C) ' + options[2] + '\n(D) ' + options[3]
20+
else: question += '\nOptions: Choose from (A) (B) (C) (D) in the image.'
21+
question += '\n' + PROMPTS[lmms_eval_specific_kwargs['prompt']]
22+
return question
23+
24+
def parse_response(response, all_choices, index2ans):
25+
"""
26+
Return the last letter appearing after 'ANSWER:' in the input text.
27+
If there's no match, return None.
28+
"""
29+
pattern = r'Answer:\s*\(([A-Za-z])\)' # Answer: (A)
30+
matches = re.findall(pattern, response)
31+
if matches:
32+
for match in matches[::-1]:
33+
if match in all_choices or match.upper() in all_choices: return match
34+
pattern = r'(?<!Final )Answer:\s*([A-Za-z])' # Answer: A
35+
matches = re.findall(pattern, response)
36+
if matches:
37+
for match in matches[::-1]:
38+
if match in all_choices or match.upper() in all_choices: return match
39+
pattern = r'Answer:\s*([A-Za-z])' # Answer: A
40+
matches = re.findall(pattern, response)
41+
if matches:
42+
for match in matches[::-1]:
43+
if match in all_choices or match.upper() in all_choices: return match
44+
pattern = r'\s*\(([A-Za-z])\)' # e.g., (A) (B) (C) (D)
45+
matches = re.findall(pattern, response)
46+
if matches:
47+
for match in matches[::-1]:
48+
if match in all_choices or match.upper() in all_choices: return match
49+
response = ' ' + response.strip()
50+
pattern = r'\s*([A-Za-z])\)' # e.g., A) B) C) D)
51+
matches = re.findall(pattern, response)
52+
if matches:
53+
for match in matches[::-1]:
54+
if match in all_choices or match.upper() in all_choices: return match
55+
pattern = r'\s*\{([A-Za-z])\}' # e.g., {A} {B} {C} {D}
56+
matches = re.findall(pattern, response)
57+
if matches:
58+
for match in matches[::-1]:
59+
if match in all_choices or match.upper() in all_choices: return match
60+
pattern = r'\s*\$([A-Za-z])\$' # e.g., $A$, $B$, $C$, $D$
61+
matches = re.findall(pattern, response)
62+
if matches:
63+
for match in matches[::-1]:
64+
if match in all_choices or match.upper() in all_choices: return match
65+
pattern = r" ([A-Da-d])\." # e.g., A. B. C. D.
66+
matches = re.findall(pattern, response)
67+
if matches:
68+
for match in matches[::-1]:
69+
if match in all_choices or match.upper() in all_choices:
70+
return match
71+
pattern = r" ([A-Da-d])" # e.g., A B C D
72+
matches = re.findall(pattern, response)
73+
if matches and len(response) <= 5:
74+
for match in matches[::-1]:
75+
if match in all_choices or match.upper() in all_choices:
76+
return match
77+
if index2ans != None:
78+
for index in all_choices:
79+
ans = index2ans[index]
80+
if f'answer: {ans}' in response.lower(): return index
81+
if f'answer:{ans}' in response.lower(): return index
82+
last_found = None
83+
last_index = -1
84+
for index in all_choices:
85+
ans = index2ans[index]
86+
idx = response.rfind(ans)
87+
if idx > last_index:
88+
last_found = index
89+
last_index = idx
90+
if last_found: return last_found
91+
return random.choice(all_choices)
92+
93+
def VisualPuzzles_process_result(doc, results):
94+
print(f"results: {results}")
95+
pred = results[0].strip()
96+
all_choices = ['A', 'B', 'C', 'D']
97+
if doc['options'] == None: index2ans = None
98+
else: index2ans = {all_choices[i]: doc['options'][i] for i in range(4)}
99+
pred = parse_response(pred, all_choices, index2ans)
100+
target = doc['answer']
101+
if pred.lower() == target.lower(): return {"exact_match": 1.0}
102+
return {"exact_match": 0.0}
103+
104+

0 commit comments

Comments
 (0)