Skip to content

Commit 5a27f91

Browse files
authored
Merge pull request #1 from soukaryag/master
Interactive feature addition, big fixes, code abstraction and separation
2 parents be2f81c + 268a8fa commit 5a27f91

File tree

8 files changed

+255
-87
lines changed

8 files changed

+255
-87
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
__pycache__/
2+
venv/
3+
/.vscode
4+
results.p

Procfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
web: sh setup.sh && streamlit run app.py

app.py

Lines changed: 88 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import streamlit as st
2+
import numpy as np
23

34
from argparse import Namespace
45

@@ -10,21 +11,15 @@
1011
import random
1112
import re
1213

13-
logger = logging.getLogger(__name__)
14+
from models.html_helper import HtmlHelper
15+
from models.args import Args
16+
from models.cache import Cache
1417

15-
INITIAL_INSTRUCTIONS_HTML = """<p style="font-size:1.em; font-weight: 300">👋 Welcome to the TextAttack demo app! Please select a model and an attack recipe from the dropdown.</p> <hr style="margin: 1.em 0;">"""
18+
logger = logging.getLogger(__name__)
1619

17-
from config import NUM_SAMPLES_TO_ATTACK, MODELS, ATTACK_RECIPES, HIDDEN_ATTACK_RECIPES, PRECOMPUTED_RESULTS_DICT_NAME
20+
from config import NUM_SAMPLES_TO_ATTACK, MODELS, ATTACK_RECIPES, HIDDEN_ATTACK_RECIPES, PRECOMPUTED_RESULTS_DICT_NAME, HISTORY
1821

19-
def load_precomputed_results():
20-
try:
21-
precomputed_results = pickle.load(open(PRECOMPUTED_RESULTS_DICT_NAME, "rb" ))
22-
except FileNotFoundError:
23-
precomputed_results = {}
24-
print(f'Found {len(precomputed_results)} keys in pre-computed results.')
25-
return precomputed_results
26-
27-
def load_attack(model_name, attack_recipe_name):
22+
def load_attack(model_name, attack_recipe_name, num_examples):
2823
# Load model.
2924
model_class_name = MODELS[model_name][0]
3025
logger.info(f"Loading transformers.AutoModelForSequenceClassification from '{model_class_name}'.")
@@ -34,54 +29,24 @@ def load_attack(model_name, attack_recipe_name):
3429
except OSError:
3530
logger.warn('Couldn\'t find tokenizer; defaulting to "bert-base-uncased".')
3631
tokenizer = textattack.models.tokenizers.AutoTokenizer("bert-base-uncased")
37-
setattr(model, "tokenizer", tokenizer)
32+
model = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
33+
3834
# Load attack.
3935
logger.info(f"Loading attack from recipe {attack_recipe_name}.")
40-
attack = eval(f"{ATTACK_RECIPES[attack_recipe_name]}(model)")
36+
attack = eval(f"{ATTACK_RECIPES[attack_recipe_name]}.build(model)")
37+
4138
# Load dataset.
4239
_, dataset_args = MODELS[model_name]
43-
dataset = textattack.datasets.HuggingFaceNlpDataset(
40+
dataset = textattack.datasets.HuggingFaceDataset(
4441
*dataset_args, shuffle=True
4542
)
43+
dataset.examples = dataset.examples[:num_examples]
4644
return model, attack, dataset
4745

48-
def improve_result_html(result_html):
49-
result_html = result_html.replace("color = bold", 'style="font-weight: bold;"')
50-
result_html = result_html.replace("color = underline", 'style="text-decoration: underline;"')
51-
result_html = result_html.replace('<font style="font-weight: bold;"', '<span style=""') # no bolding for now
52-
result_html = result_html.replace('<font style="text-decoration: underline;"', '<span style="text-decoration: underline;"')
53-
result_html = re.sub(r"<font\scolor\s=\s(\w.*?)>", r'<span style="background-color: \1; padding: 1.2px; font-weight: 600;">', result_html)
54-
# replace font colors with transparent highlight versions
55-
result_html = result_html.replace(': red', ': rgba(255, 0, 0, .7)') \
56-
.replace(': green', ': rgb(0, 255, 0, .7)') \
57-
.replace(': blue', ': rgb(0, 0, 255, .7)') \
58-
.replace(': gray', ': rgb(220, 220, 220, .7)')
59-
result_html = result_html.replace("</font>", "</span>")
60-
return result_html
61-
62-
def get_attack_result_status(attack_result):
63-
status_html = attack_result.goal_function_result_str(color_method='html')
64-
return improve_result_html(status_html)
65-
66-
def get_attack_result_html(idx, attack_result):
67-
result_status = get_attack_result_status(attack_result)
68-
result_html_lines = attack_result.str_lines(color_method='html')
69-
result_html_lines = [improve_result_html(line) for line in result_html_lines]
70-
rows = [
71-
['', result_status],
72-
['Input', result_html_lines[1]]
73-
]
74-
75-
if len(result_html_lines) > 2:
76-
rows.append(['Output', result_html_lines[2]])
77-
78-
table_html = '\n'.join((f'<b>{header}:</b> {body} <br>' if header else f'{body} <br>') for header,body in rows)
79-
return f'<h3>Result {idx+1}</h3> {table_html} <br>'
80-
8146
@st.cache
8247
def get_attack_recipe_prototype(attack_recipe_name):
8348
""" a sort of hacky way to print an attack recipe without loading a big model"""
84-
recipe = eval(textattack.commands.attack.attack_args.ATTACK_RECIPE_NAMES[attack_recipe_name])
49+
recipe = eval(textattack.commands.attack.attack_args.ATTACK_RECIPE_NAMES[attack_recipe_name]).build
8550
dummy_tokenizer = Namespace(**{ 'encode': None})
8651
dummy_model = Namespace(**{ 'tokenizer': dummy_tokenizer })
8752
recipe = recipe(dummy_model)
@@ -91,70 +56,81 @@ def get_attack_recipe_prototype(attack_recipe_name):
9156
del dummy_tokenizer
9257
return recipe_str
9358

59+
def display_history(fake_latency=False):
    """Render every entry of the cached usage history on the main page.

    Entries are read from ``PRECOMPUTE_CACHE`` under the ``HISTORY`` key and
    rendered in stored order; the position in the history is passed to
    ``HtmlHelper.get_attack_result_html`` as the result index.

    Args:
        fake_latency: when true, ``random_latency()`` is called before each
            entry is rendered.
    """
    for position, attack_result in enumerate(PRECOMPUTE_CACHE.get(HISTORY)):
        if fake_latency:
            random_latency()
        rendered = HtmlHelper.get_attack_result_html(position, attack_result)
        st.markdown(rendered, unsafe_allow_html=True)
64+
65+
def random_latency():
    """Sleep for a short random interval so cached results feel like a live run."""
    # random.triangular(low, high, mode): the delay lies in [0, 2] seconds
    # and is most likely to land near 0.8 seconds.
    delay_seconds = random.triangular(0., 2., .8)
    time.sleep(delay_seconds)
69+
9470
@st.cache(suppress_st_warning=True,allow_output_mutation=True)
95-
def get_and_print_attack_results(model_name, attack_recipe_name):
71+
def get_and_print_attack_results(model_name, attack_recipe_name, num_examples):
9672
with st.spinner(f'Loading `{model_name}` model and `{attack_recipe_name}` attack...'):
97-
model, attack, dataset = load_attack(model_name, attack_recipe_name)
73+
model, attack, dataset = load_attack(model_name, attack_recipe_name, num_examples)
9874
dataset_name = dataset._name
75+
9976
# Run attack.
10077
from collections import deque
101-
worklist = deque(range(0, NUM_SAMPLES_TO_ATTACK))
78+
worklist = deque(range(0, num_examples))
10279
results = []
103-
with st.spinner(f'Running attack on {NUM_SAMPLES_TO_ATTACK} samples from nlp dataset "{dataset_name}"...'):
80+
with st.spinner(f'Running attack on {num_examples} samples from nlp dataset "{dataset_name}"...'):
10481
for idx, result in enumerate(attack.attack_dataset(dataset, indices=worklist)):
105-
st.markdown(get_attack_result_html(idx, result), unsafe_allow_html=True)
82+
st.markdown(HtmlHelper.get_attack_result_html(idx, result), unsafe_allow_html=True)
10683
results.append(result)
107-
108-
# Update precomputed results
109-
PRECOMPUTED_RESULTS = load_precomputed_results()
110-
PRECOMPUTED_RESULTS[(model_name, attack_recipe_name)] = results
111-
pickle.dump(PRECOMPUTED_RESULTS, open(PRECOMPUTED_RESULTS_DICT_NAME, 'wb'))
112-
# Return results
113-
return { 'results': results, 'already_printed': True }
11484

115-
def random_latency():
116-
# Artificially inject a tiny bit of latency to provide
117-
# a feel of the attack _running_.
118-
time.sleep(random.triangular(0., 2., .8))
85+
# Update precomputed results
86+
PRECOMPUTE_CACHE.add((model_name, attack_recipe_name), results)
11987

120-
def run_attack(model_name, attack_recipe_name):
121-
if (model_name, attack_recipe_name) in PRECOMPUTED_RESULTS:
122-
results = PRECOMPUTED_RESULTS[(model_name, attack_recipe_name)]
123-
for idx, result in enumerate(results):
124-
random_latency()
125-
st.markdown(get_attack_result_html(idx, result), unsafe_allow_html=True)
88+
def run_attack_interactive(text, model_name, attack_recipe_name):
89+
if PRECOMPUTE_CACHE.exists((text, model_name, attack_recipe_name)) and PRECOMPUTE_CACHE.exists(HISTORY):
90+
PRECOMPUTE_CACHE.to_top((text, model_name, attack_recipe_name))
91+
display_history(fake_latency=True)
12692
else:
127-
# Precompute results
128-
results_dict = get_and_print_attack_results(model_name, attack_recipe_name)
129-
results = results_dict['results']
130-
# Print attack results, as long as this wasn't the first time they were computed.
131-
if not results_dict['already_printed']:
132-
for idx, result in enumerate(results):
133-
random_latency()
134-
st.markdown(get_attack_result_html(idx, result), unsafe_allow_html=True)
135-
results_dict['already_printed'] = False
136-
# print summary
93+
attack = textattack.commands.attack.attack_args_helpers.parse_attack_from_args(Args(model_name, attack_recipe_name))
94+
attacked_text = textattack.shared.attacked_text.AttackedText(text)
95+
initial_result = attack.goal_function.get_output(attacked_text)
96+
result = next(attack.attack_dataset([(text, initial_result)]))
97+
98+
# Update precomputed results
99+
PRECOMPUTE_CACHE.add((text, model_name, attack_recipe_name), result)
100+
display_history()
101+
102+
def run_attack(model_name, attack_recipe_name, num_examples):
    """Show attack results for a model/recipe pair, computing them if needed.

    If results for ``(model_name, attack_recipe_name)`` are already cached,
    they are moved to the top of the history and the history is replayed with
    artificial latency. Otherwise the attack is run (and printed) via
    ``get_and_print_attack_results``.

    Args:
        model_name: key into ``MODELS`` identifying the model to attack.
        attack_recipe_name: key into ``ATTACK_RECIPES``.
        num_examples: number of dataset examples to attack on a cache miss.
    """
    # NOTE(review): the cache key does not include num_examples, so a cache
    # hit replays whatever number of examples was computed the first time.
    cache_key = (model_name, attack_recipe_name)
    if not PRECOMPUTE_CACHE.exists(cache_key):
        get_and_print_attack_results(model_name, attack_recipe_name, num_examples)
        return
    PRECOMPUTE_CACHE.to_top(cache_key)
    display_history(fake_latency=True)
137108

138-
139109
def process_attack_recipe_doc(attack_recipe_text):
    """Normalize an attack recipe docstring for display.

    Strips leading/trailing whitespace from the whole text and from each of
    its lines, so indentation inherited from the source docstring does not
    leak into the rendered sidebar.

    Args:
        attack_recipe_text: the recipe's ``__doc__``; may be ``None`` when the
            recipe class has no docstring.

    Returns:
        The cleaned text, or an empty string when there is no docstring.
    """
    # A class without a docstring has __doc__ == None; render nothing
    # instead of failing on None.strip().
    if not attack_recipe_text:
        return ""
    stripped_lines = (line.strip() for line in attack_recipe_text.strip().split("\n"))
    return "\n".join(stripped_lines)
143113

144114
def main():
145-
# Print instructions.
146-
st.markdown(INITIAL_INSTRUCTIONS_HTML, unsafe_allow_html=True)
115+
st.beta_set_page_config(page_title='TextAttack Web Demo', page_icon='https://cdn.shopify.com/s/files/1/1061/1924/products/Octopus_Iphone_Emoji_JPG_large.png', initial_sidebar_state ='auto')
116+
st.markdown(HtmlHelper.INITIAL_INSTRUCTIONS_HTML, unsafe_allow_html=True)
117+
147118
# Print TextAttack info to sidebar.
148119
st.sidebar.markdown('<h1 style="text-align:center; font-size: 1.5em;">TextAttack 🐙</h1>', unsafe_allow_html=True)
149120
st.sidebar.markdown('<p style="font-size:1.em; text-align:center;"><a href="https://github.com/QData/TextAttack">https://github.com/QData/TextAttack</a></p>', unsafe_allow_html=True)
150121
st.sidebar.markdown('<hr>', unsafe_allow_html=True)
122+
151123
# Select model.
152124
all_model_names = list(re.sub(r'-mr$', '-rotten_tomatoes', m) for m in MODELS.keys())
153125
model_names = list(sorted(set(map(lambda x: x.replace(x[x.rfind('-'):],''), all_model_names))))
154126
model_default = 'bert-base-uncased'
155127
model_default_index = model_names.index(model_default)
128+
interactive = st.sidebar.checkbox('Interactive')
156129
model_name = st.sidebar.selectbox('Model from transformers:', model_names, index=model_default_index)
130+
157131
# Select dataset. (TODO make this less hacky.)
132+
if interactive:
133+
interactive_text = st.sidebar.text_input('Custom Input Data')
158134
matching_model_keys = list(m for m in all_model_names if m.startswith(model_name))
159135
dataset_names = list(sorted(map(lambda x: x.replace(x[:x.rfind('-')+1],''), matching_model_keys)))
160136
dataset_default_index = 0
@@ -166,28 +142,53 @@ def main():
166142
continue
167143
dataset_name = st.sidebar.selectbox('Dataset from nlp:', dataset_names, index=dataset_default_index)
168144
full_model_name = '-'.join((model_name, dataset_name)).replace('-rotten_tomatoes', '-mr')
145+
169146
# Select attack recipe.
170147
recipe_names = list(sorted(ATTACK_RECIPES.keys()))
171148
for hidden_attack in HIDDEN_ATTACK_RECIPES: recipe_names.remove(hidden_attack)
172149
recipe_default = 'textfooler'
173150
recipe_default_index = recipe_names.index(recipe_default)
174151
attack_recipe_name = st.sidebar.selectbox('Attack recipe', recipe_names, index=recipe_default_index)
152+
153+
# Select number of examples to be displayed
154+
if not interactive:
155+
num_examples = st.sidebar.slider('Number of Examples', 1, 100, value=10, step=1)
156+
175157
# Run attack on button press.
176158
if st.sidebar.button('Run attack'):
177159
# Run full attack.
178-
run_attack(full_model_name, attack_recipe_name)
160+
if interactive: run_attack_interactive(interactive_text, full_model_name, attack_recipe_name)
161+
else: run_attack(full_model_name, attack_recipe_name, num_examples)
162+
else:
163+
# Print History of Usage
164+
timeline_history = PRECOMPUTE_CACHE.get(HISTORY)
165+
for idx, entry in enumerate(timeline_history):
166+
st.markdown(HtmlHelper.get_attack_result_html(idx, entry), unsafe_allow_html=True)
167+
168+
# Display clear history button
169+
if PRECOMPUTE_CACHE.exists(HISTORY):
170+
clear_history = st.button("Clear History")
171+
if clear_history:
172+
PRECOMPUTE_CACHE.purge(key=HISTORY)
173+
179174
# TODO print attack metrics somewhere?
180175
# Add model info to sidebar.
181176
hf_model_name = MODELS[full_model_name][0]
182177
model_link = f"https://huggingface.co/{hf_model_name}"
183178
st.markdown(f"### Model Hub Link \n [[{hf_model_name}]({model_link})]", unsafe_allow_html=True)
179+
184180
# Add attack info to sidebar (TODO would main page be better?).
185181
attack_recipe_doc = process_attack_recipe_doc(eval(f"{ATTACK_RECIPES[attack_recipe_name]}.__doc__"))
186182
st.sidebar.markdown(f'<hr style="margin: 1.0em 0;"> <h3>Attack Recipe:</h3>\n<b>Name:</b> {attack_recipe_name} <br> <br> {attack_recipe_doc}', unsafe_allow_html=True)
183+
187184
# Print attack recipe composition
188185
attack_recipe_prototype = get_attack_recipe_prototype(attack_recipe_name)
189186
st.markdown(f'### Attack Recipe Prototype \n```\n{attack_recipe_prototype}\n```')
187+
188+
purge_cache = st.button("Purge Local Cache")
189+
if purge_cache:
190+
PRECOMPUTE_CACHE.purge()
190191

191-
if __name__ == "__main__": # @TODO split model & dataset into 2 dropdowns
192-
PRECOMPUTED_RESULTS = load_precomputed_results()
192+
if __name__ == "__main__":
193+
PRECOMPUTE_CACHE = Cache(log=False)
193194
main()

config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
HIDDEN_ATTACK_RECIPES = ['alzantot', 'seq2sick', 'hotflip']
77

88
PRECOMPUTED_RESULTS_DICT_NAME = 'results.p'
9+
HISTORY = 'timeline_history'

models/args.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
class Args():
    """Minimal stand-in for the parsed command-line arguments of an attack run.

    Instances are handed to TextAttack's ``parse_attack_from_args`` in place of
    an ``argparse`` namespace. The attributes set in ``__init__`` carry real
    values; any other option that is looked up reads as ``False``.

    Args:
        model: name of the model to attack.
        recipe: name of the attack recipe.
        model_batch_size: batch size stored as ``model_batch_size``.
        query_budget: value stored as ``query_budget``.
        model_cache_size: value stored as ``model_cache_size``.
        constraint_cache_size: value stored as ``constraint_cache_size``.
    """

    def __init__(self, model, recipe, model_batch_size=32, query_budget=200, model_cache_size=2**18, constraint_cache_size=2**18):
        self.model = model
        self.recipe = recipe
        self.model_batch_size = model_batch_size
        self.model_cache_size = model_cache_size
        self.query_budget = query_budget
        self.constraint_cache_size = constraint_cache_size

    def __getattr__(self, item):
        """Return ``False`` for any option that was not explicitly set.

        ``__getattr__`` is only consulted after normal attribute lookup fails,
        so the attributes assigned in ``__init__`` are unaffected.

        Raises:
            AttributeError: for dunder names, so protocol probes such as
                ``__deepcopy__`` or ``__setstate__`` see the hook as absent
                rather than receiving a non-callable ``False``.
        """
        if item.startswith('__') and item.endswith('__'):
            raise AttributeError(item)
        return False

models/cache.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import pickle
2+
3+
from config import PRECOMPUTED_RESULTS_DICT_NAME, HISTORY
4+
5+
class Cache():
    """Pickle-backed store of attack results plus a most-recent-first history.

    The whole cache is a single dict pickled to ``PRECOMPUTED_RESULTS_DICT_NAME``.
    Besides the result entries, the dict holds a list under the ``HISTORY`` key
    containing every stored result, newest first. Every public method reloads
    the dict from disk before acting, and every mutating method writes it back.

    Args:
        log: when true, progress messages are printed to stdout.
    """

    def __init__(self, log=False):
        self.log = log
        self.cache = self.load_precomputed_results()

    def load_precomputed_results(self):
        """Load and return the cache dict from disk (empty dict if no file exists)."""
        try:
            # NOTE(security): pickle.load can execute arbitrary code; the cache
            # file must only ever be a file this application wrote itself.
            with open(PRECOMPUTED_RESULTS_DICT_NAME, "rb") as cache_file:
                precomputed_results = pickle.load(cache_file)
        except FileNotFoundError:
            precomputed_results = {}
        if self.log: print(f'Found {len(precomputed_results)} keys in pre-computed results.')
        return precomputed_results

    def _save(self):
        """Write the in-memory cache dict back to disk."""
        with open(PRECOMPUTED_RESULTS_DICT_NAME, 'wb') as cache_file:
            pickle.dump(self.cache, cache_file)

    def add(self, key, data):
        """Store ``data`` under ``key`` and prepend it to the history.

        A list is treated as a batch of results and its items are prepended
        individually (in order); any other value is prepended as one entry.
        """
        self.cache = self.load_precomputed_results()
        self.cache[key] = data

        # update history
        new_entries = data if isinstance(data, list) else [data]
        self.cache[HISTORY] = new_entries + self.cache.get(HISTORY, [])

        self._save()
        if self.log: print(f'Successfully added {key} to the cache')

    def to_top(self, key):
        """Move the result(s) stored under ``key`` to the front of the history.

        The entries keep the order they have in the stored value, and the
        reordered history is written back to disk so later reads see it.
        Entries that are no longer in the history (e.g. after the history was
        purged) are re-inserted at the top instead of raising.

        Returns:
            ``[]`` when there is nothing to move (unknown key, empty value or
            empty history); ``None`` otherwise.
        """
        self.cache = self.load_precomputed_results()
        data, history = self.cache.get(key, None), self.cache.get(HISTORY, None)
        if not data or not history:
            return []

        # Copy so that mutating the history never disturbs the iteration.
        entries = list(data) if isinstance(data, list) else [data]
        # Insert back-to-front so the first entry ends up first.
        for entry in reversed(entries):
            try:
                history.remove(entry)
            except ValueError:
                pass  # not in the history any more; just put it on top
            history.insert(0, entry)

        self._save()

    def exists(self, key):
        """Return whether ``key`` is present in the on-disk cache."""
        self.cache = self.load_precomputed_results()
        return key in self.cache

    def purge(self, key=None):
        """Remove one entry, or everything when ``key`` is ``None``.

        Removing a key that is not present is a no-op. The result is written
        back to disk in every case.
        """
        self.cache = self.load_precomputed_results()
        if key is None:
            self.cache.clear()
        elif key in self.cache:
            del self.cache[key]
        self._save()
        if self.log: print('Cache successfully purged')

    def get(self, key):
        """Return the value stored under ``key``, or ``[]`` when it is absent."""
        self.cache = self.load_precomputed_results()
        return self.cache.get(key, [])

0 commit comments

Comments
 (0)