Skip to content
This repository was archived by the owner on Oct 25, 2021. It is now read-only.

Commit e656429

Browse files
committed
(catalyst 20.03): update code
1 parent a26cf6b commit e656429

File tree

7 files changed

+30
-38
lines changed

7 files changed

+30
-38
lines changed

bin/tests/check_instance.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ print(iou_soft)
4949
print(iou_hard)
5050
5151
assert aggregated_loss < 0.9
52-
assert iou_soft > 0.05
52+
assert iou_soft > 0.04
5353
assert iou_hard > 0.1
5454
"""
5555

configs/templates/binary.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,6 @@ stages:
118118
callback: RawMaskPostprocessingCallback
119119
output_key: binary_mask
120120

121-
raw_processor:
122-
callback: RawMaskPostprocessingCallback
123-
124121
iou_soft:
125122
callback: IouCallback
126123
input_key: mask

configs/templates/instance.yml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ args:
1010
stages:
1111

1212
state_params:
13-
main_metric: &reduce_metric iou_hard
13+
main_metric: &reduced_metric iou_hard
1414
minimize_metric: False
1515

1616
data_params:
@@ -56,7 +56,7 @@ stages:
5656
contrast_limit: 0.5
5757
- transform: A.RandomGamma
5858
- transform: A.CLAHE
59-
- transform: A.JpegCompression
59+
- transform: A.ImageCompression
6060
quality_lower: 50
6161
- &post_transforms
6262
transform: A.Compose
@@ -108,10 +108,11 @@ stages:
108108
multiplier: 1.0
109109

110110
loss_aggregator:
111-
callback: CriterionAggregatorCallback
111+
callback: MetricAggregationCallback
112112
prefix: &aggregated_loss loss
113-
loss_aggregate_fn: "mean" # or "sum"
114-
multiplier: 1.0 # scale factor for the aggregated loss
113+
metrics: [loss_bce, loss_dice, loss_iou]
114+
mode: "mean"
115+
multiplier: 1.0
115116

116117
raw_processor:
117118
callback: RawMaskPostprocessingCallback
@@ -140,7 +141,7 @@ stages:
140141
loss_key: *aggregated_loss
141142
scheduler:
142143
callback: SchedulerCallback
143-
reduce_metric: *reduce_metric
144+
reduced_metric: *reduced_metric
144145
saver:
145146
callback: CheckpointCallback
146147

configs/templates/semantic.yml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,6 @@ stages:
118118
callback: RawMaskPostprocessingCallback
119119
output_key: semantic_mask
120120

121-
raw_processor:
122-
callback: RawMaskPostprocessingCallback
123-
output_key: semantic_mask
124-
125121
iou_soft:
126122
callback: IouCallback
127123
input_key: mask

scripts/process_instance_masks.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414

1515
def build_args(parser):
1616
parser.add_argument(
17-
"--in-dir",
18-
type=Path,
19-
required=True,
20-
help="Raw masks folder path"
17+
"--in-dir", type=Path, required=True, help="Raw masks folder path"
2118
)
2219
parser.add_argument(
2320
"--out-dir",
@@ -151,10 +148,10 @@ def preprocess(self, sample: Path):
151148
else:
152149
sz = 3
153150

154-
uniq = np.unique(labels[
155-
max(0, y0 - sz):min(labels.shape[0], y0 + sz + 1),
156-
max(0, x0 - sz):min(labels.shape[1], x0 + sz + 1),
157-
])
151+
uniq = np.unique(
152+
labels[max(0, y0 - sz):min(labels.shape[0], y0 + sz + 1),
153+
max(0, x0 - sz):min(labels.shape[1], x0 + sz + 1)]
154+
)
158155
if len(uniq[uniq > 0]) > 1:
159156
borders[y0, x0] = 255
160157
mask_without_borders[y0, x0] = 0

src/callbacks/io.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import imageio
44
import numpy as np
55

6-
from catalyst.dl import Callback, CallbackOrder, State, utils
6+
from catalyst.dl import Callback, CallbackNode, CallbackOrder, State, utils
77
from .utils import crop_by_masks, mask_to_overlay_image
88

99

@@ -17,7 +17,7 @@ def __init__(
1717
input_key: str = "image",
1818
outpath_key: str = "name",
1919
):
20-
super().__init__(CallbackOrder.Logging)
20+
super().__init__(order=CallbackOrder.Logging, node=CallbackNode.Master)
2121
self.output_dir = Path(output_dir)
2222
self.relative = relative
2323
self.filename_suffix = filename_suffix
@@ -102,10 +102,11 @@ def __init__(
102102
self.output_key = output_key
103103

104104
def on_batch_end(self, state: State):
105-
names = state.input[self.outpath_key]
106-
images = utils.tensor_to_ndimage(state.input[self.input_key].cpu())
107-
masks = state.output[self.output_key]
105+
names = state.batch_in[self.outpath_key]
106+
images = state.batch_in[self.input_key]
107+
masks = state.batch_out[self.output_key]
108108

109+
images = utils.tensor_to_ndimage(images.detach().cpu())
109110
for name, image, masks_ in zip(names, images, masks):
110111
instances = crop_by_masks(image, masks_)
111112

src/callbacks/processing.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22

3-
from catalyst.dl import Callback, CallbackOrder, State
3+
from catalyst.dl import Callback, CallbackNode, CallbackOrder, State
44
from .utils import encode_mask_with_color, label_instances
55

66

@@ -11,7 +11,7 @@ def __init__(
1111
input_key: str = "logits",
1212
output_key: str = "mask",
1313
):
14-
super().__init__(CallbackOrder.Internal)
14+
super().__init__(order=CallbackOrder.Internal, node=CallbackNode.All)
1515
self.threshold = threshold
1616
self.input_key = input_key
1717
self.output_key = output_key
@@ -21,7 +21,7 @@ def on_batch_end(self, state: State):
2121

2222
output = torch.sigmoid(output).detach().cpu().numpy()
2323
state.batch_out[self.output_key] = encode_mask_with_color(
24-
output, self.threshold
24+
output, threshold=self.threshold
2525
)
2626

2727

@@ -35,7 +35,7 @@ def __init__(
3535
out_key_semantic: str = None,
3636
out_key_border: str = None,
3737
):
38-
super().__init__(CallbackOrder.Internal)
38+
super().__init__(CallbackOrder.Internal, node=CallbackNode.All)
3939
self.watershed_threshold = watershed_threshold
4040
self.mask_threshold = mask_threshold
4141
self.input_key = input_key
@@ -44,22 +44,22 @@ def __init__(
4444
self.out_key_border = out_key_border
4545

4646
def on_batch_end(self, state: State):
47-
output: torch.Tensor = torch.sigmoid(state.output[self.input_key])
47+
output = state.batch_out[self.input_key]
4848

49+
output = torch.sigmoid(output).detach().cpu()
4950
semantic, border = output.chunk(2, -3)
5051

5152
if self.out_key_semantic is not None:
52-
state.output[self.out_key_semantic] = encode_mask_with_color(
53-
semantic.data.cpu().numpy(), threshold=self.mask_threshold
53+
state.batch_out[self.out_key_semantic] = encode_mask_with_color(
54+
semantic.numpy(), threshold=self.mask_threshold
5455
)
5556

5657
if self.out_key_border is not None:
57-
state.output[self.out_key_border] = (
58-
border.data.cpu().squeeze(-3).numpy() >
59-
self.watershed_threshold
58+
state.batch_out[self.out_key_border] = (
59+
border.squeeze(-3).numpy() > self.watershed_threshold
6060
)
6161

62-
state.output[self.output_key] = label_instances(
62+
state.batch_out[self.output_key] = label_instances(
6363
semantic,
6464
border,
6565
watershed_threshold=self.watershed_threshold,

0 commit comments

Comments
 (0)