
Commit 5931976: formatting with black and prettier
1 parent ba37aa4

File tree: 8 files changed, +486 −432 lines changed

clr_parameters_finder.py

Lines changed: 61 additions & 56 deletions
@@ -1,4 +1,4 @@
-'''
+"""
 This script allows to find the optimal parameters for a learning rate scheduling:
 
 - min_lr
@@ -20,7 +20,7 @@
 
 reference: https://towardsdatascience.com/adaptive-and-cyclical-learning-rates-using-pytorch-2bf904d18dee
 
-'''
+"""
 
 import math
 import os
@@ -51,17 +51,21 @@ def run(args):
 
     batch_size = args.batch_size
 
-    training_params = {"batch_size": batch_size,
-                       "shuffle": True,
-                       "num_workers": args.workers}
+    training_params = {
+        "batch_size": batch_size,
+        "shuffle": True,
+        "num_workers": args.workers,
+    }
 
     texts, labels, number_of_classes, sample_weights = load_data(args)
-    train_texts, _, train_labels, _, _, _ = train_test_split(texts,
-                                                             labels,
-                                                             sample_weights,
-                                                             test_size=args.validation_split,
-                                                             random_state=42,
-                                                             stratify=labels)
+    train_texts, _, train_labels, _, _, _ = train_test_split(
+        texts,
+        labels,
+        sample_weights,
+        test_size=args.validation_split,
+        random_state=42,
+        stratify=labels,
+    )
 
     training_set = MyDataset(train_texts, train_labels, args)
     training_generator = DataLoader(training_set, **training_params)
@@ -74,31 +78,31 @@ def run(args):
 
     criterion = nn.CrossEntropyLoss()
 
-    if args.optimizer == 'sgd':
-        optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.start_lr, momentum=0.9
-        )
-    elif args.optimizer == 'adam':
-        optimizer = torch.optim.Adam(
-            model.parameters(), lr=args.start_lr
-        )
+    if args.optimizer == "sgd":
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.start_lr, momentum=0.9)
+    elif args.optimizer == "adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.start_lr)
 
     start_lr = args.start_lr
     end_lr = args.end_lr
     lr_find_epochs = args.epochs
     smoothing = args.smoothing
 
-    def lr_lambda(x): return math.exp(
-        x * math.log(end_lr / start_lr) / (lr_find_epochs * len(training_generator)))
+    def lr_lambda(x):
+        return math.exp(
+            x * math.log(end_lr / start_lr) / (lr_find_epochs * len(training_generator))
+        )
+
     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
 
     losses = []
     learning_rates = []
 
     for epoch in range(lr_find_epochs):
-        print(f'[epoch {epoch + 1} / {lr_find_epochs}]')
-        progress_bar = tqdm(enumerate(training_generator),
-                            total=len(training_generator))
+        print(f"[epoch {epoch + 1} / {lr_find_epochs}]")
+        progress_bar = tqdm(
+            enumerate(training_generator), total=len(training_generator)
+        )
         for iter, batch in progress_bar:
            features, labels = batch
            if torch.cuda.is_available():
@@ -124,41 +128,42 @@ def lr_lambda(x): return math.exp(
            losses.append(loss)
 
     plt.semilogx(learning_rates, losses)
-    plt.savefig('./plots/losses_vs_lr.png')
+    plt.savefig("./plots/losses_vs_lr.png")
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        'Character Based CNN for text classification')
-    parser.add_argument('--data_path', type=str,
-                        default='./data/train.csv')
-    parser.add_argument('--validation_split', type=float, default=0.2)
-    parser.add_argument('--label_column', type=str, default='Sentiment')
-    parser.add_argument('--text_column', type=str, default='SentimentText')
-    parser.add_argument('--max_rows', type=int, default=None)
-    parser.add_argument('--chunksize', type=int, default=50000)
-    parser.add_argument('--encoding', type=str, default='utf-8')
-    parser.add_argument('--sep', type=str, default=',')
-    parser.add_argument('--steps', nargs='+', default=['lower'])
-    parser.add_argument('--group_labels', type=str,
-                        default=None, choices=[None, 'binarize'])
-    parser.add_argument('--ratio', type=float, default=1)
-
-    parser.add_argument('--alphabet', type=str,
-                        default='abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:\'"\\/|_@#$%^&*~`+-=<>()[]{}')
-    parser.add_argument('--number_of_characters', type=int, default=69)
-    parser.add_argument('--extra_characters', type=str, default='')
-    parser.add_argument('--max_length', type=int, default=150)
-    parser.add_argument('--batch_size', type=int, default=128)
-    parser.add_argument('--optimizer', type=str,
-                        choices=['adam', 'sgd'], default='sgd')
-    parser.add_argument('--learning_rate', type=float, default=0.01)
-    parser.add_argument('--workers', type=int, default=1)
-
-    parser.add_argument('--start_lr', type=float, default=1e-5)
-    parser.add_argument('--end_lr', type=float, default=1e-2)
-    parser.add_argument('--smoothing', type=float, default=0.05)
-    parser.add_argument('--epochs', type=int, default=1)
+    parser = argparse.ArgumentParser("Character Based CNN for text classification")
+    parser.add_argument("--data_path", type=str, default="./data/train.csv")
+    parser.add_argument("--validation_split", type=float, default=0.2)
+    parser.add_argument("--label_column", type=str, default="Sentiment")
+    parser.add_argument("--text_column", type=str, default="SentimentText")
+    parser.add_argument("--max_rows", type=int, default=None)
+    parser.add_argument("--chunksize", type=int, default=50000)
+    parser.add_argument("--encoding", type=str, default="utf-8")
+    parser.add_argument("--sep", type=str, default=",")
+    parser.add_argument("--steps", nargs="+", default=["lower"])
+    parser.add_argument(
+        "--group_labels", type=str, default=None, choices=[None, "binarize"]
+    )
+    parser.add_argument("--ratio", type=float, default=1)
+
+    parser.add_argument(
+        "--alphabet",
+        type=str,
+        default="abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"\\/|_@#$%^&*~`+-=<>()[]{}",
+    )
+    parser.add_argument("--number_of_characters", type=int, default=69)
+    parser.add_argument("--extra_characters", type=str, default="")
+    parser.add_argument("--max_length", type=int, default=150)
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--optimizer", type=str, choices=["adam", "sgd"], default="sgd")
+    parser.add_argument("--learning_rate", type=float, default=0.01)
+    parser.add_argument("--workers", type=int, default=1)
+
+    parser.add_argument("--start_lr", type=float, default=1e-5)
+    parser.add_argument("--end_lr", type=float, default=1e-2)
+    parser.add_argument("--smoothing", type=float, default=0.05)
+    parser.add_argument("--epochs", type=int, default=1)
 
     args = parser.parse_args()
     run(args)
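
Note on the reformatted lr_lambda above: it implements an exponential learning-rate range test. LambdaLR multiplies the base learning rate (start_lr) by lr_lambda(x) at step x, so the rate sweeps geometrically from start_lr to end_lr over the whole run. A minimal standalone sketch of that behaviour, with a hypothetical total_steps standing in for lr_find_epochs * len(training_generator):

import math

start_lr, end_lr = 1e-5, 1e-2  # the --start_lr / --end_lr defaults above
total_steps = 100              # hypothetical; the script uses lr_find_epochs * len(training_generator)

def lr_lambda(x):
    # multiplicative factor applied to start_lr at step x; grows geometrically
    return math.exp(x * math.log(end_lr / start_lr) / total_steps)

print(start_lr * lr_lambda(0))            # 1e-05: the sweep starts at start_lr
print(start_lr * lr_lambda(total_steps))  # ~0.01: and ends at end_lr

Plotting the recorded losses against these learning rates on a log axis (the plt.semilogx call above) is what lets you read off usable min_lr/max_lr bounds.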

config.json

Lines changed: 46 additions & 68 deletions
@@ -1,71 +1,49 @@
 {
-    "alphabet": {
-        "en": {
-            "lower": {
-                "alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
-                "number_of_characters": 69
-            },
-            "both": {
-                "alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
-                "number_of_characters": 95
-            }
-        }
-    },
+  "alphabet": {
+    "en": {
+      "lower": {
+        "alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
+        "number_of_characters": 69
+      },
+      "both": {
+        "alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
+        "number_of_characters": 95
+      }
+    }
+  },
 
-    "model_parameters": {
-        "small": {
-            "conv": [
-                [
-                    256,
-                    7,
-                    3
-                ],
-                [
-                    256,
-                    7,
-                    3
-                ],
-                [
-                    256,
-                    3,
-                    -1
-                ],
-                [
-                    256,
-                    3,
-                    -1
-                ],
-                [
-                    256,
-                    3,
-                    -1
-                ],
-                [
-                    256,
-                    3,
-                    3
-                ]
-            ],
-            "fc": [
-                1024,
-                1024
-            ]
-        }
-    },
-    "data": {
-        "text_column": "SentimentText",
-        "label_column": "Sentiment",
-        "max_length": 150,
-        "num_of_classes": 2,
-        "encoding": null,
-        "chunksize": 50000,
-        "max_rows": 100000,
-        "preprocessing_steps": ["lower", "remove_hashtags", "remove_urls", "remove_user_mentions"]
-    },
-    "training": {
-        "batch_size": 128,
-        "learning_rate": 0.01,
-        "epochs": 10,
-        "optimizer": "sgd"
+  "model_parameters": {
+    "small": {
+      "conv": [
+        [256, 7, 3],
+        [256, 7, 3],
+        [256, 3, -1],
+        [256, 3, -1],
+        [256, 3, -1],
+        [256, 3, 3]
+      ],
+      "fc": [1024, 1024]
     }
-}
+  },
+  "data": {
+    "text_column": "SentimentText",
+    "label_column": "Sentiment",
+    "max_length": 150,
+    "num_of_classes": 2,
+    "encoding": null,
+    "chunksize": 50000,
+    "max_rows": 100000,
+    "preprocessing_steps": [
+      "lower",
+      "remove_hashtags",
+      "remove_urls",
+      "remove_user_mentions"
+    ]
+  },
+  "training": {
+    "batch_size": 128,
+    "learning_rate": 0.01,
+    "epochs": 10,
+    "optimizer": "sgd"
+  }
+}
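
Note on the reformatted config.json above: prettier changed only indentation and array layout; the structure is identical. A hedged sketch (not part of the commit) of how this file might be consumed, assuming each "conv" triple is [out_channels, kernel_size, pooling], with -1 presumably meaning no pooling:

import json

with open("config.json") as f:
    config = json.load(f)

small = config["model_parameters"]["small"]
conv_layers = small["conv"]  # six triples, e.g. [256, 7, 3]
fc_layers = small["fc"]      # [1024, 1024]
alphabet = config["alphabet"]["en"]["lower"]["alphabet"]
# the "lower" alphabet has 69 characters, matching number_of_characters
print(len(alphabet), config["alphabet"]["en"]["lower"]["number_of_characters"])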

predict.py

Lines changed: 21 additions & 17 deletions
@@ -6,18 +6,19 @@
 
 use_cuda = torch.cuda.is_available()
 
+
 def predict(args):
     model = CharacterLevelCNN(args, args.number_of_classes)
     state = torch.load(args.model)
     model.load_state_dict(state)
     model.eval()
-
+
     processed_input = utils.preprocess_input(args)
     processed_input = torch.tensor(processed_input)
     processed_input = processed_input.unsqueeze(0)
     if use_cuda:
-        processed_input = processed_input.to('cuda')
-        model = model.to('cuda')
+        processed_input = processed_input.to("cuda")
+        model = model.to("cuda")
     prediction = model(processed_input)
     probabilities = F.softmax(prediction, dim=1)
     probabilities = probabilities.detach().cpu().numpy()
@@ -26,22 +27,25 @@ def predict(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        'Testing a pretrained Character Based CNN for text classification')
-    parser.add_argument('--model', type=str, help='path for pre-trained model')
-    parser.add_argument('--text', type=str,
-                        default='I love pizza!', help='text string')
-    parser.add_argument('--steps', nargs="+", default=['lower'])
+        "Testing a pretrained Character Based CNN for text classification"
+    )
+    parser.add_argument("--model", type=str, help="path for pre-trained model")
+    parser.add_argument("--text", type=str, default="I love pizza!", help="text string")
+    parser.add_argument("--steps", nargs="+", default=["lower"])
 
     # arguments needed for the predicition
-    parser.add_argument('--alphabet', type=str,
-                        default="abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}")
-    parser.add_argument('--number_of_characters', type=int, default=69)
-    parser.add_argument('--extra_characters', type=str, default="éàèùâêîôûçëïü")
-    parser.add_argument('--max_length', type=int, default=300)
-    parser.add_argument('--number_of_classes', type=int, default=2)
+    parser.add_argument(
+        "--alphabet",
+        type=str,
+        default="abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}",
+    )
+    parser.add_argument("--number_of_characters", type=int, default=69)
+    parser.add_argument("--extra_characters", type=str, default="éàèùâêîôûçëïü")
+    parser.add_argument("--max_length", type=int, default=300)
+    parser.add_argument("--number_of_classes", type=int, default=2)
 
     args = parser.parse_args()
     prediction = predict(args)
-
-    print('input : {}'.format(args.text))
-    print('prediction : {}'.format(prediction))
+
+    print("input : {}".format(args.text))
+    print("prediction : {}".format(prediction))
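
Note on the reformatted predict() above: the prediction tail is a plain softmax over the model's class logits. A minimal sketch of just that step, with a hand-made logits tensor standing in for model(processed_input):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0]])      # hypothetical output for number_of_classes=2
probabilities = F.softmax(logits, dim=1)  # normalize logits into class probabilities
probabilities = probabilities.detach().cpu().numpy()
print(probabilities)            # ≈ [[0.953 0.047]]
print(probabilities.argmax(1))  # index of the predicted class: [0]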
