Commit f1de327

gusdlf93 authored and rwightman committed
Add files via upload
1 parent 5dc2e6d commit f1de327

File tree

1 file changed: +20 -3 lines changed


timm/models/csatv2.py

Lines changed: 20 additions & 3 deletions
@@ -194,8 +194,12 @@ def forward(self, x):
 
         # Spatial Attention logic
         attention = self.attention(x)
-        # Upsampling to match x spatial size
-        x = x * nn.UpsamplingBilinear2d(x.shape[2:])(attention)
+
+        # [Fix] Replace nn.UpsamplingBilinear2d class instantiation with the F.interpolate function
+        # align_corners=False is closer to the current default behavior (the performance difference is negligible)
+        attention = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=False)
+
+        x = x * attention
 
         x = input + self.drop_path(x)
         return x
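
(Editorial sketch, not part of the commit.) The change above matters mainly for torch.jit.script: constructing an nn.UpsamplingBilinear2d module inside forward() with a runtime-dependent size cannot be scripted, while a functional F.interpolate call can. A minimal illustration of the pattern, with a made-up attention branch standing in for the real one in csatv2.py:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialGate(nn.Module):
    # Toy stand-in: this `attention` branch is an assumption for the example,
    # not the actual csatv2 attention module.
    def __init__(self, dim: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(7),   # attention map is smaller than the input
            nn.Conv2d(dim, dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attention = self.attention(x)
        # Functional resize back to x's spatial size; scriptable because no
        # module is instantiated inside forward().
        attention = F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=False)
        return x * attention

gate = torch.jit.script(SpatialGate(32))
print(gate(torch.randn(1, 32, 56, 56)).shape)   # torch.Size([1, 32, 56, 56])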
@@ -236,6 +240,11 @@ def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=Fal
             self.pool1 = nn.MaxPool2d(3, 2, 1)
             self.pool2 = nn.MaxPool2d(3, 2, 1)
             self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
+        else:
+            # [Fix] Prevent JIT compilation errors: the attributes must be defined even when unused
+            self.pool1 = nn.Identity()
+            self.pool2 = nn.Identity()
+            self.proj = nn.Identity()
 
         # Attention block components
         # Note: In old code, PreNorm wrapped Attention. Here we split them.
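
(Editorial sketch, not the actual csatv2 block.) The hunk above works around a TorchScript rule: every attribute referenced in forward() must exist on the module at script time, including attributes only used on branches that never run. Assigning nn.Identity() placeholders in the else branch satisfies that check without changing behavior. The class name and forward logic below are assumed for demonstration:

import torch
import torch.nn as nn

class DownsampleBlock(nn.Module):
    def __init__(self, inp: int, oup: int, downsample: bool = False):
        super().__init__()
        self.downsample = downsample
        if downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
        else:
            # Without these placeholders, torch.jit.script fails with
            # "Module ... has no attribute 'pool1'".
            self.pool1 = nn.Identity()
            self.pool2 = nn.Identity()
            self.proj = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both paths mention the same attribute names; TorchScript type-checks
        # them even for the branch that never executes at runtime.
        if self.downsample:
            x = self.proj(self.pool2(self.pool1(x)))
        return x

scripted = torch.jit.script(DownsampleBlock(32, 64, downsample=False))
print(scripted(torch.randn(1, 32, 56, 56)).shape)   # torch.Size([1, 32, 56, 56])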
@@ -476,6 +485,12 @@ def forward(self, x):
 # Use the LayerNorm, GRN, DropPath, FeedForward, PosCNN, and trunc_normal_ functions from the original code as-is.
 
 class LayerNorm(nn.Module):
+    """ LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+    with shape (batch_size, channels, height, width).
+    """
+
     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(normalized_shape))
@@ -489,7 +504,9 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
     def forward(self, x):
         if self.data_format == "channels_last":
             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-        elif self.data_format == "channels_first":
+        else:
+            # [Fix] Changed elif -> else so that the JIT can see
+            # that a Tensor is returned on every code path
             u = x.mean(1, keepdim=True)
             s = (x - u).pow(2).mean(1, keepdim=True)
             x = (x - u) / torch.sqrt(s + self.eps)
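
(Editorial sketch.) The hunk ends before the channels_first branch finishes, so the weight/bias application at the end of the example below is an assumption based on the usual ConvNeXt-style implementation, not part of the diff. The point of the elif -> else change is that torch.jit.script has to prove a Tensor is returned on every path through forward(); a trailing elif with no else leaves a path that falls off the end without returning a Tensor, which the script compiler rejects for a Tensor-annotated forward.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    # Standalone copy for illustration; mirrors the class in the diff above.
    def __init__(self, normalized_shape: int, eps: float = 1e-6, data_format: str = "channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = [normalized_shape]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            # `else` rather than a bare `elif` guarantees every path returns a
            # Tensor, which is what torch.jit.script needs for the annotation.
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            return self.weight[:, None, None] * x + self.bias[:, None, None]

ln = torch.jit.script(LayerNorm(64, data_format="channels_first"))
print(ln(torch.randn(2, 64, 14, 14)).shape)   # torch.Size([2, 64, 14, 14])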
