@@ -66,13 +66,16 @@ def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = T
         factory_kwargs = dict(device=device, dtype=dtype)
         super(_DCT1D, self).__init__()
         kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3}
-        self.weights = nn.Parameter(kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T, False)
+        dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T
+        self.register_buffer('weights', dct_weights)
+
         self.register_parameter('bias', None)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return nn.functional.linear(x, self.weights, self.bias)


+
 class _DCT2D(nn.Module):
     def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
                  device=None, dtype=None) -> None:
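
Switching the DCT weights from a frozen nn.Parameter to register_buffer keeps them out of optimizer parameter groups while still moving them with the module and saving them in checkpoints. A minimal sketch of that behavior, using a made-up 4x4 buffer rather than the real DCT kernel:

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # illustrative random weights; the real module stores the transposed DCT basis here
        self.register_buffer('weights', torch.randn(4, 4))

m = WithBuffer()
print(list(m.parameters()))          # [] -> optimizers never see the buffer
print('weights' in m.state_dict())   # True -> still checkpointed and moved by .to(device)
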
@@ -96,6 +99,7 @@ def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = T
         self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0)
         self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
         self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
+
         self.mean = torch.tensor(mean, requires_grad=False)
         self.var = torch.tensor(var, requires_grad=False)
         self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False)
@@ -113,9 +117,14 @@ def rgb2ycbcr(self, x):
         return x

     def frequncy_normalize(self, x):
-        x[:, 0, ].sub_(self.mean.to(x.device)[0]).div_((self.var.to(x.device)[0] ** 0.5 + 1e-8))
-        x[:, 1, ].sub_(self.mean.to(x.device)[1]).div_((self.var.to(x.device)[1] ** 0.5 + 1e-8))
-        x[:, 2, ].sub_(self.mean.to(x.device)[2]).div_((self.var.to(x.device)[2] ** 0.5 + 1e-8))
+
+        mean_tensor = self.mean.to(x.device)
+        var_tensor = self.var.to(x.device)
+
+        std = var_tensor ** 0.5 + 1e-8
+
+        x = (x - mean_tensor) / std
+
         return x

     def forward(self, x: torch.Tensor) -> torch.Tensor:
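
The rewritten frequncy_normalize relies on broadcasting between x and the stored statistics instead of per-channel in-place ops. If the shapes ever stop lining up, an explicit reshape makes the broadcast unambiguous; a standalone sketch assuming 1-D per-channel statistics and a (B, C, H, W) input:

import torch

x = torch.randn(2, 3, 8, 8)                            # stand-in for the frequency tensor
mean = torch.tensor([0.1, 0.2, 0.3]).view(1, -1, 1, 1)
var = torch.tensor([0.9, 1.1, 1.0]).view(1, -1, 1, 1)
x = (x - mean) / (var ** 0.5 + 1e-8)                   # per-channel normalization via broadcasting
print(x.shape)                                         # torch.Size([2, 3, 8, 8])
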
@@ -181,11 +190,13 @@ def __init__(self, dim, drop_path=0.):
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.attention = Spatial_Attention()

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         input = x
         x = self.dwconv(x)
         x = x.permute(0, 2, 3, 1)
+
         x = self.norm(x)
+
         x = self.pwconv1(x)
         x = self.act(x)
         x = self.grn(x)
@@ -194,11 +205,9 @@ def forward(self, x):

         # Spatial Attention logic
         attention = self.attention(x)
-
-        # [Fix] create nn.UpsamplingBilinear2d class -> use F.interpolate function
-        attention = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=False)
-
-        x = x * attention
+        # x = x * nn.UpsamplingBilinear2d(size=x.shape[2:])(attention)
+        up_attn = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True)
+        x = x * up_attn

         x = input + self.drop_path(x)
         return x
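
The only behavioral change in the attention upsampling is align_corners. A small sketch showing that the flag alters how border pixels are sampled, not the output shape (values are illustrative, not taken from the model):

import torch
import torch.nn.functional as F

attn = torch.rand(1, 1, 7, 7)
up_true = F.interpolate(attn, size=(14, 14), mode='bilinear', align_corners=True)
up_false = F.interpolate(attn, size=(14, 14), mode='bilinear', align_corners=False)
print(up_true.shape)                               # torch.Size([1, 1, 14, 14])
print(bool((up_true - up_false).abs().max() > 0))  # True: same shape, different border sampling
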
@@ -223,11 +232,6 @@ def forward(self, x):


 class TransformerBlock(nn.Module):
-    """
-    Refactored TransformerBlock without einops and PreNorm class wrapper.
-    Manual reshaping is performed in forward().
-    """
-
     def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.):
         super().__init__()
         hidden_dim = int(inp * 4)
@@ -240,15 +244,16 @@ def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=Fal
             self.pool2 = nn.MaxPool2d(3, 2, 1)
             self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
         else:
-            # [Fix] Prevent JIT compile error
+            # [change] TorchScript needs every attribute declared on all paths, so fall back to Identity
             self.pool1 = nn.Identity()
             self.pool2 = nn.Identity()
             self.proj = nn.Identity()

         # Attention block components
-        # Note: In old code, PreNorm wrapped Attention. Here we split them.
         self.attn_norm = nn.LayerNorm(inp)
-        self.attn = Attention(inp, oup, heads, dim_head, dropout)
+        # self.attn = Attention(inp, oup, heads, dim_head, dropout)
+        self.attn = Attention(inp, oup, heads, dim_head, dropout, img_size=img_size)
+

         # FeedForward block components
         self.ff_norm = nn.LayerNorm(oup)
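
The Identity placeholders exist because TorchScript resolves attribute lookups statically: every attribute used in forward() must exist on the module no matter which constructor branch ran. A stripped-down sketch of the pattern, independent of the actual TransformerBlock:

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, downsample: bool):
        super().__init__()
        if downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
        else:
            self.pool = nn.Identity()   # same attribute on every path, neutral behavior

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool(x)

scripted = torch.jit.script(Block(downsample=False))
print(scripted(torch.randn(1, 8, 16, 16)).shape)    # torch.Size([1, 8, 16, 16])
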
@@ -316,7 +321,7 @@ class Attention(nn.Module):
     Refactored Attention without einops.rearrange.
     """

-    def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.):
+    def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0., img_size=None):
         super().__init__()
         inner_dim = dim_head * heads
         project_out = not (heads == 1 and dim_head == inp)
@@ -332,7 +337,8 @@ def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.):
             nn.Linear(inner_dim, oup),
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()
-        self.pos_embed = PosCNN(in_chans=inp)
+        # self.pos_embed = PosCNN(in_chans=inp)
+        self.pos_embed = PosCNN(in_chans=inp, img_size=img_size)

     def forward(self, x):
         # x shape: (B, N, C)
@@ -387,6 +393,16 @@ def __init__(
         self.img_size = img_size

         dims = [32, 72, 168, 386]
+        self.num_features = dims[-1]
+        self.head_hidden_size = self.num_features
+
+        self.feature_info = [
+            dict(num_chs=32, reduction=8, module='dct'),
+            dict(num_chs=dims[1], reduction=16, module='stages1'),
+            dict(num_chs=dims[2], reduction=32, module='stages2'),
+            dict(num_chs=dims[3], reduction=64, module='stages3'),
+            dict(num_chs=dims[3], reduction=64, module='stages4'),
+        ]
         channel_order = "channels_first"
         depths = [2, 2, 6, 4]
         dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
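
num_features, head_hidden_size, and feature_info follow timm's backbone conventions, so the model can be consumed through features_only once the variant is registered. A hedged usage sketch, assuming a registered name of 'csatv2' (the real name comes from _create_csatv2 and the registry, not shown here):

import timm
import torch

model = timm.create_model('csatv2', pretrained=False, features_only=True)  # name assumed
print(model.feature_info.channels())    # expected to mirror the dicts above, e.g. [32, 72, 168, 386, 386]
feats = model(torch.randn(1, 3, 512, 512))
print([f.shape for f in feats])         # one feature map per feature_info entry
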
@@ -479,17 +495,7 @@ def forward(self, x):
         return x


-# --- Components like LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ ---
-# (This part is unrelated to einops, so it stays identical to the code above; omitted here to save space.)
-# Keep using the LayerNorm, GRN, DropPath, FeedForward, PosCNN, and trunc_normal_ from the existing code as-is.
-
 class LayerNorm(nn.Module):
-    """ LayerNorm that supports two data formats: channels_last (default) or channels_first.
-    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
-    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
-    with shape (batch_size, channels, height, width).
-    """
-
     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(normalized_shape))
@@ -500,12 +506,10 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
             raise NotImplementedError
         self.normalized_shape = (normalized_shape,)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.data_format == "channels_last":
             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-        else:
-            # [Fix] change elif -> else
-            # JIT should know Tensor return path of all cases.
+        else:  # elif self.data_format == "channels_first":
             u = x.mean(1, keepdim=True)
             s = (x - u).pow(2).mean(1, keepdim=True)
             x = (x - u) / torch.sqrt(s + self.eps)
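
For the channels_first branch, the manual mean/variance math matches what F.layer_norm computes over the last dimension after a permute. A quick sanity-check sketch, assuming the LayerNorm class above is in scope:

import torch
import torch.nn.functional as F

ln = LayerNorm(8, data_format="channels_first")
x = torch.randn(2, 8, 4, 4)
ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, ln.eps).permute(0, 3, 1, 2)
print(torch.allclose(ln(x), ref, atol=1e-5))   # True: both paths normalize over the channel dim
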
@@ -561,20 +565,41 @@ def forward(self, x):


 class PosCNN(nn.Module):
-    def __init__(self, in_chans):
+    def __init__(self, in_chans, img_size=None):
         super(PosCNN, self).__init__()
         self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans)
+        self.img_size = img_size
+
+    # Ignored by the JIT compiler; keeps the dynamic size computation type-safe outside TorchScript
+    @torch.jit.ignore
+    def _get_dynamic_size(self, N):
+        # 1. Eager mode (normal execution)
+        if isinstance(N, int):
+            s = int(N ** 0.5)
+            return s, s
+
+        # 2. FX tracing & runtime
+        # If N is a Proxy or Tensor, promote it to a float tensor first
+        N_float = N * torch.tensor(1.0)  # float promotion (Proxy-compatible)
+        s = (N_float ** 0.5).to(torch.int)
+        return s, s

     def forward(self, x):
         B, N, C = x.shape
         feat_token = x
-        H, W = int(N ** 0.5), int(N ** 0.5)
+
+        # TorchScript path vs. eager/FX path
+        if torch.jit.is_scripting():
+            H = int(N ** 0.5)
+            W = H
+        else:
+            H, W = self._get_dynamic_size(N)
+
         cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
         x = self.proj(cnn_feat) + cnn_feat
         x = x.flatten(2).transpose(1, 2)
         return x

-
 def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)

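
The forward split on torch.jit.is_scripting() plus the @torch.jit.ignore helper is a common way to keep one branch TorchScript-compilable while leaving the dynamic branch to eager execution and FX tracing. A stripped-down sketch of the same pattern, independent of PosCNN:

import torch
import torch.nn as nn

class GridSize(nn.Module):
    @torch.jit.ignore
    def _dynamic(self, n: int) -> int:
        # never compiled by TorchScript; free to use arbitrary Python here
        return int(n ** 0.5)

    def forward(self, n: int) -> int:
        if torch.jit.is_scripting():
            return int(n ** 0.5)   # TorchScript-friendly branch
        return self._dynamic(n)    # eager / FX branch

m = GridSize()
scripted = torch.jit.script(m)
print(m(49), scripted(49))   # 7 7
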
@@ -605,10 +630,11 @@ def norm_cdf(x):
         'std': (0.229, 0.224, 0.225),
         'interpolation': 'bilinear',
         'crop_pct': 1.0,
+        'classifier': 'head',
+        'first_conv': [],
     },
 })

-
 def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2:
     return build_model_with_cfg(
         CSATv2,