1- '''
1+ """
22This script allows to find the optimal parameters for a learning rate scheduling:
33
44- min_lr
2020
2121reference: https://towardsdatascience.com/adaptive-and-cyclical-learning-rates-using-pytorch-2bf904d18dee
2222
23- '''
23+ """
2424
2525import math
2626import os
@@ -51,17 +51,21 @@ def run(args):
5151
5252 batch_size = args .batch_size
5353
54- training_params = {"batch_size" : batch_size ,
55- "shuffle" : True ,
56- "num_workers" : args .workers }
54+ training_params = {
55+ "batch_size" : batch_size ,
56+ "shuffle" : True ,
57+ "num_workers" : args .workers ,
58+ }
5759
5860 texts , labels , number_of_classes , sample_weights = load_data (args )
59- train_texts , _ , train_labels , _ , _ , _ = train_test_split (texts ,
60- labels ,
61- sample_weights ,
62- test_size = args .validation_split ,
63- random_state = 42 ,
64- stratify = labels )
61+ train_texts , _ , train_labels , _ , _ , _ = train_test_split (
62+ texts ,
63+ labels ,
64+ sample_weights ,
65+ test_size = args .validation_split ,
66+ random_state = 42 ,
67+ stratify = labels ,
68+ )
6569
6670 training_set = MyDataset (train_texts , train_labels , args )
6771 training_generator = DataLoader (training_set , ** training_params )
@@ -74,31 +78,31 @@ def run(args):
7478
7579 criterion = nn .CrossEntropyLoss ()
7680
77- if args .optimizer == 'sgd' :
78- optimizer = torch .optim .SGD (
79- model .parameters (), lr = args .start_lr , momentum = 0.9
80- )
81- elif args .optimizer == 'adam' :
82- optimizer = torch .optim .Adam (
83- model .parameters (), lr = args .start_lr
84- )
81+ if args .optimizer == "sgd" :
82+ optimizer = torch .optim .SGD (model .parameters (), lr = args .start_lr , momentum = 0.9 )
83+ elif args .optimizer == "adam" :
84+ optimizer = torch .optim .Adam (model .parameters (), lr = args .start_lr )
8585
8686 start_lr = args .start_lr
8787 end_lr = args .end_lr
8888 lr_find_epochs = args .epochs
8989 smoothing = args .smoothing
9090
91- def lr_lambda (x ): return math .exp (
92- x * math .log (end_lr / start_lr ) / (lr_find_epochs * len (training_generator )))
91+ def lr_lambda (x ):
92+ return math .exp (
93+ x * math .log (end_lr / start_lr ) / (lr_find_epochs * len (training_generator ))
94+ )
95+
9396 scheduler = torch .optim .lr_scheduler .LambdaLR (optimizer , lr_lambda )
9497
9598 losses = []
9699 learning_rates = []
97100
98101 for epoch in range (lr_find_epochs ):
99- print (f'[epoch { epoch + 1 } / { lr_find_epochs } ]' )
100- progress_bar = tqdm (enumerate (training_generator ),
101- total = len (training_generator ))
102+ print (f"[epoch { epoch + 1 } / { lr_find_epochs } ]" )
103+ progress_bar = tqdm (
104+ enumerate (training_generator ), total = len (training_generator )
105+ )
102106 for iter , batch in progress_bar :
103107 features , labels = batch
104108 if torch .cuda .is_available ():
@@ -124,41 +128,42 @@ def lr_lambda(x): return math.exp(
124128 losses .append (loss )
125129
126130 plt .semilogx (learning_rates , losses )
127- plt .savefig (' ./plots/losses_vs_lr.png' )
131+ plt .savefig (" ./plots/losses_vs_lr.png" )
128132
129133
130134if __name__ == "__main__" :
131- parser = argparse .ArgumentParser (
132- 'Character Based CNN for text classification' )
133- parser .add_argument ('--data_path' , type = str ,
134- default = './data/train.csv' )
135- parser .add_argument ('--validation_split' , type = float , default = 0.2 )
136- parser .add_argument ('--label_column' , type = str , default = 'Sentiment' )
137- parser .add_argument ('--text_column' , type = str , default = 'SentimentText' )
138- parser .add_argument ('--max_rows' , type = int , default = None )
139- parser .add_argument ('--chunksize' , type = int , default = 50000 )
140- parser .add_argument ('--encoding' , type = str , default = 'utf-8' )
141- parser .add_argument ('--sep' , type = str , default = ',' )
142- parser .add_argument ('--steps' , nargs = '+' , default = ['lower' ])
143- parser .add_argument ('--group_labels' , type = str ,
144- default = None , choices = [None , 'binarize' ])
145- parser .add_argument ('--ratio' , type = float , default = 1 )
146-
147- parser .add_argument ('--alphabet' , type = str ,
148- default = 'abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:\' "\\ /|_@#$%^&*~`+-=<>()[]{}' )
149- parser .add_argument ('--number_of_characters' , type = int , default = 69 )
150- parser .add_argument ('--extra_characters' , type = str , default = '' )
151- parser .add_argument ('--max_length' , type = int , default = 150 )
152- parser .add_argument ('--batch_size' , type = int , default = 128 )
153- parser .add_argument ('--optimizer' , type = str ,
154- choices = ['adam' , 'sgd' ], default = 'sgd' )
155- parser .add_argument ('--learning_rate' , type = float , default = 0.01 )
156- parser .add_argument ('--workers' , type = int , default = 1 )
157-
158- parser .add_argument ('--start_lr' , type = float , default = 1e-5 )
159- parser .add_argument ('--end_lr' , type = float , default = 1e-2 )
160- parser .add_argument ('--smoothing' , type = float , default = 0.05 )
161- parser .add_argument ('--epochs' , type = int , default = 1 )
135+ parser = argparse .ArgumentParser ("Character Based CNN for text classification" )
136+ parser .add_argument ("--data_path" , type = str , default = "./data/train.csv" )
137+ parser .add_argument ("--validation_split" , type = float , default = 0.2 )
138+ parser .add_argument ("--label_column" , type = str , default = "Sentiment" )
139+ parser .add_argument ("--text_column" , type = str , default = "SentimentText" )
140+ parser .add_argument ("--max_rows" , type = int , default = None )
141+ parser .add_argument ("--chunksize" , type = int , default = 50000 )
142+ parser .add_argument ("--encoding" , type = str , default = "utf-8" )
143+ parser .add_argument ("--sep" , type = str , default = "," )
144+ parser .add_argument ("--steps" , nargs = "+" , default = ["lower" ])
145+ parser .add_argument (
146+ "--group_labels" , type = str , default = None , choices = [None , "binarize" ]
147+ )
148+ parser .add_argument ("--ratio" , type = float , default = 1 )
149+
150+ parser .add_argument (
151+ "--alphabet" ,
152+ type = str ,
153+ default = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\" \\ /|_@#$%^&*~`+-=<>()[]{}" ,
154+ )
155+ parser .add_argument ("--number_of_characters" , type = int , default = 69 )
156+ parser .add_argument ("--extra_characters" , type = str , default = "" )
157+ parser .add_argument ("--max_length" , type = int , default = 150 )
158+ parser .add_argument ("--batch_size" , type = int , default = 128 )
159+ parser .add_argument ("--optimizer" , type = str , choices = ["adam" , "sgd" ], default = "sgd" )
160+ parser .add_argument ("--learning_rate" , type = float , default = 0.01 )
161+ parser .add_argument ("--workers" , type = int , default = 1 )
162+
163+ parser .add_argument ("--start_lr" , type = float , default = 1e-5 )
164+ parser .add_argument ("--end_lr" , type = float , default = 1e-2 )
165+ parser .add_argument ("--smoothing" , type = float , default = 0.05 )
166+ parser .add_argument ("--epochs" , type = int , default = 1 )
162167
163168 args = parser .parse_args ()
164169 run (args )
0 commit comments