Skip to content

Commit ba37aa4

Browse files
committed
add log_f1 argument
1 parent b8080f6 commit ba37aa4

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

train.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,6 @@ def train(model, training_generator, optimizer, criterion, epoch, writer, log_fi
7878
lr = optimizer.state_dict()["param_groups"][0]["lr"]
7979

8080
if (iter % print_every == 0) and (iter > 0):
81-
intermediate_report = classification_report(
82-
y_true, y_pred, output_dict=True)
83-
84-
f1_by_class = 'F1 Scores by class: '
85-
for class_name in class_names:
86-
f1_by_class += f"{class_name} : {np.round(intermediate_report[class_name]['f1-score'], 4)} |"
87-
8881
print("[Training - Epoch: {}], LR: {} , Iteration: {}/{} , Loss: {}, Accuracy: {}".format(
8982
epoch + 1,
9083
lr,
@@ -93,7 +86,16 @@ def train(model, training_generator, optimizer, criterion, epoch, writer, log_fi
9386
losses.avg,
9487
accuracies.avg
9588
))
96-
print(f1_by_class)
89+
90+
if bool(args.log_f1):
91+
intermediate_report = classification_report(
92+
y_true, y_pred, output_dict=True)
93+
94+
f1_by_class = 'F1 Scores by class: '
95+
for class_name in class_names:
96+
f1_by_class += f"{class_name} : {np.round(intermediate_report[class_name]['f1-score'], 4)} |"
97+
98+
print(f1_by_class)
9799

98100
f1_train = f1_score(y_true, y_pred, average='weighted')
99101

@@ -403,6 +405,7 @@ def run(args, both_cases=False):
403405
parser.add_argument('--workers', type=int, default=1)
404406
parser.add_argument('--log_path', type=str, default='./logs/')
405407
parser.add_argument('--log_every', type=int, default=100)
408+
parser.add_argument('--log_f1', type=int, default=1, choices=[0, 1])
406409
parser.add_argument('--flush_history', type=int,
407410
default=1, choices=[0, 1])
408411
parser.add_argument('--output', type=str, default='./models/')

0 commit comments

Comments
 (0)