@@ -102,21 +102,21 @@ def _dct_kernel_type_2(
102102 dtype = None ,
103103) -> torch .Tensor :
104104 """Generate Type-II DCT kernel matrix."""
105- factory_kwargs = dict (device = device , dtype = dtype )
106- x = torch .eye (kernel_size , ** factory_kwargs )
105+ dd = dict (device = device , dtype = dtype )
106+ x = torch .eye (kernel_size , ** dd )
107107 v = x .clone ().contiguous ().view (- 1 , kernel_size )
108108 v = torch .cat ([v , v .flip ([1 ])], dim = - 1 )
109109 v = torch .fft .fft (v , dim = - 1 )[:, :kernel_size ]
110110 try :
111- k = torch .tensor (- 1j , ** factory_kwargs ) * torch .pi * torch .arange (kernel_size , ** factory_kwargs )[None , :]
111+ k = torch .tensor (- 1j , ** dd ) * torch .pi * torch .arange (kernel_size , ** dd )[None , :]
112112 except AttributeError :
113- k = torch .tensor (- 1j , ** factory_kwargs ) * math .pi * torch .arange (kernel_size , ** factory_kwargs )[None , :]
113+ k = torch .tensor (- 1j , ** dd ) * math .pi * torch .arange (kernel_size , ** dd )[None , :]
114114 k = torch .exp (k / (kernel_size * 2 ))
115115 v = v * k
116116 v = v .real
117117 if orthonormal :
118- v [:, 0 ] = v [:, 0 ] * torch .sqrt (torch .tensor (1 / (kernel_size * 4 ), ** factory_kwargs ))
119- v [:, 1 :] = v [:, 1 :] * torch .sqrt (torch .tensor (1 / (kernel_size * 2 ), ** factory_kwargs ))
118+ v [:, 0 ] = v [:, 0 ] * torch .sqrt (torch .tensor (1 / (kernel_size * 4 ), ** dd ))
119+ v [:, 1 :] = v [:, 1 :] * torch .sqrt (torch .tensor (1 / (kernel_size * 2 ), ** dd ))
120120 v = v .contiguous ().view (* x .shape )
121121 return v
122122
@@ -142,10 +142,10 @@ def __init__(
142142 device = None ,
143143 dtype = None ,
144144 ) -> None :
145- factory_kwargs = dict (device = device , dtype = dtype )
145+ dd = dict (device = device , dtype = dtype )
146146 super ().__init__ ()
147147 kernel = {'2' : _dct_kernel_type_2 , '3' : _dct_kernel_type_3 }
148- dct_weights = kernel [f'{ kernel_type } ' ](kernel_size , orthonormal , ** factory_kwargs ).T
148+ dct_weights = kernel [f'{ kernel_type } ' ](kernel_size , orthonormal , ** dd ).T
149149 self .register_buffer ('weights' , dct_weights )
150150 self .register_parameter ('bias' , None )
151151
@@ -164,9 +164,9 @@ def __init__(
164164 device = None ,
165165 dtype = None ,
166166 ) -> None :
167- factory_kwargs = dict (device = device , dtype = dtype )
167+ dd = dict (device = device , dtype = dtype )
168168 super ().__init__ ()
169- self .transform = Dct1d (kernel_size , kernel_type , orthonormal , ** factory_kwargs )
169+ self .transform = Dct1d (kernel_size , kernel_type , orthonormal , ** dd )
170170
171171 def forward (self , x : torch .Tensor ) -> torch .Tensor :
172172 return self .transform (self .transform (x ).transpose (- 1 , - 2 )).transpose (- 1 , - 2 )
@@ -183,20 +183,20 @@ def __init__(
183183 device = None ,
184184 dtype = None ,
185185 ) -> None :
186- factory_kwargs = dict (device = device , dtype = dtype )
186+ dd = dict (device = device , dtype = dtype )
187187 super ().__init__ ()
188188 self .k = kernel_size
189189 self .unfold = nn .Unfold (kernel_size = (kernel_size , kernel_size ), stride = (kernel_size , kernel_size ))
190- self .transform = Dct2d (kernel_size , kernel_type , orthonormal , ** factory_kwargs )
190+ self .transform = Dct2d (kernel_size , kernel_type , orthonormal , ** dd )
191191 self .permutation = _zigzag_permutation (kernel_size , kernel_size )
192- self .conv_y = nn .Conv2d (kernel_size ** 2 , 24 , kernel_size = 1 , padding = 0 )
193- self .conv_cb = nn .Conv2d (kernel_size ** 2 , 4 , kernel_size = 1 , padding = 0 )
194- self .conv_cr = nn .Conv2d (kernel_size ** 2 , 4 , kernel_size = 1 , padding = 0 )
192+ self .conv_y = nn .Conv2d (kernel_size ** 2 , 24 , kernel_size = 1 , padding = 0 , ** dd )
193+ self .conv_cb = nn .Conv2d (kernel_size ** 2 , 4 , kernel_size = 1 , padding = 0 , ** dd )
194+ self .conv_cr = nn .Conv2d (kernel_size ** 2 , 4 , kernel_size = 1 , padding = 0 , ** dd )
195195
196- self .register_buffer ('mean' , torch .tensor (_DCT_MEAN ), persistent = False )
197- self .register_buffer ('var' , torch .tensor (_DCT_VAR ), persistent = False )
198- self .register_buffer ('imagenet_mean' , torch .tensor ([0.485 , 0.456 , 0.406 ]), persistent = False )
199- self .register_buffer ('imagenet_std' , torch .tensor ([0.229 , 0.224 , 0.225 ]), persistent = False )
196+ self .register_buffer ('mean' , torch .tensor (_DCT_MEAN , device = device ), persistent = False )
197+ self .register_buffer ('var' , torch .tensor (_DCT_VAR , device = device ), persistent = False )
198+ self .register_buffer ('imagenet_mean' , torch .tensor ([0.485 , 0.456 , 0.406 ], device = device ), persistent = False )
199+ self .register_buffer ('imagenet_std' , torch .tensor ([0.229 , 0.224 , 0.225 ], device = device ), persistent = False )
200200
201201 def _denormalize (self , x : torch .Tensor ) -> torch .Tensor :
202202 """Convert from ImageNet normalized to [0, 255] range."""
@@ -245,11 +245,11 @@ def __init__(
245245 device = None ,
246246 dtype = None ,
247247 ) -> None :
248- factory_kwargs = dict (device = device , dtype = dtype )
248+ dd = dict (device = device , dtype = dtype )
249249 super ().__init__ ()
250250 self .k = kernel_size
251251 self .unfold = nn .Unfold (kernel_size = (kernel_size , kernel_size ), stride = (kernel_size , kernel_size ))
252- self .transform = Dct2d (kernel_size , kernel_type , orthonormal , ** factory_kwargs )
252+ self .transform = Dct2d (kernel_size , kernel_type , orthonormal , ** dd )
253253 self .permutation = _zigzag_permutation (kernel_size , kernel_size )
254254
255255 def forward (self , x : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
@@ -276,16 +276,19 @@ def __init__(
276276 self ,
277277 dim : int ,
278278 drop_path : float = 0. ,
279+ device = None ,
280+ dtype = None ,
279281 ) -> None :
282+ dd = dict (device = device , dtype = dtype )
280283 super ().__init__ ()
281- self .dwconv = nn .Conv2d (dim , dim , kernel_size = 7 , padding = 3 , groups = dim )
282- self .norm = nn .LayerNorm (dim , eps = 1e-6 )
283- self .pwconv1 = nn .Linear (dim , 4 * dim )
284+ self .dwconv = nn .Conv2d (dim , dim , kernel_size = 7 , padding = 3 , groups = dim , ** dd )
285+ self .norm = nn .LayerNorm (dim , eps = 1e-6 , ** dd )
286+ self .pwconv1 = nn .Linear (dim , 4 * dim , ** dd )
284287 self .act = nn .GELU ()
285- self .grn = GlobalResponseNorm (4 * dim , channels_last = True )
286- self .pwconv2 = nn .Linear (4 * dim , dim )
288+ self .grn = GlobalResponseNorm (4 * dim , channels_last = True , ** dd )
289+ self .pwconv2 = nn .Linear (4 * dim , dim , ** dd )
287290 self .drop_path = DropPath (drop_path ) if drop_path > 0. else nn .Identity ()
288- self .attention = SpatialAttention ()
291+ self .attention = SpatialAttention (** dd )
289292
290293 def forward (self , x : torch .Tensor ) -> torch .Tensor :
291294 shortcut = x
@@ -312,16 +315,21 @@ class SpatialTransformerBlock(nn.Module):
312315 positions. Used inside SpatialAttention where input is 1 channel at 7x7 resolution.
313316 """
314317
315- def __init__ (self ) -> None :
318+ def __init__ (
319+ self ,
320+ device = None ,
321+ dtype = None ,
322+ ) -> None :
323+ dd = dict (device = device , dtype = dtype )
316324 super ().__init__ ()
317325 # Single-head attention with 1-dim q/k/v (no output projection needed)
318- self .pos_embed = PosConv (in_chans = 1 )
319- self .norm1 = nn .LayerNorm (1 )
320- self .qkv = nn .Linear (1 , 3 , bias = False )
326+ self .pos_embed = PosConv (in_chans = 1 , ** dd )
327+ self .norm1 = nn .LayerNorm (1 , ** dd )
328+ self .qkv = nn .Linear (1 , 3 , bias = False , ** dd )
321329
322330 # Feedforward: 1 -> 4 -> 1
323- self .norm2 = nn .LayerNorm (1 )
324- self .mlp = Mlp (1 , 4 , 1 , act_layer = nn .GELU )
331+ self .norm2 = nn .LayerNorm (1 , ** dd )
332+ self .mlp = Mlp (1 , 4 , 1 , act_layer = nn .GELU , ** dd )
325333
326334 def forward (self , x : torch .Tensor ) -> torch .Tensor :
327335 B , C , H , W = x .shape
@@ -354,11 +362,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
354362class SpatialAttention (nn .Module ):
355363 """Spatial attention module using channel statistics and transformer."""
356364
357- def __init__ (self ) -> None :
365+ def __init__ (
366+ self ,
367+ device = None ,
368+ dtype = None ,
369+ ) -> None :
370+ dd = dict (device = device , dtype = dtype )
358371 super ().__init__ ()
359372 self .avgpool = nn .AdaptiveAvgPool2d ((7 , 7 ))
360- self .conv = nn .Conv2d (2 , 1 , kernel_size = 7 , padding = 3 )
361- self .attention = SpatialTransformerBlock ()
373+ self .conv = nn .Conv2d (2 , 1 , kernel_size = 7 , padding = 3 , ** dd )
374+ self .attention = SpatialTransformerBlock (** dd )
362375
363376 def forward (self , x : torch .Tensor ) -> torch .Tensor :
364377 x_avg = x .mean (dim = 1 , keepdim = True )
@@ -382,33 +395,37 @@ def __init__(
382395 downsample : bool = False ,
383396 attn_drop : float = 0. ,
384397 proj_drop : float = 0. ,
398+ device = None ,
399+ dtype = None ,
385400 ) -> None :
401+ dd = dict (device = device , dtype = dtype )
386402 super ().__init__ ()
387403 hidden_dim = int (inp * 4 )
388404 self .downsample = downsample
389405
390406 if self .downsample :
391407 self .pool1 = nn .MaxPool2d (3 , 2 , 1 )
392408 self .pool2 = nn .MaxPool2d (3 , 2 , 1 )
393- self .proj = nn .Conv2d (inp , oup , 1 , 1 , 0 , bias = False )
409+ self .proj = nn .Conv2d (inp , oup , 1 , 1 , 0 , bias = False , ** dd )
394410 else :
395411 self .pool1 = nn .Identity ()
396412 self .pool2 = nn .Identity ()
397413 self .proj = nn .Identity ()
398414
399- self .pos_embed = PosConv (in_chans = inp )
400- self .norm1 = nn .LayerNorm (inp )
415+ self .pos_embed = PosConv (in_chans = inp , ** dd )
416+ self .norm1 = nn .LayerNorm (inp , ** dd )
401417 self .attn = Attention (
402418 dim = inp ,
403419 num_heads = num_heads ,
404420 attn_head_dim = attn_head_dim ,
405421 dim_out = oup ,
406422 attn_drop = attn_drop ,
407423 proj_drop = proj_drop ,
424+ ** dd ,
408425 )
409426
410- self .norm2 = nn .LayerNorm (oup )
411- self .mlp = Mlp (oup , hidden_dim , oup , act_layer = nn .GELU , drop = proj_drop )
427+ self .norm2 = nn .LayerNorm (oup , ** dd )
428+ self .mlp = Mlp (oup , hidden_dim , oup , act_layer = nn .GELU , drop = proj_drop , ** dd )
412429
413430 def forward (self , x : torch .Tensor ) -> torch .Tensor :
414431 if self .downsample :
@@ -448,9 +465,12 @@ class PosConv(nn.Module):
448465 def __init__ (
449466 self ,
450467 in_chans : int ,
468+ device = None ,
469+ dtype = None ,
451470 ) -> None :
471+ dd = dict (device = device , dtype = dtype )
452472 super ().__init__ ()
453- self .proj = nn .Conv2d (in_chans , in_chans , kernel_size = 3 , stride = 1 , padding = 1 , bias = True , groups = in_chans )
473+ self .proj = nn .Conv2d (in_chans , in_chans , kernel_size = 3 , stride = 1 , padding = 1 , bias = True , groups = in_chans , ** dd )
454474
455475 def forward (self , x : torch .Tensor , size : Tuple [int , int ]) -> torch .Tensor :
456476 B , N , C = x .shape
@@ -473,8 +493,11 @@ def __init__(
473493 in_chans : int = 3 ,
474494 drop_path_rate : float = 0.0 ,
475495 global_pool : str = 'avg' ,
496+ device = None ,
497+ dtype = None ,
476498 ** kwargs ,
477499 ) -> None :
500+ dd = dict (device = device , dtype = dtype )
478501 super ().__init__ ()
479502 self .num_classes = num_classes
480503 self .global_pool = global_pool
@@ -495,44 +518,44 @@ def __init__(
495518 depths = [2 , 2 , 6 , 4 ]
496519 dp_rates = [x .item () for x in torch .linspace (0 , drop_path_rate , sum (depths ))]
497520
498- self .stem_dct = LearnableDct2d (8 )
521+ self .stem_dct = LearnableDct2d (8 , ** dd )
499522
500523 self .stages = nn .Sequential (
501524 nn .Sequential (
502- Block (dim = dims [0 ], drop_path = dp_rates [0 ]),
503- Block (dim = dims [0 ], drop_path = dp_rates [1 ]),
504- LayerNorm2d (dims [0 ], eps = 1e-6 ),
525+ Block (dim = dims [0 ], drop_path = dp_rates [0 ], ** dd ),
526+ Block (dim = dims [0 ], drop_path = dp_rates [1 ], ** dd ),
527+ LayerNorm2d (dims [0 ], eps = 1e-6 , ** dd ),
505528 ),
506529 nn .Sequential (
507- nn .Conv2d (dims [0 ], dims [1 ], kernel_size = 2 , stride = 2 ),
508- Block (dim = dims [1 ], drop_path = dp_rates [2 ]),
509- Block (dim = dims [1 ], drop_path = dp_rates [3 ]),
510- LayerNorm2d (dims [1 ], eps = 1e-6 ),
530+ nn .Conv2d (dims [0 ], dims [1 ], kernel_size = 2 , stride = 2 , ** dd ),
531+ Block (dim = dims [1 ], drop_path = dp_rates [2 ], ** dd ),
532+ Block (dim = dims [1 ], drop_path = dp_rates [3 ], ** dd ),
533+ LayerNorm2d (dims [1 ], eps = 1e-6 , ** dd ),
511534 ),
512535 nn .Sequential (
513- nn .Conv2d (dims [1 ], dims [2 ], kernel_size = 2 , stride = 2 ),
514- Block (dim = dims [2 ], drop_path = dp_rates [4 ]),
515- Block (dim = dims [2 ], drop_path = dp_rates [5 ]),
516- Block (dim = dims [2 ], drop_path = dp_rates [6 ]),
517- Block (dim = dims [2 ], drop_path = dp_rates [7 ]),
518- Block (dim = dims [2 ], drop_path = dp_rates [8 ]),
519- Block (dim = dims [2 ], drop_path = dp_rates [9 ]),
520- TransformerBlock (inp = dims [2 ], oup = dims [2 ]),
521- TransformerBlock (inp = dims [2 ], oup = dims [2 ]),
522- LayerNorm2d (dims [2 ], eps = 1e-6 ),
536+ nn .Conv2d (dims [1 ], dims [2 ], kernel_size = 2 , stride = 2 , ** dd ),
537+ Block (dim = dims [2 ], drop_path = dp_rates [4 ], ** dd ),
538+ Block (dim = dims [2 ], drop_path = dp_rates [5 ], ** dd ),
539+ Block (dim = dims [2 ], drop_path = dp_rates [6 ], ** dd ),
540+ Block (dim = dims [2 ], drop_path = dp_rates [7 ], ** dd ),
541+ Block (dim = dims [2 ], drop_path = dp_rates [8 ], ** dd ),
542+ Block (dim = dims [2 ], drop_path = dp_rates [9 ], ** dd ),
543+ TransformerBlock (inp = dims [2 ], oup = dims [2 ], ** dd ),
544+ TransformerBlock (inp = dims [2 ], oup = dims [2 ], ** dd ),
545+ LayerNorm2d (dims [2 ], eps = 1e-6 , ** dd ),
523546 ),
524547 nn .Sequential (
525- nn .Conv2d (dims [2 ], dims [3 ], kernel_size = 2 , stride = 2 ),
526- Block (dim = dims [3 ], drop_path = dp_rates [10 ]),
527- Block (dim = dims [3 ], drop_path = dp_rates [11 ]),
528- Block (dim = dims [3 ], drop_path = dp_rates [12 ]),
529- Block (dim = dims [3 ], drop_path = dp_rates [13 ]),
530- TransformerBlock (inp = dims [3 ], oup = dims [3 ]),
531- TransformerBlock (inp = dims [3 ], oup = dims [3 ]),
548+ nn .Conv2d (dims [2 ], dims [3 ], kernel_size = 2 , stride = 2 , ** dd ),
549+ Block (dim = dims [3 ], drop_path = dp_rates [10 ], ** dd ),
550+ Block (dim = dims [3 ], drop_path = dp_rates [11 ], ** dd ),
551+ Block (dim = dims [3 ], drop_path = dp_rates [12 ], ** dd ),
552+ Block (dim = dims [3 ], drop_path = dp_rates [13 ], ** dd ),
553+ TransformerBlock (inp = dims [3 ], oup = dims [3 ], ** dd ),
554+ TransformerBlock (inp = dims [3 ], oup = dims [3 ], ** dd ),
532555 ),
533556 )
534557
535- self .head = NormMlpClassifierHead (dims [- 1 ], num_classes , pool_type = global_pool )
558+ self .head = NormMlpClassifierHead (dims [- 1 ], num_classes , pool_type = global_pool , ** dd )
536559
537560 self .apply (self ._init_weights )
538561
0 commit comments