11import torch
22
3- from catalyst .dl import Callback , CallbackOrder , State
3+ from catalyst .dl import Callback , CallbackNode , CallbackOrder , State
44from .utils import encode_mask_with_color , label_instances
55
66
@@ -11,7 +11,7 @@ def __init__(
1111 input_key : str = "logits" ,
1212 output_key : str = "mask" ,
1313 ):
14- super ().__init__ (CallbackOrder .Internal )
14+ super ().__init__ (order = CallbackOrder .Internal , node = CallbackNode . All )
1515 self .threshold = threshold
1616 self .input_key = input_key
1717 self .output_key = output_key
@@ -21,7 +21,7 @@ def on_batch_end(self, state: State):
2121
2222 output = torch .sigmoid (output ).detach ().cpu ().numpy ()
2323 state .batch_out [self .output_key ] = encode_mask_with_color (
24- output , self .threshold
24+ output , threshold = self .threshold
2525 )
2626
2727
@@ -35,7 +35,7 @@ def __init__(
3535 out_key_semantic : str = None ,
3636 out_key_border : str = None ,
3737 ):
38- super ().__init__ (CallbackOrder .Internal )
38+ super ().__init__ (CallbackOrder .Internal , node = CallbackNode . All )
3939 self .watershed_threshold = watershed_threshold
4040 self .mask_threshold = mask_threshold
4141 self .input_key = input_key
@@ -44,22 +44,22 @@ def __init__(
4444 self .out_key_border = out_key_border
4545
4646 def on_batch_end (self , state : State ):
47- output : torch . Tensor = torch . sigmoid ( state .output [self .input_key ])
47+ output = state .batch_out [self .input_key ]
4848
49+ output = torch .sigmoid (output ).detach ().cpu ()
4950 semantic , border = output .chunk (2 , - 3 )
5051
5152 if self .out_key_semantic is not None :
52- state .output [self .out_key_semantic ] = encode_mask_with_color (
53- semantic .data . cpu (). numpy (), threshold = self .mask_threshold
53+ state .batch_out [self .out_key_semantic ] = encode_mask_with_color (
54+ semantic .numpy (), threshold = self .mask_threshold
5455 )
5556
5657 if self .out_key_border is not None :
57- state .output [self .out_key_border ] = (
58- border .data .cpu ().squeeze (- 3 ).numpy () >
59- self .watershed_threshold
58+ state .batch_out [self .out_key_border ] = (
59+ border .squeeze (- 3 ).numpy () > self .watershed_threshold
6060 )
6161
62- state .output [self .output_key ] = label_instances (
62+ state .batch_out [self .output_key ] = label_instances (
6363 semantic ,
6464 border ,
6565 watershed_threshold = self .watershed_threshold ,
0 commit comments