Skip to content

Commit 1a288d1

Browse files
authored
Refactor swin transfomer so later we can reuse component for 3d version (#6088) (#6100)
* Use List[int] instead of int for window_size and shift_size * Make PatchMerging and SwinTransformerBlock able to handle 2d and 3d cases * Separate patch embedding from SwinTransformer and enable to get model without head by specifying num_heads=None * Dont use if before padding so it is fx friendly * Put the handling on window_size edge cases on separate function and wrap with torch.fx.wrap so it is excluded from tracing * Update the weight url to the converted weight with new structure * Update the accuracy of swin_transformer * Change assert to Exception and nit * Make num_classes optional * Add typing output for _fix_window_and_shift_size function * init head to None to make it jit scriptable * Revert the change to make num_classes optional * Revert unneccesarry changes that might be risky * Remove self.head declaration
1 parent eced17c commit 1a288d1

File tree

1 file changed

+83
-63
lines changed

1 file changed

+83
-63
lines changed

torchvision/models/swin_transformer.py

Lines changed: 83 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,23 @@ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm
3939
self.norm = norm_layer(4 * dim)
4040

4141
def forward(self, x: Tensor):
42-
B, H, W, C = x.shape
43-
44-
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
45-
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
46-
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
47-
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
48-
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
49-
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
42+
"""
43+
Args:
44+
x (Tensor): input tensor with expected layout of [..., H, W, C]
45+
Returns:
46+
Tensor with layout of [..., H/2, W/2, 2*C]
47+
"""
48+
H, W, _ = x.shape[-3:]
49+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
50+
51+
x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
52+
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
53+
x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
54+
x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
55+
x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
5056

5157
x = self.norm(x)
52-
x = self.reduction(x)
53-
x = x.view(B, H // 2, W // 2, 2 * C)
58+
x = self.reduction(x) # ... H/2 W/2 2*C
5459
return x
5560

5661

@@ -59,9 +64,9 @@ def shifted_window_attention(
5964
qkv_weight: Tensor,
6065
proj_weight: Tensor,
6166
relative_position_bias: Tensor,
62-
window_size: int,
67+
window_size: List[int],
6368
num_heads: int,
64-
shift_size: int = 0,
69+
shift_size: List[int],
6570
attention_dropout: float = 0.0,
6671
dropout: float = 0.0,
6772
qkv_bias: Optional[Tensor] = None,
@@ -75,9 +80,9 @@ def shifted_window_attention(
7580
qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
7681
proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
7782
relative_position_bias (Tensor): The learned relative position bias added to attention.
78-
window_size (int): Window size.
83+
window_size (List[int]): Window size.
7984
num_heads (int): Number of attention heads.
80-
shift_size (int): Shift size for shifted window attention. Default: 0.
85+
shift_size (List[int]): Shift size for shifted window attention.
8186
attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
8287
dropout (float): Dropout ratio of output. Default: 0.0.
8388
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
@@ -87,23 +92,25 @@ def shifted_window_attention(
8792
"""
8893
B, H, W, C = input.shape
8994
# pad feature maps to multiples of window size
90-
pad_r = (window_size - W % window_size) % window_size
91-
pad_b = (window_size - H % window_size) % window_size
95+
pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
96+
pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
9297
x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
9398
_, pad_H, pad_W, _ = x.shape
9499

95-
# If window size is larger than feature size, there is no need to shift window.
96-
if window_size == min(pad_H, pad_W):
97-
shift_size = 0
100+
# If window size is larger than feature size, there is no need to shift window
101+
if window_size[0] >= pad_H:
102+
shift_size[0] = 0
103+
if window_size[1] >= pad_W:
104+
shift_size[1] = 0
98105

99106
# cyclic shift
100-
if shift_size > 0:
101-
x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
107+
if sum(shift_size) > 0:
108+
x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
102109

103110
# partition windows
104-
num_windows = (pad_H // window_size) * (pad_W // window_size)
105-
x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C)
106-
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size * window_size, C) # B*nW, Ws*Ws, C
111+
num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
112+
x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
113+
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C
107114

108115
# multi-head attention
109116
qkv = F.linear(x, qkv_weight, qkv_bias)
@@ -114,17 +121,18 @@ def shifted_window_attention(
114121
# add relative position bias
115122
attn = attn + relative_position_bias
116123

117-
if shift_size > 0:
124+
if sum(shift_size) > 0:
118125
# generate attention mask
119126
attn_mask = x.new_zeros((pad_H, pad_W))
120-
slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None))
127+
h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
128+
w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
121129
count = 0
122-
for h in slices:
123-
for w in slices:
130+
for h in h_slices:
131+
for w in w_slices:
124132
attn_mask[h[0] : h[1], w[0] : w[1]] = count
125133
count += 1
126-
attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size)
127-
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size)
134+
attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
135+
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
128136
attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
129137
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
130138
attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
@@ -139,12 +147,12 @@ def shifted_window_attention(
139147
x = F.dropout(x, p=dropout)
140148

141149
# reverse windows
142-
x = x.view(B, pad_H // window_size, pad_W // window_size, window_size, window_size, C)
150+
x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
143151
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)
144152

145153
# reverse cyclic shift
146-
if shift_size > 0:
147-
x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
154+
if sum(shift_size) > 0:
155+
x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
148156

149157
# unpad features
150158
x = x[:, :H, :W, :].contiguous()
@@ -162,15 +170,17 @@ class ShiftedWindowAttention(nn.Module):
162170
def __init__(
163171
self,
164172
dim: int,
165-
window_size: int,
166-
shift_size: int,
173+
window_size: List[int],
174+
shift_size: List[int],
167175
num_heads: int,
168176
qkv_bias: bool = True,
169177
proj_bias: bool = True,
170178
attention_dropout: float = 0.0,
171179
dropout: float = 0.0,
172180
):
173181
super().__init__()
182+
if len(window_size) != 2 or len(shift_size) != 2:
183+
raise ValueError("window_size and shift_size must be of length 2")
174184
self.window_size = window_size
175185
self.shift_size = shift_size
176186
self.num_heads = num_heads
@@ -182,29 +192,35 @@ def __init__(
182192

183193
# define a parameter table of relative position bias
184194
self.relative_position_bias_table = nn.Parameter(
185-
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
195+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
186196
) # 2*Wh-1 * 2*Ww-1, nH
187197

188198
# get pair-wise relative position index for each token inside the window
189-
coords_h = torch.arange(self.window_size)
190-
coords_w = torch.arange(self.window_size)
199+
coords_h = torch.arange(self.window_size[0])
200+
coords_w = torch.arange(self.window_size[1])
191201
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
192202
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
193203
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
194204
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
195-
relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
196-
relative_coords[:, :, 1] += self.window_size - 1
197-
relative_coords[:, :, 0] *= 2 * self.window_size - 1
205+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
206+
relative_coords[:, :, 1] += self.window_size[1] - 1
207+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
198208
relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww
199209
self.register_buffer("relative_position_index", relative_position_index)
200210

201211
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
202212

203213
def forward(self, x: Tensor):
214+
"""
215+
Args:
216+
x (Tensor): Tensor with layout of [B, H, W, C]
217+
Returns:
218+
Tensor with same layout as input, i.e. [B, H, W, C]
219+
"""
220+
221+
N = self.window_size[0] * self.window_size[1]
204222
relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index]
205-
relative_position_bias = relative_position_bias.view(
206-
self.window_size * self.window_size, self.window_size * self.window_size, -1
207-
)
223+
relative_position_bias = relative_position_bias.view(N, N, -1)
208224
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
209225

210226
return shifted_window_attention(
@@ -228,31 +244,33 @@ class SwinTransformerBlock(nn.Module):
228244
Args:
229245
dim (int): Number of input channels.
230246
num_heads (int): Number of attention heads.
231-
window_size (int): Window size. Default: 7.
232-
shift_size (int): Shift size for shifted window attention. Default: 0.
247+
window_size (List[int]): Window size.
248+
shift_size (List[int]): Shift size for shifted window attention.
233249
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
234250
dropout (float): Dropout rate. Default: 0.0.
235251
attention_dropout (float): Attention dropout rate. Default: 0.0.
236252
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
237253
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
254+
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
238255
"""
239256

240257
def __init__(
241258
self,
242259
dim: int,
243260
num_heads: int,
244-
window_size: int = 7,
245-
shift_size: int = 0,
261+
window_size: List[int],
262+
shift_size: List[int],
246263
mlp_ratio: float = 4.0,
247264
dropout: float = 0.0,
248265
attention_dropout: float = 0.0,
249266
stochastic_depth_prob: float = 0.0,
250267
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
268+
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
251269
):
252270
super().__init__()
253271

254272
self.norm1 = norm_layer(dim)
255-
self.attn = ShiftedWindowAttention(
273+
self.attn = attn_layer(
256274
dim,
257275
window_size,
258276
shift_size,
@@ -281,11 +299,11 @@ class SwinTransformer(nn.Module):
281299
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
282300
Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper.
283301
Args:
284-
patch_size (int): Patch size.
302+
patch_size (List[int]): Patch size.
285303
embed_dim (int): Patch embedding dimension.
286304
depths (List(int)): Depth of each Swin Transformer layer.
287305
num_heads (List(int)): Number of attention heads in different layers.
288-
window_size (int): Window size. Default: 7.
306+
window_size (List[int]): Window size.
289307
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
290308
dropout (float): Dropout rate. Default: 0.0.
291309
attention_dropout (float): Attention dropout rate. Default: 0.0.
@@ -297,11 +315,11 @@ class SwinTransformer(nn.Module):
297315

298316
def __init__(
299317
self,
300-
patch_size: int,
318+
patch_size: List[int],
301319
embed_dim: int,
302320
depths: List[int],
303321
num_heads: List[int],
304-
window_size: int = 7,
322+
window_size: List[int],
305323
mlp_ratio: float = 4.0,
306324
dropout: float = 0.0,
307325
attention_dropout: float = 0.0,
@@ -324,7 +342,9 @@ def __init__(
324342
# split image into non-overlapping patches
325343
layers.append(
326344
nn.Sequential(
327-
nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size),
345+
nn.Conv2d(
346+
3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
347+
),
328348
Permute([0, 2, 3, 1]),
329349
norm_layer(embed_dim),
330350
)
@@ -344,7 +364,7 @@ def __init__(
344364
dim,
345365
num_heads[i_stage],
346366
window_size=window_size,
347-
shift_size=0 if i_layer % 2 == 0 else window_size // 2,
367+
shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
348368
mlp_ratio=mlp_ratio,
349369
dropout=dropout,
350370
attention_dropout=attention_dropout,
@@ -381,11 +401,11 @@ def forward(self, x):
381401

382402

383403
def _swin_transformer(
384-
patch_size: int,
404+
patch_size: List[int],
385405
embed_dim: int,
386406
depths: List[int],
387407
num_heads: List[int],
388-
window_size: int,
408+
window_size: List[int],
389409
stochastic_depth_prob: float,
390410
weights: Optional[WeightsEnum],
391411
progress: bool,
@@ -508,11 +528,11 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
508528
weights = Swin_T_Weights.verify(weights)
509529

510530
return _swin_transformer(
511-
patch_size=4,
531+
patch_size=[4, 4],
512532
embed_dim=96,
513533
depths=[2, 2, 6, 2],
514534
num_heads=[3, 6, 12, 24],
515-
window_size=7,
535+
window_size=[7, 7],
516536
stochastic_depth_prob=0.2,
517537
weights=weights,
518538
progress=progress,
@@ -544,11 +564,11 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, *
544564
weights = Swin_S_Weights.verify(weights)
545565

546566
return _swin_transformer(
547-
patch_size=4,
567+
patch_size=[4, 4],
548568
embed_dim=96,
549569
depths=[2, 2, 18, 2],
550570
num_heads=[3, 6, 12, 24],
551-
window_size=7,
571+
window_size=[7, 7],
552572
stochastic_depth_prob=0.3,
553573
weights=weights,
554574
progress=progress,
@@ -580,11 +600,11 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
580600
weights = Swin_B_Weights.verify(weights)
581601

582602
return _swin_transformer(
583-
patch_size=4,
603+
patch_size=[4, 4],
584604
embed_dim=128,
585605
depths=[2, 2, 18, 2],
586606
num_heads=[4, 8, 16, 32],
587-
window_size=7,
607+
window_size=[7, 7],
588608
stochastic_depth_prob=0.5,
589609
weights=weights,
590610
progress=progress,

0 commit comments

Comments
 (0)