diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index 0d3ab9ad32a..249ca37b9d2 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -126,7 +126,8 @@ def shifted_window_attention(
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
     logit_scale: Optional[torch.Tensor] = None,
-):
+    training: bool = True,
+) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
     It supports both of shifted and non-shifted window.
@@ -143,6 +144,7 @@ def shifted_window_attention(
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
         logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """
@@ -207,11 +209,11 @@ def shifted_window_attention(
         attn = attn.view(-1, num_heads, x.size(1), x.size(1))

     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)

     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)

     # reverse windows
     x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
@@ -286,7 +288,7 @@ def get_relative_position_bias(self) -> torch.Tensor:
             self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
         )

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         """
         Args:
             x (Tensor): Tensor with layout of [B, H, W, C]
@@ -306,6 +308,7 @@ def forward(self, x: Tensor):
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )


@@ -391,6 +394,7 @@ def forward(self, x: Tensor):
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
             logit_scale=self.logit_scale,
+            training=self.training,
         )


diff --git a/torchvision/models/video/swin_transformer.py b/torchvision/models/video/swin_transformer.py
index c6a1602d255..25cf3cf997e 100644
--- a/torchvision/models/video/swin_transformer.py
+++ b/torchvision/models/video/swin_transformer.py
@@ -124,6 +124,7 @@ def shifted_window_attention_3d(
     dropout: float = 0.0,
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
+    training: bool = True,
 ) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
@@ -140,6 +141,7 @@ def shifted_window_attention_3d(
         dropout (float): Dropout ratio of output. Default: 0.0.
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[B, T, H, W, C]: The output tensor after shifted window attention.
     """
@@ -194,11 +196,11 @@ def shifted_window_attention_3d(
         attn = attn.view(-1, num_heads, x.size(1), x.size(1))

     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)

     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)

     # reverse windows
     x = x.view(
@@ -310,6 +312,7 @@ def forward(self, x: Tensor) -> Tensor:
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )