@@ -126,7 +126,8 @@ def shifted_window_attention(
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
     logit_scale: Optional[torch.Tensor] = None,
-):
+    training: bool = True,
+) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
     It supports both of shifted and non-shifted window.
@@ -143,6 +144,7 @@ def shifted_window_attention(
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
         logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """
@@ -207,11 +209,11 @@ def shifted_window_attention(
     attn = attn.view(-1, num_heads, x.size(1), x.size(1))
 
     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)
 
     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)
 
     # reverse windows
     x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
@@ -286,7 +288,7 @@ def get_relative_position_bias(self) -> torch.Tensor:
             self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
         )
 
-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         """
         Args:
             x (Tensor): Tensor with layout of [B, H, W, C]
@@ -306,6 +308,7 @@ def forward(self, x: Tensor):
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )
 
 
@@ -391,6 +394,7 @@ def forward(self, x: Tensor):
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
             logit_scale=self.logit_scale,
+            training=self.training,
         )
 
 
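For context (not part of the diff): the functional `torch.nn.functional.dropout` defaults to `training=True`, so unless the module's `self.training` flag is threaded through to these calls, dropout keeps zeroing activations even after `model.eval()`. A minimal sketch of that difference, using only stock PyTorch and made-up example tensors:

```python
import torch
import torch.nn.functional as F

x = torch.ones(4, 4)

# With training=False the call is a no-op, matching eval() behaviour.
eval_out = F.dropout(x, p=0.5, training=False)

# With training=True (the functional default) elements are zeroed and the
# survivors rescaled by 1 / (1 - p), regardless of the module's eval() state.
train_out = F.dropout(x, p=0.5, training=True)

print(torch.equal(eval_out, x))   # True
print(torch.equal(train_out, x))  # almost surely False
```

Passing `training=self.training` from the modules into `shifted_window_attention` ties the functional dropout calls back to the usual train/eval switch.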