@@ -39,18 +39,23 @@ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm
        self.norm = norm_layer(4 * dim)

    def forward(self, x: Tensor):
-        B, H, W, C = x.shape
-
-        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
-        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
-        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
-        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
-        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
-        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+        """
+        Args:
+            x (Tensor): input tensor with expected layout of [..., H, W, C]
+        Returns:
+            Tensor with layout of [..., H/2, W/2, 2*C]
+        """
+        H, W, _ = x.shape[-3:]
+        x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+        x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
+        x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
+        x2 = x[..., 0::2, 1::2, :]  # ... H/2 W/2 C
+        x3 = x[..., 1::2, 1::2, :]  # ... H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # ... H/2 W/2 4*C

        x = self.norm(x)
-        x = self.reduction(x)
-        x = x.view(B, H // 2, W // 2, 2 * C)
+        x = self.reduction(x)  # ... H/2 W/2 2*C

        return x

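The rewritten forward pads odd spatial sizes before merging and indexes with ellipses, so it works with any leading batch dimensions. A minimal shape check of that behaviour (a sketch; it assumes PatchMerging is importable from torchvision.models.swin_transformer, which may differ across versions):

    # Sketch: odd H/W are padded internally, channels double, spatial dims halve (rounded up).
    import torch
    from torchvision.models.swin_transformer import PatchMerging  # assumed import path

    merge = PatchMerging(dim=96)
    x = torch.randn(2, 7, 9, 96)   # odd H and W are padded to 8 x 10 before merging
    out = merge(x)
    print(out.shape)               # expected: torch.Size([2, 4, 5, 192]), i.e. [..., H/2, W/2, 2*C]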
@@ -59,9 +64,9 @@ def shifted_window_attention(
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
-    window_size: int,
+    window_size: List[int],
    num_heads: int,
-    shift_size: int = 0,
+    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
@@ -75,9 +80,9 @@ def shifted_window_attention(
        qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
        proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
        relative_position_bias (Tensor): The learned relative position bias added to attention.
-        window_size (int): Window size.
+        window_size (List[int]): Window size.
        num_heads (int): Number of attention heads.
-        shift_size (int): Shift size for shifted window attention. Default: 0.
+        shift_size (List[int]): Shift size for shifted window attention.
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
@@ -87,23 +92,25 @@ def shifted_window_attention(
    """
    B, H, W, C = input.shape
    # pad feature maps to multiples of window size
-    pad_r = (window_size - W % window_size) % window_size
-    pad_b = (window_size - H % window_size) % window_size
+    pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
+    pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
    x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
    _, pad_H, pad_W, _ = x.shape

-    # If window size is larger than feature size, there is no need to shift window.
-    if window_size == min(pad_H, pad_W):
-        shift_size = 0
+    # If window size is larger than feature size, there is no need to shift window
+    if window_size[0] >= pad_H:
+        shift_size[0] = 0
+    if window_size[1] >= pad_W:
+        shift_size[1] = 0

    # cyclic shift
-    if shift_size > 0:
-        x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
+    if sum(shift_size) > 0:
+        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))

    # partition windows
-    num_windows = (pad_H // window_size) * (pad_W // window_size)
-    x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C)
-    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size * window_size, C)  # B*nW, Ws*Ws, C
+    num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
+    x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
+    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C)  # B*nW, Ws*Ws, C

    # multi-head attention
    qkv = F.linear(x, qkv_weight, qkv_bias)
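The padding arithmetic now indexes the window per axis (index 0 for height, 1 for width). A standalone check of those two formulas with plain integers (a sketch, independent of the rest of the function):

    # Sketch: pad_b/pad_r round H and W up to the next multiple of the window size on each axis.
    H, W = 35, 49
    window_size = [8, 7]                                               # [height, width]
    pad_b = (window_size[0] - H % window_size[0]) % window_size[0]     # 8 - 35 % 8 = 5
    pad_r = (window_size[1] - W % window_size[1]) % window_size[1]     # 49 is already a multiple of 7 -> 0
    print(pad_b, pad_r)                                                # 5 0
    print((H + pad_b) % window_size[0], (W + pad_r) % window_size[1])  # 0 0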
@@ -114,17 +121,18 @@ def shifted_window_attention(
    # add relative position bias
    attn = attn + relative_position_bias

-    if shift_size > 0:
+    if sum(shift_size) > 0:
        # generate attention mask
        attn_mask = x.new_zeros((pad_H, pad_W))
-        slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None))
+        h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
+        w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
        count = 0
-        for h in slices:
-            for w in slices:
+        for h in h_slices:
+            for w in w_slices:
                attn_mask[h[0] : h[1], w[0] : w[1]] = count
                count += 1
-        attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size)
-        attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size)
+        attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
+        attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
        attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
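Splitting the slice triples per axis keeps the original 3 x 3 region labelling, just with possibly different extents in height and width. A toy standalone reproduction of the labelling loop (a sketch with made-up sizes):

    # Sketch: the nested loop labels 9 rectangular regions of the padded feature map.
    import torch

    pad_H, pad_W = 8, 8
    window_size = [4, 4]
    shift_size = [2, 2]

    attn_mask = torch.zeros((pad_H, pad_W))
    h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
    w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
    count = 0
    for h in h_slices:
        for w in w_slices:
            attn_mask[h[0] : h[1], w[0] : w[1]] = count
            count += 1
    print(attn_mask)  # labels 0..8; token pairs with different labels are later masked with -100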
@@ -139,12 +147,12 @@ def shifted_window_attention(
    x = F.dropout(x, p=dropout)

    # reverse windows
-    x = x.view(B, pad_H // window_size, pad_W // window_size, window_size, window_size, C)
+    x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)

    # reverse cyclic shift
-    if shift_size > 0:
-        x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
+    if sum(shift_size) > 0:
+        x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

    # unpad features
    x = x[:, :H, :W, :].contiguous()
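The reverse roll uses the same per-axis shifts with the sign flipped, so it exactly undoes the cyclic shift applied earlier; a two-line check (a sketch):

    # Sketch: rolling forward and then backward by the per-axis shifts restores the tensor.
    import torch

    shift_size = [3, 2]
    x = torch.arange(5 * 4).reshape(1, 5, 4, 1)
    shifted = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
    restored = torch.roll(shifted, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
    print(torch.equal(x, restored))  # True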
@@ -162,15 +170,17 @@ class ShiftedWindowAttention(nn.Module):
    def __init__(
        self,
        dim: int,
-        window_size: int,
-        shift_size: int,
+        window_size: List[int],
+        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
+        if len(window_size) != 2 or len(shift_size) != 2:
+            raise ValueError("window_size and shift_size must be of length 2")
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads
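The new length check turns a silent shape error into an immediate, readable failure; a quick demonstration (a sketch; the import path is assumed):

    # Sketch: passing a wrong-length window_size now fails fast in the constructor.
    from torchvision.models.swin_transformer import ShiftedWindowAttention  # assumed import path

    try:
        ShiftedWindowAttention(dim=96, window_size=[7], shift_size=[3, 3], num_heads=3)
    except ValueError as err:
        print(err)  # window_size and shift_size must be of length 2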
@@ -182,29 +192,35 @@ def __init__(

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
-            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
-        coords_h = torch.arange(self.window_size)
-        coords_w = torch.arange(self.window_size)
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
-        relative_coords[:, :, 0] += self.window_size - 1  # shift to start from 0
-        relative_coords[:, :, 1] += self.window_size - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size - 1
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1).view(-1)  # Wh*Ww*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x: Tensor):
+        """
+        Args:
+            x (Tensor): Tensor with layout of [B, H, W, C]
+        Returns:
+            Tensor with same layout as input, i.e. [B, H, W, C]
+        """
+
+        N = self.window_size[0] * self.window_size[1]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index]  # type: ignore[index]
-        relative_position_bias = relative_position_bias.view(
-            self.window_size * self.window_size, self.window_size * self.window_size, -1
-        )
+        relative_position_bias = relative_position_bias.view(N, N, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)

        return shifted_window_attention(
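With a rectangular window the bias table has (2*Wh-1)*(2*Ww-1) rows, and the index math above has to stay within it. The lines can be checked in isolation with toy sizes (a sketch mirroring the code above):

    # Sketch: the rectangular relative-position indices cover exactly the rows of the bias table.
    import torch

    window_size = [3, 5]                                                       # Wh, Ww
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))    # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)                                  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += window_size[0] - 1
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)

    table_rows = (2 * window_size[0] - 1) * (2 * window_size[1] - 1)           # 5 * 9 = 45
    print(relative_position_index.min().item(), relative_position_index.max().item(), table_rows)
    # 0 44 45 -> every index addresses a valid row of relative_position_bias_table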
@@ -228,31 +244,33 @@ class SwinTransformerBlock(nn.Module):
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
-        window_size (int): Window size. Default: 7.
-        shift_size (int): Shift size for shifted window attention. Default: 0.
+        window_size (List[int]): Window size.
+        shift_size (List[int]): Shift size for shifted window attention.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
-        window_size: int = 7,
-        shift_size: int = 0,
+        window_size: List[int],
+        shift_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
    ):
        super().__init__()

        self.norm1 = norm_layer(dim)
-        self.attn = ShiftedWindowAttention(
+        self.attn = attn_layer(
            dim,
            window_size,
            shift_size,
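Because the attention module is now injected via attn_layer, a variant can be swapped in without editing the block. A hypothetical sketch (MyWindowAttention is a placeholder name, and the import paths are assumed):

    # Hypothetical sketch: plugging a custom attention layer into SwinTransformerBlock.
    from torchvision.models.swin_transformer import ShiftedWindowAttention, SwinTransformerBlock  # assumed paths

    class MyWindowAttention(ShiftedWindowAttention):
        # Toy variant: reuses the stock attention unchanged; a real variant would override forward.
        pass

    block = SwinTransformerBlock(
        dim=96,
        num_heads=3,
        window_size=[7, 7],
        shift_size=[3, 3],
        attn_layer=MyWindowAttention,  # the extension point introduced by this diff
    )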
@@ -281,11 +299,11 @@ class SwinTransformer(nn.Module):
    Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
    Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper.
    Args:
-        patch_size (int): Patch size.
+        patch_size (List[int]): Patch size.
        embed_dim (int): Patch embedding dimension.
        depths (List(int)): Depth of each Swin Transformer layer.
        num_heads (List(int)): Number of attention heads in different layers.
-        window_size (int): Window size. Default: 7.
+        window_size (List[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
@@ -297,11 +315,11 @@ class SwinTransformer(nn.Module):
    def __init__(
        self,
-        patch_size: int,
+        patch_size: List[int],
        embed_dim: int,
        depths: List[int],
        num_heads: List[int],
-        window_size: int = 7,
+        window_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
@@ -324,7 +342,9 @@ def __init__(
        # split image into non-overlapping patches
        layers.append(
            nn.Sequential(
-                nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size),
+                nn.Conv2d(
+                    3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
+                ),
                Permute([0, 2, 3, 1]),
                norm_layer(embed_dim),
            )
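The stem now builds its convolution from the two patch_size entries. Checked in isolation (a sketch; it uses a plain permute in place of the Permute wrapper used in the Sequential above):

    # Sketch: the patchify stem maps [B, 3, H, W] to the [B, H/ph, W/pw, embed_dim] layout.
    import torch
    import torch.nn as nn

    patch_size = [4, 4]
    embed_dim = 96
    stem = nn.Conv2d(3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]))

    img = torch.randn(1, 3, 224, 224)
    feat = stem(img).permute(0, 2, 3, 1)  # 1, 56, 56, 96
    print(feat.shape)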
@@ -344,7 +364,7 @@ def __init__(
                        dim,
                        num_heads[i_stage],
                        window_size=window_size,
-                        shift_size=0 if i_layer % 2 == 0 else window_size // 2,
+                        shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
                        mlp_ratio=mlp_ratio,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
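The alternating-shift rule is now applied per axis by the comprehension; evaluated by hand for a rectangular window (a trivial sketch):

    # Sketch: even-indexed blocks get no shift, odd-indexed blocks shift by half the window per axis.
    window_size = [8, 7]
    for i_layer in range(4):
        shift_size = [0 if i_layer % 2 == 0 else w // 2 for w in window_size]
        print(i_layer, shift_size)
    # 0 [0, 0]
    # 1 [4, 3]
    # 2 [0, 0]
    # 3 [4, 3]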
@@ -381,11 +401,11 @@ def forward(self, x):


def _swin_transformer(
-    patch_size: int,
+    patch_size: List[int],
    embed_dim: int,
    depths: List[int],
    num_heads: List[int],
-    window_size: int,
+    window_size: List[int],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
@@ -508,11 +528,11 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
    weights = Swin_T_Weights.verify(weights)

    return _swin_transformer(
-        patch_size=4,
+        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
-        window_size=7,
+        window_size=[7, 7],
        stochastic_depth_prob=0.2,
        weights=weights,
        progress=progress,
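The public builder API is unchanged; only the hyper-parameters passed internally become lists. An end-to-end smoke test (a sketch; it assumes a torchvision build that contains this patch):

    # Sketch: building an untrained swin_t and running one forward pass.
    import torch
    from torchvision.models import swin_t

    model = swin_t(weights=None).eval()  # random init; pass a Swin_T_Weights entry for pretrained weights
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)                  # torch.Size([1, 1000])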
@@ -544,11 +564,11 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, *
    weights = Swin_S_Weights.verify(weights)

    return _swin_transformer(
-        patch_size=4,
+        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
-        window_size=7,
+        window_size=[7, 7],
        stochastic_depth_prob=0.3,
        weights=weights,
        progress=progress,
@@ -580,11 +600,11 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
    weights = Swin_B_Weights.verify(weights)

    return _swin_transformer(
-        patch_size=4,
+        patch_size=[4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
-        window_size=7,
+        window_size=[7, 7],
        stochastic_depth_prob=0.5,
        weights=weights,
        progress=progress,