Commit 4706c8e

Consistent use of device/dtype factory kwargs through csatv2 and submodules

1 parent 1c6b0f5 commit 4706c8e
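
For context, this is the factory-kwargs idiom the commit standardizes: collect device and dtype once into a dict and thread it through every submodule and tensor constructor, so the whole module tree materializes on the requested device and dtype. A minimal sketch of the pattern, using a hypothetical ToyBlock (not part of timm; the dd name mirrors the diff below):

    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        """Hypothetical module illustrating the device/dtype factory-kwargs pattern."""

        def __init__(self, dim: int, device=None, dtype=None) -> None:
            dd = dict(device=device, dtype=dtype)  # collected once...
            super().__init__()
            # ...then forwarded to every submodule and tensor factory call,
            # so no weight is first allocated on the default device.
            self.norm = nn.LayerNorm(dim, **dd)
            self.proj = nn.Linear(dim, dim, **dd)
            self.register_buffer('scale', torch.ones(dim, **dd))

Standard torch.nn layers with parameters accept these two kwargs, which is what lets a model be instantiated directly on an accelerator, or on the meta device, without a post-hoc .to().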


timm/models/csatv2.py

Lines changed: 91 additions & 68 deletions
@@ -102,21 +102,21 @@ def _dct_kernel_type_2(
         dtype=None,
 ) -> torch.Tensor:
     """Generate Type-II DCT kernel matrix."""
-    factory_kwargs = dict(device=device, dtype=dtype)
-    x = torch.eye(kernel_size, **factory_kwargs)
+    dd = dict(device=device, dtype=dtype)
+    x = torch.eye(kernel_size, **dd)
     v = x.clone().contiguous().view(-1, kernel_size)
     v = torch.cat([v, v.flip([1])], dim=-1)
     v = torch.fft.fft(v, dim=-1)[:, :kernel_size]
     try:
-        k = torch.tensor(-1j, **factory_kwargs) * torch.pi * torch.arange(kernel_size, **factory_kwargs)[None, :]
+        k = torch.tensor(-1j, **dd) * torch.pi * torch.arange(kernel_size, **dd)[None, :]
     except AttributeError:
-        k = torch.tensor(-1j, **factory_kwargs) * math.pi * torch.arange(kernel_size, **factory_kwargs)[None, :]
+        k = torch.tensor(-1j, **dd) * math.pi * torch.arange(kernel_size, **dd)[None, :]
     k = torch.exp(k / (kernel_size * 2))
     v = v * k
     v = v.real
     if orthonormal:
-        v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **factory_kwargs))
-        v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **factory_kwargs))
+        v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **dd))
+        v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **dd))
     v = v.contiguous().view(*x.shape)
     return v

@@ -142,10 +142,10 @@ def __init__(
             device=None,
             dtype=None,
     ) -> None:
-        factory_kwargs = dict(device=device, dtype=dtype)
+        dd = dict(device=device, dtype=dtype)
         super().__init__()
         kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3}
-        dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T
+        dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **dd).T
         self.register_buffer('weights', dct_weights)
         self.register_parameter('bias', None)

@@ -164,9 +164,9 @@ def __init__(
             device=None,
             dtype=None,
     ) -> None:
-        factory_kwargs = dict(device=device, dtype=dtype)
+        dd = dict(device=device, dtype=dtype)
         super().__init__()
-        self.transform = Dct1d(kernel_size, kernel_type, orthonormal, **factory_kwargs)
+        self.transform = Dct1d(kernel_size, kernel_type, orthonormal, **dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2)
@@ -183,20 +183,20 @@ def __init__(
             device=None,
             dtype=None,
     ) -> None:
-        factory_kwargs = dict(device=device, dtype=dtype)
+        dd = dict(device=device, dtype=dtype)
         super().__init__()
         self.k = kernel_size
         self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
-        self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs)
+        self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd)
         self.permutation = _zigzag_permutation(kernel_size, kernel_size)
-        self.conv_y = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0)
-        self.conv_cb = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
-        self.conv_cr = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
+        self.conv_y = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0, **dd)
+        self.conv_cb = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0, **dd)
+        self.conv_cr = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0, **dd)

-        self.register_buffer('mean', torch.tensor(_DCT_MEAN), persistent=False)
-        self.register_buffer('var', torch.tensor(_DCT_VAR), persistent=False)
-        self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406]), persistent=False)
-        self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225]), persistent=False)
+        self.register_buffer('mean', torch.tensor(_DCT_MEAN, device=device), persistent=False)
+        self.register_buffer('var', torch.tensor(_DCT_VAR, device=device), persistent=False)
+        self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406], device=device), persistent=False)
+        self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225], device=device), persistent=False)

     def _denormalize(self, x: torch.Tensor) -> torch.Tensor:
         """Convert from ImageNet normalized to [0, 255] range."""
@@ -245,11 +245,11 @@ def __init__(
             device=None,
             dtype=None,
     ) -> None:
-        factory_kwargs = dict(device=device, dtype=dtype)
+        dd = dict(device=device, dtype=dtype)
         super().__init__()
         self.k = kernel_size
         self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
-        self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **factory_kwargs)
+        self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd)
         self.permutation = _zigzag_permutation(kernel_size, kernel_size)

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -276,16 +276,19 @@ def __init__(
             self,
             dim: int,
             drop_path: float = 0.,
+            device=None,
+            dtype=None,
     ) -> None:
+        dd = dict(device=device, dtype=dtype)
         super().__init__()
-        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
-        self.norm = nn.LayerNorm(dim, eps=1e-6)
-        self.pwconv1 = nn.Linear(dim, 4 * dim)
+        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, **dd)
+        self.norm = nn.LayerNorm(dim, eps=1e-6, **dd)
+        self.pwconv1 = nn.Linear(dim, 4 * dim, **dd)
         self.act = nn.GELU()
-        self.grn = GlobalResponseNorm(4 * dim, channels_last=True)
-        self.pwconv2 = nn.Linear(4 * dim, dim)
+        self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd)
+        self.pwconv2 = nn.Linear(4 * dim, dim, **dd)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.attention = SpatialAttention()
+        self.attention = SpatialAttention(**dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x
@@ -312,16 +315,21 @@ class SpatialTransformerBlock(nn.Module):
     positions. Used inside SpatialAttention where input is 1 channel at 7x7 resolution.
     """

-    def __init__(self) -> None:
+    def __init__(
+            self,
+            device=None,
+            dtype=None,
+    ) -> None:
+        dd = dict(device=device, dtype=dtype)
         super().__init__()
         # Single-head attention with 1-dim q/k/v (no output projection needed)
-        self.pos_embed = PosConv(in_chans=1)
-        self.norm1 = nn.LayerNorm(1)
-        self.qkv = nn.Linear(1, 3, bias=False)
+        self.pos_embed = PosConv(in_chans=1, **dd)
+        self.norm1 = nn.LayerNorm(1, **dd)
+        self.qkv = nn.Linear(1, 3, bias=False, **dd)

         # Feedforward: 1 -> 4 -> 1
-        self.norm2 = nn.LayerNorm(1)
-        self.mlp = Mlp(1, 4, 1, act_layer=nn.GELU)
+        self.norm2 = nn.LayerNorm(1, **dd)
+        self.mlp = Mlp(1, 4, 1, act_layer=nn.GELU, **dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
@@ -354,11 +362,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class SpatialAttention(nn.Module):
     """Spatial attention module using channel statistics and transformer."""

-    def __init__(self) -> None:
+    def __init__(
+            self,
+            device=None,
+            dtype=None,
+    ) -> None:
+        dd = dict(device=device, dtype=dtype)
         super().__init__()
         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
-        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
-        self.attention = SpatialTransformerBlock()
+        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, **dd)
+        self.attention = SpatialTransformerBlock(**dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_avg = x.mean(dim=1, keepdim=True)
@@ -382,33 +395,37 @@ def __init__(
             downsample: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
+            device=None,
+            dtype=None,
     ) -> None:
+        dd = dict(device=device, dtype=dtype)
         super().__init__()
         hidden_dim = int(inp * 4)
         self.downsample = downsample

         if self.downsample:
             self.pool1 = nn.MaxPool2d(3, 2, 1)
             self.pool2 = nn.MaxPool2d(3, 2, 1)
-            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
+            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False, **dd)
         else:
             self.pool1 = nn.Identity()
             self.pool2 = nn.Identity()
             self.proj = nn.Identity()

-        self.pos_embed = PosConv(in_chans=inp)
-        self.norm1 = nn.LayerNorm(inp)
+        self.pos_embed = PosConv(in_chans=inp, **dd)
+        self.norm1 = nn.LayerNorm(inp, **dd)
         self.attn = Attention(
             dim=inp,
             num_heads=num_heads,
             attn_head_dim=attn_head_dim,
             dim_out=oup,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
+            **dd,
         )

-        self.norm2 = nn.LayerNorm(oup)
-        self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop)
+        self.norm2 = nn.LayerNorm(oup, **dd)
+        self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.downsample:
@@ -448,9 +465,12 @@ class PosConv(nn.Module):
     def __init__(
             self,
             in_chans: int,
+            device=None,
+            dtype=None,
     ) -> None:
+        dd = dict(device=device, dtype=dtype)
         super().__init__()
-        self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans)
+        self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans, **dd)

     def forward(self, x: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
         B, N, C = x.shape
@@ -473,8 +493,11 @@ def __init__(
             in_chans: int = 3,
             drop_path_rate: float = 0.0,
             global_pool: str = 'avg',
+            device=None,
+            dtype=None,
             **kwargs,
     ) -> None:
+        dd = dict(device=device, dtype=dtype)
         super().__init__()
         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -495,44 +518,44 @@ def __init__(
         depths = [2, 2, 6, 4]
         dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

-        self.stem_dct = LearnableDct2d(8)
+        self.stem_dct = LearnableDct2d(8, **dd)

         self.stages = nn.Sequential(
             nn.Sequential(
-                Block(dim=dims[0], drop_path=dp_rates[0]),
-                Block(dim=dims[0], drop_path=dp_rates[1]),
-                LayerNorm2d(dims[0], eps=1e-6),
+                Block(dim=dims[0], drop_path=dp_rates[0], **dd),
+                Block(dim=dims[0], drop_path=dp_rates[1], **dd),
+                LayerNorm2d(dims[0], eps=1e-6, **dd),
             ),
             nn.Sequential(
-                nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2),
-                Block(dim=dims[1], drop_path=dp_rates[2]),
-                Block(dim=dims[1], drop_path=dp_rates[3]),
-                LayerNorm2d(dims[1], eps=1e-6),
+                nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2, **dd),
+                Block(dim=dims[1], drop_path=dp_rates[2], **dd),
+                Block(dim=dims[1], drop_path=dp_rates[3], **dd),
+                LayerNorm2d(dims[1], eps=1e-6, **dd),
             ),
             nn.Sequential(
-                nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2),
-                Block(dim=dims[2], drop_path=dp_rates[4]),
-                Block(dim=dims[2], drop_path=dp_rates[5]),
-                Block(dim=dims[2], drop_path=dp_rates[6]),
-                Block(dim=dims[2], drop_path=dp_rates[7]),
-                Block(dim=dims[2], drop_path=dp_rates[8]),
-                Block(dim=dims[2], drop_path=dp_rates[9]),
-                TransformerBlock(inp=dims[2], oup=dims[2]),
-                TransformerBlock(inp=dims[2], oup=dims[2]),
-                LayerNorm2d(dims[2], eps=1e-6),
+                nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2, **dd),
+                Block(dim=dims[2], drop_path=dp_rates[4], **dd),
+                Block(dim=dims[2], drop_path=dp_rates[5], **dd),
+                Block(dim=dims[2], drop_path=dp_rates[6], **dd),
+                Block(dim=dims[2], drop_path=dp_rates[7], **dd),
+                Block(dim=dims[2], drop_path=dp_rates[8], **dd),
+                Block(dim=dims[2], drop_path=dp_rates[9], **dd),
+                TransformerBlock(inp=dims[2], oup=dims[2], **dd),
+                TransformerBlock(inp=dims[2], oup=dims[2], **dd),
+                LayerNorm2d(dims[2], eps=1e-6, **dd),
             ),
             nn.Sequential(
-                nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2),
-                Block(dim=dims[3], drop_path=dp_rates[10]),
-                Block(dim=dims[3], drop_path=dp_rates[11]),
-                Block(dim=dims[3], drop_path=dp_rates[12]),
-                Block(dim=dims[3], drop_path=dp_rates[13]),
-                TransformerBlock(inp=dims[3], oup=dims[3]),
-                TransformerBlock(inp=dims[3], oup=dims[3]),
+                nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2, **dd),
+                Block(dim=dims[3], drop_path=dp_rates[10], **dd),
+                Block(dim=dims[3], drop_path=dp_rates[11], **dd),
+                Block(dim=dims[3], drop_path=dp_rates[12], **dd),
+                Block(dim=dims[3], drop_path=dp_rates[13], **dd),
+                TransformerBlock(inp=dims[3], oup=dims[3], **dd),
+                TransformerBlock(inp=dims[3], oup=dims[3], **dd),
             ),
         )

-        self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool)
+        self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool, **dd)

         self.apply(self._init_weights)

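Taken together, the hunks let device and dtype flow from the top-level constructor down through every Block, TransformerBlock, and DCT submodule. A hedged usage sketch of what that enables, assuming the public class in timm/models/csatv2.py is named CSATv2 (the diff shows the file, not the class name) and a hypothetical checkpoint path:

    import torch
    from timm.models.csatv2 import CSATv2  # assumed import

    # Build directly on the accelerator in half precision -- every submodule
    # receives device/dtype at construction, so there is no float32-on-CPU detour.
    model = CSATv2(num_classes=1000, device='cuda', dtype=torch.float16)

    # Or allocate nothing up front: build on the meta device, materialize empty
    # storage on the target, then load a checkpoint into it.
    meta_model = CSATv2(num_classes=1000, device='meta')
    meta_model = meta_model.to_empty(device='cuda')
    meta_model.load_state_dict(torch.load('csatv2.pth'))  # hypothetical checkpoint
    # Caveat: buffers registered with persistent=False (e.g. the DCT stats above)
    # are absent from the state_dict and must be re-initialized after to_empty().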