11import streamlit as st
2+ import numpy as np
23
34from argparse import Namespace
45
1011import random
1112import re
1213
13- logger = logging .getLogger (__name__ )
14+ from models .html_helper import HtmlHelper
15+ from models .args import Args
16+ from models .cache import Cache
1417
15- INITIAL_INSTRUCTIONS_HTML = """<p style="font-size:1.em; font-weight: 300">👋 Welcome to the TextAttack demo app! Please select a model and an attack recipe from the dropdown.</p> <hr style="margin: 1.em 0;">"""
18+ logger = logging . getLogger ( __name__ )
1619
17- from config import NUM_SAMPLES_TO_ATTACK , MODELS , ATTACK_RECIPES , HIDDEN_ATTACK_RECIPES , PRECOMPUTED_RESULTS_DICT_NAME
20+ from config import NUM_SAMPLES_TO_ATTACK , MODELS , ATTACK_RECIPES , HIDDEN_ATTACK_RECIPES , PRECOMPUTED_RESULTS_DICT_NAME , HISTORY
1821
19- def load_precomputed_results ():
20- try :
21- precomputed_results = pickle .load (open (PRECOMPUTED_RESULTS_DICT_NAME , "rb" ))
22- except FileNotFoundError :
23- precomputed_results = {}
24- print (f'Found { len (precomputed_results )} keys in pre-computed results.' )
25- return precomputed_results
26-
27- def load_attack (model_name , attack_recipe_name ):
22+ def load_attack (model_name , attack_recipe_name , num_examples ):
2823 # Load model.
2924 model_class_name = MODELS [model_name ][0 ]
3025 logger .info (f"Loading transformers.AutoModelForSequenceClassification from '{ model_class_name } '." )
@@ -34,54 +29,24 @@ def load_attack(model_name, attack_recipe_name):
3429 except OSError :
3530 logger .warn ('Couldn\' t find tokenizer; defaulting to "bert-base-uncased".' )
3631 tokenizer = textattack .models .tokenizers .AutoTokenizer ("bert-base-uncased" )
37- setattr (model , "tokenizer" , tokenizer )
32+ model = textattack .models .wrappers .HuggingFaceModelWrapper (model , tokenizer )
33+
3834 # Load attack.
3935 logger .info (f"Loading attack from recipe { attack_recipe_name } ." )
40- attack = eval (f"{ ATTACK_RECIPES [attack_recipe_name ]} (model)" )
36+ attack = eval (f"{ ATTACK_RECIPES [attack_recipe_name ]} .build(model)" )
37+
4138 # Load dataset.
4239 _ , dataset_args = MODELS [model_name ]
43- dataset = textattack .datasets .HuggingFaceNlpDataset (
40+ dataset = textattack .datasets .HuggingFaceDataset (
4441 * dataset_args , shuffle = True
4542 )
43+ dataset .examples = dataset .examples [:num_examples ]
4644 return model , attack , dataset
4745
48- def improve_result_html (result_html ):
49- result_html = result_html .replace ("color = bold" , 'style="font-weight: bold;"' )
50- result_html = result_html .replace ("color = underline" , 'style="text-decoration: underline;"' )
51- result_html = result_html .replace ('<font style="font-weight: bold;"' , '<span style=""' ) # no bolding for now
52- result_html = result_html .replace ('<font style="text-decoration: underline;"' , '<span style="text-decoration: underline;"' )
53- result_html = re .sub (r"<font\scolor\s=\s(\w.*?)>" , r'<span style="background-color: \1; padding: 1.2px; font-weight: 600;">' , result_html )
54- # replace font colors with transparent highlight versions
55- result_html = result_html .replace (': red' , ': rgba(255, 0, 0, .7)' ) \
56- .replace (': green' , ': rgb(0, 255, 0, .7)' ) \
57- .replace (': blue' , ': rgb(0, 0, 255, .7)' ) \
58- .replace (': gray' , ': rgb(220, 220, 220, .7)' )
59- result_html = result_html .replace ("</font>" , "</span>" )
60- return result_html
61-
62- def get_attack_result_status (attack_result ):
63- status_html = attack_result .goal_function_result_str (color_method = 'html' )
64- return improve_result_html (status_html )
65-
66- def get_attack_result_html (idx , attack_result ):
67- result_status = get_attack_result_status (attack_result )
68- result_html_lines = attack_result .str_lines (color_method = 'html' )
69- result_html_lines = [improve_result_html (line ) for line in result_html_lines ]
70- rows = [
71- ['' , result_status ],
72- ['Input' , result_html_lines [1 ]]
73- ]
74-
75- if len (result_html_lines ) > 2 :
76- rows .append (['Output' , result_html_lines [2 ]])
77-
78- table_html = '\n ' .join ((f'<b>{ header } :</b> { body } <br>' if header else f'{ body } <br>' ) for header ,body in rows )
79- return f'<h3>Result { idx + 1 } </h3> { table_html } <br>'
80-
8146@st .cache
8247def get_attack_recipe_prototype (attack_recipe_name ):
8348 """ a sort of hacky way to print an attack recipe without loading a big model"""
84- recipe = eval (textattack .commands .attack .attack_args .ATTACK_RECIPE_NAMES [attack_recipe_name ])
49+ recipe = eval (textattack .commands .attack .attack_args .ATTACK_RECIPE_NAMES [attack_recipe_name ]). build
8550 dummy_tokenizer = Namespace (** { 'encode' : None })
8651 dummy_model = Namespace (** { 'tokenizer' : dummy_tokenizer })
8752 recipe = recipe (dummy_model )
@@ -91,70 +56,81 @@ def get_attack_recipe_prototype(attack_recipe_name):
9156 del dummy_tokenizer
9257 return recipe_str
9358
def display_history(fake_latency=False):
    """Render every attack result stored under the ``HISTORY`` cache key.

    Each entry is converted to HTML with ``HtmlHelper.get_attack_result_html``
    and written to the page with ``st.markdown``.

    Args:
        fake_latency: When true, ``random_latency()`` is called before each
            entry is rendered so replayed results appear one at a time.
    """
    # NOTE(review): what Cache.get returns for a missing key is not visible
    # from this file. Other call sites check ``exists(HISTORY)`` separately,
    # so an absent history is a real state; treat any falsy value as "nothing
    # to show" instead of letting enumerate() fail on it.
    history = PRECOMPUTE_CACHE.get(HISTORY) or []
    for idx, result in enumerate(history):
        if fake_latency:
            random_latency()
        st.markdown(
            HtmlHelper.get_attack_result_html(idx, result),
            unsafe_allow_html=True,
        )
64+
def random_latency():
    """Pause briefly so replayed results feel like an attack that is running.

    The pause length is drawn from a triangular distribution over
    0 to 2 seconds whose mode is 0.8 seconds.
    """
    delay_seconds = random.triangular(0., 2., .8)
    time.sleep(delay_seconds)
69+
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_and_print_attack_results(model_name, attack_recipe_name, num_examples):
    """Run an attack over ``num_examples`` samples, printing results as they arrive.

    Loads the model, attack and dataset through ``load_attack``, renders each
    attack result to the page as soon as it is produced, and finally stores
    the collected results in ``PRECOMPUTE_CACHE`` under the
    ``(model_name, attack_recipe_name)`` key.

    Args:
        model_name: Key into ``MODELS`` identifying the model/dataset pair.
        attack_recipe_name: Key into ``ATTACK_RECIPES``.
        num_examples: Number of dataset samples to attack.
    """
    from collections import deque

    loading_message = f'Loading `{ model_name } ` model and `{ attack_recipe_name } ` attack...'
    with st.spinner(loading_message):
        _model, attack, dataset = load_attack(model_name, attack_recipe_name, num_examples)
    dataset_name = dataset._name

    # Run the attack over the first `num_examples` dataset indices.
    pending_indices = deque(range(num_examples))
    collected = []
    running_message = f'Running attack on { num_examples } samples from nlp dataset "{ dataset_name } "...'
    with st.spinner(running_message):
        attack_results = attack.attack_dataset(dataset, indices=pending_indices)
        for position, attack_result in enumerate(attack_results):
            rendered = HtmlHelper.get_attack_result_html(position, attack_result)
            st.markdown(rendered, unsafe_allow_html=True)
            collected.append(attack_result)

    # Remember the results so later runs can replay them from the cache.
    PRECOMPUTE_CACHE.add((model_name, attack_recipe_name), collected)
11987
120- def run_attack (model_name , attack_recipe_name ):
121- if (model_name , attack_recipe_name ) in PRECOMPUTED_RESULTS :
122- results = PRECOMPUTED_RESULTS [(model_name , attack_recipe_name )]
123- for idx , result in enumerate (results ):
124- random_latency ()
125- st .markdown (get_attack_result_html (idx , result ), unsafe_allow_html = True )
def run_attack_interactive(text, model_name, attack_recipe_name):
    """Attack a single user-supplied text and show the usage history.

    When a result for ``(text, model_name, attack_recipe_name)`` is already
    cached and a history exists, the cached entry is moved to the top and the
    history is replayed with artificial latency. Otherwise the attack is
    built and run on ``text``, the result is cached, and the history is shown.

    Args:
        text: The custom input text to attack.
        model_name: Key into ``MODELS`` identifying the model/dataset pair.
        attack_recipe_name: Key into ``ATTACK_RECIPES``.
    """
    cache_key = (text, model_name, attack_recipe_name)

    if PRECOMPUTE_CACHE.exists(cache_key) and PRECOMPUTE_CACHE.exists(HISTORY):
        PRECOMPUTE_CACHE.to_top(cache_key)
        display_history(fake_latency=True)
        return

    attack_args = Args(model_name, attack_recipe_name)
    attack = textattack.commands.attack.attack_args_helpers.parse_attack_from_args(attack_args)
    attacked_text = textattack.shared.attacked_text.AttackedText(text)
    # The goal function's output on the input text is passed to the attack
    # as the second element of the (text, output) pair.
    initial_result = attack.goal_function.get_output(attacked_text)
    result = next(attack.attack_dataset([(text, initial_result)]))

    # Update precomputed results
    PRECOMPUTE_CACHE.add(cache_key, result)
    display_history()
101+
def run_attack(model_name, attack_recipe_name, num_examples):
    """Replay cached results for a model/recipe pair, or compute them.

    Args:
        model_name: Key into ``MODELS`` identifying the model/dataset pair.
        attack_recipe_name: Key into ``ATTACK_RECIPES``.
        num_examples: Number of samples to attack when nothing is cached.
    """
    cache_key = (model_name, attack_recipe_name)

    if not PRECOMPUTE_CACHE.exists(cache_key):
        # Nothing cached yet: run the attack, which prints results itself.
        get_and_print_attack_results(model_name, attack_recipe_name, num_examples)
        return

    PRECOMPUTE_CACHE.to_top(cache_key)
    display_history(fake_latency=True)
137108
138-
def process_attack_recipe_doc(attack_recipe_text):
    """Strip indentation from an attack recipe docstring.

    Surrounding whitespace is removed from the text as a whole and from each
    individual line; blank lines between paragraphs are kept.

    Args:
        attack_recipe_text: Raw docstring text. May be ``None`` or empty,
            which is what ``__doc__`` yields for an undocumented recipe.

    Returns:
        The text with every line stripped, joined by single newlines, or an
        empty string when there is no text.
    """
    if not attack_recipe_text:
        return ""
    # Split and join on a bare newline: splitting on "\n " would miss lines
    # that are not followed by a space, and joining with "\n " would put a
    # leading space back on every stripped line.
    stripped_lines = (line.strip() for line in attack_recipe_text.strip().split("\n"))
    return "\n".join(stripped_lines)
143113
144114def main ():
145- # Print instructions.
146- st .markdown (INITIAL_INSTRUCTIONS_HTML , unsafe_allow_html = True )
115+ st .beta_set_page_config (page_title = 'TextAttack Web Demo' , page_icon = 'https://cdn.shopify.com/s/files/1/1061/1924/products/Octopus_Iphone_Emoji_JPG_large.png' , initial_sidebar_state = 'auto' )
116+ st .markdown (HtmlHelper .INITIAL_INSTRUCTIONS_HTML , unsafe_allow_html = True )
117+
147118 # Print TextAttack info to sidebar.
148119 st .sidebar .markdown ('<h1 style="text-align:center; font-size: 1.5em;">TextAttack 🐙</h1>' , unsafe_allow_html = True )
149120 st .sidebar .markdown ('<p style="font-size:1.em; text-align:center;"><a href="https://github.com/QData/TextAttack">https://github.com/QData/TextAttack</a></p>' , unsafe_allow_html = True )
150121 st .sidebar .markdown ('<hr>' , unsafe_allow_html = True )
122+
151123 # Select model.
152124 all_model_names = list (re .sub (r'-mr$' , '-rotten_tomatoes' , m ) for m in MODELS .keys ())
153125 model_names = list (sorted (set (map (lambda x : x .replace (x [x .rfind ('-' ):],'' ), all_model_names ))))
154126 model_default = 'bert-base-uncased'
155127 model_default_index = model_names .index (model_default )
128+ interactive = st .sidebar .checkbox ('Interactive' )
156129 model_name = st .sidebar .selectbox ('Model from transformers:' , model_names , index = model_default_index )
130+
157131 # Select dataset. (TODO make this less hacky.)
132+ if interactive :
133+ interactive_text = st .sidebar .text_input ('Custom Input Data' )
158134 matching_model_keys = list (m for m in all_model_names if m .startswith (model_name ))
159135 dataset_names = list (sorted (map (lambda x : x .replace (x [:x .rfind ('-' )+ 1 ],'' ), matching_model_keys )))
160136 dataset_default_index = 0
@@ -166,28 +142,53 @@ def main():
166142 continue
167143 dataset_name = st .sidebar .selectbox ('Dataset from nlp:' , dataset_names , index = dataset_default_index )
168144 full_model_name = '-' .join ((model_name , dataset_name )).replace ('-rotten_tomatoes' , '-mr' )
145+
169146 # Select attack recipe.
170147 recipe_names = list (sorted (ATTACK_RECIPES .keys ()))
171148 for hidden_attack in HIDDEN_ATTACK_RECIPES : recipe_names .remove (hidden_attack )
172149 recipe_default = 'textfooler'
173150 recipe_default_index = recipe_names .index (recipe_default )
174151 attack_recipe_name = st .sidebar .selectbox ('Attack recipe' , recipe_names , index = recipe_default_index )
152+
153+ # Select number of examples to be displayed
154+ if not interactive :
155+ num_examples = st .sidebar .slider ('Number of Examples' , 1 , 100 , value = 10 , step = 1 )
156+
175157 # Run attack on button press.
176158 if st .sidebar .button ('Run attack' ):
177159 # Run full attack.
178- run_attack (full_model_name , attack_recipe_name )
160+ if interactive : run_attack_interactive (interactive_text , full_model_name , attack_recipe_name )
161+ else : run_attack (full_model_name , attack_recipe_name , num_examples )
162+ else :
163+ # Print History of Usage
164+ timeline_history = PRECOMPUTE_CACHE .get (HISTORY )
165+ for idx , entry in enumerate (timeline_history ):
166+ st .markdown (HtmlHelper .get_attack_result_html (idx , entry ), unsafe_allow_html = True )
167+
168+ # Display clear history button
169+ if PRECOMPUTE_CACHE .exists (HISTORY ):
170+ clear_history = st .button ("Clear History" )
171+ if clear_history :
172+ PRECOMPUTE_CACHE .purge (key = HISTORY )
173+
179174 # TODO print attack metrics somewhere?
180175 # Add model info to sidebar.
181176 hf_model_name = MODELS [full_model_name ][0 ]
182177 model_link = f"https://huggingface.co/{ hf_model_name } "
183178 st .markdown (f"### Model Hub Link \n [[{ hf_model_name } ]({ model_link } )]" , unsafe_allow_html = True )
179+
184180 # Add attack info to sidebar (TODO would main page be better?).
185181 attack_recipe_doc = process_attack_recipe_doc (eval (f"{ ATTACK_RECIPES [attack_recipe_name ]} .__doc__" ))
186182 st .sidebar .markdown (f'<hr style="margin: 1.0em 0;"> <h3>Attack Recipe:</h3>\n <b>Name:</b> { attack_recipe_name } <br> <br> { attack_recipe_doc } ' , unsafe_allow_html = True )
183+
187184 # Print attack recipe composition
188185 attack_recipe_prototype = get_attack_recipe_prototype (attack_recipe_name )
189186 st .markdown (f'### Attack Recipe Prototype \n ```\n { attack_recipe_prototype } \n ```' )
187+
188+ purge_cache = st .button ("Purge Local Cache" )
189+ if purge_cache :
190+ PRECOMPUTE_CACHE .purge ()
190191
191- if __name__ == "__main__" : # @TODO split model & dataset into 2 dropdowns
192- PRECOMPUTED_RESULTS = load_precomputed_results ( )
if __name__ == "__main__":
    # PRECOMPUTE_CACHE is bound at module level here because main(),
    # display_history(), run_attack() and run_attack_interactive() all read
    # it as a global; it must exist before main() is called.
    # NOTE(review): the effect of log=False is defined by models.cache.Cache
    # and is not visible from this file.
    PRECOMPUTE_CACHE = Cache(log=False)
    main()
0 commit comments