@@ -78,13 +78,6 @@ def train(model, training_generator, optimizer, criterion, epoch, writer, log_fi
7878 lr = optimizer .state_dict ()["param_groups" ][0 ]["lr" ]
7979
8080 if (iter % print_every == 0 ) and (iter > 0 ):
81- intermediate_report = classification_report (
82- y_true , y_pred , output_dict = True )
83-
84- f1_by_class = 'F1 Scores by class: '
85- for class_name in class_names :
86- f1_by_class += f"{ class_name } : { np .round (intermediate_report [class_name ]['f1-score' ], 4 )} |"
87-
8881 print ("[Training - Epoch: {}], LR: {} , Iteration: {}/{} , Loss: {}, Accuracy: {}" .format (
8982 epoch + 1 ,
9083 lr ,
@@ -93,7 +86,16 @@ def train(model, training_generator, optimizer, criterion, epoch, writer, log_fi
9386 losses .avg ,
9487 accuracies .avg
9588 ))
96- print (f1_by_class )
89+
90+ if bool (args .log_f1 ):
91+ intermediate_report = classification_report (
92+ y_true , y_pred , output_dict = True )
93+
94+ f1_by_class = 'F1 Scores by class: '
95+ for class_name in class_names :
96+ f1_by_class += f"{ class_name } : { np .round (intermediate_report [class_name ]['f1-score' ], 4 )} |"
97+
98+ print (f1_by_class )
9799
98100 f1_train = f1_score (y_true , y_pred , average = 'weighted' )
99101
@@ -403,6 +405,7 @@ def run(args, both_cases=False):
403405 parser .add_argument ('--workers' , type = int , default = 1 )
404406 parser .add_argument ('--log_path' , type = str , default = './logs/' )
405407 parser .add_argument ('--log_every' , type = int , default = 100 )
408+ parser .add_argument ('--log_f1' , type = int , default = 1 , choices = [0 , 1 ])
406409 parser .add_argument ('--flush_history' , type = int ,
407410 default = 1 , choices = [0 , 1 ])
408411 parser .add_argument ('--output' , type = str , default = './models/' )
0 commit comments