@@ -194,8 +194,12 @@ def forward(self, x):
194194
195195 # Spatial Attention logic
196196 attention = self .attention (x )
197- # Upsampling to match x spatial size
198- x = x * nn .UpsamplingBilinear2d (x .shape [2 :])(attention )
197+
198+ # [Fix] nn.UpsamplingBilinear2d 클래스 생성 -> F.interpolate 함수 사용
199+ # align_corners=False가 최신 기본값에 가깝습니다. (성능 차이는 미미함)
200+ attention = F .interpolate (attention , size = x .shape [2 :], mode = 'bilinear' , align_corners = False )
201+
202+ x = x * attention
199203
200204 x = input + self .drop_path (x )
201205 return x
@@ -236,6 +240,11 @@ def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=Fal
236240 self .pool1 = nn .MaxPool2d (3 , 2 , 1 )
237241 self .pool2 = nn .MaxPool2d (3 , 2 , 1 )
238242 self .proj = nn .Conv2d (inp , oup , 1 , 1 , 0 , bias = False )
243+ else :
244+ # [Fix] JIT 컴파일 에러 방지: 사용하지 않더라도 속성을 정의해야 함
245+ self .pool1 = nn .Identity ()
246+ self .pool2 = nn .Identity ()
247+ self .proj = nn .Identity ()
239248
240249 # Attention block components
241250 # Note: In old code, PreNorm wrapped Attention. Here we split them.
@@ -476,6 +485,12 @@ def forward(self, x):
476485# 기존 코드의 LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ 함수를 그대로 사용하세요.
477486
478487class LayerNorm (nn .Module ):
488+ """ LayerNorm that supports two data formats: channels_last (default) or channels_first.
489+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
490+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
491+ with shape (batch_size, channels, height, width).
492+ """
493+
479494 def __init__ (self , normalized_shape , eps = 1e-6 , data_format = "channels_last" ):
480495 super ().__init__ ()
481496 self .weight = nn .Parameter (torch .ones (normalized_shape ))
@@ -489,7 +504,9 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
489504 def forward (self , x ):
490505 if self .data_format == "channels_last" :
491506 return F .layer_norm (x , self .normalized_shape , self .weight , self .bias , self .eps )
492- elif self .data_format == "channels_first" :
507+ else :
508+ # [Fix] elif -> else로 변경
509+ # JIT이 "모든 경로에서 Tensor가 반환됨"을 알 수 있게 함
493510 u = x .mean (1 , keepdim = True )
494511 s = (x - u ).pow (2 ).mean (1 , keepdim = True )
495512 x = (x - u ) / torch .sqrt (s + self .eps )
0 commit comments