
Commit 8d5e51e

gusdlf93 authored and rwightman committed
Pass the Test
1 parent 9e40d5a commit 8d5e51e

File tree

1 file changed (+64, -38 lines)

timm/models/csatv2.py

Lines changed: 64 additions & 38 deletions
@@ -66,13 +66,16 @@ def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = T
         factory_kwargs = dict(device=device, dtype=dtype)
         super(_DCT1D, self).__init__()
         kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3}
-        self.weights = nn.Parameter(kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T, False)
+        dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T
+        self.register_buffer('weights', dct_weights)
+
         self.register_parameter('bias', None)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return nn.functional.linear(x, self.weights, self.bias)


+
 class _DCT2D(nn.Module):
     def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
                  device=None, dtype=None) -> None:
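The core change in this hunk: the precomputed DCT matrix moves from a frozen nn.Parameter into a buffer. A buffer still travels with .to(device)/.cuda() and is serialized in the state_dict, but it no longer appears in model.parameters(), so optimizers and weight decay cannot touch it. A minimal sketch of the pattern (FixedLinear and its names are illustrative, not part of csatv2.py):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FixedLinear(nn.Module):
    def __init__(self, weights: torch.Tensor):
        super().__init__()
        # Registered as a buffer: saved, device-tracked, but not trainable
        self.register_buffer('weights', weights)
        self.register_parameter('bias', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weights, self.bias)

m = FixedLinear(torch.eye(4))
assert len(list(m.parameters())) == 0  # nothing for an optimizer to update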
@@ -96,6 +99,7 @@ def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = T
         self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0)
         self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
         self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
+
         self.mean = torch.tensor(mean, requires_grad=False)
         self.var = torch.tensor(var, requires_grad=False)
         self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False)
@@ -113,9 +117,14 @@ def rgb2ycbcr(self, x):
         return x

     def frequncy_normalize(self, x):
-        x[:, 0, ].sub_(self.mean.to(x.device)[0]).div_((self.var.to(x.device)[0] ** 0.5 + 1e-8))
-        x[:, 1, ].sub_(self.mean.to(x.device)[1]).div_((self.var.to(x.device)[1] ** 0.5 + 1e-8))
-        x[:, 2, ].sub_(self.mean.to(x.device)[2]).div_((self.var.to(x.device)[2] ** 0.5 + 1e-8))
+
+        mean_tensor = self.mean.to(x.device)
+        var_tensor = self.var.to(x.device)
+
+        std = var_tensor ** 0.5 + 1e-8
+
+        x = (x - mean_tensor) / std
+
         return x

     def forward(self, x: torch.Tensor) -> torch.Tensor:
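The rewritten frequncy_normalize replaces three per-channel in-place sub_()/div_() calls with one out-of-place expression, which avoids mutating the caller's tensor (in-place ops on inputs can break autograd and graph tracing). A small sketch of the same computation, assuming per-channel statistics and a (B, C, H, W) input; the (1, C, 1, 1) reshape is an assumption for clean broadcasting and may differ from the stored shapes of self.mean/self.var:

import torch

def normalize(x: torch.Tensor, mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    mean = mean.to(x.device).view(1, -1, 1, 1)
    std = var.to(x.device).view(1, -1, 1, 1) ** 0.5 + 1e-8
    return (x - mean) / std  # out-of-place: x itself is never modified

x = torch.randn(2, 3, 8, 8)
y = normalize(x, torch.zeros(3), torch.ones(3))
assert torch.allclose(y, x)  # zero mean / unit variance is (almost) the identity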
@@ -181,11 +190,13 @@ def __init__(self, dim, drop_path=0.):
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.attention = Spatial_Attention()

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         input = x
         x = self.dwconv(x)
         x = x.permute(0, 2, 3, 1)
+
         x = self.norm(x)
+
         x = self.pwconv1(x)
         x = self.act(x)
         x = self.grn(x)
@@ -194,11 +205,9 @@ def forward(self, x):

         # Spatial Attention logic
         attention = self.attention(x)
-
-        # [Fix] create nn.UpsamplingBilinear2d class -> use F.interpolate function
-        attention = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=False)
-
-        x = x * attention
+        # x = x * nn.UpsamplingBilinear2d(size=x.shape[2:])(attention)
+        up_attn = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True)
+        x = x * up_attn

         x = input + self.drop_path(x)
         return x
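This hunk also flips align_corners from False back to True, which matters: nn.UpsamplingBilinear2d (the operator documented in the retained comment) fixes align_corners=True internally, so only the True setting reproduces the original module's output. A quick equivalence check (shapes arbitrary):

import torch
import torch.nn as nn
import torch.nn.functional as F

attn = torch.randn(1, 8, 7, 7)
target = (14, 14)

old = nn.UpsamplingBilinear2d(size=target)(attn)  # deprecated class, align_corners=True
new = F.interpolate(attn, size=target, mode='bilinear', align_corners=True)
assert torch.allclose(old, new)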
@@ -223,11 +232,6 @@ def forward(self, x):


 class TransformerBlock(nn.Module):
-    """
-    Refactored TransformerBlock without einops and PreNorm class wrapper.
-    Manual reshaping is performed in forward().
-    """
-
     def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.):
         super().__init__()
         hidden_dim = int(inp * 4)
@@ -240,15 +244,16 @@ def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=Fal
             self.pool2 = nn.MaxPool2d(3, 2, 1)
             self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
         else:
-            # [Fix] Prevent JIT compile error
+            # [change] TorchScript needs every attribute declared, so use Identity placeholders
             self.pool1 = nn.Identity()
             self.pool2 = nn.Identity()
             self.proj = nn.Identity()

         # Attention block components
-        # Note: In old code, PreNorm wrapped Attention. Here we split them.
         self.attn_norm = nn.LayerNorm(inp)
-        self.attn = Attention(inp, oup, heads, dim_head, dropout)
+        # self.attn = Attention(inp, oup, heads, dim_head, dropout)
+        self.attn = Attention(inp, oup, heads, dim_head, dropout, img_size=img_size)
+

         # FeedForward block components
         self.ff_norm = nn.LayerNorm(oup)
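The Identity placeholders exist because torch.jit.script resolves every attribute referenced in forward() against the module's type; if self.pool1/self.pool2/self.proj were only defined on downsampling blocks, scripting a non-downsampling block would fail. A minimal sketch of the pattern with an illustrative Block class:

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int, downsample: bool = False):
        super().__init__()
        if downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(dim, dim * 2, 1, 1, 0, bias=False)
        else:
            # Always declare the attributes so TorchScript can type them
            self.pool = nn.Identity()
            self.proj = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(self.pool(x))

torch.jit.script(Block(16))         # compiles with or without downsampling
torch.jit.script(Block(16, True))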
@@ -316,7 +321,7 @@ class Attention(nn.Module):
     Refactored Attention without einops.rearrange.
     """

-    def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.):
+    def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0., img_size=None):
         super().__init__()
         inner_dim = dim_head * heads
         project_out = not (heads == 1 and dim_head == inp)
@@ -332,7 +337,8 @@ def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.):
             nn.Linear(inner_dim, oup),
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()
-        self.pos_embed = PosCNN(in_chans=inp)
+        # self.pos_embed = PosCNN(in_chans=inp)
+        self.pos_embed = PosCNN(in_chans=inp, img_size=img_size)

     def forward(self, x):
         # x shape: (B, N, C)
@@ -387,6 +393,16 @@ def __init__(
         self.img_size = img_size

         dims = [32, 72, 168, 386]
+        self.num_features = dims[-1]
+        self.head_hidden_size = self.num_features
+
+        self.feature_info = [
+            dict(num_chs=32, reduction=8, module='dct'),
+            dict(num_chs=dims[1], reduction=16, module='stages1'),
+            dict(num_chs=dims[2], reduction=32, module='stages2'),
+            dict(num_chs=dims[3], reduction=64, module='stages3'),
+            dict(num_chs=dims[3], reduction=64, module='stages4'),
+        ]
         channel_order = "channels_first"
         depths = [2, 2, 6, 4]
         dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
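num_features, head_hidden_size, and feature_info are the attributes timm's factory machinery reads: the first two size the classifier head, and feature_info describes each stage's channel count and total stride for feature-extraction wrappers. A hedged usage sketch, assuming the variant is registered under the name csatv2 and supports features_only:

import timm

model = timm.create_model('csatv2', pretrained=False, features_only=True)
print(model.feature_info.channels())   # expected: [32, 72, 168, 386, 386]
print(model.feature_info.reduction())  # expected: [8, 16, 32, 64, 64]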
@@ -479,17 +495,7 @@ def forward(self, x):
         return x


-# --- Components like LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ ---
-# (This part is unrelated to einops, so it is kept identical to the code above; omitted here to save space.)
-# Use the LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ functions from the original code as-is.
-
 class LayerNorm(nn.Module):
-    """ LayerNorm that supports two data formats: channels_last (default) or channels_first.
-    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
-    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
-    with shape (batch_size, channels, height, width).
-    """
-
     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(normalized_shape))
@@ -500,12 +506,10 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
             raise NotImplementedError
         self.normalized_shape = (normalized_shape,)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.data_format == "channels_last":
             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-        else:
-            # [Fix] change elif -> else
-            # JIT should know Tensor return path of all cases.
+        else:  # elif self.data_format == "channels_first":
             u = x.mean(1, keepdim=True)
             s = (x - u).pow(2).mean(1, keepdim=True)
             x = (x - u) / torch.sqrt(s + self.eps)
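Replacing the trailing elif with a bare else gives the TorchScript compiler a Tensor return on every path; with elif and no else, the compiler sees a path that can fall through without returning a Tensor. The manual channels_first branch is just LayerNorm computed over dim 1; a quick check that it matches permuting into channels_last and calling F.layer_norm (with unit weight and zero bias):

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 4, 4)  # (B, C, H, W)
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
manual = (x - u) / torch.sqrt(s + 1e-6)

ref = F.layer_norm(x.permute(0, 2, 3, 1), (5,), eps=1e-6).permute(0, 3, 1, 2)
assert torch.allclose(manual, ref, atol=1e-5)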
@@ -561,20 +565,41 @@ def forward(self, x):


 class PosCNN(nn.Module):
-    def __init__(self, in_chans):
+    def __init__(self, in_chans, img_size=None):
         super(PosCNN, self).__init__()
         self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans)
+        self.img_size = img_size
+
+    # Hidden from the JIT compiler; safely handles N changing type across modes
+    @torch.jit.ignore
+    def _get_dynamic_size(self, N):
+        # 1. Eager mode (normal execution)
+        if isinstance(N, int):
+            s = int(N ** 0.5)
+            return s, s
+
+        # 2. FX tracing & runtime
+        # If N is a Proxy or Tensor, promote it to a float Tensor
+        N_float = N * torch.tensor(1.0)  # float promotion (Proxy-compatible)
+        s = (N_float ** 0.5).to(torch.int)
+        return s, s

     def forward(self, x):
         B, N, C = x.shape
         feat_token = x
-        H, W = int(N ** 0.5), int(N ** 0.5)
+
+        # JIT mode vs others
+        if torch.jit.is_scripting():
+            H = int(N ** 0.5)
+            W = H
+        else:
+            H, W = self._get_dynamic_size(N)
+
         cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
         x = self.proj(cnn_feat) + cnn_feat
         x = x.flatten(2).transpose(1, 2)
         return x

-
 def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)

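The forward() dispatch pairs torch.jit.is_scripting() with the @torch.jit.ignore helper: under scripting, the branch with plain int math is taken and the ignored helper is never compiled, while eager execution and FX tracing go through _get_dynamic_size, which is free to use Python the compiler cannot handle. A minimal sketch of the same pattern with an illustrative module:

import torch
import torch.nn as nn

class SquareSize(nn.Module):
    @torch.jit.ignore
    def _dynamic(self, n: int) -> int:
        # Left as plain Python; never compiled by TorchScript
        return int(n ** 0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = x.shape[1]
        if torch.jit.is_scripting():
            h = int(n ** 0.5)      # scripted path: int math only
        else:
            h = self._dynamic(n)   # eager/FX path: ignored helper
        return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], h, h)

m = torch.jit.script(SquareSize())
assert m(torch.randn(2, 16, 8)).shape == (2, 8, 4, 4)  # N=16 -> 4x4 grid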
@@ -605,10 +630,11 @@ def norm_cdf(x):
         'std': (0.229, 0.224, 0.225),
         'interpolation': 'bilinear',
         'crop_pct': 1.0,
+        'classifier': 'head',
+        'first_conv': [],
     },
 })

-
 def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2:
     return build_model_with_cfg(
         CSATv2,
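The two new cfg keys are the ones timm consults when adapting pretrained weights: 'classifier' names the head module (so num_classes overrides and reset_classifier know what to replace), and 'first_conv' lists stem convolutions to re-map when in_chans != 3; the empty list plausibly reflects that this model's DCT stem has no conventional first conv. A hedged sketch, again assuming a csatv2 entry point:

import timm

model = timm.create_model('csatv2', pretrained=False, num_classes=10)
print(model.pretrained_cfg['classifier'])  # 'head'
print(model.pretrained_cfg['first_conv'])  # []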
