Commit e1a8141

NicolasHug authored and facebook-github-bot committed
[fbsync] Fix dropout issue in swin transformers (#7224)
Summary:
Co-authored-by: Nicolas Hug <[email protected]>

Reviewed By: vmoens

Differential Revision: D44416635

fbshipit-source-id: 8f90f59beb98e36ae98a33a8e1b2abd2d07a430b
1 parent cf54a56 commit e1a8141
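
For context: torch.nn.functional.dropout defaults to training=True, so the bare F.dropout(...) calls inside shifted_window_attention kept applying dropout even after model.eval(), unlike the nn.Dropout modules used elsewhere in these models. The diff below threads the module's self.training flag into those functional calls. A minimal sketch of the default behavior (not part of the commit):

import torch
import torch.nn.functional as F

x = torch.ones(4)

# F.dropout defaults to training=True, so it zeroes and rescales elements
# regardless of whether the surrounding module is in eval mode.
print(F.dropout(x, p=0.5))                  # stochastic output

# Passing an explicit flag (False here, mimicking eval mode) makes it a no-op,
# which is what the commit wires up via training=self.training.
print(F.dropout(x, p=0.5, training=False))  # tensor([1., 1., 1., 1.])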

File tree

2 files changed, +13 -6 lines changed

torchvision/models/swin_transformer.py

Lines changed: 8 additions & 4 deletions
@@ -126,7 +126,8 @@ def shifted_window_attention(
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
     logit_scale: Optional[torch.Tensor] = None,
-):
+    training: bool = True,
+) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
     It supports both of shifted and non-shifted window.
@@ -143,6 +144,7 @@ def shifted_window_attention(
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
         logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """
@@ -207,11 +209,11 @@ def shifted_window_attention(
     attn = attn.view(-1, num_heads, x.size(1), x.size(1))

     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)

     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)

     # reverse windows
     x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
@@ -286,7 +288,7 @@ def get_relative_position_bias(self) -> torch.Tensor:
             self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
         )

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         """
         Args:
             x (Tensor): Tensor with layout of [B, H, W, C]
@@ -306,6 +308,7 @@ def forward(self, x: Tensor):
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )


@@ -391,6 +394,7 @@ def forward(self, x: Tensor):
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
             logit_scale=self.logit_scale,
+            training=self.training,
         )

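A quick way to observe the effect on the image models (a sketch assuming a torchvision build that includes this change; the non-zero dropout rates are only there to make the behavior visible):

import torch
from torchvision.models import swin_t

# Swin-T with non-zero dropout rates, switched to eval mode.
model = swin_t(weights=None, dropout=0.5, attention_dropout=0.5).eval()
x = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    out1 = model(x)
    out2 = model(x)

# With training=self.training forwarded to shifted_window_attention,
# eval-mode outputs are deterministic; before this fix they were not.
print(torch.allclose(out1, out2))  # expected: True
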
torchvision/models/video/swin_transformer.py

Lines changed: 5 additions & 2 deletions
@@ -124,6 +124,7 @@ def shifted_window_attention_3d(
     dropout: float = 0.0,
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
+    training: bool = True,
 ) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
@@ -140,6 +141,7 @@ def shifted_window_attention_3d(
         dropout (float): Dropout ratio of output. Default: 0.0.
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[B, T, H, W, C]: The output tensor after shifted window attention.
     """
@@ -194,11 +196,11 @@ def shifted_window_attention_3d(
     attn = attn.view(-1, num_heads, x.size(1), x.size(1))

     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)

     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)

     # reverse windows
     x = x.view(
@@ -310,6 +312,7 @@ def forward(self, x: Tensor) -> Tensor:
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )

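The same spot-check applies to the video variant touched by this file (again a sketch under the same assumptions; swin3d_t forwards dropout and attention_dropout to the 3D attention blocks):

import torch
from torchvision.models.video import swin3d_t

# Swin3D-T with non-zero dropout rates, in eval mode.
model = swin3d_t(weights=None, dropout=0.5, attention_dropout=0.5).eval()
video = torch.rand(1, 3, 8, 112, 112)  # (B, C, T, H, W)

with torch.no_grad():
    out1 = model(video)
    out2 = model(video)

# Deterministic in eval mode once F.dropout receives training=self.training.
print(torch.allclose(out1, out2))  # expected: True
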
0 commit comments
