diff --git a/timm/layers/blur_pool.py b/timm/layers/blur_pool.py index 6a4b668c1..b7302d1ac 100644 --- a/timm/layers/blur_pool.py +++ b/timm/layers/blur_pool.py @@ -6,12 +6,12 @@ Hacked together by Chris Ha and Ross Wightman """ from functools import partial +from math import comb # Python 3.8 from typing import Optional, Type import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np from .padding import get_padding from .typing import LayerType @@ -45,7 +45,11 @@ def __init__( self.pad_mode = pad_mode self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 - coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) + # (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N + coeffs = torch.tensor( + [comb(filt_size - 1, k) for k in range(filt_size)], + dtype=torch.float32, + ) / (2 ** (filt_size - 1)) # normalise so coefficients sum to 1 blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :] if channels is not None: blur_filter = blur_filter.repeat(self.channels, 1, 1, 1) diff --git a/timm/layers/cond_conv2d.py b/timm/layers/cond_conv2d.py index 43654c597..32f02c98a 100644 --- a/timm/layers/cond_conv2d.py +++ b/timm/layers/cond_conv2d.py @@ -8,7 +8,6 @@ import math from functools import partial -import numpy as np import torch from torch import nn as nn from torch.nn import functional as F @@ -21,7 +20,7 @@ def get_condconv_initializer(initializer, num_experts, expert_shape): def condconv_initializer(weight): """CondConv initializer function.""" - num_params = np.prod(expert_shape) + num_params = math.prod(expert_shape) if (len(weight.shape) != 2 or weight.shape[0] != num_experts or weight.shape[1] != num_params): raise (ValueError( @@ -75,7 +74,7 @@ def reset_parameters(self): partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) init_weight(self.weight) if self.bias is not None: - fan_in = np.prod(self.weight_shape[1:]) + fan_in = math.prod(self.weight_shape[1:]) bound = 1 / math.sqrt(fan_in) init_bias = get_condconv_initializer( partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) diff --git a/timm/models/_builder.py b/timm/models/_builder.py index 03e5dc3a8..d51db363a 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -3,7 +3,7 @@ import os from copy import deepcopy from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union from torch import nn as nn from torch.hub import load_state_dict_from_url @@ -26,11 +26,21 @@ _CHECK_HASH = False _USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0 -__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained', - 'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg'] +__all__ = [ + 'set_pretrained_download_progress', + 'set_pretrained_check_hash', + 'load_custom_pretrained', + 'load_pretrained', + 'pretrained_cfg_for_features', + 'resolve_pretrained_cfg', + 'build_model_with_cfg', +] -def _resolve_pretrained_source(pretrained_cfg): +ModelT = TypeVar("ModelT", bound=nn.Module) # any subclass of nn.Module + + +def _resolve_pretrained_source(pretrained_cfg: Dict[str, Any]) -> Tuple[str, str]: cfg_source = pretrained_cfg.get('source', '') pretrained_url = pretrained_cfg.get('url', None) pretrained_file = pretrained_cfg.get('file', None) @@ -78,13 +88,13 @@ def 
_resolve_pretrained_source(pretrained_cfg): return load_from, pretrained_loc -def set_pretrained_download_progress(enable=True): +def set_pretrained_download_progress(enable: bool = True) -> None: """ Set download progress for pretrained weights on/off (globally). """ global _DOWNLOAD_PROGRESS _DOWNLOAD_PROGRESS = enable -def set_pretrained_check_hash(enable=True): +def set_pretrained_check_hash(enable: bool = True) -> None: """ Set hash checking for pretrained weights on/off (globally). """ global _CHECK_HASH _CHECK_HASH = enable @@ -92,11 +102,11 @@ def set_pretrained_check_hash(enable=True): def load_custom_pretrained( model: nn.Module, - pretrained_cfg: Optional[Dict] = None, + pretrained_cfg: Optional[Dict[str, Any]] = None, load_fn: Optional[Callable] = None, cache_dir: Optional[Union[str, Path]] = None, -): - r"""Loads a custom (read non .pth) weight file +) -> None: + """Loads a custom (read non .pth) weight file Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls a passed in custom load fun, or the `load_pretrained` model member fn. @@ -141,13 +151,13 @@ def load_custom_pretrained( def load_pretrained( model: nn.Module, - pretrained_cfg: Optional[Dict] = None, + pretrained_cfg: Optional[Dict[str, Any]] = None, num_classes: int = 1000, in_chans: int = 3, filter_fn: Optional[Callable] = None, strict: bool = True, cache_dir: Optional[Union[str, Path]] = None, -): +) -> None: """ Load pretrained checkpoint Args: @@ -278,7 +288,7 @@ def load_pretrained( f' This may be expected if model is being adapted.') -def pretrained_cfg_for_features(pretrained_cfg): +def pretrained_cfg_for_features(pretrained_cfg: Dict[str, Any]) -> Dict[str, Any]: pretrained_cfg = deepcopy(pretrained_cfg) # remove default pretrained cfg fields that don't have much relevance for feature backbone to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size? 
@@ -287,14 +297,14 @@ def pretrained_cfg_for_features(pretrained_cfg): return pretrained_cfg -def _filter_kwargs(kwargs, names): +def _filter_kwargs(kwargs: Dict[str, Any], names: List[str]) -> None: if not kwargs or not names: return for n in names: kwargs.pop(n, None) -def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter): +def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter) -> None: """ Update the default_cfg and kwargs before passing to model Args: @@ -340,6 +350,7 @@ def resolve_pretrained_cfg( pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None, pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, ) -> PretrainedCfg: + """Resolve pretrained configuration from various sources.""" model_with_tag = variant pretrained_tag = None if pretrained_cfg: @@ -371,7 +382,7 @@ def resolve_pretrained_cfg( def build_model_with_cfg( - model_cls: Callable, + model_cls: Union[Type[ModelT], Callable[..., ModelT]], variant: str, pretrained: bool, pretrained_cfg: Optional[Dict] = None, @@ -383,7 +394,7 @@ def build_model_with_cfg( cache_dir: Optional[Union[str, Path]] = None, kwargs_filter: Optional[Tuple[str]] = None, **kwargs, -): +) -> ModelT: """ Build model with specified default_cfg and optional model_cfg This helper fn aids in the construction of a model including: diff --git a/timm/models/_factory.py b/timm/models/_factory.py index 63e897e53..dc6909165 100644 --- a/timm/models/_factory.py +++ b/timm/models/_factory.py @@ -1,8 +1,10 @@ import os from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from urllib.parse import urlsplit +from torch import nn + from timm.layers import set_layer_config from ._helpers import load_checkpoint from ._hub import load_model_config_from_hf, load_model_config_from_path @@ -13,7 +15,8 @@ __all__ = ['parse_model_name', 'safe_model_name', 'create_model'] -def parse_model_name(model_name: str): +def parse_model_name(model_name: str) -> Tuple[Optional[str], str]: + """Parse source and name from potentially prefixed model name.""" if model_name.startswith('hf_hub'): # NOTE for backwards compat, deprecate hf_hub use model_name = model_name.replace('hf_hub', 'hf-hub') @@ -29,9 +32,9 @@ def parse_model_name(model_name: str): return None, model_name -def safe_model_name(model_name: str, remove_source: bool = True): - # return a filename / path safe model name - def make_safe(name): +def safe_model_name(model_name: str, remove_source: bool = True) -> str: + """Return a filename / path safe model name.""" + def make_safe(name: str) -> str: return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') if remove_source: model_name = parse_model_name(model_name)[-1] @@ -42,14 +45,14 @@ def create_model( model_name: str, pretrained: bool = False, pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None, - pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, + pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, checkpoint_path: Optional[Union[str, Path]] = None, cache_dir: Optional[Union[str, Path]] = None, scriptable: Optional[bool] = None, exportable: Optional[bool] = None, no_jit: Optional[bool] = None, - **kwargs, -): + **kwargs: Any, +) -> nn.Module: """Create a model. Lookup model's entrypoint function and pass relevant args to create a new model. 
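The numpy-to-stdlib swaps earlier in this diff (math.comb in timm/layers/blur_pool.py, math.prod in timm/layers/cond_conv2d.py) are intended to be behaviour-preserving. Below is a minimal sanity-check sketch, not part of the diff itself: it assumes numpy is still importable in the test environment (these modules no longer import it) and uses a hypothetical local helper, binomial_coeffs, that mirrors the new coefficient construction.

# Sanity-check sketch (not part of the diff): confirm the stdlib replacements
# match the removed numpy expressions. numpy availability is an assumption of
# this check only; `binomial_coeffs` is an illustrative helper, not timm API.
import math
from math import comb

import numpy as np
import torch


def binomial_coeffs(filt_size: int) -> torch.Tensor:
    # (0.5 + 0.5 x)^(filt_size - 1) expanded via binomial coefficients; sums to 1
    return torch.tensor(
        [comb(filt_size - 1, k) for k in range(filt_size)],
        dtype=torch.float32,
    ) / (2 ** (filt_size - 1))


for filt_size in (2, 3, 4, 5, 7):
    old = torch.tensor((np.poly1d((0.5, 0.5)) ** (filt_size - 1)).coeffs.astype(np.float32))
    assert torch.allclose(old, binomial_coeffs(filt_size))

# math.prod matches np.prod for the small integer shapes used in cond_conv2d.py
assert math.prod((8, 3, 3)) == int(np.prod((8, 3, 3)))

For filt_size=3 both paths give [0.25, 0.5, 0.25], the usual binomial blur kernel, so the only observable difference is the dropped numpy dependency.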
diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py index ca5dc2445..e1da0e224 100644 --- a/timm/models/_helpers.py +++ b/timm/models/_helpers.py @@ -7,8 +7,10 @@ from typing import Any, Callable, Dict, Optional, Union import torch + try: import safetensors.torch + _has_safetensors = True except ImportError: _has_safetensors = False @@ -18,7 +20,7 @@ __all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint'] -def _remove_prefix(text, prefix): +def _remove_prefix(text: str, prefix: str) -> str: # FIXME replace with 3.9 stdlib fn when min at 3.9 if text.startswith(prefix): return text[len(prefix):] @@ -45,6 +47,17 @@ def load_state_dict( device: Union[str, torch.device] = 'cpu', weights_only: bool = False, ) -> Dict[str, Any]: + """Load state dictionary from checkpoint file. + + Args: + checkpoint_path: Path to checkpoint file. + use_ema: Whether to use EMA weights if available. + device: Device to load checkpoint to. + weights_only: Whether to load only weights (torch.load parameter). + + Returns: + State dictionary loaded from checkpoint. + """ if checkpoint_path and os.path.isfile(checkpoint_path): # Check if safetensors or not and load weights accordingly if str(checkpoint_path).endswith(".safetensors"): @@ -83,7 +96,22 @@ def load_checkpoint( remap: bool = False, filter_fn: Optional[Callable] = None, weights_only: bool = False, -): +) -> Any: + """Load checkpoint into model. + + Args: + model: Model to load checkpoint into. + checkpoint_path: Path to checkpoint file. + use_ema: Whether to use EMA weights if available. + device: Device to load checkpoint to. + strict: Whether to strictly enforce state_dict keys match. + remap: Whether to remap state dict keys by order. + filter_fn: Optional function to filter state dict. + weights_only: Whether to load only weights (torch.load parameter). + + Returns: + Incompatible keys from model.load_state_dict(). + """ if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): # numpy checkpoint, try to load via model specific load_pretrained fn if hasattr(model, 'load_pretrained'): @@ -105,9 +133,18 @@ def remap_state_dict( state_dict: Dict[str, Any], model: torch.nn.Module, allow_reshape: bool = True -): - """ remap checkpoint by iterating over state dicts in order (ignoring original keys). +) -> Dict[str, Any]: + """Remap checkpoint by iterating over state dicts in order (ignoring original keys). + This assumes models (and originating state dict) were created with params registered in same order. + + Args: + state_dict: State dict to remap. + model: Model whose state dict keys to use. + allow_reshape: Whether to allow reshaping tensors to match. + + Returns: + Remapped state dictionary. """ out_dict = {} for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): @@ -116,7 +153,7 @@ def remap_state_dict( if allow_reshape: vb = vb.reshape(va.shape) else: - assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' + assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' out_dict[ka] = vb return out_dict @@ -124,10 +161,22 @@ def remap_state_dict( def resume_checkpoint( model: torch.nn.Module, checkpoint_path: str, - optimizer: torch.optim.Optimizer = None, - loss_scaler: Any = None, + optimizer: Optional[torch.optim.Optimizer] = None, + loss_scaler: Optional[Any] = None, log_info: bool = True, -): +) -> Optional[int]: + """Resume training from checkpoint. 
+ + Args: + model: Model to load checkpoint into. + checkpoint_path: Path to checkpoint file. + optimizer: Optional optimizer to restore state. + loss_scaler: Optional AMP loss scaler to restore state. + log_info: Whether to log loading info. + + Returns: + Resume epoch number if available, else None. + """ resume_epoch = None if os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) @@ -162,5 +211,3 @@ def resume_checkpoint( else: _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() - - diff --git a/timm/models/beit.py b/timm/models/beit.py index 5123a6062..2ee5fbb01 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -39,7 +39,7 @@ # --------------------------------------------------------' import math -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -58,6 +58,18 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: + """Generate relative position index for window-based attention. + + Creates a lookup table for relative position indices between all pairs of positions + within a window, including special handling for cls token interactions. + + Args: + window_size: Height and width of the attention window. + + Returns: + Relative position index tensor of shape (window_area+1, window_area+1) + where +1 accounts for the cls token. + """ num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window @@ -78,6 +90,11 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: class Attention(nn.Module): + """Multi-head attention module with optional relative position bias. + + Implements multi-head self-attention with support for relative position bias + and fused attention operations. Can use either standard or custom head dimensions. + """ fused_attn: torch.jit.Final[bool] def __init__( @@ -91,6 +108,18 @@ def __init__( window_size: Optional[Tuple[int, int]] = None, attn_head_dim: Optional[int] = None, ): + """Initialize attention module. + + Args: + dim: Input feature dimension. + num_heads: Number of attention heads. + qkv_bias: If True, add learnable bias to query, key, value projections. + qkv_bias_separate: If True, use separate bias for q, k, v projections. + attn_drop: Dropout rate for attention weights. + proj_drop: Dropout rate for output projection. + window_size: Window size for relative position bias. If None, no relative position bias. + attn_head_dim: Dimension per attention head. If None, uses dim // num_heads. + """ super().__init__() self.num_heads = num_heads head_dim = dim // num_heads @@ -126,7 +155,12 @@ def __init__( self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) - def _get_rel_pos_bias(self): + def _get_rel_pos_bias(self) -> torch.Tensor: + """Get relative position bias for the attention window. + + Returns: + Relative position bias tensor of shape (1, num_heads, window_area+1, window_area+1). 
+ """ relative_position_bias = self.relative_position_bias_table[ self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, @@ -134,7 +168,16 @@ def _get_rel_pos_bias(self): relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww return relative_position_bias.unsqueeze(0) - def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, shared_rel_pos_bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of attention module. + + Args: + x: Input tensor of shape (batch_size, num_tokens, dim). + shared_rel_pos_bias: Optional shared relative position bias from parent module. + + Returns: + Output tensor of shape (batch_size, num_tokens, dim). + """ B, N, C = x.shape if self.q_bias is None: @@ -183,6 +226,12 @@ def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None): class Block(nn.Module): + """Transformer block with attention and MLP. + + Standard transformer block consisting of multi-head self-attention and MLP + with residual connections and layer normalization. Supports layer scale and + stochastic depth regularization. + """ def __init__( self, @@ -201,6 +250,24 @@ def __init__( window_size: Optional[Tuple[int, int]] = None, attn_head_dim: Optional[int] = None, ): + """Initialize transformer block. + + Args: + dim: Input feature dimension. + num_heads: Number of attention heads. + qkv_bias: If True, add learnable bias to query, key, value projections. + mlp_ratio: Ratio of MLP hidden dimension to input dimension. + scale_mlp: If True, apply layer normalization in MLP. + swiglu_mlp: If True, use SwiGLU activation in MLP. + proj_drop: Dropout rate for projections. + attn_drop: Dropout rate for attention. + drop_path: Drop path rate for stochastic depth. + init_values: Initial values for layer scale. If None, no layer scale. + act_layer: Activation function class. + norm_layer: Normalization layer class. + window_size: Window size for relative position bias in attention. + attn_head_dim: Dimension per attention head. + """ super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( @@ -239,7 +306,16 @@ def __init__( else: self.gamma_1, self.gamma_2 = None, None - def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, shared_rel_pos_bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of transformer block. + + Args: + x: Input tensor of shape (batch_size, num_tokens, dim). + shared_rel_pos_bias: Optional shared relative position bias. + + Returns: + Output tensor of shape (batch_size, num_tokens, dim). + """ if self.gamma_1 is None: x = x + self.drop_path1(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias)) x = x + self.drop_path2(self.mlp(self.norm2(x))) @@ -250,8 +326,19 @@ def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None): class RelativePositionBias(nn.Module): + """Relative position bias module for window-based attention. + + Generates learnable relative position biases for all pairs of positions + within a window, including special handling for cls token. + """ + + def __init__(self, window_size: Tuple[int, int], num_heads: int): + """Initialize relative position bias module. - def __init__(self, window_size, num_heads): + Args: + window_size: Height and width of the attention window. + num_heads: Number of attention heads. 
+ """ super().__init__() self.window_size = window_size self.window_area = window_size[0] * window_size[1] @@ -260,14 +347,23 @@ def __init__(self, window_size, num_heads): # trunc_normal_(self.relative_position_bias_table, std=.02) self.register_buffer("relative_position_index", gen_relative_position_index(window_size)) - def forward(self): + def forward(self) -> torch.Tensor: + """Generate relative position bias. + + Returns: + Relative position bias tensor of shape (num_heads, window_area+1, window_area+1). + """ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_area + 1, self.window_area + 1, -1) # Wh*Ww,Wh*Ww,nH return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww class Beit(nn.Module): - """ Vision Transformer with support for patch or hybrid CNN input stage + """BEiT: BERT Pre-Training of Image Transformers. + + Vision Transformer model with support for relative position bias and + shared relative position bias across layers. Implements both BEiT v1 and v2 + architectures with flexible configuration options. """ def __init__( @@ -296,6 +392,33 @@ def __init__( use_shared_rel_pos_bias: bool = False, head_init_scale: float = 0.001, ): + """Initialize BEiT model. + + Args: + img_size: Input image size. + patch_size: Patch size for patch embedding. + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + global_pool: Type of global pooling ('avg' or ''). + embed_dim: Embedding dimension. + depth: Number of transformer blocks. + num_heads: Number of attention heads. + qkv_bias: If True, add learnable bias to query, key, value projections. + mlp_ratio: Ratio of MLP hidden dimension to embedding dimension. + swiglu_mlp: If True, use SwiGLU activation in MLP. + scale_mlp: If True, apply layer normalization in MLP. + drop_rate: Dropout rate. + pos_drop_rate: Dropout rate for position embeddings. + proj_drop_rate: Dropout rate for projections. + attn_drop_rate: Dropout rate for attention. + drop_path_rate: Stochastic depth rate. + norm_layer: Normalization layer class. + init_values: Initial values for layer scale. + use_abs_pos_emb: If True, use absolute position embeddings. + use_rel_pos_bias: If True, use relative position bias in attention. + use_shared_rel_pos_bias: If True, share relative position bias across layers. + head_init_scale: Scale factor for head initialization. + """ super().__init__() self.num_classes = num_classes self.global_pool = global_pool @@ -363,6 +486,11 @@ def __init__( self.head.bias.data.mul_(head_init_scale) def fix_init_weight(self): + """Fix initialization weights according to BEiT paper. + + Rescales attention and MLP weights based on layer depth to improve + training stability. + """ def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) @@ -370,7 +498,12 @@ def rescale(param, layer_id): rescale(layer.attn.proj.weight.data, layer_id + 1) rescale(layer.mlp.fc2.weight.data, layer_id + 1) - def _init_weights(self, m): + def _init_weights(self, m: nn.Module): + """Initialize model weights. + + Args: + m: Module to initialize. + """ if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: @@ -380,7 +513,12 @@ def _init_weights(self, m): nn.init.constant_(m.weight, 1.0) @torch.jit.ignore - def no_weight_decay(self): + def no_weight_decay(self) -> Set[str]: + """Get parameter names that should not use weight decay. 
+ + Returns: + Set of parameter names to exclude from weight decay. + """ nwd = {'pos_embed', 'cls_token'} for n, _ in self.named_parameters(): if 'relative_position_bias_table' in n: @@ -388,11 +526,24 @@ def no_weight_decay(self): return nwd @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True): + """Enable or disable gradient checkpointing. + + Args: + enable: If True, enable gradient checkpointing. + """ self.grad_checkpointing = enable @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Create parameter group matcher for optimizer parameter groups. + + Args: + coarse: If True, use coarse grouping. + + Returns: + Dictionary mapping group names to regex patterns. + """ matcher = dict( stem=r'^cls_token|pos_embed|patch_embed|rel_pos_bias', # stem and embed blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))], @@ -401,9 +552,20 @@ def group_matcher(self, coarse=False): @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier head. + + Returns: + The classification head module. + """ return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + """Reset the classification head. + + Args: + num_classes: Number of classes for new head. + global_pool: Global pooling type. + """ self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool @@ -419,18 +581,20 @@ def forward_intermediates( output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: - """ Forward features that returns intermediates. + """Forward pass that returns intermediate feature maps. Args: - x: Input image tensor - indices: Take last n blocks if an int, if is a sequence, select by matching indices - return_prefix_tokens: Return both prefix and spatial intermediate tokens - norm: Apply norm layer to all intermediates - stop_early: Stop iterating over blocks when last desired intermediate hit - output_fmt: Shape of intermediate feature outputs - intermediates_only: Only return intermediate features - Returns: + x: Input image tensor of shape (batch_size, channels, height, width). + indices: Block indices to return features from. If int, returns last n blocks. + return_prefix_tokens: If True, return both prefix and spatial tokens. + norm: If True, apply normalization to intermediate features. + stop_early: If True, stop at last selected intermediate. + output_fmt: Output format ('NCHW' or 'NLC'). + intermediates_only: If True, only return intermediate features. + Returns: + If intermediates_only is True, returns list of intermediate tensors. + Otherwise, returns tuple of (final_features, intermediates). """ assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' reshape = output_fmt == 'NCHW' @@ -481,8 +645,16 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediate outputs. + + Args: + indices: Indices of blocks to keep. + prune_norm: If True, remove final normalization. + prune_head: If True, remove classification head. + + Returns: + List of indices that were kept. 
""" take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks @@ -493,7 +665,15 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers. + + Args: + x: Input tensor of shape (batch_size, channels, height, width). + + Returns: + Feature tensor of shape (batch_size, num_tokens, embed_dim). + """ x = self.patch_embed(x) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.pos_embed is not None: @@ -509,20 +689,46 @@ def forward_features(self, x): x = self.norm(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classification head. + + Args: + x: Feature tensor of shape (batch_size, num_tokens, embed_dim). + pre_logits: If True, return features before final linear layer. + + Returns: + Logits tensor of shape (batch_size, num_classes) or pre-logits. + """ if self.global_pool: x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) x = self.head_drop(x) return x if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the model. + + Args: + x: Input tensor of shape (batch_size, channels, height, width). + + Returns: + Logits tensor of shape (batch_size, num_classes). + """ x = self.forward_features(x) x = self.forward_head(x) return x -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create a default configuration dictionary for BEiT models. + + Args: + url: Model weights URL. + **kwargs: Additional configuration parameters. + + Returns: + Configuration dictionary. + """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, @@ -599,7 +805,21 @@ def _cfg(url='', **kwargs): }) -def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True): +def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module, interpolation: str = 'bicubic', antialias: bool = True) -> Dict[str, torch.Tensor]: + """Filter and process checkpoint state dict for loading. + + Handles resizing of patch embeddings, position embeddings, and relative position + bias tables when model size differs from checkpoint. + + Args: + state_dict: Checkpoint state dictionary. + model: Target model to load weights into. + interpolation: Interpolation method for resizing. + antialias: If True, use antialiasing when resizing. + + Returns: + Filtered state dictionary. + """ state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('module', state_dict) # beit v2 didn't strip module @@ -641,7 +861,17 @@ def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=T return out_dict -def _create_beit(variant, pretrained=False, **kwargs): +def _create_beit(variant: str, pretrained: bool = False, **kwargs) -> Beit: + """Create a BEiT model. + + Args: + variant: Model variant name. + pretrained: If True, load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + BEiT model instance. 
+ """ out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Beit, variant, pretrained, @@ -653,7 +883,8 @@ def _create_beit(variant, pretrained=False, **kwargs): @register_model -def beit_base_patch16_224(pretrained=False, **kwargs) -> Beit: +def beit_base_patch16_224(pretrained: bool = False, **kwargs) -> Beit: + """BEiT base model @ 224x224 with patch size 16x16.""" model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1) @@ -662,7 +893,8 @@ def beit_base_patch16_224(pretrained=False, **kwargs) -> Beit: @register_model -def beit_base_patch16_384(pretrained=False, **kwargs) -> Beit: +def beit_base_patch16_384(pretrained: bool = False, **kwargs) -> Beit: + """BEiT base model @ 384x384 with patch size 16x16.""" model_args = dict( img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1) @@ -671,7 +903,8 @@ def beit_base_patch16_384(pretrained=False, **kwargs) -> Beit: @register_model -def beit_large_patch16_224(pretrained=False, **kwargs) -> Beit: +def beit_large_patch16_224(pretrained: bool = False, **kwargs) -> Beit: + """BEiT large model @ 224x224 with patch size 16x16.""" model_args = dict( patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5) @@ -680,7 +913,8 @@ def beit_large_patch16_224(pretrained=False, **kwargs) -> Beit: @register_model -def beit_large_patch16_384(pretrained=False, **kwargs) -> Beit: +def beit_large_patch16_384(pretrained: bool = False, **kwargs) -> Beit: + """BEiT large model @ 384x384 with patch size 16x16.""" model_args = dict( img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5) @@ -689,7 +923,8 @@ def beit_large_patch16_384(pretrained=False, **kwargs) -> Beit: @register_model -def beit_large_patch16_512(pretrained=False, **kwargs) -> Beit: +def beit_large_patch16_512(pretrained: bool = False, **kwargs) -> Beit: + """BEiT large model @ 512x512 with patch size 16x16.""" model_args = dict( img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5) @@ -698,7 +933,8 @@ def beit_large_patch16_512(pretrained=False, **kwargs) -> Beit: @register_model -def beitv2_base_patch16_224(pretrained=False, **kwargs) -> Beit: +def beitv2_base_patch16_224(pretrained: bool = False, **kwargs) -> Beit: + """BEiT v2 base model @ 224x224 with patch size 16x16.""" model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5) @@ -707,7 +943,8 @@ def beitv2_base_patch16_224(pretrained=False, **kwargs) -> Beit: @register_model -def beitv2_large_patch16_224(pretrained=False, **kwargs) -> Beit: +def beitv2_large_patch16_224(pretrained: bool = False, **kwargs) -> Beit: + """BEiT v2 large model @ 224x224 with patch size 16x16.""" model_args = dict( patch_size=16, embed_dim=1024, depth=24, num_heads=16, use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 683ed0ca0..f5af4fd11 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -12,6 +12,8 @@ Hacked together by / copyright Ross Wightman, 2021. 
""" +from typing import Any, Dict, Optional + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from ._builder import build_model_with_cfg from ._registry import register_model, generate_default_cfgs @@ -260,7 +262,18 @@ ) -def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs): +def _create_byoanet(variant: str, cfg_variant: Optional[str] = None, pretrained: bool = False, **kwargs) -> ByobNet: + """Create a Bring-Your-Own-Attention network model. + + Args: + variant: Model variant name. + cfg_variant: Config variant name if different from model variant. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + Instantiated ByobNet model. + """ return build_model_with_cfg( ByobNet, variant, pretrained, model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], @@ -269,7 +282,16 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs): ) -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Generate default model configuration. + + Args: + url: URL for pretrained weights. + **kwargs: Override default configuration values. + + Returns: + Model configuration dictionary. + """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.95, 'interpolation': 'bicubic', @@ -346,7 +368,7 @@ def _cfg(url='', **kwargs): @register_model -def botnet26t_256(pretrained=False, **kwargs) -> ByobNet: +def botnet26t_256(pretrained: bool = False, **kwargs) -> ByobNet: """ Bottleneck Transformer w/ ResNet26-T backbone. """ kwargs.setdefault('img_size', 256) @@ -354,14 +376,14 @@ def botnet26t_256(pretrained=False, **kwargs) -> ByobNet: @register_model -def sebotnet33ts_256(pretrained=False, **kwargs) -> ByobNet: +def sebotnet33ts_256(pretrained: bool = False, **kwargs) -> ByobNet: """ Bottleneck Transformer w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU, """ return _create_byoanet('sebotnet33ts_256', 'sebotnet33ts', pretrained=pretrained, **kwargs) @register_model -def botnet50ts_256(pretrained=False, **kwargs) -> ByobNet: +def botnet50ts_256(pretrained: bool = False, **kwargs) -> ByobNet: """ Bottleneck Transformer w/ ResNet50-T backbone, silu act. """ kwargs.setdefault('img_size', 256) @@ -369,7 +391,7 @@ def botnet50ts_256(pretrained=False, **kwargs) -> ByobNet: @register_model -def eca_botnext26ts_256(pretrained=False, **kwargs) -> ByobNet: +def eca_botnext26ts_256(pretrained: bool = False, **kwargs) -> ByobNet: """ Bottleneck Transformer w/ ResNet26-T backbone, silu act. """ kwargs.setdefault('img_size', 256) @@ -377,7 +399,7 @@ def eca_botnext26ts_256(pretrained=False, **kwargs) -> ByobNet: @register_model -def halonet_h1(pretrained=False, **kwargs) -> ByobNet: +def halonet_h1(pretrained: bool = False, **kwargs) -> ByobNet: """ HaloNet-H1. Halo attention in all stages as per the paper. NOTE: This runs very slowly! """ @@ -385,49 +407,49 @@ def halonet_h1(pretrained=False, **kwargs) -> ByobNet: @register_model -def halonet26t(pretrained=False, **kwargs) -> ByobNet: +def halonet26t(pretrained: bool = False, **kwargs) -> ByobNet: """ HaloNet w/ a ResNet26-t backbone. 
Halo attention in final two stages """ return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs) @register_model -def sehalonet33ts(pretrained=False, **kwargs) -> ByobNet: +def sehalonet33ts(pretrained: bool = False, **kwargs) -> ByobNet: """ HaloNet w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU, 1-2 Halo in stage 2,3,4. """ return _create_byoanet('sehalonet33ts', pretrained=pretrained, **kwargs) @register_model -def halonet50ts(pretrained=False, **kwargs) -> ByobNet: +def halonet50ts(pretrained: bool = False, **kwargs) -> ByobNet: """ HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages """ return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs) @register_model -def eca_halonext26ts(pretrained=False, **kwargs) -> ByobNet: +def eca_halonext26ts(pretrained: bool = False, **kwargs) -> ByobNet: """ HaloNet w/ a ResNet26-t backbone, silu act. Halo attention in final two stages """ return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs) @register_model -def lambda_resnet26t(pretrained=False, **kwargs) -> ByobNet: +def lambda_resnet26t(pretrained: bool = False, **kwargs) -> ByobNet: """ Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages. """ return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs) @register_model -def lambda_resnet50ts(pretrained=False, **kwargs) -> ByobNet: +def lambda_resnet50ts(pretrained: bool = False, **kwargs) -> ByobNet: """ Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages. """ return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs) @register_model -def lambda_resnet26rpt_256(pretrained=False, **kwargs) -> ByobNet: +def lambda_resnet26rpt_256(pretrained: bool = False, **kwargs) -> ByobNet: """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages. """ kwargs.setdefault('img_size', 256) @@ -435,21 +457,21 @@ def lambda_resnet26rpt_256(pretrained=False, **kwargs) -> ByobNet: @register_model -def haloregnetz_b(pretrained=False, **kwargs) -> ByobNet: +def haloregnetz_b(pretrained: bool = False, **kwargs) -> ByobNet: """ Halo + RegNetZ """ return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs) @register_model -def lamhalobotnet50ts_256(pretrained=False, **kwargs) -> ByobNet: +def lamhalobotnet50ts_256(pretrained: bool = False, **kwargs) -> ByobNet: """ Combo Attention (Lambda + Halo + Bot) Network """ return _create_byoanet('lamhalobotnet50ts_256', 'lamhalobotnet50ts', pretrained=pretrained, **kwargs) @register_model -def halo2botnet50ts_256(pretrained=False, **kwargs) -> ByobNet: +def halo2botnet50ts_256(pretrained: bool = False, **kwargs) -> ByobNet: """ Combo Attention (Halo + Halo + Bot) Network """ return _create_byoanet('halo2botnet50ts_256', 'halo2botnet50ts', pretrained=pretrained, **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 764d5ad5e..18da1f2dc 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -52,6 +52,10 @@ @dataclass class ByoBlockCfg: + """Block configuration for Bring-Your-Own-Blocks. + + Defines configuration for a single block or stage of blocks. + """ type: Union[str, nn.Module] d: int # block depth (number of block repeats in stage) c: int # number of output channels for each block in stage @@ -69,6 +73,10 @@ class ByoBlockCfg: @dataclass class ByoModelCfg: + """Model configuration for Bring-Your-Own-Blocks network. + + Defines overall architecture configuration. 
+ """ blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...] downsample: str = 'conv1x1' stem_type: str = '3x3' @@ -97,7 +105,18 @@ class ByoModelCfg: block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict()) -def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): +def _rep_vgg_bcfg(d: Tuple[int, ...] = (4, 6, 16, 1), wf: Tuple[float, ...] = (1., 1., 1., 1.), groups: int = 0) -> \ +Tuple[ByoBlockCfg, ...]: + """Create RepVGG block configuration. + + Args: + d: Depth (number of blocks) per stage. + wf: Width factor per stage. + groups: Number of groups for grouped convolution. + + Returns: + Tuple of block configurations. + """ c = (64, 128, 256, 512) group_size = 0 if groups > 0: @@ -106,7 +125,23 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): return bcfg -def _mobileone_bcfg(d=(2, 8, 10, 1), wf=(1., 1., 1., 1.), se_blocks=(), num_conv_branches=1): +def _mobileone_bcfg( + d: Tuple[int, ...] = (2, 8, 10, 1), + wf: Tuple[float, ...] = (1., 1., 1., 1.), + se_blocks: Tuple[int, ...] = (), + num_conv_branches: int = 1 +) -> List[List[ByoBlockCfg]]: + """Create MobileOne block configuration. + + Args: + d: Depth (number of blocks) per stage. + wf: Width factor per stage. + se_blocks: Number of SE blocks per stage. + num_conv_branches: Number of conv branches. + + Returns: + List of block configurations per stage. + """ c = (64, 128, 256, 512) prev_c = min(64, c[0] * wf[0]) se_blocks = se_blocks or (0,) * len(d) @@ -128,12 +163,23 @@ def _mobileone_bcfg(d=(2, 8, 10, 1), wf=(1., 1., 1., 1.), se_blocks=(), num_conv def interleave_blocks( - types: Tuple[str, str], d, + types: Tuple[str, str], + d: int, every: Union[int, List[int]] = 1, first: bool = False, **kwargs, -) -> Tuple[ByoBlockCfg]: - """ interleave 2 block types in stack +) -> Tuple[ByoBlockCfg, ...]: + """Interleave 2 block types in stack. + + Args: + types: Two block type names to interleave. + d: Total depth of blocks. + every: Interval for alternating blocks. + first: Whether to start with alternate block. + **kwargs: Additional block arguments. + + Returns: + Tuple of interleaved block configurations. """ assert len(types) == 2 if isinstance(every, int): @@ -149,6 +195,14 @@ def interleave_blocks( def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]: + """Expand block config into individual block instances. + + Args: + stage_blocks_cfg: Block configuration(s) for a stage. + + Returns: + List of individual block configurations. + """ if not isinstance(stage_blocks_cfg, Sequence): stage_blocks_cfg = (stage_blocks_cfg,) block_cfgs = [] @@ -157,7 +211,16 @@ def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg] return block_cfgs -def num_groups(group_size, channels): +def num_groups(group_size: Optional[int], channels: int) -> int: + """Calculate number of groups for grouped convolution. + + Args: + group_size: Size of each group (1 for depthwise). + channels: Number of channels. + + Returns: + Number of groups. + """ if not group_size: # 0 or None return 1 # normal conv with 1 group else: @@ -168,6 +231,7 @@ def num_groups(group_size, channels): @dataclass class LayerFn: + """Container for layer factory functions.""" conv_norm_act: Callable = ConvNormAct norm_act: Callable = BatchNormAct2d act: Callable = nn.ReLU @@ -176,6 +240,11 @@ class LayerFn: class DownsampleAvg(nn.Module): + """Average pool downsampling module. + + AvgPool Downsampling as in 'D' ResNet variants. 
+ """ + def __init__( self, in_chs: int, @@ -183,9 +252,18 @@ def __init__( stride: int = 1, dilation: int = 1, apply_act: bool = False, - layers: LayerFn = None, + layers: Optional[LayerFn] = None, ): - """ AvgPool Downsampling as in 'D' ResNet variants.""" + """Initialize DownsampleAvg. + + Args: + in_chs: Number of input channels. + out_chs: Number of output channels. + stride: Stride for downsampling. + dilation: Dilation rate. + apply_act: Whether to apply activation. + layers: Layer factory functions. + """ super(DownsampleAvg, self).__init__() layers = layers or LayerFn() avg_stride = stride if dilation == 1 else 1 @@ -196,7 +274,15 @@ def __init__( self.pool = nn.Identity() self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ return self.conv(self.pool(x)) @@ -208,7 +294,21 @@ def create_shortcut( dilation: Tuple[int, int], layers: LayerFn, **kwargs, -): +) -> Optional[nn.Module]: + """Create shortcut connection for residual blocks. + + Args: + downsample_type: Type of downsampling ('avg', 'conv1x1', or ''). + in_chs: Input channels. + out_chs: Output channels. + stride: Stride for downsampling. + dilation: Dilation rates. + layers: Layer factory functions. + **kwargs: Additional arguments. + + Returns: + Shortcut module or None. + """ assert downsample_type in ('avg', 'conv1x1', '') if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: if not downsample_type: @@ -911,7 +1011,7 @@ def forward(self, x): ) -def register_block(block_type:str, block_fn: nn.Module): +def register_block(block_type: str, block_fn: nn.Module): _block_registry[block_type] = block_fn @@ -1108,7 +1208,6 @@ def create_byob_stages( layers: Optional[LayerFn] = None, block_kwargs_fn: Optional[Callable] = update_block_kwargs, ): - layers = layers or LayerFn() feature_info = [] block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks] @@ -1177,13 +1276,14 @@ def get_layer_fns(cfg: ByoModelCfg, allow_aa: bool = True): class ByobNet(nn.Module): - """ 'Bring-your-own-blocks' Net + """Bring-your-own-blocks Network. A flexible network backbone that allows building model stem + blocks via dataclass cfg definition w/ factory functions for module instantiation. Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act). """ + def __init__( self, cfg: ByoModelCfg, @@ -1193,7 +1293,7 @@ def __init__( output_stride: int = 32, img_size: Optional[Union[int, Tuple[int, int]]] = None, drop_rate: float = 0., - drop_path_rate: float =0., + drop_path_rate: float = 0., zero_init_last: bool = True, **kwargs, ): @@ -1288,7 +1388,7 @@ def __init__( qkv_separate=True, ) self.head_hidden_size = self.head.embed_dim - elif cfg.head_type =='attn_rot': + elif cfg.head_type == 'attn_rot': if global_pool is None: global_pool = 'token' assert global_pool in ('', 'token') @@ -1318,7 +1418,15 @@ def __init__( named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Group matcher for parameter groups. + + Args: + coarse: Whether to use coarse grouping. + + Returns: + Dictionary mapping group names to patterns. 
+ """ matcher = dict( stem=r'^stem', blocks=[ @@ -1329,14 +1437,30 @@ def group_matcher(self, coarse=False): return matcher @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing. + + Args: + enable: Whether to enable gradient checkpointing. + """ self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get classifier module. + + Returns: + Classifier module. + """ return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset classifier. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes self.head.reset(num_classes, global_pool) @@ -1404,8 +1528,16 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layer. + prune_head: Whether to prune the classifier head. + + Returns: + List of indices that were kept. """ take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) max_index = self.stage_ends[max_index] @@ -1416,8 +1548,15 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction. - def forward_features(self, x): + Args: + x: Input tensor. + + Returns: + Feature tensor. + """ x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.stages, x) @@ -1426,16 +1565,40 @@ def forward_features(self, x): x = self.final_conv(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through head. + + Args: + x: Input features. + pre_logits: Return features before final linear layer. + + Returns: + Classification logits or features. + """ return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output logits. + """ x = self.forward_features(x) x = self.forward_head(x) return x -def _init_weights(module, name='', zero_init_last=False): +def _init_weights(module: nn.Module, name: str = '', zero_init_last: bool = False) -> None: + """Initialize weights. + + Args: + module: Module to initialize. + name: Module name. + zero_init_last: Zero-initialize last layer. 
+ """ if isinstance(module, nn.Conv2d): fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels fan_out //= module.groups @@ -2058,7 +2221,7 @@ def _down_sub(m): if k.startswith(f'{prefix}attnpool'): if not model_has_attn_pool: continue - k = k.replace(prefix + 'attnpool', 'head') #'attn_pool') + k = k.replace(prefix + 'attnpool', 'head') # 'attn_pool') k = k.replace('positional_embedding', 'pos_embed') k = k.replace('q_proj', 'q') k = k.replace('k_proj', 'k') @@ -2078,7 +2241,17 @@ def checkpoint_filter_fn( return state_dict -def _create_byobnet(variant, pretrained=False, **kwargs): +def _create_byobnet(variant: str, pretrained: bool = False, **kwargs) -> ByobNet: + """Create a ByobNet model. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + ByobNet model instance. + """ return build_model_with_cfg( ByobNet, variant, pretrained, model_cfg=model_cfgs[variant], @@ -2088,7 +2261,16 @@ def _create_byobnet(variant, pretrained=False, **kwargs): ) -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create default configuration dictionary. + + Args: + url: Model weight URL. + **kwargs: Additional configuration options. + + Returns: + Configuration dictionary. + """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', @@ -2098,7 +2280,16 @@ def _cfg(url='', **kwargs): } -def _cfgr(url='', **kwargs): +def _cfgr(url: str = '', **kwargs) -> Dict[str, Any]: + """Create RepVGG configuration dictionary. + + Args: + url: Model weight URL. + **kwargs: Additional configuration options. + + Returns: + Configuration dictionary. + """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), 'crop_pct': 0.9, 'interpolation': 'bicubic', diff --git a/timm/models/coat.py b/timm/models/coat.py index 906ecb908..e0b0bcfb8 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -1,4 +1,4 @@ -""" +""" CoaT architecture. Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399 @@ -44,8 +44,8 @@ def __init__(self, head_chs, num_heads, window): elif isinstance(window, dict): self.window = window else: - raise ValueError() - + raise ValueError() + self.conv_list = nn.ModuleList() self.head_splits = [] for cur_window, cur_head_split in window.items(): @@ -56,9 +56,9 @@ def __init__(self, head_chs, num_heads, window): cur_conv = nn.Conv2d( cur_head_split * head_chs, cur_head_split * head_chs, - kernel_size=(cur_window, cur_window), + kernel_size=(cur_window, cur_window), padding=(padding_size, padding_size), - dilation=(dilation, dilation), + dilation=(dilation, dilation), groups=cur_head_split * head_chs, ) self.conv_list.append(cur_conv) @@ -138,13 +138,13 @@ def forward(self, x, size: Tuple[int, int]): class ConvPosEnc(nn.Module): - """ Convolutional Position Encoding. + """ Convolutional Position Encoding. Note: This module is similar to the conditional position encoding in CPVT. """ def __init__(self, dim, k=3): super(ConvPosEnc, self).__init__() - self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim) - + self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim) + def forward(self, x, size: Tuple[int, int]): B, N, C = x.shape H, W = size @@ -152,7 +152,7 @@ def forward(self, x, size: Tuple[int, int]): # Extract CLS token and image tokens. 
cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C] - + # Depthwise convolution. feat = img_tokens.transpose(1, 2).view(B, C, H, W) x = self.proj(feat) + feat @@ -212,9 +212,9 @@ def forward(self, x, size: Tuple[int, int]): x = self.cpe(x, size) cur = self.norm1(x) cur = self.factoratt_crpe(cur, size) - x = x + self.drop_path(cur) + x = x + self.drop_path(cur) - # MLP. + # MLP. cur = self.norm2(x) cur = self.mlp(cur) x = x + self.drop_path(cur) @@ -300,7 +300,7 @@ def interpolate(self, x, scale_factor: float, size: Tuple[int, int]): cls_token = x[:, :1, :] img_tokens = x[:, 1:, :] - + img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) img_tokens = F.interpolate( img_tokens, @@ -310,7 +310,7 @@ def interpolate(self, x, scale_factor: float, size: Tuple[int, int]): align_corners=False, ) img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2) - + out = torch.cat((cls_token, img_tokens), dim=1) return out @@ -332,11 +332,11 @@ def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]): cur2 = cur2 + upsample3_2 + upsample4_2 cur3 = cur3 + upsample4_3 + downsample2_3 cur4 = cur4 + downsample3_4 + downsample2_4 - x2 = x2 + self.drop_path(cur2) - x3 = x3 + self.drop_path(cur3) - x4 = x4 + self.drop_path(cur4) + x2 = x2 + self.drop_path(cur2) + x3 = x3 + self.drop_path(cur3) + x4 = x4 + self.drop_path(cur4) - # MLP. + # MLP. cur2 = self.norm22(x2) cur3 = self.norm23(x3) cur4 = self.norm24(x4) @@ -345,7 +345,7 @@ def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]): cur4 = self.mlp4(cur4) x2 = x2 + self.drop_path(cur2) x3 = x3 + self.drop_path(cur3) - x4 = x4 + self.drop_path(cur4) + x4 = x4 + self.drop_path(cur4) return x1, x2, x3, x4 @@ -576,7 +576,7 @@ def forward_features(self, x0): for blk in self.serial_blocks1: x1 = blk(x1, size=(H1, W1)) x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() - + # Serial blocks 2. x2 = self.patch_embed2(x1_nocls) H2, W2 = self.patch_embed2.grid_size @@ -605,7 +605,7 @@ def forward_features(self, x0): if self.parallel_blocks is None: if not torch.jit.is_scripting() and self.return_interm_layers: # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). - feat_out = {} + feat_out = {} if 'x1_nocls' in self.out_features: feat_out['x1_nocls'] = x1_nocls if 'x2_nocls' in self.out_features: @@ -627,7 +627,7 @@ def forward_features(self, x0): if not torch.jit.is_scripting() and self.return_interm_layers: # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). 
- feat_out = {} + feat_out = {} if 'x1_nocls' in self.out_features: x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() feat_out['x1_nocls'] = x1_nocls diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index c7a250776..de986e3ba 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -82,7 +82,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): if global_pool is not None: self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/convnext.py b/timm/models/convnext.py index e2eb48d37..509c49a2c 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -38,7 +38,7 @@ # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially. from functools import partial -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -57,8 +57,17 @@ class Downsample(nn.Module): + """Downsample module for ConvNeXt.""" - def __init__(self, in_chs, out_chs, stride=1, dilation=1): + def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1) -> None: + """Initialize Downsample module. + + Args: + in_chs: Number of input channels. + out_chs: Number of output channels. + stride: Stride for downsampling. + dilation: Dilation rate. + """ super().__init__() avg_stride = stride if dilation == 1 else 1 if stride > 1 or dilation > 1: @@ -72,14 +81,16 @@ def __init__(self, in_chs, out_chs, stride=1, dilation=1): else: self.conv = nn.Identity() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" x = self.pool(x) x = self.conv(x) return x class ConvNeXtBlock(nn.Module): - """ ConvNeXt Block + """ConvNeXt Block. + There are two equivalent implementations: (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back @@ -148,7 +159,8 @@ def __init__( self.shortcut = nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" shortcut = x x = self.conv_dw(x) if self.use_conv_mlp: @@ -167,24 +179,43 @@ def forward(self, x): class ConvNeXtStage(nn.Module): + """ConvNeXt stage (multiple blocks).""" def __init__( self, - in_chs, - out_chs, - kernel_size=7, - stride=2, - depth=2, - dilation=(1, 1), - drop_path_rates=None, - ls_init_value=1.0, - conv_mlp=False, - conv_bias=True, - use_grn=False, - act_layer='gelu', - norm_layer=None, - norm_layer_cl=None - ): + in_chs: int, + out_chs: int, + kernel_size: int = 7, + stride: int = 2, + depth: int = 2, + dilation: Tuple[int, int] = (1, 1), + drop_path_rates: Optional[List[float]] = None, + ls_init_value: float = 1.0, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + act_layer: Union[str, Callable] = 'gelu', + norm_layer: Optional[Callable] = None, + norm_layer_cl: Optional[Callable] = None + ) -> None: + """Initialize ConvNeXt stage. + + Args: + in_chs: Number of input channels. + out_chs: Number of output channels. 
+ kernel_size: Kernel size for depthwise convolution. + stride: Stride for downsampling. + depth: Number of blocks in stage. + dilation: Dilation rates. + drop_path_rates: Drop path rates for each block. + ls_init_value: Initial value for layer scale. + conv_mlp: Use convolutional MLP. + conv_bias: Use bias in convolutions. + use_grn: Use global response normalization. + act_layer: Activation layer. + norm_layer: Normalization layer. + norm_layer_cl: Normalization layer for channels last. + """ super().__init__() self.grad_checkpointing = False @@ -226,7 +257,8 @@ def __init__( in_chs = out_chs self.blocks = nn.Sequential(*stage_blocks) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) @@ -264,8 +296,9 @@ def _get_norm_layers(norm_layer: Union[Callable, str], conv_mlp: bool, norm_eps: class ConvNeXt(nn.Module): - r""" ConvNeXt - A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf + """ConvNeXt model architecture. + + A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf """ def __init__( @@ -406,7 +439,15 @@ def __init__( named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Union[str, List]]: + """Create regex patterns for parameter grouping. + + Args: + coarse: Use coarse grouping. + + Returns: + Dictionary mapping group names to regex patterns. + """ return dict( stem=r'^stem', blocks=r'^stages\.(\d+)' if coarse else [ @@ -417,15 +458,27 @@ def group_matcher(self, coarse=False): ) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing. + + Args: + enable: Whether to enable gradient checkpointing. + """ for s in self.stages: s.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier module.""" return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes self.head.reset(num_classes, global_pool) @@ -438,17 +491,18 @@ def forward_intermediates( output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: - """ Forward features that returns intermediates. + """Forward features that returns intermediates. Args: - x: Input image tensor - indices: Take last n blocks if int, all if None, select matching indices if sequence - norm: Apply norm layer to compatible intermediates - stop_early: Stop iterating over blocks when last desired intermediate hit - output_fmt: Shape of intermediate feature outputs - intermediates_only: Only return intermediate features - Returns: + x: Input image tensor. + indices: Take last n blocks if int, all if None, select matching indices if sequence. + norm: Apply norm layer to compatible intermediates. + stop_early: Stop iterating over blocks when last desired intermediate hit. + output_fmt: Shape of intermediate feature outputs. 
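A hedged usage sketch of the `forward_intermediates` API whose docstring is being cleaned up above, assuming a recent timm where the method is available; `convnext_tiny` is just a convenient registered model name:

import timm
import torch

model = timm.create_model('convnext_tiny', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# take the last two stage outputs without running the classifier head
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
for f in feats:
    print(f.shape)   # NCHW feature maps at decreasing spatial resolution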
+ intermediates_only: Only return intermediate features. + Returns: + List of intermediate features or tuple of (final features, intermediates). """ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' intermediates = [] @@ -483,8 +537,16 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layer. + prune_head: Whether to prune the classifier head. + + Returns: + List of indices that were kept. """ take_indices, max_index = feature_take_indices(len(self.stages), indices) self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 @@ -494,22 +556,40 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers.""" x = self.stem(x) x = self.stages(x) x = self.norm_pre(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Feature tensor. + pre_logits: Return features before final classifier. + + Returns: + Output tensor. + """ return self.head(x, pre_logits=True) if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" x = self.forward_features(x) x = self.forward_head(x) return x -def _init_weights(module, name=None, head_init_scale=1.0): +def _init_weights(module: nn.Module, name: Optional[str] = None, head_init_scale: float = 1.0) -> None: + """Initialize model weights. + + Args: + module: Module to initialize. + name: Module name. + head_init_scale: Scale factor for head initialization. + """ if isinstance(module, nn.Conv2d): trunc_normal_(module.weight, std=.02) if module.bias is not None: diff --git a/timm/models/davit.py b/timm/models/davit.py index f538ecca8..a82f2e5fa 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -447,7 +447,7 @@ def __init__( ''' repeating alternating attention blocks in each stage default: (spatial -> channel) x depth - + potential opportunity to integrate with a more general version of ByobNet/ByoaNet since the logic is similar ''' @@ -503,7 +503,7 @@ class DaVit(nn.Module): r""" DaViT A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645 Supports arbitrary input sizes and pyramid feature extraction - + Args: in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 @@ -669,7 +669,7 @@ def forward_intermediates( stages = self.stages else: stages = self.stages[:max_index + 1] - + for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/densenet.py b/timm/models/densenet.py index d52296590..1a9f9887a 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -4,6 +4,7 @@ """ import re from collections import OrderedDict +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,15 +21,30 @@ class DenseLayer(nn.Module): + """Dense layer for DenseNet. 
+ + Implements the bottleneck layer with 1x1 and 3x3 convolutions. + """ + def __init__( self, - num_input_features, - growth_rate, - bn_size, - norm_layer=BatchNormAct2d, - drop_rate=0., - grad_checkpointing=False, - ): + num_input_features: int, + growth_rate: int, + bn_size: int, + norm_layer: type = BatchNormAct2d, + drop_rate: float = 0., + grad_checkpointing: bool = False, + ) -> None: + """Initialize DenseLayer. + + Args: + num_input_features: Number of input features. + growth_rate: Growth rate (k) of the layer. + bn_size: Bottleneck size multiplier. + norm_layer: Normalization layer class. + drop_rate: Dropout rate. + grad_checkpointing: Use gradient checkpointing. + """ super(DenseLayer, self).__init__() self.add_module('norm1', norm_layer(num_input_features)), self.add_module('conv1', nn.Conv2d( @@ -39,23 +55,23 @@ def __init__( self.drop_rate = float(drop_rate) self.grad_checkpointing = grad_checkpointing - def bottleneck_fn(self, xs): - # type: (List[torch.Tensor]) -> torch.Tensor + def bottleneck_fn(self, xs: List[torch.Tensor]) -> torch.Tensor: + """Bottleneck function for concatenated features.""" concated_features = torch.cat(xs, 1) bottleneck_output = self.conv1(self.norm1(concated_features)) # noqa: T484 return bottleneck_output # todo: rewrite when torchscript supports any - def any_requires_grad(self, x): - # type: (List[torch.Tensor]) -> bool + def any_requires_grad(self, x: List[torch.Tensor]) -> bool: + """Check if any tensor in list requires gradient.""" for tensor in x: if tensor.requires_grad: return True return False @torch.jit.unused # noqa: T484 - def call_checkpoint_bottleneck(self, x): - # type: (List[torch.Tensor]) -> torch.Tensor + def call_checkpoint_bottleneck(self, x: List[torch.Tensor]) -> torch.Tensor: + """Call bottleneck function with gradient checkpointing.""" def closure(*xs): return self.bottleneck_fn(xs) @@ -73,7 +89,15 @@ def forward(self, x): # torchscript does not yet support *args, so we overload method # allowing it to take either a List[Tensor] or single Tensor - def forward(self, x): # noqa: F811 + def forward(self, x: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor: # noqa: F811 + """Forward pass. + + Args: + x: Input features (single tensor or list of tensors). + + Returns: + New features to be concatenated. + """ if isinstance(x, torch.Tensor): prev_features = [x] else: @@ -93,18 +117,33 @@ def forward(self, x): # noqa: F811 class DenseBlock(nn.ModuleDict): + """DenseNet Block. + + Contains multiple dense layers with concatenated features. + """ _version = 2 def __init__( self, - num_layers, - num_input_features, - bn_size, - growth_rate, - norm_layer=BatchNormAct2d, - drop_rate=0., - grad_checkpointing=False, - ): + num_layers: int, + num_input_features: int, + bn_size: int, + growth_rate: int, + norm_layer: type = BatchNormAct2d, + drop_rate: float = 0., + grad_checkpointing: bool = False, + ) -> None: + """Initialize DenseBlock. + + Args: + num_layers: Number of layers in the block. + num_input_features: Number of input features. + bn_size: Bottleneck size multiplier. + growth_rate: Growth rate (k) for each layer. + norm_layer: Normalization layer class. + drop_rate: Dropout rate. + grad_checkpointing: Use gradient checkpointing. 
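A standalone sketch of the closure pattern used by `DenseLayer.call_checkpoint_bottleneck` above: the concat + 1x1 bottleneck is recomputed in the backward pass instead of caching activations. Plain `BatchNorm2d` stands in for timm's `BatchNormAct2d`, and `use_reentrant=False` assumes a reasonably recent PyTorch:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

norm = nn.BatchNorm2d(96)
conv = nn.Conv2d(96, 128, kernel_size=1, bias=False)

def bottleneck_fn(xs):
    return conv(norm(torch.cat(xs, 1)))

def closure(*xs):                       # checkpoint wants positional tensor args
    return bottleneck_fn(xs)

features = [torch.randn(2, 32, 8, 8, requires_grad=True) for _ in range(3)]  # 3 * 32 = 96 channels
out = checkpoint(closure, *features, use_reentrant=False)
out.sum().backward()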
+ """ super(DenseBlock, self).__init__() for i in range(num_layers): layer = DenseLayer( @@ -117,7 +156,15 @@ def __init__( ) self.add_module('denselayer%d' % (i + 1), layer) - def forward(self, init_features): + def forward(self, init_features: torch.Tensor) -> torch.Tensor: + """Forward pass through all layers in the block. + + Args: + init_features: Initial features from previous layer. + + Returns: + Concatenated features from all layers. + """ features = [init_features] for name, layer in self.items(): new_features = layer(features) @@ -126,13 +173,26 @@ def forward(self, init_features): class DenseTransition(nn.Sequential): + """Transition layer between DenseNet blocks. + + Reduces feature dimensions and spatial resolution. + """ + def __init__( self, - num_input_features, - num_output_features, - norm_layer=BatchNormAct2d, - aa_layer=None, - ): + num_input_features: int, + num_output_features: int, + norm_layer: type = BatchNormAct2d, + aa_layer: Optional[type] = None, + ) -> None: + """Initialize DenseTransition. + + Args: + num_input_features: Number of input features. + num_output_features: Number of output features. + norm_layer: Normalization layer class. + aa_layer: Anti-aliasing layer class. + """ super(DenseTransition, self).__init__() self.add_module('norm', norm_layer(num_input_features)) self.add_module('conv', nn.Conv2d( @@ -144,38 +204,57 @@ def __init__( class DenseNet(nn.Module): - r"""Densenet-BC model class, based on - `"Densely Connected Convolutional Networks" `_ + """Densenet-BC model class. + + Based on `"Densely Connected Convolutional Networks" `_ Args: - growth_rate (int) - how many filters to add each layer (`k` in paper) - block_config (list of 4 ints) - how many layers in each pooling block - bn_size (int) - multiplicative factor for number of bottle neck layers - (i.e. bn_size * k features in the bottleneck layer) - drop_rate (float) - dropout rate before classifier layer - proj_drop_rate (float) - dropout rate after each dense layer - num_classes (int) - number of classification classes - memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, - but slower. Default: *False*. See `"paper" `_ + growth_rate: How many filters to add each layer (`k` in paper). + block_config: How many layers in each pooling block. + bn_size: Multiplicative factor for number of bottle neck layers + (i.e. bn_size * k features in the bottleneck layer). + drop_rate: Dropout rate before classifier layer. + proj_drop_rate: Dropout rate after each dense layer. + num_classes: Number of classification classes. + memory_efficient: If True, uses checkpointing. Much more memory efficient, + but slower. Default: *False*. See `"paper" `_. """ def __init__( self, - growth_rate=32, - block_config=(6, 12, 24, 16), - num_classes=1000, - in_chans=3, - global_pool='avg', - bn_size=4, - stem_type='', - act_layer='relu', - norm_layer='batchnorm2d', - aa_layer=None, - drop_rate=0., - proj_drop_rate=0., - memory_efficient=False, - aa_stem_only=True, - ): + growth_rate: int = 32, + block_config: Tuple[int, ...] = (6, 12, 24, 16), + num_classes: int = 1000, + in_chans: int = 3, + global_pool: str = 'avg', + bn_size: int = 4, + stem_type: str = '', + act_layer: str = 'relu', + norm_layer: str = 'batchnorm2d', + aa_layer: Optional[type] = None, + drop_rate: float = 0., + proj_drop_rate: float = 0., + memory_efficient: bool = False, + aa_stem_only: bool = True, + ) -> None: + """Initialize DenseNet. + + Args: + growth_rate: How many filters to add each layer (k in paper). 
+ block_config: How many layers in each pooling block. + num_classes: Number of classification classes. + in_chans: Number of input channels. + global_pool: Global pooling type. + bn_size: Multiplicative factor for number of bottle neck layers. + stem_type: Type of stem ('', 'deep', 'deep_tiered'). + act_layer: Activation layer. + norm_layer: Normalization layer. + aa_layer: Anti-aliasing layer. + drop_rate: Dropout rate before classifier layer. + proj_drop_rate: Dropout rate after each dense layer. + memory_efficient: If True, uses checkpointing for memory efficiency. + aa_stem_only: Apply anti-aliasing only to stem. + """ self.num_classes = num_classes super(DenseNet, self).__init__() norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer) @@ -269,7 +348,8 @@ def __init__( nn.init.constant_(m.bias, 0) @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Group parameters for optimization.""" matcher = dict( stem=r'^features\.conv[012]|features\.norm[012]|features\.pool[012]', blocks=r'^features\.(?:denseblock|transition)(\d+)' if coarse else [ @@ -280,35 +360,69 @@ def group_matcher(self, coarse=False): return matcher @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing.""" for b in self.features.modules(): if isinstance(b, DenseLayer): b.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier head.""" return self.classifier - def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers.""" return self.features(x) - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Feature tensor. + pre_logits: Return features before final classifier. + + Returns: + Output tensor. + """ x = self.global_pool(x) x = self.head_drop(x) return x if pre_logits else self.classifier(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output logits. + """ x = self.forward_features(x) x = self.forward_head(x) return x -def _filter_torchvision_pretrained(state_dict): +def _filter_torchvision_pretrained(state_dict: dict) -> Dict[str, torch.Tensor]: + """Filter torchvision pretrained state dict for compatibility. + + Args: + state_dict: State dictionary from torchvision checkpoint. + + Returns: + Filtered state dictionary. 
+ """ pattern = re.compile( r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') @@ -321,7 +435,25 @@ def _filter_torchvision_pretrained(state_dict): return state_dict -def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs): +def _create_densenet( + variant: str, + growth_rate: int, + block_config: Tuple[int, ...], + pretrained: bool, + **kwargs, +) -> DenseNet: + """Create a DenseNet model. + + Args: + variant: Model variant name. + growth_rate: Growth rate parameter. + block_config: Block configuration. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + DenseNet model instance. + """ kwargs['growth_rate'] = growth_rate kwargs['block_config'] = block_config return build_model_with_cfg( @@ -334,7 +466,8 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs): ) -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create default configuration for DenseNet models.""" return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index e21be9713..333f73f0d 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -451,7 +451,7 @@ def forward_intermediates( stages = self.stages else: stages = self.stages[:max_index + 1] - + for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index 5bdc473fc..39c832c8c 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -659,7 +659,7 @@ def forward_intermediates( stages = self.stages else: stages = self.stages[:max_index + 1] - + for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index b5bc35c03..245f54406 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -36,7 +36,7 @@ Hacked together by / Copyright 2019, Ross Wightman """ from functools import partial -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -57,7 +57,7 @@ class EfficientNet(nn.Module): - """ EfficientNet + """EfficientNet model architecture. A flexible and performant PyTorch implementation of efficient network architectures, including: * EfficientNet-V2 Small, Medium, Large, XL & B0-B3 @@ -70,6 +70,12 @@ class EfficientNet(nn.Module): * FBNet C * Single-Path NAS Pixel1 * TinyNet + + References: + - EfficientNet: https://arxiv.org/abs/1905.11946 + - EfficientNetV2: https://arxiv.org/abs/2104.00298 + - MixNet: https://arxiv.org/abs/1907.09595 + - MnasNet: https://arxiv.org/abs/1807.11626 """ def __init__( @@ -91,7 +97,28 @@ def __init__( drop_rate: float = 0., drop_path_rate: float = 0., global_pool: str = 'avg' - ): + ) -> None: + """Initialize EfficientNet model. + + Args: + block_args: Arguments for building blocks. + num_classes: Number of classifier classes. + num_features: Number of features for penultimate layer. + in_chans: Number of input channels. + stem_size: Number of output channels in stem. + stem_kernel_size: Kernel size for stem convolution. + fix_stem: If True, don't scale stem channels. + output_stride: Output stride of network. + pad_type: Padding type. 
+ act_layer: Activation layer class. + norm_layer: Normalization layer class. + aa_layer: Anti-aliasing layer class. + se_layer: Squeeze-and-excitation layer class. + round_chs_fn: Channel rounding function. + drop_rate: Dropout rate for classifier. + drop_path_rate: Drop path rate for stochastic depth. + global_pool: Global pooling type. + """ super(EfficientNet, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d @@ -138,7 +165,8 @@ def __init__( efficientnet_init_weights(self) - def as_sequential(self): + def as_sequential(self) -> nn.Sequential: + """Convert model to sequential for feature extraction.""" layers = [self.conv_stem, self.bn1] layers.extend(self.blocks) layers.extend([self.conv_head, self.bn2, self.global_pool]) @@ -146,7 +174,15 @@ def as_sequential(self): return nn.Sequential(*layers) @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Union[str, List]]: + """Create regex patterns for parameter groups. + + Args: + coarse: Use coarse (stage-level) grouping. + + Returns: + Dictionary mapping group names to regex patterns. + """ return dict( stem=r'^conv_stem|bn1', blocks=[ @@ -156,14 +192,26 @@ def group_matcher(self, coarse=False): ) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing. + + Args: + enable: Whether to enable gradient checkpointing. + """ self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier module.""" return self.classifier - def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) @@ -178,18 +226,19 @@ def forward_intermediates( intermediates_only: bool = False, extra_blocks: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: - """ Forward features that returns intermediates. + """Forward features that returns intermediates. Args: - x: Input image tensor - indices: Take last n blocks if int, all if None, select matching indices if sequence - norm: Apply norm layer to compatible intermediates - stop_early: Stop iterating over blocks when last desired intermediate hit - output_fmt: Shape of intermediate feature outputs - intermediates_only: Only return intermediate features - extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info - Returns: + x: Input image tensor. + indices: Take last n blocks if int, all if None, select matching indices if sequence. + norm: Apply norm layer to compatible intermediates. + stop_early: Stop iterating over blocks when last desired intermediate hit. + output_fmt: Shape of intermediate feature outputs. + intermediates_only: Only return intermediate features. + extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info. + Returns: + List of intermediate features or tuple of (final features, intermediates). """ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
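A hedged example of the `get_classifier` / `reset_classifier` pair documented above, using `efficientnet_b0` as an illustrative model; exact feature width (1280) is specific to that variant:

import timm
import torch

model = timm.create_model('efficientnet_b0', pretrained=False)
print(type(model.get_classifier()))        # the default 1000-class Linear head

model.reset_classifier(num_classes=10, global_pool='avg')
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)                         # torch.Size([1, 10])

# num_classes=0 drops the classification layer; with avg pooling kept, forward() yields embeddings
model.reset_classifier(0)
feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)                          # torch.Size([1, 1280]) for efficientnet_b0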
intermediates = [] @@ -231,8 +280,17 @@ def prune_intermediate_layers( prune_norm: bool = False, prune_head: bool = True, extra_blocks: bool = False, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layers. + prune_head: Whether to prune the classifier head. + extra_blocks: Include all blocks in indexing. + + Returns: + List of indices that were kept. """ if extra_blocks: take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) @@ -247,7 +305,8 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers.""" x = self.conv_stem(x) x = self.bn1(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -258,13 +317,23 @@ def forward_features(self, x): x = self.bn2(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Feature tensor. + pre_logits: Return features before final classifier. + + Returns: + Output tensor. + """ x = self.global_pool(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) return x if pre_logits else self.classifier(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" x = self.forward_features(x) x = self.forward_head(x) return x @@ -335,7 +404,12 @@ def __init__( self.feature_hooks = FeatureHooks(hooks, self.named_modules()) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing. + + Args: + enable: Whether to enable gradient checkpointing. 
+ """ self.grad_checkpointing = enable def forward(self, x) -> List[torch.Tensor]: diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index 27872310e..80fbe4331 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -423,7 +423,7 @@ def forward(self, x): return res -def build_local_block( +def build_local_block( in_channels: int, out_channels: int, stride: int, @@ -787,7 +787,7 @@ def forward_intermediates( stages = self.stages else: stages = self.stages[:max_index + 1] - + for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: @@ -941,7 +941,7 @@ def forward_intermediates( stages = self.stages else: stages = self.stages[:max_index + 1] - + for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index 91caaa5a4..f3c8db74c 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -508,7 +508,7 @@ def forward_intermediates( stages = self.stages else: stages = self.stages[:max_index + 1] - + for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/eva.py b/timm/models/eva.py index 166a07bb0..61301616d 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -26,7 +26,7 @@ # EVA02 models Copyright (c) 2023 BAAI-Vision import math from functools import partial -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -167,12 +167,12 @@ def forward( else: q = q * self.scale attn = (q @ k.transpose(-2, -1)) - + if attn_mask is not None: attn_mask = attn_mask.to(torch.bool) attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) attn = attn.softmax(dim=-1) - + attn = self.attn_drop(attn) x = attn @ v @@ -599,7 +599,8 @@ def __init__( self.head.weight.data.mul_(head_init_scale) self.head.bias.data.mul_(head_init_scale) - def fix_init_weight(self): + def fix_init_weight(self) -> None: + """Fix initialization weights by rescaling based on layer depth.""" def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) @@ -607,23 +608,31 @@ def rescale(param, layer_id): rescale(layer.attn.proj.weight.data, layer_id + 1) rescale(layer.mlp.fc2.weight.data, layer_id + 1) - def _init_weights(self, m): + def _init_weights(self, m: nn.Module) -> None: + """Initialize weights for Linear layers. + + Args: + m: Module to initialize. 
+ """ if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if m.bias is not None: nn.init.zeros_(m.bias) @torch.jit.ignore - def no_weight_decay(self): + def no_weight_decay(self) -> Set[str]: + """Parameters to exclude from weight decay.""" nwd = {'pos_embed', 'cls_token'} return nwd @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing.""" self.grad_checkpointing = enable @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Create layer groupings for optimization.""" matcher = dict( stem=r'^cls_token|pos_embed|patch_embed', # stem and embed blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))], @@ -634,7 +643,13 @@ def group_matcher(self, coarse=False): def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classifier head. + + Args: + num_classes: Number of output classes. + global_pool: Global pooling type. + """ self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool @@ -766,7 +781,15 @@ def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens) return x - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers. + + Args: + x: Input tensor. + + Returns: + Feature tensor. + """ x = self.patch_embed(x) x, rot_pos_embed = self._pos_embed(x) x = self.norm_pre(x) @@ -778,24 +801,50 @@ def forward_features(self, x): x = self.norm(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Feature tensor. + pre_logits: Return pre-logits if True. + + Returns: + Output tensor. + """ x = self.pool(x) x = self.fc_norm(x) x = self.head_drop(x) return x if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ x = self.forward_features(x) x = self.forward_head(x) return x def _convert_pe( - state_dict, - model, + state_dict: Dict[str, torch.Tensor], + model: nn.Module, prefix: str = 'visual.', -): - """ Convert Perception Encoder weights """ +) -> Dict[str, torch.Tensor]: + """Convert Perception Encoder weights. + + Args: + state_dict: State dictionary to convert. + model: Target model instance. + prefix: Prefix to strip from keys. + + Returns: + Converted state dictionary. + """ state_dict = state_dict.get('model', state_dict) state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} @@ -852,12 +901,22 @@ def _convert_pe( def checkpoint_filter_fn( - state_dict, - model, - interpolation='bicubic', - antialias=True, -): - """ convert patch embedding weight from manual patchify + linear proj to conv""" + state_dict: Dict[str, torch.Tensor], + model: nn.Module, + interpolation: str = 'bicubic', + antialias: bool = True, +) -> Dict[str, torch.Tensor]: + """Convert patch embedding weight from manual patchify + linear proj to conv. + + Args: + state_dict: Checkpoint state dictionary. 
+ model: Target model instance. + interpolation: Interpolation method for resizing. + antialias: Whether to use antialiasing when resizing. + + Returns: + Filtered state dictionary. + """ out_dict = {} state_dict = state_dict.get('model_ema', state_dict) state_dict = state_dict.get('model', state_dict) @@ -936,7 +995,17 @@ def checkpoint_filter_fn( return out_dict -def _create_eva(variant, pretrained=False, **kwargs): +def _create_eva(variant: str, pretrained: bool = False, **kwargs) -> Eva: + """Create an EVA model. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + Instantiated Eva model. + """ out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Eva, variant, pretrained, @@ -947,7 +1016,16 @@ def _create_eva(variant, pretrained=False, **kwargs): return model -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Generate default configuration for EVA models. + + Args: + url: Model weights URL. + **kwargs: Additional configuration parameters. + + Returns: + Model configuration dictionary. + """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, @@ -958,7 +1036,16 @@ def _cfg(url='', **kwargs): } -def _pe_cfg(url='', **kwargs): +def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Generate default configuration for Perception Encoder models. + + Args: + url: Model weights URL. + **kwargs: Additional configuration parameters. + + Returns: + Model configuration dictionary. + """ return { 'url': url, 'num_classes': 0, 'input_size': (3, 224, 224), 'pool_size': None, @@ -1203,7 +1290,7 @@ def _pe_cfg(url='', **kwargs): @register_model -def eva_giant_patch14_224(pretrained=False, **kwargs) -> Eva: +def eva_giant_patch14_224(pretrained: bool = False, **kwargs) -> Eva: """ EVA-g model https://arxiv.org/abs/2211.07636 """ model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408) model = _create_eva('eva_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -1211,7 +1298,7 @@ def eva_giant_patch14_224(pretrained=False, **kwargs) -> Eva: @register_model -def eva_giant_patch14_336(pretrained=False, **kwargs) -> Eva: +def eva_giant_patch14_336(pretrained: bool = False, **kwargs) -> Eva: """ EVA-g model https://arxiv.org/abs/2211.07636 """ model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408) model = _create_eva('eva_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -1219,7 +1306,7 @@ def eva_giant_patch14_336(pretrained=False, **kwargs) -> Eva: @register_model -def eva_giant_patch14_560(pretrained=False, **kwargs) -> Eva: +def eva_giant_patch14_560(pretrained: bool = False, **kwargs) -> Eva: """ EVA-g model https://arxiv.org/abs/2211.07636 """ model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408) model = _create_eva('eva_giant_patch14_560', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -1227,7 +1314,7 @@ def eva_giant_patch14_560(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_tiny_patch14_224(pretrained=False, **kwargs) -> Eva: +def eva02_tiny_patch14_224(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=224, patch_size=14, @@ -1244,7 +1331,7 @@ def eva02_tiny_patch14_224(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_small_patch14_224(pretrained=False, **kwargs) -> Eva: +def 
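A minimal illustration of the patch-embed conversion described in the `checkpoint_filter_fn` docstring above: a Linear projection over flattened patches becomes an equivalent Conv2d kernel via a reshape. The channel-first flattening order is an assumption of this sketch; the real filter handles whatever layout the checkpoint uses:

import torch
import torch.nn.functional as F

embed_dim, patch_size = 768, 14
linear_w = torch.randn(embed_dim, 3 * patch_size * patch_size)    # [768, 588]
conv_w = linear_w.reshape(embed_dim, 3, patch_size, patch_size)   # [768, 3, 14, 14]

# sanity check: both produce the same projection for a single patch
patch = torch.randn(1, 3, patch_size, patch_size)
out_linear = patch.flatten(1) @ linear_w.T
out_conv = F.conv2d(patch, conv_w).flatten(1)
print(torch.allclose(out_linear, out_conv, atol=1e-4))   # True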
eva02_small_patch14_224(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=224, patch_size=14, @@ -1261,7 +1348,7 @@ def eva02_small_patch14_224(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_base_patch14_224(pretrained=False, **kwargs) -> Eva: +def eva02_base_patch14_224(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=224, patch_size=14, @@ -1280,7 +1367,7 @@ def eva02_base_patch14_224(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_large_patch14_224(pretrained=False, **kwargs) -> Eva: +def eva02_large_patch14_224(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=224, patch_size=14, @@ -1299,7 +1386,7 @@ def eva02_large_patch14_224(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_tiny_patch14_336(pretrained=False, **kwargs) -> Eva: +def eva02_tiny_patch14_336(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=336, patch_size=14, @@ -1316,7 +1403,7 @@ def eva02_tiny_patch14_336(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_small_patch14_336(pretrained=False, **kwargs) -> Eva: +def eva02_small_patch14_336(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=336, patch_size=14, @@ -1333,7 +1420,7 @@ def eva02_small_patch14_336(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_base_patch14_448(pretrained=False, **kwargs) -> Eva: +def eva02_base_patch14_448(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=448, patch_size=14, @@ -1352,7 +1439,7 @@ def eva02_base_patch14_448(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_large_patch14_448(pretrained=False, **kwargs) -> Eva: +def eva02_large_patch14_448(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=448, patch_size=14, @@ -1371,7 +1458,7 @@ def eva02_large_patch14_448(pretrained=False, **kwargs) -> Eva: @register_model -def eva_giant_patch14_clip_224(pretrained=False, **kwargs) -> Eva: +def eva_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> Eva: """ EVA-g CLIP model (only difference from non-CLIP is the pooling) """ model_args = dict( patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, @@ -1381,7 +1468,7 @@ def eva_giant_patch14_clip_224(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_base_patch16_clip_224(pretrained=False, **kwargs) -> Eva: +def eva02_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> Eva: """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_base """ model_args = dict( img_size=224, @@ -1403,7 +1490,7 @@ def eva02_base_patch16_clip_224(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_large_patch14_clip_224(pretrained=False, **kwargs) -> Eva: +def eva02_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> Eva: """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large """ model_args = dict( img_size=224, @@ -1425,7 +1512,7 @@ def eva02_large_patch14_clip_224(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_large_patch14_clip_336(pretrained=False, **kwargs) -> Eva: +def eva02_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> Eva: """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large """ model_args = dict( img_size=336, @@ -1447,7 +1534,7 @@ def eva02_large_patch14_clip_336(pretrained=False, **kwargs) -> Eva: @register_model -def eva02_enormous_patch14_clip_224(pretrained=False, 
**kwargs) -> Eva: +def eva02_enormous_patch14_clip_224(pretrained: bool = False, **kwargs) -> Eva: """ A EVA-CLIP specific variant that uses residual post-norm in blocks """ model_args = dict( img_size=224, @@ -1464,7 +1551,7 @@ def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs) -> Eva: @register_model -def vit_medium_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva: +def vit_medium_patch16_rope_reg1_gap_256(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=256, patch_size=16, @@ -1485,7 +1572,7 @@ def vit_medium_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva: @register_model -def vit_mediumd_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva: +def vit_mediumd_patch16_rope_reg1_gap_256(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=256, patch_size=16, @@ -1506,7 +1593,7 @@ def vit_mediumd_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva: @register_model -def vit_betwixt_patch16_rope_reg4_gap_256(pretrained=False, **kwargs) -> Eva: +def vit_betwixt_patch16_rope_reg4_gap_256(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=256, patch_size=16, @@ -1527,7 +1614,7 @@ def vit_betwixt_patch16_rope_reg4_gap_256(pretrained=False, **kwargs) -> Eva: @register_model -def vit_base_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva: +def vit_base_patch16_rope_reg1_gap_256(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( img_size=256, patch_size=16, @@ -1548,7 +1635,7 @@ def vit_base_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva: @register_model -def vit_pe_core_base_patch16_224(pretrained=False, **kwargs): +def vit_pe_core_base_patch16_224(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=16, embed_dim=768, @@ -1571,7 +1658,7 @@ def vit_pe_core_base_patch16_224(pretrained=False, **kwargs): @register_model -def vit_pe_core_large_patch14_336(pretrained=False, **kwargs): +def vit_pe_core_large_patch14_336(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=14, embed_dim=1024, @@ -1594,7 +1681,7 @@ def vit_pe_core_large_patch14_336(pretrained=False, **kwargs): @register_model -def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs): +def vit_pe_core_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=14, embed_dim=1536, @@ -1617,7 +1704,7 @@ def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs): @register_model -def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs): +def vit_pe_lang_large_patch14_448(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=14, embed_dim=1024, @@ -1641,7 +1728,7 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs): @register_model -def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs): +def vit_pe_lang_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=14, embed_dim=1536, @@ -1664,7 +1751,7 @@ def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs): @register_model -def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs): +def vit_pe_spatial_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=14, embed_dim=1536, diff --git a/timm/models/fasternet.py b/timm/models/fasternet.py index b9f857b06..d73f49a26 100644 --- a/timm/models/fasternet.py +++ b/timm/models/fasternet.py @@ -52,7 +52,7 @@ def forward_slicing(self, x: torch.Tensor) -> torch.Tensor: 
x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :]) return x - def forward_split_cat(self, x: torch.Tensor) -> torch.Tensor: + def forward_split_cat(self, x: torch.Tensor) -> torch.Tensor: # for training/inference x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1) x1 = self.partial_conv3(x1) @@ -74,7 +74,7 @@ def __init__( ): super().__init__() mlp_hidden_dim = int(dim * mlp_ratio) - + self.mlp = nn.Sequential(*[ nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False), norm_layer(mlp_hidden_dim), @@ -152,7 +152,7 @@ class PatchEmbed(nn.Module): def __init__( self, in_chans: int, - embed_dim: int, + embed_dim: int, patch_size: Union[int, Tuple[int, int]] = 4, norm_layer: LayerType = nn.BatchNorm2d, ): @@ -327,7 +327,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates @@ -361,7 +361,7 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tenso if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) return x if pre_logits else self.classifier(x) - + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) x = self.forward_head(x) diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index ec7cd1cff..aa3237925 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -503,7 +503,7 @@ def forward_intermediates( if intermediates_only: return intermediates - + if feat_idx == last_idx: x = self.norm(x) diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index c862dc4a2..214619de9 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -531,7 +531,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 2f1587015..a2ddad469 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -113,7 +113,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class GhostModuleV3(nn.Module): def __init__( - self, + self, in_chs: int, out_chs: int, kernel_size: int = 1, @@ -152,7 +152,7 @@ def __init__( self.cheap_rpr_scale = ConvBnAct(init_chs, new_chs, 1, 1, pad_type=0, group_size=1, act_layer=None) self.cheap_activation = act_layer(inplace=True) - self.short_conv = nn.Sequential( + self.short_conv = nn.Sequential( nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False), nn.BatchNorm2d(out_chs), nn.Conv2d(out_chs, out_chs, kernel_size=(1,5), stride=1, padding=(0,2), groups=out_chs, bias=False), @@ -164,7 +164,7 @@ def __init__( self.in_channels = init_chs self.groups = init_chs self.kernel_size = dw_size - + def forward(self, x): if self.infer_mode: x1 = self.primary_conv(x) @@ -179,9 +179,9 @@ def forward(self, x): for cheap_rpr_conv in self.cheap_rpr_conv: x2 += cheap_rpr_conv(x1) x2 = self.cheap_activation(x2) - + out = torch.cat([x1,x2], dim=1) - if self.mode not in ['shortcut']: + if self.mode not in ['shortcut']: return out else: res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2)) @@ -211,7 +211,7 @@ def _get_kernel_bias_primary(self): kernel_final = kernel_conv + kernel_scale + kernel_identity bias_final = bias_conv + bias_scale + bias_identity return kernel_final, bias_final - + def _get_kernel_bias_cheap(self): kernel_scale = 0 bias_scale = 
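A standalone sketch of FasterNet's partial convolution (the `forward_split_cat` path annotated above): only the first `dim_conv3` channels pass through the 3x3 conv, the rest are carried through untouched. Dimensions are illustrative:

import torch
import torch.nn as nn

dim, n_div = 64, 4
dim_conv3 = dim // n_div             # 16 channels convolved
dim_untouched = dim - dim_conv3      # 48 channels left as-is
partial_conv3 = nn.Conv2d(dim_conv3, dim_conv3, 3, 1, 1, bias=False)

x = torch.randn(2, dim, 32, 32)
x1, x2 = torch.split(x, [dim_conv3, dim_untouched], dim=1)
x1 = partial_conv3(x1)
out = torch.cat((x1, x2), dim=1)     # same shape as the input: [2, 64, 32, 32]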
0 @@ -235,7 +235,7 @@ def _get_kernel_bias_cheap(self): kernel_final = kernel_conv + kernel_scale + kernel_identity bias_final = bias_conv + bias_scale + bias_identity return kernel_final, bias_final - + def _fuse_bn_tensor(self, branch): if isinstance(branch, ConvBnAct): kernel = branch.conv.weight @@ -285,7 +285,7 @@ def switch_to_deploy(self): self.primary_conv.weight.data = primary_kernel self.primary_conv.bias.data = primary_bias self.primary_conv = nn.Sequential( - self.primary_conv, + self.primary_conv, self.primary_activation if self.primary_activation is not None else nn.Sequential() ) @@ -304,7 +304,7 @@ def switch_to_deploy(self): self.cheap_operation.bias.data = cheap_bias self.cheap_operation = nn.Sequential( - self.cheap_operation, + self.cheap_operation, self.cheap_activation if self.cheap_activation is not None else nn.Sequential() ) @@ -326,7 +326,7 @@ def switch_to_deploy(self): self.__delattr__('cheap_rpr_skip') self.infer_mode = True - + def reparameterize(self): self.switch_to_deploy() @@ -370,7 +370,7 @@ def __init__( # Point-wise linear projection self.ghost2 = GhostModule(mid_chs, out_chs, act_layer=nn.Identity) - + # shortcut if in_chs == out_chs and self.stride == 1: self.shortcut = nn.Sequential() @@ -401,19 +401,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # 2nd ghost bottleneck x = self.ghost2(x) - + x += self.shortcut(shortcut) return x -class GhostBottleneckV3(nn.Module): +class GhostBottleneckV3(nn.Module): """ GhostV3 bottleneck w/ optional SE""" def __init__( - self, - in_chs: int, - mid_chs: int, - out_chs: int, + self, + in_chs: int, + mid_chs: int, + out_chs: int, dw_kernel_size: int = 3, stride: int = 1, act_layer: LayerType = nn.ReLU, @@ -436,7 +436,7 @@ def __init__( # Depth-wise convolution if self.stride > 1: self.dw_rpr_conv = nn.ModuleList( - [ConvBnAct(mid_chs, mid_chs, dw_kernel_size, stride, pad_type=(dw_kernel_size - 1) // 2, + [ConvBnAct(mid_chs, mid_chs, dw_kernel_size, stride, pad_type=(dw_kernel_size - 1) // 2, group_size=1, act_layer=None) for _ in range(self.num_conv_branches)] ) # Re-parameterizable scale branch @@ -453,7 +453,7 @@ def __init__( # Point-wise linear projection self.ghost2 = GhostModuleV3(mid_chs, out_chs, act_layer=nn.Identity, mode='original') - + # shortcut if in_chs == out_chs and self.stride == 1: self.shortcut = nn.Identity() @@ -489,7 +489,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # 2nd ghost bottleneck x = self.ghost2(x) - + x += self.shortcut(shortcut) return x @@ -549,7 +549,7 @@ def _fuse_bn_tensor(self, branch): t = (gamma / std).reshape(-1, 1, 1, 1) return kernel * t, beta - running_mean * gamma / std - def switch_to_deploy(self): + def switch_to_deploy(self): if self.infer_mode or self.stride == 1: return dw_kernel, dw_bias = self._get_kernel_bias_dw() @@ -578,7 +578,7 @@ def switch_to_deploy(self): self.__delattr__('dw_rpr_skip') self.infer_mode = True - + def reparameterize(self): self.switch_to_deploy() @@ -642,8 +642,8 @@ def __init__( out_chs = make_divisible(exp_size * width, 4) stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1))) self.pool_dim = prev_chs = out_chs - - self.blocks = nn.Sequential(*stages) + + self.blocks = nn.Sequential(*stages) # building last several layers self.num_features = prev_chs @@ -729,7 +729,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages, start=1): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates @@ -750,7 +750,7 @@ def 
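A standalone check of the Conv+BN fusion used by the GhostNetV3 reparameterization above: W' = W * gamma/std and b' = beta - running_mean * gamma/std, evaluated with the BN in eval mode. The layer sizes are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(16)
bn.eval()
# give the BN non-trivial statistics so the comparison is meaningful
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1, 1)

std = (bn.running_var + bn.eps).sqrt()
fused_w = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
fused_b = bn.bias - bn.running_mean * bn.weight / std

x = torch.randn(2, 8, 16, 16)
ref = bn(conv(x))
fused = F.conv2d(x, fused_w, fused_b, padding=1)
print(torch.allclose(ref, fused, atol=1e-5))   # True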
prune_intermediate_layers( self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0 if prune_head: self.reset_classifier(0, '') - return take_indices + return take_indices def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.conv_stem(x) @@ -801,7 +801,7 @@ def _create_ghostnet(variant: str, width: float = 1.0, pretrained: bool = False, Constructs a GhostNet model """ cfgs = [ - # k, t, c, SE, s + # k, t, c, SE, s # stage1 [[3, 16, 16, 0, 1]], # stage2 diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py index 3e44c9dc9..212cbb58f 100644 --- a/timm/models/hgnet.py +++ b/timm/models/hgnet.py @@ -544,7 +544,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates diff --git a/timm/models/inception_next.py b/timm/models/inception_next.py index 2fcf123ff..3159ac03b 100644 --- a/timm/models/inception_next.py +++ b/timm/models/inception_next.py @@ -385,7 +385,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 315328a21..cadbf9528 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -324,7 +324,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index 71d12fe67..b28827ed7 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -454,7 +454,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if channel_first: # reshape to BCHW output format diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index b7d4e7e44..85cc3b423 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -38,7 +38,7 @@ from collections import OrderedDict from dataclasses import dataclass, replace, field from functools import partial -from typing import Callable, Optional, Union, Tuple, List +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -60,6 +60,7 @@ @dataclass class MaxxVitTransformerCfg: + """Configuration for MaxxVit transformer blocks.""" dim_head: int = 32 head_first: bool = True # head ordering in qkv channel dim expand_ratio: float = 4.0 @@ -93,6 +94,7 @@ def __post_init__(self): @dataclass class MaxxVitConvCfg: + """Configuration for MaxxVit convolution blocks.""" block_type: str = 'mbconv' expand_ratio: float = 4.0 expand_output: bool = True # calculate expansion channels from output (vs input chs) @@ -129,6 +131,7 @@ def __post_init__(self): @dataclass class MaxxVitCfg: + """Configuration for MaxxVit models.""" embed_dim: Tuple[int, ...] = (96, 192, 384, 768) depths: Tuple[int, ...] = (2, 3, 5, 2) block_type: Tuple[Union[str, Tuple[str, ...]], ...] 
= ('C', 'C', 'T', 'T') @@ -136,14 +139,14 @@ class MaxxVitCfg: stem_bias: bool = False conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg) transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg) - head_hidden_size: int = None + head_hidden_size: Optional[int] = None weight_init: str = 'vit_eff' class Attention2d(nn.Module): + """Multi-head attention for 2D NCHW tensors.""" fused_attn: Final[bool] - """ multi-head attention for 2D NCHW tensors""" def __init__( self, dim: int, @@ -152,10 +155,22 @@ def __init__( bias: bool = True, expand_first: bool = True, head_first: bool = True, - rel_pos_cls: Callable = None, + rel_pos_cls: Optional[Callable] = None, attn_drop: float = 0., proj_drop: float = 0. ): + """ + Args: + dim: Input dimension. + dim_out: Output dimension (defaults to input dimension). + dim_head: Dimension per attention head. + bias: Whether to use bias in qkv and projection. + expand_first: Whether to expand channels before or after qkv. + head_first: Whether heads are first in tensor layout. + rel_pos_cls: Relative position class to use. + attn_drop: Attention dropout rate. + proj_drop: Projection dropout rate. + """ super().__init__() dim_out = dim_out or dim dim_attn = dim_out if expand_first else dim @@ -171,7 +186,7 @@ def __init__( self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) self.proj_drop = nn.Dropout(proj_drop) - def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor: B, C, H, W = x.shape if self.head_first: @@ -210,7 +225,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): class AttentionCl(nn.Module): - """ Channels-last multi-head attention (B, ..., C) """ + """Channels-last multi-head attention (B, ..., C).""" fused_attn: Final[bool] def __init__( @@ -221,10 +236,22 @@ def __init__( bias: bool = True, expand_first: bool = True, head_first: bool = True, - rel_pos_cls: Callable = None, + rel_pos_cls: Optional[Callable] = None, attn_drop: float = 0., proj_drop: float = 0. ): + """ + Args: + dim: Input dimension. + dim_out: Output dimension (defaults to input dimension). + dim_head: Dimension per attention head. + bias: Whether to use bias in qkv and projection. + expand_first: Whether to expand channels before or after qkv. + head_first: Whether heads are first in tensor layout. + rel_pos_cls: Relative position class to use. + attn_drop: Attention dropout rate. + proj_drop: Projection dropout rate. + """ super().__init__() dim_out = dim_out or dim dim_attn = dim_out if expand_first and dim_out > dim else dim @@ -241,7 +268,7 @@ def __init__( self.proj = nn.Linear(dim_attn, dim_out, bias=bias) self.proj_drop = nn.Dropout(proj_drop) - def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor: B = x.shape[0] restore_shape = x.shape[:-1] @@ -280,29 +307,46 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): class LayerScale(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): + """Per-channel scaling layer.""" + + def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): + """ + Args: + dim: Number of channels. + init_values: Initial scaling value. + inplace: Whether to perform inplace operations. 
+ """ super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: gamma = self.gamma return x.mul_(gamma) if self.inplace else x * gamma class LayerScale2d(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): + """Per-channel scaling layer for 2D tensors.""" + + def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): + """ + Args: + dim: Number of channels. + init_values: Initial scaling value. + inplace: Whether to perform inplace operations. + """ super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: gamma = self.gamma.view(1, -1, 1, 1) return x.mul_(gamma) if self.inplace else x * gamma class Downsample2d(nn.Module): - """ A downsample pooling module supporting several maxpool and avgpool modes + """A downsample pooling module supporting several maxpool and avgpool modes. + * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1 * 'max2' - MaxPool2d w/ kernel_size = stride = 2 * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1 @@ -317,6 +361,14 @@ def __init__( padding: str = '', bias: bool = True, ): + """ + Args: + dim: Input dimension. + dim_out: Output dimension. + pool_type: Type of pooling operation. + padding: Padding mode. + bias: Whether to use bias in expansion conv. + """ super().__init__() assert pool_type in ('max', 'max2', 'avg', 'avg2') if pool_type == 'max': @@ -334,13 +386,14 @@ def __init__( else: self.expand = nn.Identity() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.pool(x) # spatial downsample x = self.expand(x) # expand chs return x -def _init_transformer(module, name, scheme=''): +def _init_transformer(module: nn.Module, name: str, scheme: str = '') -> None: + """Initialize transformer module weights.""" if isinstance(module, (nn.Conv2d, nn.Linear)): if scheme == 'normal': nn.init.normal_(module.weight, std=.02) @@ -365,7 +418,8 @@ def _init_transformer(module, name, scheme=''): class TransformerBlock2d(nn.Module): - """ Transformer block with 2D downsampling + """Transformer block with 2D downsampling. + '2D' NCHW tensor layout Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW @@ -379,10 +433,19 @@ def __init__( dim: int, dim_out: int, stride: int = 1, - rel_pos_cls: Callable = None, + rel_pos_cls: Optional[Callable] = None, cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), drop_path: float = 0., ): + """ + Args: + dim: Input dimension. + dim_out: Output dimension. + stride: Stride for downsampling. + rel_pos_cls: Relative position class. + cfg: Transformer block configuration. + drop_path: Drop path rate. + """ super().__init__() norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) act_layer = get_act_layer(cfg.act_layer) @@ -420,16 +483,17 @@ def __init__( self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - def init_weights(self, scheme=''): + def init_weights(self, scheme: str = '') -> None: named_apply(partial(_init_transformer, scheme=scheme), self) - def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor: x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x -def _init_conv(module, name, scheme=''): +def _init_conv(module: nn.Module, name: str, scheme: str = '') -> None: + """Initialize convolution module weights.""" if isinstance(module, nn.Conv2d): if scheme == 'normal': nn.init.normal_(module.weight, std=.02) @@ -452,7 +516,8 @@ def _init_conv(module, name, scheme=''): nn.init.zeros_(module.bias) -def num_groups(group_size, channels): +def num_groups(group_size: Optional[int], channels: int) -> int: + """Calculate number of groups for grouped convolution.""" if not group_size: # 0 or None return 1 # normal conv with 1 group else: @@ -462,8 +527,8 @@ def num_groups(group_size, channels): class MbConvBlock(nn.Module): - """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand) - """ + """Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand).""" + def __init__( self, in_chs: int, @@ -473,6 +538,15 @@ def __init__( cfg: MaxxVitConvCfg = MaxxVitConvCfg(), drop_path: float = 0. ): + """ + Args: + in_chs: Input channels. + out_chs: Output channels. + stride: Stride for conv. + dilation: Dilation for conv. + cfg: Convolution block configuration. + drop_path: Drop path rate. + """ super(MbConvBlock, self).__init__() norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps) mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio) @@ -527,10 +601,10 @@ def __init__( self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def init_weights(self, scheme=''): + def init_weights(self, scheme: str = '') -> None: named_apply(partial(_init_conv, scheme=scheme), self) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = self.shortcut(x) x = self.pre_norm(x) x = self.down(x) @@ -554,8 +628,7 @@ def forward(self, x): class ConvNeXtBlock(nn.Module): - """ ConvNeXt Block - """ + """ConvNeXt Block.""" def __init__( self, @@ -568,6 +641,17 @@ def __init__( conv_mlp: bool = True, drop_path: float = 0. ): + """ + Args: + in_chs: Input channels. + out_chs: Output channels. + kernel_size: Kernel size for depthwise conv. + stride: Stride for conv. + dilation: Dilation for conv. + cfg: Convolution block configuration. + conv_mlp: Whether to use convolutional MLP. + drop_path: Drop path rate. + """ super().__init__() out_chs = out_chs or in_chs act_layer = get_act_layer(cfg.act_layer) @@ -611,7 +695,7 @@ def __init__( self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = self.shortcut(x) x = self.down(x) x = self.conv_dw(x) @@ -630,7 +714,8 @@ def forward(self, x): return x -def window_partition(x, window_size: List[int]): +def window_partition(x: torch.Tensor, window_size: List[int]) -> torch.Tensor: + """Partition into non-overlapping windows.""" B, H, W, C = x.shape _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})') @@ -640,7 +725,8 @@ def window_partition(x, window_size: List[int]): @register_notrace_function # reason: int argument is a Proxy -def window_reverse(windows, window_size: List[int], img_size: List[int]): +def window_reverse(windows: torch.Tensor, window_size: List[int], img_size: List[int]) -> torch.Tensor: + """Reverse window partition.""" H, W = img_size C = windows.shape[-1] x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) @@ -648,7 +734,8 @@ def window_reverse(windows, window_size: List[int], img_size: List[int]): return x -def grid_partition(x, grid_size: List[int]): +def grid_partition(x: torch.Tensor, grid_size: List[int]) -> torch.Tensor: + """Partition into overlapping windows with grid striding.""" B, H, W, C = x.shape _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}') @@ -658,7 +745,8 @@ def grid_partition(x, grid_size: List[int]): @register_notrace_function # reason: int argument is a Proxy -def grid_reverse(windows, grid_size: List[int], img_size: List[int]): +def grid_reverse(windows: torch.Tensor, grid_size: List[int], img_size: List[int]) -> torch.Tensor: + """Reverse grid partition.""" H, W = img_size C = windows.shape[-1] x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C) @@ -666,7 +754,8 @@ def grid_reverse(windows, grid_size: List[int], img_size: List[int]): return x -def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size): +def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size: Tuple[int, int]) -> Optional[Callable]: + """Get relative position class based on config.""" rel_pos_cls = None if cfg.rel_pos_type == 'mlp': rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim) @@ -678,7 +767,8 @@ def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size): class PartitionAttentionCl(nn.Module): - """ Grid or Block partition + Attn + FFN. + """Grid or Block partition + Attn + FFN. + NxC 'channels last' tensor layout. """ @@ -742,7 +832,8 @@ def forward(self, x): class ParallelPartitionAttention(nn.Module): - """ Experimental. Grid and Block partition + single FFN + """Experimental. Grid and Block partition + single FFN. + NxC tensor layout. """ @@ -752,6 +843,12 @@ def __init__( cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), drop_path: float = 0., ): + """ + Args: + dim: Input dimension. + cfg: Transformer block configuration. + drop_path: Drop path rate. + """ super().__init__() assert dim % 2 == 0 norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last @@ -795,7 +892,7 @@ def __init__( self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - def _partition_attn(self, x): + def _partition_attn(self, x: torch.Tensor) -> torch.Tensor: img_size = x.shape[1:3] partitioned_block = window_partition(x, self.partition_size) @@ -808,13 +905,14 @@ def _partition_attn(self, x): return torch.cat([x_window, x_grid], dim=-1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x -def window_partition_nchw(x, window_size: List[int]): +def window_partition_nchw(x: torch.Tensor, window_size: List[int]) -> torch.Tensor: + """Partition windows for NCHW tensors.""" B, C, H, W = x.shape _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})') @@ -824,7 +922,8 @@ def window_partition_nchw(x, window_size: List[int]): @register_notrace_function # reason: int argument is a Proxy -def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]): +def window_reverse_nchw(windows: torch.Tensor, window_size: List[int], img_size: List[int]) -> torch.Tensor: + """Reverse window partition for NCHW tensors.""" H, W = img_size C = windows.shape[1] x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1]) @@ -832,7 +931,8 @@ def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]): return x -def grid_partition_nchw(x, grid_size: List[int]): +def grid_partition_nchw(x: torch.Tensor, grid_size: List[int]) -> torch.Tensor: + """Grid partition for NCHW tensors.""" B, C, H, W = x.shape _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}') @@ -842,7 +942,8 @@ def grid_partition_nchw(x, grid_size: List[int]): @register_notrace_function # reason: int argument is a Proxy -def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]): +def grid_reverse_nchw(windows: torch.Tensor, grid_size: List[int], img_size: List[int]) -> torch.Tensor: + """Reverse grid partition for NCHW tensors.""" H, W = img_size C = windows.shape[1] x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1]) @@ -851,7 +952,7 @@ def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]): class PartitionAttention2d(nn.Module): - """ Grid or Block partition + Attn + FFN + """Grid or Block partition + Attn + FFN. '2D' NCHW tensor layout. """ @@ -863,6 +964,13 @@ def __init__( cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), drop_path: float = 0., ): + """ + Args: + dim: Input dimension. + partition_type: Partition type ('block' or 'grid'). + cfg: Transformer block configuration. + drop_path: Drop path rate. + """ super().__init__() norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) # NOTE this block is channels-last act_layer = get_act_layer(cfg.act_layer) @@ -894,7 +1002,7 @@ def __init__( self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - def _partition_attn(self, x): + def _partition_attn(self, x: torch.Tensor) -> torch.Tensor: img_size = x.shape[-2:] if self.partition_block: partitioned = window_partition_nchw(x, self.partition_size) @@ -909,15 +1017,14 @@ def _partition_attn(self, x): x = grid_reverse_nchw(partitioned, self.partition_size, img_size) return x - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x class MaxxVitBlock(nn.Module): - """ MaxVit conv, window partition + FFN , grid partition + FFN - """ + """MaxVit conv, window partition + FFN , grid partition + FFN.""" def __init__( self, @@ -928,6 +1035,16 @@ def __init__( transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), drop_path: float = 0., ): + """Initialize MaxxVitBlock. + + Args: + dim: Input channel dimension. + dim_out: Output channel dimension. + stride: Stride for downsampling. + conv_cfg: Configuration for convolutional blocks. + transformer_cfg: Configuration for transformer blocks. + drop_path: Drop path rate. + """ super().__init__() self.nchw_attn = transformer_cfg.use_nchw_attn @@ -960,20 +1077,31 @@ def forward(self, x): class ParallelMaxxVitBlock(nn.Module): - """ MaxVit block with parallel cat(window + grid), one FF + """MaxVit block with parallel cat(window + grid), one FF. + Experimental timm block. """ def __init__( self, - dim, - dim_out, - stride=1, - num_conv=2, + dim: int, + dim_out: int, + stride: int = 1, + num_conv: int = 2, conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - drop_path=0., + drop_path: float = 0., ): + """ + Args: + dim: Input dimension. + dim_out: Output dimension. + stride: Stride for first conv block. + num_conv: Number of convolution blocks. + conv_cfg: Convolution block configuration. + transformer_cfg: Transformer block configuration. + drop_path: Drop path rate. + """ super().__init__() conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock @@ -985,11 +1113,11 @@ def __init__( self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) - def init_weights(self, scheme=''): + def init_weights(self, scheme: str = '') -> None: named_apply(partial(_init_transformer, scheme=scheme), self.attn) named_apply(partial(_init_conv, scheme=scheme), self.conv) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = x.permute(0, 2, 3, 1) x = self.attn(x) @@ -998,6 +1126,8 @@ def forward(self, x): class MaxxVitStage(nn.Module): + """MaxxVit stage consisting of mixed convolution and transformer blocks.""" + def __init__( self, in_chs: int, @@ -1010,6 +1140,18 @@ def __init__( conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), drop_path: Union[float, List[float]] = 0., ): + """ + Args: + in_chs: Input channels. + out_chs: Output channels. + stride: Stride for first block. + depth: Number of blocks in stage. + feat_size: Feature map size. + block_types: Block types ('C' for conv, 'T' for transformer, etc). + transformer_cfg: Transformer block configuration. + conv_cfg: Convolution block configuration. + drop_path: Drop path rate(s). 
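# Illustrative round-trip for the block/grid partition helpers these stages rely
# on (a sketch, assuming the helpers import from timm.models.maxxvit): both
# schemes yield identically shaped windows and are exactly undone by their
# reverse functions.
import torch
from timm.models.maxxvit import window_partition, window_reverse, grid_partition, grid_reverse

x = torch.randn(2, 8, 8, 16)        # B, H, W, C channels-last
win = window_partition(x, [4, 4])   # contiguous 4x4 blocks -> local window attention
grd = grid_partition(x, [4, 4])     # 4x4 grids strided across the image -> sparse global attention
assert win.shape == grd.shape == (2 * 2 * 2, 4, 4, 16)
assert torch.equal(window_reverse(win, [4, 4], [8, 8]), x)
assert torch.equal(grid_reverse(grd, [4, 4], [8, 8]), x)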
+ """ super().__init__() self.grad_checkpointing = False @@ -1058,7 +1200,7 @@ def __init__( in_chs = out_chs self.blocks = nn.Sequential(*blocks) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: @@ -1067,6 +1209,7 @@ def forward(self, x): class Stem(nn.Module): + """Stem layer for feature extraction.""" def __init__( self, @@ -1079,6 +1222,17 @@ def __init__( norm_layer: str = 'batchnorm2d', norm_eps: float = 1e-5, ): + """ + Args: + in_chs: Input channels. + out_chs: Output channels. + kernel_size: Kernel size for convolutions. + padding: Padding mode. + bias: Whether to use bias. + act_layer: Activation layer. + norm_layer: Normalization layer. + norm_eps: Normalization epsilon. + """ super().__init__() if not isinstance(out_chs, (list, tuple)): out_chs = to_2tuple(out_chs) @@ -1091,17 +1245,18 @@ def __init__( self.norm1 = norm_act_layer(out_chs[0]) self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1, padding=padding, bias=bias) - def init_weights(self, scheme=''): + def init_weights(self, scheme: str = '') -> None: named_apply(partial(_init_conv, scheme=scheme), self) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv1(x) x = self.norm1(x) x = self.conv2(x) return x -def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]): +def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]) -> MaxxVitTransformerCfg: + """Configure window size based on image size and partition ratio.""" if cfg.window_size is not None: assert cfg.grid_size return cfg @@ -1110,7 +1265,8 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]): return cfg -def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs): +def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs: Any) -> MaxxVitCfg: + """Overlay keyword arguments onto configuration.""" transformer_kwargs = {} conv_kwargs = {} base_kwargs = {} @@ -1131,7 +1287,7 @@ def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs): class MaxxVit(nn.Module): - """ CoaTNet + MaxVit base model. + """CoaTNet + MaxVit base model. Highly configurable for different block compositions, tensor layouts, pooling types. """ @@ -1145,8 +1301,19 @@ def __init__( global_pool: str = 'avg', drop_rate: float = 0., drop_path_rate: float = 0., - **kwargs, + **kwargs: Any, ): + """ + Args: + cfg: Model configuration. + img_size: Input image size. + in_chans: Number of input channels. + num_classes: Number of classification classes. + global_pool: Global pooling type. + drop_rate: Dropout rate. + drop_path_rate: Drop path rate. + **kwargs: Additional keyword arguments to overlay on config. 
+ """ super().__init__() img_size = to_2tuple(img_size) if kwargs: @@ -1219,7 +1386,7 @@ def __init__( if cfg.weight_init: named_apply(partial(self._init_weights, scheme=cfg.weight_init), self) - def _init_weights(self, module, name, scheme=''): + def _init_weights(self, module: nn.Module, name: str, scheme: str = '') -> None: if hasattr(module, 'init_weights'): try: module.init_weights(scheme=scheme) @@ -1227,13 +1394,13 @@ def _init_weights(self, module, name, scheme=''): module.init_weights() @torch.jit.ignore - def no_weight_decay(self): + def no_weight_decay(self) -> Set[str]: return { k for k, _ in self.named_parameters() if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: matcher = dict( stem=r'^stem', # stem and embed blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] @@ -1241,7 +1408,7 @@ def group_matcher(self, coarse=False): return matcher @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: for s in self.stages: s.grad_checkpointing = enable @@ -1249,7 +1416,7 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: self.num_classes = num_classes self.head.reset(num_classes, global_pool) @@ -1312,9 +1479,8 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): - """ Prune layers not required for specified intermediates. 
- """ + ) -> Tuple[int, ...]: + """Prune layers not required for specified intermediates.""" take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 if prune_norm: @@ -1323,48 +1489,53 @@ def prune_intermediate_layers( self.head = self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.stem(x) x = self.stages(x) x = self.norm(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) x = self.forward_head(x) return x def _rw_coat_cfg( - stride_mode='pool', - pool_type='avg2', - conv_output_bias=False, - conv_attn_early=False, - conv_attn_act_layer='relu', - conv_norm_layer='', - transformer_shortcut_bias=True, - transformer_norm_layer='layernorm2d', - transformer_norm_layer_cl='layernorm', - init_values=None, - rel_pos_type='bias', - rel_pos_dim=512, -): - # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit - # Common differences for initial timm models: - # - pre-norm layer in MZBConv included an activation after norm - # - mbconv expansion calculated from input instead of output chs - # - mbconv shortcut and final 1x1 conv did not have a bias - # - SE act layer was relu, not silu - # - mbconv uses silu in timm, not gelu - # - expansion in attention block done via output proj, not input proj - # Variable differences (evolved over training initial models): - # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) - # - SE attention was between conv2 and norm/act - # - default to avg pool for mbconv downsample instead of 1x1 or dw conv - # - transformer block shortcut has no bias + stride_mode: str = 'pool', + pool_type: str = 'avg2', + conv_output_bias: bool = False, + conv_attn_early: bool = False, + conv_attn_act_layer: str = 'relu', + conv_norm_layer: str = '', + transformer_shortcut_bias: bool = True, + transformer_norm_layer: str = 'layernorm2d', + transformer_norm_layer_cl: str = 'layernorm', + init_values: Optional[float] = None, + rel_pos_type: str = 'bias', + rel_pos_dim: int = 512, +) -> Dict[str, Any]: + """RW variant configuration for CoAtNet models. 
+ + These models were created and trained before seeing https://github.com/google-research/maxvit + + Common differences for initial timm models: + - pre-norm layer in MZBConv included an activation after norm + - mbconv expansion calculated from input instead of output chs + - mbconv shortcut and final 1x1 conv did not have a bias + - SE act layer was relu, not silu + - mbconv uses silu in timm, not gelu + - expansion in attention block done via output proj, not input proj + + Variable differences (evolved over training initial models): + - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) + - SE attention was between conv2 and norm/act + - default to avg pool for mbconv downsample instead of 1x1 or dw conv + - transformer block shortcut has no bias + """ return dict( conv_cfg=MaxxVitConvCfg( stride_mode=stride_mode, @@ -1391,25 +1562,29 @@ def _rw_coat_cfg( def _rw_max_cfg( - stride_mode='dw', - pool_type='avg2', - conv_output_bias=False, - conv_attn_ratio=1 / 16, - conv_norm_layer='', - transformer_norm_layer='layernorm2d', - transformer_norm_layer_cl='layernorm', - window_size=None, - dim_head=32, - init_values=None, - rel_pos_type='bias', - rel_pos_dim=512, -): - # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit - # Differences of initial timm models: - # - mbconv expansion calculated from input instead of output chs - # - mbconv shortcut and final 1x1 conv did not have a bias - # - mbconv uses silu in timm, not gelu - # - expansion in attention block done via output proj, not input proj + stride_mode: str = 'dw', + pool_type: str = 'avg2', + conv_output_bias: bool = False, + conv_attn_ratio: float = 1 / 16, + conv_norm_layer: str = '', + transformer_norm_layer: str = 'layernorm2d', + transformer_norm_layer_cl: str = 'layernorm', + window_size: Optional[Tuple[int, int]] = None, + dim_head: int = 32, + init_values: Optional[float] = None, + rel_pos_type: str = 'bias', + rel_pos_dim: int = 512, +) -> Dict[str, Any]: + """RW variant configuration for MaxViT models. 
+ + These models were created and trained before seeing https://github.com/google-research/maxvit + + Differences of initial timm models: + - mbconv expansion calculated from input instead of output chs + - mbconv shortcut and final 1x1 conv did not have a bias + - mbconv uses silu in timm, not gelu + - expansion in attention block done via output proj, not input proj + """ return dict( conv_cfg=MaxxVitConvCfg( stride_mode=stride_mode, @@ -1435,19 +1610,19 @@ def _rw_max_cfg( def _next_cfg( - stride_mode='dw', - pool_type='avg2', - conv_norm_layer='layernorm2d', - conv_norm_layer_cl='layernorm', - transformer_norm_layer='layernorm2d', - transformer_norm_layer_cl='layernorm', - window_size=None, - no_block_attn=False, - init_values=1e-6, - rel_pos_type='mlp', # MLP by default for maxxvit - rel_pos_dim=512, -): - # For experimental models with convnext instead of mbconv + stride_mode: str = 'dw', + pool_type: str = 'avg2', + conv_norm_layer: str = 'layernorm2d', + conv_norm_layer_cl: str = 'layernorm', + transformer_norm_layer: str = 'layernorm2d', + transformer_norm_layer_cl: str = 'layernorm', + window_size: Optional[Tuple[int, int]] = None, + no_block_attn: bool = False, + init_values: Union[float, Tuple[float, float]] = 1e-6, + rel_pos_type: str = 'mlp', # MLP by default for maxxvit + rel_pos_dim: int = 512, +) -> Dict[str, Any]: + """Configuration for experimental ConvNeXt-based MaxxViT models.""" init_values = to_2tuple(init_values) return dict( conv_cfg=MaxxVitConvCfg( @@ -1473,7 +1648,8 @@ def _next_cfg( ) -def _tf_cfg(): +def _tf_cfg() -> Dict[str, Any]: + """Configuration matching TensorFlow MaxViT models.""" return dict( conv_cfg=MaxxVitConvCfg( norm_eps=1e-3, @@ -1858,7 +2034,8 @@ def _tf_cfg(): ) -def checkpoint_filter_fn(state_dict, model: nn.Module): +def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]: + """Filter checkpoint state dict for compatibility.""" model_state_dict = model.state_dict() out_dict = {} for k, v in state_dict.items(): @@ -1879,7 +2056,8 @@ def checkpoint_filter_fn(state_dict, model: nn.Module): return out_dict -def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): +def _create_maxxvit(variant: str, cfg_variant: Optional[str] = None, pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """Create a MaxxVit model variant.""" if cfg_variant is None: if variant in model_cfgs: cfg_variant = variant @@ -1893,7 +2071,8 @@ def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): **kwargs) -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]: + """Create a default configuration dict.""" return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.95, 'interpolation': 'bicubic', @@ -2124,280 +2303,336 @@ def _cfg(url='', **kwargs): @register_model -def coatnet_pico_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_pico_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet Pico model with RW configuration.""" return _create_maxxvit('coatnet_pico_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_nano_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet Nano model with RW configuration.""" return _create_maxxvit('coatnet_nano_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_0_rw_224(pretrained=False, **kwargs) -> MaxxVit: 
+def coatnet_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-0 model with RW configuration.""" return _create_maxxvit('coatnet_0_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_1_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_1_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-1 model with RW configuration.""" return _create_maxxvit('coatnet_1_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_2_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_2_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-2 model with RW configuration.""" return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_3_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_3_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-3 model with RW configuration.""" return _create_maxxvit('coatnet_3_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_bn_0_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_bn_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-0 model with BatchNorm and RW configuration.""" return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_rmlp_nano_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_rmlp_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet Nano model with Relative Position MLP.""" return _create_maxxvit('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_rmlp_0_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_rmlp_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-0 model with Relative Position MLP.""" return _create_maxxvit('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_rmlp_1_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-1 model with Relative Position MLP.""" return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_rmlp_1_rw2_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_rmlp_1_rw2_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-1 model with Relative Position MLP v2.""" return _create_maxxvit('coatnet_rmlp_1_rw2_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_rmlp_2_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-2 model with Relative Position MLP.""" return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_rmlp_2_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-2 model with Relative Position MLP at 384x384.""" return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs) @register_model -def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_rmlp_3_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-3 model with Relative Position MLP.""" return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_nano_cc_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_nano_cc_224(pretrained: bool = 
False, **kwargs: Any) -> MaxxVit: + """CoatNet Nano model with ConvNeXt blocks.""" return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs) @register_model -def coatnext_nano_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnext_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoAtNeXt Nano model with RW configuration.""" return _create_maxxvit('coatnext_nano_rw_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_0_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_0_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-0 model.""" return _create_maxxvit('coatnet_0_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_1_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_1_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-1 model.""" return _create_maxxvit('coatnet_1_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_2_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_2_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-2 model.""" return _create_maxxvit('coatnet_2_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_3_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_3_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-3 model.""" return _create_maxxvit('coatnet_3_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_4_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_4_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-4 model.""" return _create_maxxvit('coatnet_4_224', pretrained=pretrained, **kwargs) @register_model -def coatnet_5_224(pretrained=False, **kwargs) -> MaxxVit: +def coatnet_5_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-5 model.""" return _create_maxxvit('coatnet_5_224', pretrained=pretrained, **kwargs) @register_model -def maxvit_pico_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_pico_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Pico model with RW configuration.""" return _create_maxxvit('maxvit_pico_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxvit_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Nano model with RW configuration.""" return _create_maxxvit('maxvit_nano_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxvit_tiny_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_tiny_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model with RW configuration.""" return _create_maxxvit('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs) @register_model -def maxvit_tiny_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model with RW configuration at 256x256.""" return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxvit_rmlp_pico_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_rmlp_pico_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Relative Position MLP Pico RW 256x256 model.""" return _create_maxxvit('maxvit_rmlp_pico_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_rmlp_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + 
"""MaxViT Relative Position MLP Nano RW 256x256 model.""" return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_rmlp_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Relative Position MLP Tiny RW 256x256 model.""" return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxvit_rmlp_small_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_rmlp_small_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Relative Position MLP Small RW 224x224 model.""" return _create_maxxvit('maxvit_rmlp_small_rw_224', pretrained=pretrained, **kwargs) @register_model -def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_rmlp_small_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Small model with Relative Position MLP at 256x256.""" return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxvit_rmlp_base_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_rmlp_base_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Base model with Relative Position MLP.""" return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs) @register_model -def maxvit_rmlp_base_rw_384(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_rmlp_base_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Base model with Relative Position MLP at 384x384.""" return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs) @register_model -def maxvit_tiny_pm_256(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_tiny_pm_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model with parallel blocks.""" return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs) @register_model -def maxxvit_rmlp_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxxvit_rmlp_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT Relative Position MLP Nano RW 256x256 model.""" return _create_maxxvit('maxxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxxvit_rmlp_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT Tiny model with Relative Position MLP.""" return _create_maxxvit('maxxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxxvit_rmlp_small_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT Small model with Relative Position MLP.""" return _create_maxxvit('maxxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxxvitv2_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit: +def maxxvitv2_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT-V2 Nano model.""" return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs) @register_model -def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def maxxvitv2_rmlp_base_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT-V2 Base model with Relative Position MLP.""" return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs) @register_model -def 
maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs) -> MaxxVit: +def maxxvitv2_rmlp_base_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT-V2 Base model with Relative Position MLP at 384x384.""" return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs) @register_model -def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs) -> MaxxVit: +def maxxvitv2_rmlp_large_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT-V2 Large model with Relative Position MLP.""" return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs) @register_model -def maxvit_tiny_tf_224(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_tiny_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model from TensorFlow.""" return _create_maxxvit('maxvit_tiny_tf_224', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_tiny_tf_384(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_tiny_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model from TensorFlow at 384x384.""" return _create_maxxvit('maxvit_tiny_tf_384', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_tiny_tf_512(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_tiny_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model from TensorFlow at 512x512.""" return _create_maxxvit('maxvit_tiny_tf_512', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_small_tf_224(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_small_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Small model from TensorFlow.""" return _create_maxxvit('maxvit_small_tf_224', 'maxvit_small_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_small_tf_384(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_small_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Small model from TensorFlow at 384x384.""" return _create_maxxvit('maxvit_small_tf_384', 'maxvit_small_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_small_tf_512(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_small_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Small model from TensorFlow at 512x512.""" return _create_maxxvit('maxvit_small_tf_512', 'maxvit_small_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_base_tf_224(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_base_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Base model from TensorFlow.""" return _create_maxxvit('maxvit_base_tf_224', 'maxvit_base_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_base_tf_384(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_base_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Base model from TensorFlow at 384x384.""" return _create_maxxvit('maxvit_base_tf_384', 'maxvit_base_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_base_tf_512(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_base_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Base model from TensorFlow at 512x512.""" return _create_maxxvit('maxvit_base_tf_512', 'maxvit_base_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_large_tf_224(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_large_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Large 
model from TensorFlow.""" return _create_maxxvit('maxvit_large_tf_224', 'maxvit_large_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_large_tf_384(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_large_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Large model from TensorFlow at 384x384.""" return _create_maxxvit('maxvit_large_tf_384', 'maxvit_large_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_large_tf_512(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_large_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Large model from TensorFlow at 512x512.""" return _create_maxxvit('maxvit_large_tf_512', 'maxvit_large_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_xlarge_tf_224(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_xlarge_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT XLarge model from TensorFlow.""" return _create_maxxvit('maxvit_xlarge_tf_224', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_xlarge_tf_384(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_xlarge_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT XLarge model from TensorFlow at 384x384.""" return _create_maxxvit('maxvit_xlarge_tf_384', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_xlarge_tf_512(pretrained=False, **kwargs) -> MaxxVit: +def maxvit_xlarge_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT XLarge model from TensorFlow at 512x512.""" return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 490852cfe..2e93e01b1 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -633,7 +633,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 25cde6a67..b024fba47 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -40,7 +40,7 @@ """ import math from functools import partial -from typing import List, Optional, Union, Tuple +from typing import Any, Dict, List, Optional, Union, Tuple import torch import torch.nn as nn @@ -56,20 +56,33 @@ class MixerBlock(nn.Module): - """ Residual Block w/ token mixing and channel MLPs + """Residual Block w/ token mixing and channel MLPs. + Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ def __init__( self, - dim, - seq_len, - mlp_ratio=(0.5, 4.0), - mlp_layer=Mlp, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU, - drop=0., - drop_path=0., - ): + dim: int, + seq_len: int, + mlp_ratio: Union[float, Tuple[float, float]] = (0.5, 4.0), + mlp_layer: type = Mlp, + norm_layer: type = partial(nn.LayerNorm, eps=1e-6), + act_layer: type = nn.GELU, + drop: float = 0., + drop_path: float = 0., + ) -> None: + """Initialize MixerBlock. + + Args: + dim: Dimension of input features. + seq_len: Sequence length. + mlp_ratio: Expansion ratios for token mixing and channel MLPs. + mlp_layer: MLP layer class. + norm_layer: Normalization layer. + act_layer: Activation layer. + drop: Dropout rate. + drop_path: Drop path rate. 
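# Illustrative shape check for the MixerBlock documented above (a sketch,
# assuming direct import from timm.models.mlp_mixer): token mixing transposes
# the sequence and channel dims internally, so the residual output keeps the
# (B, seq_len, dim) layout.
import torch
from timm.models.mlp_mixer import MixerBlock

block = MixerBlock(dim=512, seq_len=196)   # default mlp_ratio=(0.5, 4.0) -> 256 token / 2048 channel MLP dims
tokens = torch.randn(2, 196, 512)          # B, seq_len, dim
assert block(tokens).shape == tokens.shape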
+ """ super().__init__() tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)] self.norm1 = norm_layer(dim) @@ -78,39 +91,61 @@ def __init__( self.norm2 = norm_layer(dim) self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) x = x + self.drop_path(self.mlp_channels(self.norm2(x))) return x class Affine(nn.Module): - def __init__(self, dim): + """Affine transformation layer.""" + + def __init__(self, dim: int) -> None: + """Initialize Affine layer. + + Args: + dim: Dimension of features. + """ super().__init__() self.alpha = nn.Parameter(torch.ones((1, 1, dim))) self.beta = nn.Parameter(torch.zeros((1, 1, dim))) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply affine transformation.""" return torch.addcmul(self.beta, self.alpha, x) class ResBlock(nn.Module): - """ Residual MLP block w/ LayerScale and Affine 'norm' + """Residual MLP block w/ LayerScale and Affine 'norm'. Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ def __init__( self, - dim, - seq_len, - mlp_ratio=4, - mlp_layer=Mlp, - norm_layer=Affine, - act_layer=nn.GELU, - init_values=1e-4, - drop=0., - drop_path=0., - ): + dim: int, + seq_len: int, + mlp_ratio: float = 4, + mlp_layer: type = Mlp, + norm_layer: type = Affine, + act_layer: type = nn.GELU, + init_values: float = 1e-4, + drop: float = 0., + drop_path: float = 0., + ) -> None: + """Initialize ResBlock. + + Args: + dim: Dimension of input features. + seq_len: Sequence length. + mlp_ratio: Channel MLP expansion ratio. + mlp_layer: MLP layer class. + norm_layer: Normalization layer. + act_layer: Activation layer. + init_values: Initial values for layer scale. + drop: Dropout rate. + drop_path: Drop path rate. + """ super().__init__() channel_dim = int(dim * mlp_ratio) self.norm1 = norm_layer(dim) @@ -121,29 +156,39 @@ def __init__( self.ls1 = nn.Parameter(init_values * torch.ones(dim)) self.ls2 = nn.Parameter(init_values * torch.ones(dim)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" x = x + self.drop_path(self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x))) return x class SpatialGatingUnit(nn.Module): - """ Spatial Gating Unit + """Spatial Gating Unit. Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ - def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm): + def __init__(self, dim: int, seq_len: int, norm_layer: type = nn.LayerNorm) -> None: + """Initialize Spatial Gating Unit. + + Args: + dim: Dimension of input features. + seq_len: Sequence length. + norm_layer: Normalization layer. 
+ """ super().__init__() gate_dim = dim // 2 self.norm = norm_layer(gate_dim) self.proj = nn.Linear(seq_len, seq_len) - def init_weights(self): + def init_weights(self) -> None: + """Initialize weights for projection gate.""" # special init for the projection gate, called as override by base model init nn.init.normal_(self.proj.weight, std=1e-6) nn.init.ones_(self.proj.bias) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply spatial gating.""" u, v = x.chunk(2, dim=-1) v = self.norm(v) v = self.proj(v.transpose(-1, -2)) @@ -151,21 +196,33 @@ def forward(self, x): class SpatialGatingBlock(nn.Module): - """ Residual Block w/ Spatial Gating + """Residual Block w/ Spatial Gating. Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ def __init__( self, - dim, - seq_len, - mlp_ratio=4, - mlp_layer=GatedMlp, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU, - drop=0., - drop_path=0., - ): + dim: int, + seq_len: int, + mlp_ratio: float = 4, + mlp_layer: type = GatedMlp, + norm_layer: type = partial(nn.LayerNorm, eps=1e-6), + act_layer: type = nn.GELU, + drop: float = 0., + drop_path: float = 0., + ) -> None: + """Initialize SpatialGatingBlock. + + Args: + dim: Dimension of input features. + seq_len: Sequence length. + mlp_ratio: Channel MLP expansion ratio. + mlp_layer: MLP layer class. + norm_layer: Normalization layer. + act_layer: Activation layer. + drop: Dropout rate. + drop_path: Drop path rate. + """ super().__init__() channel_dim = int(dim * mlp_ratio) self.norm = norm_layer(dim) @@ -173,33 +230,59 @@ def __init__( self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" x = x + self.drop_path(self.mlp_channels(self.norm(x))) return x class MlpMixer(nn.Module): + """MLP-Mixer model architecture. + + Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ def __init__( self, - num_classes=1000, - img_size=224, - in_chans=3, - patch_size=16, - num_blocks=8, - embed_dim=512, - mlp_ratio=(0.5, 4.0), - block_layer=MixerBlock, - mlp_layer=Mlp, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU, - drop_rate=0., - proj_drop_rate=0., - drop_path_rate=0., - nlhb=False, - stem_norm=False, - global_pool='avg', - ): + num_classes: int = 1000, + img_size: int = 224, + in_chans: int = 3, + patch_size: int = 16, + num_blocks: int = 8, + embed_dim: int = 512, + mlp_ratio: Union[float, Tuple[float, float]] = (0.5, 4.0), + block_layer: type = MixerBlock, + mlp_layer: type = Mlp, + norm_layer: type = partial(nn.LayerNorm, eps=1e-6), + act_layer: type = nn.GELU, + drop_rate: float = 0., + proj_drop_rate: float = 0., + drop_path_rate: float = 0., + nlhb: bool = False, + stem_norm: bool = False, + global_pool: str = 'avg', + ) -> None: + """Initialize MLP-Mixer. + + Args: + num_classes: Number of classes for classification. + img_size: Input image size. + in_chans: Number of input channels. + patch_size: Patch size. + num_blocks: Number of mixer blocks. + embed_dim: Embedding dimension. + mlp_ratio: MLP expansion ratio(s). + block_layer: Block layer class. + mlp_layer: MLP layer class. + norm_layer: Normalization layer. + act_layer: Activation layer. + drop_rate: Head dropout rate. + proj_drop_rate: Projection dropout rate. + drop_path_rate: Drop path rate. 
+ nlhb: Use negative log bias initialization. + stem_norm: Apply normalization to stem. + global_pool: Global pooling type. + """ super().__init__() self.num_classes = num_classes self.global_pool = global_pool @@ -236,26 +319,51 @@ def __init__( self.init_weights(nlhb=nlhb) @torch.jit.ignore - def init_weights(self, nlhb=False): + def init_weights(self, nlhb: bool = False) -> None: + """Initialize model weights. + + Args: + nlhb: Use negative log bias initialization for head. + """ head_bias = -math.log(self.num_classes) if nlhb else 0. named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Create regex patterns for parameter grouping. + + Args: + coarse: Use coarse grouping. + + Returns: + Dictionary mapping group names to regex patterns. + """ return dict( stem=r'^stem', # stem and embed blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] ) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing. + + Args: + enable: Whether to enable gradient checkpointing. + """ self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier module.""" return self.head - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg') @@ -271,18 +379,18 @@ def forward_intermediates( output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: - """ Forward features that returns intermediates. + """Forward features that returns intermediates. Args: - x: Input image tensor - indices: Take last n blocks if int, all if None, select matching indices if sequence - return_prefix_tokens: Return both prefix and spatial intermediate tokens - norm: Apply norm layer to all intermediates - stop_early: Stop iterating over blocks when last desired intermediate hit - output_fmt: Shape of intermediate feature outputs - intermediates_only: Only return intermediate features - Returns: + x: Input image tensor. + indices: Take last n blocks if int, all if None, select matching indices if sequence. + norm: Apply norm layer to all intermediates. + stop_early: Stop iterating over blocks when last desired intermediate hit. + output_fmt: Shape of intermediate feature outputs ('NCHW' or 'NLC'). + intermediates_only: Only return intermediate features. + Returns: + List of intermediate features or tuple of (final features, intermediates). """ assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' reshape = output_fmt == 'NCHW' @@ -321,8 +429,16 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layer. + prune_head: Whether to prune the classifier head. 
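# Illustrative use of forward_intermediates as documented above (a sketch,
# assuming the stock mixer_b16_224 entrypoint in a recent timm; shapes follow
# from patch size 16, image size 224, embed dim 768).
import timm
import torch

model = timm.create_model('mixer_b16_224', pretrained=False)
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=2,                 # take the last two blocks
    intermediates_only=True,   # return only the intermediate features
)
# default output_fmt='NCHW' reshapes tokens back to the 14x14 patch grid
assert [tuple(f.shape) for f in feats] == [(1, 768, 14, 14)] * 2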
+ + Returns: + List of indices that were kept. """ take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks @@ -332,7 +448,8 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers.""" x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) @@ -341,20 +458,36 @@ def forward_features(self, x): x = self.norm(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Feature tensor. + pre_logits: Return features before final classifier. + + Returns: + Output tensor. + """ if self.global_pool == 'avg': x = x.mean(dim=1) x = self.head_drop(x) return x if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" x = self.forward_features(x) x = self.forward_head(x) return x -def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): - """ Mixer weight initialization (trying to match Flax defaults) +def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax: bool = False) -> None: + """Mixer weight initialization (trying to match Flax defaults). + + Args: + module: Module to initialize. + name: Module name. + head_bias: Bias value for head layer. + flax: Use Flax-style initialization. """ if isinstance(module, nn.Linear): if name.startswith('head'): @@ -404,7 +537,7 @@ def checkpoint_filter_fn(state_dict, model): return state_dict -def _create_mixer(variant, pretrained=False, **kwargs): +def _create_mixer(variant, pretrained=False, **kwargs) -> MlpMixer: out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( MlpMixer, @@ -417,7 +550,7 @@ def _create_mixer(variant, pretrained=False, **kwargs): return model -def _cfg(url='', **kwargs): +def _cfg(url='', **kwargs) -> Dict[str, Any]: return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 08dcb064f..8e25674b6 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -7,7 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman """ from functools import partial -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Dict, Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -27,7 +27,7 @@ class MobileNetV3(nn.Module): - """ MobiletNet-V3 + """MobileNetV3. Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the head convolution without a final batch-norm @@ -64,7 +64,8 @@ def __init__( layer_scale_init_value: Optional[float] = None, global_pool: str = 'avg', ): - """ + """Initialize MobileNetV3. + Args: block_args: Arguments for blocks of the network. num_classes: Number of classes for classification head. @@ -73,6 +74,7 @@ def __init__( fix_stem: If True, don't scale stem by round_chs_fn. num_features: Number of output channels of the conv head layer. head_bias: If True, add a learnable bias to the conv head layer. + head_norm: If True, add normalization to the head layer. 
pad_type: Type of padding to use for convolution layers. act_layer: Type of activation layer. norm_layer: Type of normalization layer. @@ -137,7 +139,12 @@ def __init__( efficientnet_init_weights(self) - def as_sequential(self): + def as_sequential(self) -> nn.Sequential: + """Convert model to sequential form. + + Returns: + Sequential module containing all layers. + """ layers = [self.conv_stem, self.bn1] layers.extend(self.blocks) layers.extend([self.global_pool, self.conv_head, self.norm_head, self.act2]) @@ -145,21 +152,30 @@ def as_sequential(self): return nn.Sequential(*layers) @torch.jit.ignore - def group_matcher(self, coarse: bool = False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Group parameters for optimization.""" return dict( stem=r'^conv_stem|bn1', blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)' ) @torch.jit.ignore - def set_grad_checkpointing(self, enable: bool = True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing.""" self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier head.""" return self.classifier - def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes # NOTE: cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -228,8 +244,17 @@ def prune_intermediate_layers( prune_norm: bool = False, prune_head: bool = True, extra_blocks: bool = False, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layer. + prune_head: Whether to prune the classifier head. + extra_blocks: Include outputs of all blocks. + + Returns: + List of indices that were kept. """ if extra_blocks: take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) @@ -247,6 +272,14 @@ def prune_intermediate_layers( return take_indices def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers. + + Args: + x: Input tensor. + + Returns: + Feature tensor. + """ x = self.conv_stem(x) x = self.bn1(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -256,6 +289,15 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: return x def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Input features. + pre_logits: Return features before final linear layer. + + Returns: + Classification logits or features. + """ x = self.global_pool(x) x = self.conv_head(x) x = self.norm_head(x) @@ -268,13 +310,21 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tenso return self.classifier(x) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output logits. + """ x = self.forward_features(x) x = self.forward_head(x) return x class MobileNetV3Features(nn.Module): - """ MobileNetV3 Feature Extractor + """MobileNetV3 Feature Extractor. 
A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation and object detection models. @@ -300,11 +350,12 @@ def __init__( drop_path_rate: float = 0., layer_scale_init_value: Optional[float] = None, ): - """ + """Initialize MobileNetV3Features. + Args: block_args: Arguments for blocks of the network. out_indices: Output from stages at indices. - feature_location: Location of feature before/after each block, must be in ['bottleneck', 'expansion'] + feature_location: Location of feature before/after each block, must be in ['bottleneck', 'expansion']. in_chans: Number of input image channels. stem_size: Number of output channels of the initial stem convolution. fix_stem: If True, don't scale stem by round_chs_fn. @@ -314,6 +365,7 @@ def __init__( se_from_exp: If True, calculate SE channel reduction from expanded mid channels. act_layer: Type of activation layer. norm_layer: Type of normalization layer. + aa_layer: Type of anti-aliasing layer. se_layer: Type of Squeeze-and-Excite layer. drop_rate: Dropout rate. drop_path_rate: Stochastic depth rate. @@ -360,10 +412,19 @@ def __init__( self.feature_hooks = FeatureHooks(hooks, self.named_modules()) @torch.jit.ignore - def set_grad_checkpointing(self, enable: bool = True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing.""" self.grad_checkpointing = enable def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Forward pass through feature extraction. + + Args: + x: Input tensor. + + Returns: + List of feature tensors. + """ x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) @@ -386,6 +447,16 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]: def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3: + """Create a MobileNetV3 model. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + MobileNetV3 model instance. + """ features_mode = '' model_cls = MobileNetV3 kwargs_filter = None @@ -420,7 +491,13 @@ def _gen_mobilenet_v3_rw( Paper: https://arxiv.org/abs/1905.02244 Args: - channel_multiplier: multiplier to number of channels per layer. + variant: Model variant name. + channel_multiplier: Multiplier to number of channels per layer. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + MobileNetV3 model instance. """ arch_def = [ # stage 0, 112x112 in @@ -452,8 +529,12 @@ def _gen_mobilenet_v3_rw( def _gen_mobilenet_v3( - variant: str, channel_multiplier: float = 1.0, depth_multiplier: float = 1.0, - group_size=None, pretrained: bool = False, **kwargs + variant: str, + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + group_size: Optional[int] = None, + pretrained: bool = False, + **kwargs ) -> MobileNetV3: """Creates a MobileNet-V3 model. @@ -461,7 +542,15 @@ def _gen_mobilenet_v3( Paper: https://arxiv.org/abs/1905.02244 Args: - channel_multiplier: multiplier to number of channels per layer. + variant: Model variant name. + channel_multiplier: Multiplier to number of channels per layer. + depth_multiplier: Depth multiplier for model scaling. + group_size: Group size for grouped convolutions. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + MobileNetV3 model instance. 
""" if 'small' in variant: num_features = 1024 @@ -551,11 +640,21 @@ def _gen_mobilenet_v3( return model -def _gen_fbnetv3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs): - """ FBNetV3 +def _gen_fbnetv3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3: + """FBNetV3 model generator. + Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining` - https://arxiv.org/abs/2006.02049 FIXME untested, this is a preliminary impl of some FBNet-V3 variants. + + Args: + variant: Model variant name. + channel_multiplier: Channel width multiplier. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + MobileNetV3 model instance. """ vl = variant.split('_')[-1] if vl in ('a', 'b'): @@ -612,14 +711,21 @@ def _gen_fbnetv3(variant: str, channel_multiplier: float = 1.0, pretrained: bool return model -def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs): - """ LCNet +def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3: + """LCNet model generator. + Essentially a MobileNet-V3 crossed with a MobileNet-V1 Paper: `PP-LCNet: A Lightweight CPU Convolutional Neural Network` - https://arxiv.org/abs/2109.15099 Args: - channel_multiplier: multiplier to number of channels per layer. + variant: Model variant name. + channel_multiplier: Multiplier to number of channels per layer. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + MobileNetV3 model instance. """ arch_def = [ # stage 0, 112x112 in @@ -651,15 +757,25 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = def _gen_mobilenet_v4( - variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs, + variant: str, + channel_multiplier: float = 1.0, + group_size: Optional[int] = None, + pretrained: bool = False, + **kwargs, ) -> MobileNetV3: """Creates a MobileNet-V4 model. - Ref impl: ? - Paper: https://arxiv.org/abs/1905.02244 + Paper: https://arxiv.org/abs/2404.10518 Args: - channel_multiplier: multiplier to number of channels per layer. + variant: Model variant name. + channel_multiplier: Multiplier to number of channels per layer. + group_size: Group size for grouped convolutions. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + MobileNetV3 model instance. """ num_features = 1280 if 'hybrid' in variant: @@ -899,7 +1015,16 @@ def _gen_mobilenet_v4( return model -def _cfg(url: str = '', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create default configuration dictionary. + + Args: + url: Model weight URL. + **kwargs: Additional configuration options. + + Returns: + Configuration dictionary. 
+ """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 0bcc04856..a0d2d8a73 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -132,7 +132,7 @@ def __init__(self, stem_size, num_channels, pad_type=''): self.path_1 = nn.Sequential() self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) - + self.path_2 = nn.Sequential() self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 68e92128f..d848f81eb 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -19,7 +19,7 @@ from collections import OrderedDict from dataclasses import dataclass, replace from functools import partial -from typing import Callable, Tuple, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -37,6 +37,7 @@ @dataclass class NfCfg: + """Configuration for Normalization-Free Networks.""" depths: Tuple[int, int, int, int] channels: Tuple[int, int, int, int] alpha: float = 0.2 @@ -44,7 +45,7 @@ class NfCfg: stem_chs: Optional[int] = None group_size: Optional[int] = None attn_layer: Optional[str] = None - attn_kwargs: dict = None + attn_kwargs: Optional[Dict[str, Any]] = None attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used width_factor: float = 1.0 bottle_ratio: float = 0.5 @@ -61,23 +62,51 @@ class NfCfg: class GammaAct(nn.Module): - def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False): + """Activation function with gamma scaling factor.""" + + def __init__(self, act_type: str = 'relu', gamma: float = 1.0, inplace: bool = False): + """Initialize GammaAct. + + Args: + act_type: Type of activation function. + gamma: Scaling factor for activation output. + inplace: Whether to perform activation in-place. + """ super().__init__() self.act_fn = get_act_fn(act_type) self.gamma = gamma self.inplace = inplace - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Scaled activation output. + """ return self.act_fn(x, inplace=self.inplace).mul_(self.gamma) -def act_with_gamma(act_type, gamma: float = 1.): - def _create(inplace=False): +def act_with_gamma(act_type: str, gamma: float = 1.) -> Callable: + """Create activation function factory with gamma scaling. + + Args: + act_type: Type of activation function. + gamma: Scaling factor for activation output. + + Returns: + Activation function factory. + """ + def _create(inplace: bool = False) -> GammaAct: return GammaAct(act_type, gamma=gamma, inplace=inplace) return _create class DownsampleAvg(nn.Module): + """AvgPool downsampling as in 'D' ResNet variants with dilation support.""" + def __init__( self, in_chs: int, @@ -87,7 +116,16 @@ def __init__( first_dilation: Optional[int] = None, conv_layer: Callable = ScaledStdConv2d, ): - """ AvgPool Downsampling as in 'D' ResNet variants. Support for dilation.""" + """Initialize DownsampleAvg. + + Args: + in_chs: Input channels. + out_chs: Output channels. + stride: Stride for downsampling. + dilation: Dilation rate. + first_dilation: First dilation rate (unused). 
+ conv_layer: Convolution layer type. + """ super(DownsampleAvg, self).__init__() avg_stride = stride if dilation == 1 else 1 if stride > 1 or dilation > 1: @@ -97,7 +135,15 @@ def __init__( self.pool = nn.Identity() self.conv = conv_layer(in_chs, out_chs, 1, stride=1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Downsampled tensor. + """ return self.conv(self.pool(x)) @@ -122,11 +168,33 @@ def __init__( extra_conv: bool = False, skipinit: bool = False, attn_layer: Optional[Callable] = None, - attn_gain: bool = 2.0, + attn_gain: float = 2.0, act_layer: Optional[Callable] = None, conv_layer: Callable = ScaledStdConv2d, drop_path_rate: float = 0., ): + """Initialize NormFreeBlock. + + Args: + in_chs: Input channels. + out_chs: Output channels. + stride: Stride for convolution. + dilation: Dilation rate. + first_dilation: First dilation rate. + alpha: Alpha scaling factor for residual. + beta: Beta scaling factor for pre-activation. + bottle_ratio: Bottleneck ratio. + group_size: Group convolution size. + ch_div: Channel divisor for rounding. + reg: Use RegNet-style configuration. + extra_conv: Add extra 3x3 convolution. + skipinit: Use skipinit initialization. + attn_layer: Attention layer type. + attn_gain: Attention gain factor. + act_layer: Activation layer type. + conv_layer: Convolution layer type. + drop_path_rate: Stochastic depth drop rate. + """ super().__init__() first_dilation = first_dilation or dilation out_chs = out_chs or in_chs @@ -174,7 +242,15 @@ def __init__( self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ out = self.act1(x) * self.beta # shortcut branch @@ -207,7 +283,20 @@ def create_stem( conv_layer: Optional[Callable] = None, act_layer: Optional[Callable] = None, preact_feature: bool = True, -): +) -> Tuple[nn.Sequential, int, Dict[str, Any]]: + """Create stem module for NFNet models. + + Args: + in_chs: Input channels. + out_chs: Output channels. + stem_type: Type of stem ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', etc.). + conv_layer: Convolution layer type. + act_layer: Activation layer type. + preact_feature: Use pre-activation feature. + + Returns: + Tuple of (stem_module, stem_stride, stem_feature_info). 
+ """ stem_stride = 2 stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv') stem = OrderedDict() @@ -298,7 +387,7 @@ def __init__( output_stride: int = 32, drop_rate: float = 0., drop_path_rate: float = 0., - **kwargs, + **kwargs: Any, ): """ Args: @@ -415,7 +504,8 @@ def __init__( nn.init.zeros_(m.bias) @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Group parameters for optimization.""" matcher = dict( stem=r'^stem', blocks=[ @@ -426,18 +516,34 @@ def group_matcher(self, coarse=False): return matcher @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing.""" self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier head.""" return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes self.head.reset(num_classes, global_pool) - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers. + + Args: + x: Input tensor. + + Returns: + Feature tensor. + """ x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.stages, x) @@ -447,23 +553,53 @@ def forward_features(self, x): x = self.final_act(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Input features. + pre_logits: Return features before final linear layer. + + Returns: + Classification logits or features. + """ return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output logits. + """ x = self.forward_features(x) x = self.forward_head(x) return x def _nfres_cfg( - depths, - channels=(256, 512, 1024, 2048), - group_size=None, - act_layer='relu', - attn_layer=None, - attn_kwargs=None, -): + depths: Tuple[int, ...], + channels: Tuple[int, ...] = (256, 512, 1024, 2048), + group_size: Optional[int] = None, + act_layer: str = 'relu', + attn_layer: Optional[str] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, +) -> NfCfg: + """Create NFNet ResNet configuration. + + Args: + depths: Number of blocks in each stage. + channels: Channel dimensions for each stage. + group_size: Group convolution size. + act_layer: Activation layer type. + attn_layer: Attention layer type. + attn_kwargs: Attention layer arguments. + + Returns: + NFNet configuration. + """ attn_kwargs = attn_kwargs or {} cfg = NfCfg( depths=depths, @@ -479,7 +615,16 @@ def _nfres_cfg( return cfg -def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): +def _nfreg_cfg(depths: Tuple[int, ...], channels: Tuple[int, ...] = (48, 104, 208, 440)) -> NfCfg: + """Create NFNet RegNet configuration. + + Args: + depths: Number of blocks in each stage. + channels: Channel dimensions for each stage. + + Returns: + NFNet configuration. 
+ """ num_features = 1280 * channels[-1] // 440 attn_kwargs = dict(rd_ratio=0.5) cfg = NfCfg( @@ -498,15 +643,30 @@ def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): def _nfnet_cfg( - depths, - channels=(256, 512, 1536, 1536), - group_size=128, - bottle_ratio=0.5, - feat_mult=2., - act_layer='gelu', - attn_layer='se', - attn_kwargs=None, -): + depths: Tuple[int, ...], + channels: Tuple[int, ...] = (256, 512, 1536, 1536), + group_size: int = 128, + bottle_ratio: float = 0.5, + feat_mult: float = 2., + act_layer: str = 'gelu', + attn_layer: str = 'se', + attn_kwargs: Optional[Dict[str, Any]] = None, +) -> NfCfg: + """Create NFNet configuration. + + Args: + depths: Number of blocks in each stage. + channels: Channel dimensions for each stage. + group_size: Group convolution size. + bottle_ratio: Bottleneck ratio. + feat_mult: Feature multiplier for final layer. + act_layer: Activation layer type. + attn_layer: Attention layer type. + attn_kwargs: Attention layer arguments. + + Returns: + NFNet configuration. + """ num_features = int(channels[-1] * feat_mult) attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5) cfg = NfCfg( @@ -526,11 +686,22 @@ def _nfnet_cfg( def _dm_nfnet_cfg( - depths, - channels=(256, 512, 1536, 1536), - act_layer='gelu', - skipinit=True, -): + depths: Tuple[int, ...], + channels: Tuple[int, ...] = (256, 512, 1536, 1536), + act_layer: str = 'gelu', + skipinit: bool = True, +) -> NfCfg: + """Create DeepMind NFNet configuration. + + Args: + depths: Number of blocks in each stage. + channels: Channel dimensions for each stage. + act_layer: Activation layer type. + skipinit: Use skipinit initialization. + + Returns: + NFNet configuration. + """ cfg = NfCfg( depths=depths, channels=channels, @@ -615,7 +786,17 @@ def _dm_nfnet_cfg( ) -def _create_normfreenet(variant, pretrained=False, **kwargs): +def _create_normfreenet(variant: str, pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Create a NormFreeNet model. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + NormFreeNet model instance. + """ model_cfg = model_cfgs[variant] feature_cfg = dict(flatten_sequential=True) return build_model_with_cfg( @@ -628,7 +809,16 @@ def _create_normfreenet(variant, pretrained=False, **kwargs): ) -def _dcfg(url='', **kwargs): +def _dcfg(url: str = '', **kwargs: Any) -> Dict[str, Any]: + """Create default configuration dictionary. + + Args: + url: Model weight URL. + **kwargs: Additional configuration options. + + Returns: + Configuration dictionary. 
+ """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), @@ -744,303 +934,240 @@ def _dcfg(url='', **kwargs): @register_model -def dm_nfnet_f0(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F0 (DeepMind weight compatible) - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def dm_nfnet_f0(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F0 (DeepMind weight compatible).""" return _create_normfreenet('dm_nfnet_f0', pretrained=pretrained, **kwargs) @register_model -def dm_nfnet_f1(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F1 (DeepMind weight compatible) - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def dm_nfnet_f1(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F1 (DeepMind weight compatible).""" return _create_normfreenet('dm_nfnet_f1', pretrained=pretrained, **kwargs) @register_model -def dm_nfnet_f2(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F2 (DeepMind weight compatible) - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def dm_nfnet_f2(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F2 (DeepMind weight compatible).""" return _create_normfreenet('dm_nfnet_f2', pretrained=pretrained, **kwargs) @register_model -def dm_nfnet_f3(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F3 (DeepMind weight compatible) - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def dm_nfnet_f3(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F3 (DeepMind weight compatible).""" return _create_normfreenet('dm_nfnet_f3', pretrained=pretrained, **kwargs) @register_model -def dm_nfnet_f4(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F4 (DeepMind weight compatible) - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def dm_nfnet_f4(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F4 (DeepMind weight compatible).""" return _create_normfreenet('dm_nfnet_f4', pretrained=pretrained, **kwargs) @register_model -def dm_nfnet_f5(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F5 (DeepMind weight compatible) - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def dm_nfnet_f5(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F5 (DeepMind weight compatible).""" return _create_normfreenet('dm_nfnet_f5', pretrained=pretrained, **kwargs) @register_model -def dm_nfnet_f6(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F6 (DeepMind weight compatible) - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def dm_nfnet_f6(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F6 (DeepMind weight compatible).""" return _create_normfreenet('dm_nfnet_f6', pretrained=pretrained, **kwargs) @register_model -def nfnet_f0(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F0 - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def nfnet_f0(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F0.""" return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs) 
@register_model -def nfnet_f1(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F1 - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def nfnet_f1(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F1.""" return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs) @register_model -def nfnet_f2(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F2 - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def nfnet_f2(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F2.""" return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs) @register_model -def nfnet_f3(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F3 - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def nfnet_f3(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F3.""" return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs) @register_model -def nfnet_f4(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F4 - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def nfnet_f4(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F4.""" return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs) @register_model -def nfnet_f5(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F5 - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def nfnet_f5(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F5.""" return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs) @register_model -def nfnet_f6(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F6 - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def nfnet_f6(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F6.""" return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs) @register_model -def nfnet_f7(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-F7 - `High-Performance Large-Scale Image Recognition Without Normalization` - - https://arxiv.org/abs/2102.06171 - """ +def nfnet_f7(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-F7.""" return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs) @register_model -def nfnet_l0(pretrained=False, **kwargs) -> NormFreeNet: - """ NFNet-L0b w/ SiLU +def nfnet_l0(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """NFNet-L0b w/ SiLU. + My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio """ return _create_normfreenet('nfnet_l0', pretrained=pretrained, **kwargs) @register_model -def eca_nfnet_l0(pretrained=False, **kwargs) -> NormFreeNet: - """ ECA-NFNet-L0 w/ SiLU +def eca_nfnet_l0(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """ECA-NFNet-L0 w/ SiLU. + My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn """ return _create_normfreenet('eca_nfnet_l0', pretrained=pretrained, **kwargs) @register_model -def eca_nfnet_l1(pretrained=False, **kwargs) -> NormFreeNet: - """ ECA-NFNet-L1 w/ SiLU +def eca_nfnet_l1(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """ECA-NFNet-L1 w/ SiLU. 
+ My experimental 'light' model w/ F1 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn """ return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs) @register_model -def eca_nfnet_l2(pretrained=False, **kwargs) -> NormFreeNet: - """ ECA-NFNet-L2 w/ SiLU +def eca_nfnet_l2(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """ECA-NFNet-L2 w/ SiLU. + My experimental 'light' model w/ F2 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn """ return _create_normfreenet('eca_nfnet_l2', pretrained=pretrained, **kwargs) @register_model -def eca_nfnet_l3(pretrained=False, **kwargs) -> NormFreeNet: - """ ECA-NFNet-L3 w/ SiLU +def eca_nfnet_l3(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """ECA-NFNet-L3 w/ SiLU. + My experimental 'light' model w/ F3 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn """ return _create_normfreenet('eca_nfnet_l3', pretrained=pretrained, **kwargs) @register_model -def nf_regnet_b0(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free RegNet-B0 - `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - - https://arxiv.org/abs/2101.08692 +def nf_regnet_b0(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free RegNet-B0. """ return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs) @register_model -def nf_regnet_b1(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free RegNet-B1 - `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - - https://arxiv.org/abs/2101.08692 +def nf_regnet_b1(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free RegNet-B1. """ return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs) @register_model -def nf_regnet_b2(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free RegNet-B2 - `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - - https://arxiv.org/abs/2101.08692 +def nf_regnet_b2(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free RegNet-B2. """ return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs) @register_model -def nf_regnet_b3(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free RegNet-B3 - `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - - https://arxiv.org/abs/2101.08692 +def nf_regnet_b3(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free RegNet-B3. """ return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs) @register_model -def nf_regnet_b4(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free RegNet-B4 - `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - - https://arxiv.org/abs/2101.08692 +def nf_regnet_b4(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free RegNet-B4. """ return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs) @register_model -def nf_regnet_b5(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free RegNet-B5 - `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - - https://arxiv.org/abs/2101.08692 +def nf_regnet_b5(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free RegNet-B5. 
""" return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs) @register_model -def nf_resnet26(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free ResNet-26 - `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - - https://arxiv.org/abs/2101.08692 +def nf_resnet26(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free ResNet-26. """ return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs) @register_model -def nf_resnet50(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free ResNet-50 - `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - - https://arxiv.org/abs/2101.08692 +def nf_resnet50(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free ResNet-50. """ return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs) @register_model -def nf_resnet101(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free ResNet-101 - `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - - https://arxiv.org/abs/2101.08692 +def nf_resnet101(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free ResNet-101. """ return _create_normfreenet('nf_resnet101', pretrained=pretrained, **kwargs) @register_model -def nf_seresnet26(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free SE-ResNet26 - """ +def nf_seresnet26(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free SE-ResNet26.""" return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs) @register_model -def nf_seresnet50(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free SE-ResNet50 - """ +def nf_seresnet50(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free SE-ResNet50.""" return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs) @register_model -def nf_seresnet101(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free SE-ResNet101 - """ +def nf_seresnet101(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free SE-ResNet101.""" return _create_normfreenet('nf_seresnet101', pretrained=pretrained, **kwargs) @register_model -def nf_ecaresnet26(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free ECA-ResNet26 - """ +def nf_ecaresnet26(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free ECA-ResNet26.""" return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs) @register_model -def nf_ecaresnet50(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free ECA-ResNet50 - """ +def nf_ecaresnet50(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free ECA-ResNet50.""" return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs) @register_model -def nf_ecaresnet101(pretrained=False, **kwargs) -> NormFreeNet: - """ Normalization-Free ECA-ResNet101 - """ +def nf_ecaresnet101(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Normalization-Free ECA-ResNet101.""" return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs) @register_model -def test_nfnet(pretrained=False, **kwargs) -> NormFreeNet: +def test_nfnet(pretrained: bool = False, **kwargs: Any) -> NormFreeNet: + """Test NFNet model for experimentation.""" return _create_normfreenet('test_nfnet', pretrained=pretrained, **kwargs) \ 
No newline at end of file diff --git a/timm/models/pit.py b/timm/models/pit.py index 1d5386a92..b7985cf99 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -279,12 +279,12 @@ def forward_intermediates( assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' intermediates = [] take_indices, max_index = feature_take_indices(len(self.transformers), indices) - + # forward pass x = self.patch_embed(x) x = self.pos_drop(x + self.pos_embed) cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) - + last_idx = len(self.transformers) - 1 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript stages = self.transformers @@ -294,11 +294,11 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x, cls_tokens = stage((x, cls_tokens)) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates - + if feat_idx == last_idx: cls_tokens = self.norm(cls_tokens) diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 8cd42fe8a..bb1baf664 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -422,7 +422,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 49a19aa16..f05bad37b 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -26,9 +26,8 @@ import math from dataclasses import dataclass, replace from functools import partial -from typing import Callable, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Union, Tuple -import numpy as np import torch import torch.nn as nn @@ -45,6 +44,7 @@ @dataclass class RegNetCfg: + """RegNet architecture configuration.""" depth: int = 21 w0: int = 80 wa: float = 42.63 @@ -62,13 +62,36 @@ class RegNetCfg: norm_layer: Union[str, Callable] = 'batchnorm' -def quantize_float(f, q): - """Converts a float to the closest non-zero int divisible by q.""" +def quantize_float(f: float, q: int) -> int: + """Converts a float to the closest non-zero int divisible by q. + + Args: + f: Input float value. + q: Quantization divisor. + + Returns: + Quantized integer value. + """ return int(round(f / q) * q) -def adjust_widths_groups_comp(widths, bottle_ratios, groups, min_ratio=0.): - """Adjusts the compatibility of widths and groups.""" +def adjust_widths_groups_comp( + widths: List[int], + bottle_ratios: List[float], + groups: List[int], + min_ratio: float = 0. +) -> Tuple[List[int], List[int]]: + """Adjusts the compatibility of widths and groups. + + Args: + widths: List of channel widths. + bottle_ratios: List of bottleneck ratios. + groups: List of group sizes. + min_ratio: Minimum ratio for divisibility. + + Returns: + Tuple of adjusted widths and groups. 
+ """ bottleneck_widths = [int(w * b) for w, b in zip(widths, bottle_ratios)] groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)] if min_ratio: @@ -80,29 +103,62 @@ def adjust_widths_groups_comp(widths, bottle_ratios, groups, min_ratio=0.): return widths, groups -def generate_regnet(width_slope, width_initial, width_mult, depth, group_size, quant=8): - """Generates per block widths from RegNet parameters.""" +def generate_regnet( + width_slope: float, + width_initial: int, + width_mult: float, + depth: int, + group_size: int, + quant: int = 8 +) -> Tuple[List[int], int, List[int]]: + """Generates per block widths from RegNet parameters. + + Args: + width_slope: Slope parameter for width progression. + width_initial: Initial width. + width_mult: Width multiplier. + depth: Network depth. + group_size: Group convolution size. + quant: Quantization factor. + + Returns: + Tuple of (widths, num_stages, groups). + """ assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % quant == 0 # TODO dWr scaling? # depth = int(depth * (scale ** 0.1)) # width_scale = scale ** 0.4 # dWr scale, exp 0.8 / 2, applied to both group and layer widths - widths_cont = np.arange(depth) * width_slope + width_initial - width_exps = np.round(np.log(widths_cont / width_initial) / np.log(width_mult)) - widths = np.round(np.divide(width_initial * np.power(width_mult, width_exps), quant)) * quant - num_stages, max_stage = len(np.unique(widths)), width_exps.max() + 1 - groups = np.array([group_size for _ in range(num_stages)]) - return widths.astype(int).tolist(), num_stages, groups.astype(int).tolist() + widths_cont = torch.arange(depth, dtype=torch.float32) * width_slope + width_initial + width_exps = torch.round(torch.log(widths_cont / width_initial) / math.log(width_mult)) + widths = torch.round((width_initial * torch.pow(width_mult, width_exps)) / quant) * quant + num_stages, max_stage = len(torch.unique(widths)), int(width_exps.max().item()) + 1 + groups = torch.tensor([group_size for _ in range(num_stages)], dtype=torch.int32) + return widths.int().tolist(), num_stages, groups.tolist() def downsample_conv( - in_chs, - out_chs, - kernel_size=1, - stride=1, - dilation=1, - norm_layer=None, - preact=False, -): + in_chs: int, + out_chs: int, + kernel_size: int = 1, + stride: int = 1, + dilation: int = 1, + norm_layer: Optional[Callable] = None, + preact: bool = False, +) -> nn.Module: + """Create convolutional downsampling module. + + Args: + in_chs: Input channels. + out_chs: Output channels. + kernel_size: Convolution kernel size. + stride: Convolution stride. + dilation: Convolution dilation. + norm_layer: Normalization layer. + preact: Use pre-activation. + + Returns: + Downsampling module. + """ norm_layer = norm_layer or nn.BatchNorm2d kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size dilation = dilation if kernel_size > 1 else 1 @@ -127,15 +183,30 @@ def downsample_conv( def downsample_avg( - in_chs, - out_chs, - kernel_size=1, - stride=1, - dilation=1, - norm_layer=None, - preact=False, -): - """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" + in_chs: int, + out_chs: int, + kernel_size: int = 1, + stride: int = 1, + dilation: int = 1, + norm_layer: Optional[Callable] = None, + preact: bool = False, +) -> nn.Sequential: + """Create average pool downsampling module. + + AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment. 
+ + Args: + in_chs: Input channels. + out_chs: Output channels. + kernel_size: Convolution kernel size. + stride: Convolution stride. + dilation: Convolution dilation. + norm_layer: Normalization layer. + preact: Use pre-activation. + + Returns: + Sequential downsampling module. + """ norm_layer = norm_layer or nn.BatchNorm2d avg_stride = stride if dilation == 1 else 1 pool = nn.Identity() @@ -150,15 +221,30 @@ def downsample_avg( def create_shortcut( - downsample_type, - in_chs, - out_chs, - kernel_size, - stride, - dilation=(1, 1), - norm_layer=None, - preact=False, -): + downsample_type: Optional[str], + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int, + dilation: Tuple[int, int] = (1, 1), + norm_layer: Optional[Callable] = None, + preact: bool = False, +) -> Optional[nn.Module]: + """Create shortcut connection for residual blocks. + + Args: + downsample_type: Type of downsampling ('avg', 'conv1x1', or None). + in_chs: Input channels. + out_chs: Output channels. + kernel_size: Kernel size for conv downsampling. + stride: Stride for downsampling. + dilation: Dilation rates. + norm_layer: Normalization layer. + preact: Use pre-activation. + + Returns: + Shortcut module or None. + """ assert downsample_type in ('avg', 'conv1x1', '', None) if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact) @@ -173,7 +259,7 @@ def create_shortcut( class Bottleneck(nn.Module): - """ RegNet Bottleneck + """RegNet Bottleneck block. This is almost exactly the same as a ResNet Bottleneck. The main difference is the SE block is moved from after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels. @@ -181,20 +267,37 @@ class Bottleneck(nn.Module): def __init__( self, - in_chs, - out_chs, - stride=1, - dilation=(1, 1), - bottle_ratio=1, - group_size=1, - se_ratio=0.25, - downsample='conv1x1', - linear_out=False, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + bottle_ratio: float = 1, + group_size: int = 1, + se_ratio: float = 0.25, + downsample: str = 'conv1x1', + linear_out: bool = False, + act_layer: Callable = nn.ReLU, + norm_layer: Callable = nn.BatchNorm2d, drop_block=None, - drop_path_rate=0., + drop_path_rate: float = 0., ): + """Initialize RegNet Bottleneck block. + + Args: + in_chs: Input channels. + out_chs: Output channels. + stride: Convolution stride. + dilation: Dilation rates for conv2 and shortcut. + bottle_ratio: Bottleneck ratio (reduction factor). + group_size: Group convolution size. + se_ratio: Squeeze-and-excitation ratio. + downsample: Shortcut downsampling type. + linear_out: Use linear activation for output. + act_layer: Activation layer. + norm_layer: Normalization layer. + drop_block: Drop block layer. + drop_path_rate: Stochastic depth drop rate. + """ super(Bottleneck, self).__init__() act_layer = get_act_layer(act_layer) bottleneck_chs = int(round(out_chs * bottle_ratio)) @@ -230,10 +333,19 @@ def __init__( ) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() - def zero_init_last(self): + def zero_init_last(self) -> None: + """Zero-initialize the last batch norm in the block.""" nn.init.zeros_(self.conv3.bn.weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. 
+ """ shortcut = x x = self.conv1(x) x = self.conv2(x) @@ -248,28 +360,44 @@ def forward(self, x): class PreBottleneck(nn.Module): - """ RegNet Bottleneck + """Pre-activation RegNet Bottleneck block. - This is almost exactly the same as a ResNet Bottleneck. The main difference is the SE block is moved from - after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels. + Similar to Bottleneck but with pre-activation normalization. """ def __init__( self, - in_chs, - out_chs, - stride=1, - dilation=(1, 1), - bottle_ratio=1, - group_size=1, - se_ratio=0.25, - downsample='conv1x1', - linear_out=False, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + bottle_ratio: float = 1, + group_size: int = 1, + se_ratio: float = 0.25, + downsample: str = 'conv1x1', + linear_out: bool = False, + act_layer: Callable = nn.ReLU, + norm_layer: Callable = nn.BatchNorm2d, drop_block=None, - drop_path_rate=0., + drop_path_rate: float = 0., ): + """Initialize pre-activation RegNet Bottleneck block. + + Args: + in_chs: Input channels. + out_chs: Output channels. + stride: Convolution stride. + dilation: Dilation rates for conv2 and shortcut. + bottle_ratio: Bottleneck ratio (reduction factor). + group_size: Group convolution size. + se_ratio: Squeeze-and-excitation ratio. + downsample: Shortcut downsampling type. + linear_out: Use linear activation for output. + act_layer: Activation layer. + norm_layer: Normalization layer. + drop_block: Drop block layer. + drop_path_rate: Stochastic depth drop rate. + """ super(PreBottleneck, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) bottleneck_chs = int(round(out_chs * bottle_ratio)) @@ -304,10 +432,19 @@ def __init__( ) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() - def zero_init_last(self): + def zero_init_last(self) -> None: + """Zero-initialize the last batch norm (no-op for pre-activation).""" pass - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ x = self.norm1(x) shortcut = x x = self.conv1(x) @@ -324,19 +461,34 @@ def forward(self, x): class RegStage(nn.Module): - """Stage (sequence of blocks w/ the same output shape).""" + """RegNet stage (sequence of blocks with the same output shape). + + A stage consists of multiple bottleneck blocks with the same output dimensions. + """ def __init__( self, - depth, - in_chs, - out_chs, - stride, - dilation, - drop_path_rates=None, - block_fn=Bottleneck, + depth: int, + in_chs: int, + out_chs: int, + stride: int, + dilation: int, + drop_path_rates: Optional[List[float]] = None, + block_fn: Callable = Bottleneck, **block_kwargs, ): + """Initialize RegNet stage. + + Args: + depth: Number of blocks in stage. + in_chs: Input channels. + out_chs: Output channels. + stride: Stride for first block. + dilation: Dilation rate. + drop_path_rates: Drop path rates for each block. + block_fn: Block class to use. + **block_kwargs: Additional block arguments. + """ super(RegStage, self).__init__() self.grad_checkpointing = False @@ -360,7 +512,15 @@ def __init__( ) first_dilation = dilation - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through all blocks in the stage. + + Args: + x: Input tensor. + + Returns: + Output tensor. 
+ """ if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.children(), x) else: @@ -370,7 +530,7 @@ def forward(self, x): class RegNet(nn.Module): - """RegNet-X, Y, and Z Models + """RegNet-X, Y, and Z Models. Paper: https://arxiv.org/abs/2003.13678 Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py @@ -379,27 +539,27 @@ class RegNet(nn.Module): def __init__( self, cfg: RegNetCfg, - in_chans=3, - num_classes=1000, - output_stride=32, - global_pool='avg', - drop_rate=0., - drop_path_rate=0., - zero_init_last=True, + in_chans: int = 3, + num_classes: int = 1000, + output_stride: int = 32, + global_pool: str = 'avg', + drop_rate: float = 0., + drop_path_rate: float = 0., + zero_init_last: bool = True, **kwargs, ): - """ + """Initialize RegNet model. Args: - cfg (RegNetCfg): Model architecture configuration - in_chans (int): Number of input channels (default: 3) - num_classes (int): Number of classifier classes (default: 1000) - output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) - global_pool (str): Global pooling type (default: 'avg') - drop_rate (float): Dropout rate (default: 0.) - drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) - zero_init_last (bool): Zero-init last weight of residual path - kwargs (dict): Extra kwargs overlayed onto cfg + cfg: Model architecture configuration. + in_chans: Number of input channels. + num_classes: Number of classifier classes. + output_stride: Output stride of network, one of (8, 16, 32). + global_pool: Global pooling type. + drop_rate: Dropout rate. + drop_path_rate: Stochastic depth drop-path rate. + zero_init_last: Zero-init last weight of residual path. + kwargs: Extra kwargs overlayed onto cfg. """ super().__init__() self.num_classes = num_classes @@ -459,12 +619,30 @@ def __init__( named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) - def _get_stage_args(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.): + def _get_stage_args( + self, + cfg: RegNetCfg, + default_stride: int = 2, + output_stride: int = 32, + drop_path_rate: float = 0. + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """Generate stage arguments from configuration. + + Args:` + cfg: RegNet configuration. + default_stride: Default stride for stages. + output_stride: Target output stride. + drop_path_rate: Stochastic depth rate. + + Returns: + Tuple of (per_stage_args, common_args). 
+ """ # Generate RegNet ws per block widths, num_stages, stage_gs = generate_regnet(cfg.wa, cfg.w0, cfg.wm, cfg.depth, cfg.group_size) # Convert to per stage format - stage_widths, stage_depths = np.unique(widths, return_counts=True) + stage_widths, stage_depths = torch.unique(torch.tensor(widths), return_counts=True) + stage_widths, stage_depths = stage_widths.tolist(), stage_depths.tolist() stage_br = [cfg.bottle_ratio for _ in range(num_stages)] stage_strides = [] stage_dilations = [] @@ -479,7 +657,10 @@ def _get_stage_args(self, cfg: RegNetCfg, default_stride=2, output_stride=32, dr net_stride *= stride stage_strides.append(stride) stage_dilations.append(dilation) - stage_dpr = np.split(np.linspace(0, drop_path_rate, sum(stage_depths)), np.cumsum(stage_depths[:-1])) + dpr_tensor = torch.linspace(0, drop_path_rate, sum(stage_depths)) + split_indices = torch.cumsum(torch.tensor(stage_depths[:-1]), dim=0) + stage_dpr = torch.tensor_split(dpr_tensor, split_indices.tolist()) + stage_dpr = [dpr.tolist() for dpr in stage_dpr] # Adjust the compatibility of ws and gws stage_widths, stage_gs = adjust_widths_groups_comp( @@ -499,22 +680,31 @@ def _get_stage_args(self, cfg: RegNetCfg, default_stride=2, output_stride=32, dr return per_stage_args, common_args @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Group parameters for optimization.""" return dict( stem=r'^stem', blocks=r'^s(\d+)' if coarse else r'^s(\d+)\.b(\d+)', ) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing.""" for s in list(self.children())[1:-1]: s.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier head.""" return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) @@ -571,8 +761,16 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layer. + prune_head: Whether to prune the classifier head. + + Returns: + List of indices that were kept. """ take_indices, max_index = feature_take_indices(5, indices) layer_names = ('s1', 's2', 's3', 's4') @@ -585,7 +783,15 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers. + + Args: + x: Input tensor. + + Returns: + Feature tensor. + """ x = self.stem(x) x = self.s1(x) x = self.s2(x) @@ -594,16 +800,40 @@ def forward_features(self, x): x = self.final_conv(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Input features. 
+ pre_logits: Return features before final linear layer. + + Returns: + Classification logits or features. + """ return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output logits. + """ x = self.forward_features(x) x = self.forward_head(x) return x -def _init_weights(module, name='', zero_init_last=False): +def _init_weights(module: nn.Module, name: str = '', zero_init_last: bool = False) -> None: + """Initialize module weights. + + Args: + module: PyTorch module to initialize. + name: Module name. + zero_init_last: Zero-initialize last layer weights. + """ if isinstance(module, nn.Conv2d): fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels fan_out //= module.groups @@ -618,7 +848,15 @@ def _init_weights(module, name='', zero_init_last=False): module.zero_init_last() -def _filter_fn(state_dict): +def _filter_fn(state_dict: Dict[str, Any]) -> Dict[str, Any]: + """Filter and remap state dict keys for compatibility. + + Args: + state_dict: Raw state dictionary. + + Returns: + Filtered state dictionary. + """ state_dict = state_dict.get('model', state_dict) replaces = [ ('f.a.0', 'conv1.conv'), @@ -740,7 +978,17 @@ def _filter_fn(state_dict): ) -def _create_regnet(variant, pretrained, **kwargs): +def _create_regnet(variant: str, pretrained: bool, **kwargs) -> RegNet: + """Create a RegNet model. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + RegNet model instance. + """ return build_model_with_cfg( RegNet, variant, pretrained, model_cfg=model_cfgs[variant], @@ -748,7 +996,16 @@ def _create_regnet(variant, pretrained, **kwargs): **kwargs) -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create default configuration dictionary. + + Args: + url: Model weight URL. + **kwargs: Additional configuration options. + + Returns: + Configuration dictionary. + """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'test_input_size': (3, 288, 288), 'crop_pct': 0.95, 'test_crop_pct': 1.0, @@ -758,7 +1015,16 @@ def _cfg(url='', **kwargs): } -def _cfgpyc(url='', **kwargs): +def _cfgpyc(url: str = '', **kwargs) -> Dict[str, Any]: + """Create pycls configuration dictionary. + + Args: + url: Model weight URL. + **kwargs: Additional configuration options. + + Returns: + Configuration dictionary. + """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', @@ -768,7 +1034,16 @@ def _cfgpyc(url='', **kwargs): } -def _cfgtv2(url='', **kwargs): +def _cfgtv2(url: str = '', **kwargs) -> Dict[str, Any]: + """Create torchvision v2 configuration dictionary. + + Args: + url: Model weight URL. + **kwargs: Additional configuration options. + + Returns: + Configuration dictionary. 
+ """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.965, 'interpolation': 'bicubic', @@ -963,205 +1238,205 @@ def _cfgtv2(url='', **kwargs): @register_model -def regnetx_002(pretrained=False, **kwargs) -> RegNet: +def regnetx_002(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-200MF""" return _create_regnet('regnetx_002', pretrained, **kwargs) @register_model -def regnetx_004(pretrained=False, **kwargs) -> RegNet: +def regnetx_004(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-400MF""" return _create_regnet('regnetx_004', pretrained, **kwargs) @register_model -def regnetx_004_tv(pretrained=False, **kwargs) -> RegNet: +def regnetx_004_tv(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-400MF w/ torchvision group rounding""" return _create_regnet('regnetx_004_tv', pretrained, **kwargs) @register_model -def regnetx_006(pretrained=False, **kwargs) -> RegNet: +def regnetx_006(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-600MF""" return _create_regnet('regnetx_006', pretrained, **kwargs) @register_model -def regnetx_008(pretrained=False, **kwargs) -> RegNet: +def regnetx_008(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-800MF""" return _create_regnet('regnetx_008', pretrained, **kwargs) @register_model -def regnetx_016(pretrained=False, **kwargs) -> RegNet: +def regnetx_016(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-1.6GF""" return _create_regnet('regnetx_016', pretrained, **kwargs) @register_model -def regnetx_032(pretrained=False, **kwargs) -> RegNet: +def regnetx_032(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-3.2GF""" return _create_regnet('regnetx_032', pretrained, **kwargs) @register_model -def regnetx_040(pretrained=False, **kwargs) -> RegNet: +def regnetx_040(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-4.0GF""" return _create_regnet('regnetx_040', pretrained, **kwargs) @register_model -def regnetx_064(pretrained=False, **kwargs) -> RegNet: +def regnetx_064(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-6.4GF""" return _create_regnet('regnetx_064', pretrained, **kwargs) @register_model -def regnetx_080(pretrained=False, **kwargs) -> RegNet: +def regnetx_080(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-8.0GF""" return _create_regnet('regnetx_080', pretrained, **kwargs) @register_model -def regnetx_120(pretrained=False, **kwargs) -> RegNet: +def regnetx_120(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-12GF""" return _create_regnet('regnetx_120', pretrained, **kwargs) @register_model -def regnetx_160(pretrained=False, **kwargs) -> RegNet: +def regnetx_160(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-16GF""" return _create_regnet('regnetx_160', pretrained, **kwargs) @register_model -def regnetx_320(pretrained=False, **kwargs) -> RegNet: +def regnetx_320(pretrained: bool = False, **kwargs) -> RegNet: """RegNetX-32GF""" return _create_regnet('regnetx_320', pretrained, **kwargs) @register_model -def regnety_002(pretrained=False, **kwargs) -> RegNet: +def regnety_002(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-200MF""" return _create_regnet('regnety_002', pretrained, **kwargs) @register_model -def regnety_004(pretrained=False, **kwargs) -> RegNet: +def regnety_004(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-400MF""" return _create_regnet('regnety_004', pretrained, **kwargs) @register_model -def regnety_006(pretrained=False, **kwargs) -> RegNet: +def 
regnety_006(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-600MF""" return _create_regnet('regnety_006', pretrained, **kwargs) @register_model -def regnety_008(pretrained=False, **kwargs) -> RegNet: +def regnety_008(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-800MF""" return _create_regnet('regnety_008', pretrained, **kwargs) @register_model -def regnety_008_tv(pretrained=False, **kwargs) -> RegNet: +def regnety_008_tv(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-800MF w/ torchvision group rounding""" return _create_regnet('regnety_008_tv', pretrained, **kwargs) @register_model -def regnety_016(pretrained=False, **kwargs) -> RegNet: +def regnety_016(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-1.6GF""" return _create_regnet('regnety_016', pretrained, **kwargs) @register_model -def regnety_032(pretrained=False, **kwargs) -> RegNet: +def regnety_032(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-3.2GF""" return _create_regnet('regnety_032', pretrained, **kwargs) @register_model -def regnety_040(pretrained=False, **kwargs) -> RegNet: +def regnety_040(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-4.0GF""" return _create_regnet('regnety_040', pretrained, **kwargs) @register_model -def regnety_064(pretrained=False, **kwargs) -> RegNet: +def regnety_064(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-6.4GF""" return _create_regnet('regnety_064', pretrained, **kwargs) @register_model -def regnety_080(pretrained=False, **kwargs) -> RegNet: +def regnety_080(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-8.0GF""" return _create_regnet('regnety_080', pretrained, **kwargs) @register_model -def regnety_080_tv(pretrained=False, **kwargs) -> RegNet: +def regnety_080_tv(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-8.0GF w/ torchvision group rounding""" return _create_regnet('regnety_080_tv', pretrained, **kwargs) @register_model -def regnety_120(pretrained=False, **kwargs) -> RegNet: +def regnety_120(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-12GF""" return _create_regnet('regnety_120', pretrained, **kwargs) @register_model -def regnety_160(pretrained=False, **kwargs) -> RegNet: +def regnety_160(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-16GF""" return _create_regnet('regnety_160', pretrained, **kwargs) @register_model -def regnety_320(pretrained=False, **kwargs) -> RegNet: +def regnety_320(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-32GF""" return _create_regnet('regnety_320', pretrained, **kwargs) @register_model -def regnety_640(pretrained=False, **kwargs) -> RegNet: +def regnety_640(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-64GF""" return _create_regnet('regnety_640', pretrained, **kwargs) @register_model -def regnety_1280(pretrained=False, **kwargs) -> RegNet: +def regnety_1280(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-128GF""" return _create_regnet('regnety_1280', pretrained, **kwargs) @register_model -def regnety_2560(pretrained=False, **kwargs) -> RegNet: +def regnety_2560(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-256GF""" return _create_regnet('regnety_2560', pretrained, **kwargs) @register_model -def regnety_040_sgn(pretrained=False, **kwargs) -> RegNet: +def regnety_040_sgn(pretrained: bool = False, **kwargs) -> RegNet: """RegNetY-4.0GF w/ GroupNorm """ return _create_regnet('regnety_040_sgn', pretrained, **kwargs) @register_model -def regnetv_040(pretrained=False, **kwargs) -> RegNet: +def 
regnetv_040(pretrained: bool = False, **kwargs) -> RegNet: """RegNetV-4.0GF (pre-activation)""" return _create_regnet('regnetv_040', pretrained, **kwargs) @register_model -def regnetv_064(pretrained=False, **kwargs) -> RegNet: +def regnetv_064(pretrained: bool = False, **kwargs) -> RegNet: """RegNetV-6.4GF (pre-activation)""" return _create_regnet('regnetv_064', pretrained, **kwargs) @register_model -def regnetz_005(pretrained=False, **kwargs) -> RegNet: +def regnetz_005(pretrained: bool = False, **kwargs) -> RegNet: """RegNetZ-500MF NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py but it's not clear it is equivalent to paper model as not detailed in the paper. @@ -1170,7 +1445,7 @@ def regnetz_005(pretrained=False, **kwargs) -> RegNet: @register_model -def regnetz_040(pretrained=False, **kwargs) -> RegNet: +def regnetz_040(pretrained: bool = False, **kwargs) -> RegNet: """RegNetZ-4.0GF NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py but it's not clear it is equivalent to paper model as not detailed in the paper. @@ -1179,7 +1454,7 @@ def regnetz_040(pretrained=False, **kwargs) -> RegNet: @register_model -def regnetz_040_h(pretrained=False, **kwargs) -> RegNet: +def regnetz_040_h(pretrained: bool = False, **kwargs) -> RegNet: """RegNetZ-4.0GF NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py but it's not clear it is equivalent to paper model as not detailed in the paper. diff --git a/timm/models/repghost.py b/timm/models/repghost.py index 77fc35d59..a75d9d850 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -168,7 +168,7 @@ def __init__( # Point-wise linear projection self.ghost2 = RepGhostModule(mid_chs, out_chs, relu=False, reparam=reparam) - + # shortcut if in_chs == out_chs and self.stride == 1: self.shortcut = nn.Sequential() @@ -199,7 +199,7 @@ def forward(self, x): # 2nd ghost bottleneck x = self.ghost2(x) - + x += self.shortcut(shortcut) return x @@ -256,8 +256,8 @@ def __init__( out_chs = make_divisible(exp_size * width * 2, 4) stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1))) self.pool_dim = prev_chs = out_chs - - self.blocks = nn.Sequential(*stages) + + self.blocks = nn.Sequential(*stages) # building last several layers self.num_features = prev_chs @@ -338,7 +338,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages, start=1): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates @@ -408,7 +408,7 @@ def _create_repghostnet(variant, width=1.0, pretrained=False, **kwargs): Constructs a RepGhostNet model """ cfgs = [ - # k, t, c, SE, s + # k, t, c, SE, s # stage1 [[3, 8, 16, 0, 1]], # stage2 diff --git a/timm/models/repvit.py b/timm/models/repvit.py index ddcfed55c..190f4b529 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -4,7 +4,7 @@ - https://arxiv.org/abs/2307.09283 @misc{wang2023repvit, - title={RepViT: Revisiting Mobile CNN From ViT Perspective}, + title={RepViT: Revisiting Mobile CNN From ViT Perspective}, author={Ao Wang and Hui Chen and Zijia Lin and Hengjun Pu and Guiguang Ding}, year={2023}, eprint={2307.09283}, @@ -369,7 +369,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + 
intermediates.append(x) if intermediates_only: return intermediates diff --git a/timm/models/resnet.py b/timm/models/resnet.py index dd6aa1e8b..ca07682b8 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -32,6 +32,10 @@ def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int: class BasicBlock(nn.Module): + """Basic residual block for ResNet. + + This is the standard residual block used in ResNet-18 and ResNet-34. + """ expansion = 1 def __init__( @@ -51,7 +55,7 @@ def __init__( aa_layer: Optional[Type[nn.Module]] = None, drop_block: Optional[Type[nn.Module]] = None, drop_path: Optional[nn.Module] = None, - ): + ) -> None: """ Args: inplanes: Input channel dimensionality. @@ -63,12 +67,12 @@ def __init__( reduce_first: Reduction factor for first convolution output width of residual blocks. dilation: Dilation rate for convolution layers. first_dilation: Dilation rate for first convolution layer. - act_layer: Activation layer. - norm_layer: Normalization layer. - attn_layer: Attention layer. - aa_layer: Anti-aliasing layer. - drop_block: Class for DropBlock layer. - drop_path: Optional DropPath layer. + act_layer: Activation layer class. + norm_layer: Normalization layer class. + attn_layer: Attention layer class. + aa_layer: Anti-aliasing layer class. + drop_block: DropBlock layer class. + drop_path: Optional DropPath layer instance. """ super(BasicBlock, self).__init__() @@ -99,7 +103,8 @@ def __init__( self.dilation = dilation self.drop_path = drop_path - def zero_init_last(self): + def zero_init_last(self) -> None: + """Initialize the last batch norm layer weights to zero for better convergence.""" if getattr(self.bn2, 'weight', None) is not None: nn.init.zeros_(self.bn2.weight) @@ -130,6 +135,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Bottleneck(nn.Module): + """Bottleneck residual block for ResNet. + + This is the bottleneck block used in ResNet-50, ResNet-101, and ResNet-152. + """ expansion = 4 def __init__( @@ -149,7 +158,7 @@ def __init__( aa_layer: Optional[Type[nn.Module]] = None, drop_block: Optional[Type[nn.Module]] = None, drop_path: Optional[nn.Module] = None, - ): + ) -> None: """ Args: inplanes: Input channel dimensionality. @@ -161,12 +170,12 @@ def __init__( reduce_first: Reduction factor for first convolution output width of residual blocks. dilation: Dilation rate for convolution layers. first_dilation: Dilation rate for first convolution layer. - act_layer: Activation layer. - norm_layer: Normalization layer. - attn_layer: Attention layer. - aa_layer: Anti-aliasing layer. - drop_block: Class for DropBlock layer. - drop_path: Optional DropPath layer. + act_layer: Activation layer class. + norm_layer: Normalization layer class. + attn_layer: Attention layer class. + aa_layer: Anti-aliasing layer class. + drop_block: DropBlock layer class. + drop_path: Optional DropPath layer instance. """ super(Bottleneck, self).__init__() @@ -199,7 +208,8 @@ def __init__( self.dilation = dilation self.drop_path = drop_path - def zero_init_last(self): + def zero_init_last(self) -> None: + """Initialize the last batch norm layer weights to zero for better convergence.""" if getattr(self.bn3, 'weight', None) is not None: nn.init.zeros_(self.bn3.weight) @@ -278,7 +288,15 @@ def downsample_avg( ]) -def drop_blocks(drop_prob: float = 0.): +def drop_blocks(drop_prob: float = 0.) -> List[Optional[partial]]: + """Create DropBlock layer instances for each stage. + + Args: + drop_prob: Drop probability for DropBlock. 
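# Sketch of what zero_init_last() above buys: zeroing the final BN gamma makes each
# residual branch output ~0 at init, so every block starts close to an identity mapping.
import torch
import torch.nn as nn

bn2 = nn.BatchNorm2d(64)
nn.init.zeros_(bn2.weight)              # gamma = 0 -> branch contributes nothing yet
branch_out = bn2(torch.randn(2, 64, 8, 8))
print(branch_out.abs().max())           # ~0, so the shortcut dominates early training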
+ + Returns: + List of DropBlock partial instances or None for each stage. + """ return [ None, None, partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None, @@ -286,7 +304,7 @@ def drop_blocks(drop_prob: float = 0.): def make_blocks( - block_fns: Tuple[Union[BasicBlock, Bottleneck]], + block_fns: Tuple[Union[Type[BasicBlock], Type[Bottleneck]], ...], channels: Tuple[int, ...], block_repeats: Tuple[int, ...], inplanes: int, @@ -298,6 +316,24 @@ def make_blocks( drop_path_rate: float = 0., **kwargs, ) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]: + """Create ResNet stages with specified block configurations. + + Args: + block_fns: Block class to use for each stage. + channels: Number of channels for each stage. + block_repeats: Number of blocks to repeat for each stage. + inplanes: Number of input channels. + reduce_first: Reduction factor for first convolution in each stage. + output_stride: Target output stride of network. + down_kernel_size: Kernel size for downsample layers. + avg_down: Use average pooling for downsample. + drop_block_rate: DropBlock drop rate. + drop_path_rate: Drop path rate for stochastic depth. + **kwargs: Additional arguments passed to block constructors. + + Returns: + Tuple of stage modules list and feature info list. + """ stages = [] feature_info = [] net_num_blocks = sum(block_repeats) @@ -445,7 +481,7 @@ def __init__( self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False - + act_layer = get_act_layer(act_layer) norm_layer = get_norm_layer(norm_layer) @@ -520,7 +556,12 @@ def __init__( self.init_weights(zero_init_last=zero_init_last) @torch.jit.ignore - def init_weights(self, zero_init_last: bool = True): + def init_weights(self, zero_init_last: bool = True) -> None: + """Initialize model weights. + + Args: + zero_init_last: Zero-initialize the last BN in each residual branch. + """ for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') @@ -530,19 +571,46 @@ def init_weights(self, zero_init_last: bool = True): m.zero_init_last() @torch.jit.ignore - def group_matcher(self, coarse: bool = False): + def group_matcher(self, coarse: bool = False) -> Dict[str, str]: + """Create regex patterns for parameter grouping. + + Args: + coarse: Use coarse (stage-level) or fine (block-level) grouping. + + Returns: + Dictionary mapping group names to regex patterns. + """ matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)') return matcher @torch.jit.ignore - def set_grad_checkpointing(self, enable: bool = True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing. + + Args: + enable: Whether to enable gradient checkpointing. + """ self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self, name_only: bool = False): + def get_classifier(self, name_only: bool = False) -> Union[str, nn.Module]: + """Get the classifier module. + + Args: + name_only: Return classifier module name instead of module. + + Returns: + Classifier module or name. + """ return 'fc' if name_only else self.fc - def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. 
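# Sketch of how the group_matcher regexes above carve parameter names into groups
# (e.g. for layer-wise learning-rate decay): stem params match one pattern, each
# (stage, block) pair matches the other, and anything else (the head) falls through.
import re

stem_re = re.compile(r'^conv1|bn1|maxpool')
block_re = re.compile(r'^layer(\d+)\.(\d+)')
for name in ['conv1.weight', 'layer1.0.conv1.weight', 'layer3.2.bn2.bias', 'fc.weight']:
    m = block_re.match(name)
    if stem_re.match(name):
        print(name, '-> stem')
    elif m:
        print(name, '-> stage', m.group(1), 'block', m.group(2))
    else:
        print(name, '-> head')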
+ """ self.num_classes = num_classes self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) @@ -555,17 +623,18 @@ def forward_intermediates( output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: - """ Forward features that returns intermediates. + """Forward features that returns intermediates. Args: - x: Input image tensor - indices: Take last n blocks if int, all if None, select matching indices if sequence - norm: Apply norm layer to compatible intermediates - stop_early: Stop iterating over blocks when last desired intermediate hit - output_fmt: Shape of intermediate feature outputs - intermediates_only: Only return intermediate features - Returns: + x: Input image tensor. + indices: Take last n blocks if int, all if None, select matching indices if sequence. + norm: Apply norm layer to compatible intermediates. + stop_early: Stop iterating over blocks when last desired intermediate hit. + output_fmt: Shape of intermediate feature outputs. + intermediates_only: Only return intermediate features. + Returns: + Features and list of intermediate features or just intermediate features. """ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' intermediates = [] @@ -599,8 +668,16 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layers. + prune_head: Whether to prune the classifier head. + + Returns: + List of indices that were kept. """ take_indices, max_index = feature_take_indices(5, indices) layer_names = ('layer1', 'layer2', 'layer3', 'layer4') @@ -612,6 +689,7 @@ def prune_intermediate_layers( return take_indices def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers.""" x = self.conv1(x) x = self.bn1(x) x = self.act1(x) @@ -627,22 +705,43 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: return x def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Feature tensor. + pre_logits: Return features before final classifier layer. + + Returns: + Output tensor. + """ x = self.global_pool(x) if self.drop_rate: x = F.dropout(x, p=float(self.drop_rate), training=self.training) return x if pre_logits else self.fc(x) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" x = self.forward_features(x) x = self.forward_head(x) return x -def _create_resnet(variant, pretrained: bool = False, **kwargs) -> ResNet: +def _create_resnet(variant: str, pretrained: bool = False, **kwargs) -> ResNet: + """Create a ResNet model. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + ResNet model instance. 
+ """ return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create a default configuration for ResNet models.""" return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), @@ -653,25 +752,29 @@ def _cfg(url='', **kwargs): } -def _tcfg(url='', **kwargs): +def _tcfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create a configuration with bicubic interpolation.""" return _cfg(url=url, **dict({'interpolation': 'bicubic'}, **kwargs)) -def _ttcfg(url='', **kwargs): +def _ttcfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create a configuration for models trained with timm.""" return _cfg(url=url, **dict({ 'interpolation': 'bicubic', 'test_input_size': (3, 288, 288), 'test_crop_pct': 0.95, 'origin_url': 'https://github.com/huggingface/pytorch-image-models', }, **kwargs)) -def _rcfg(url='', **kwargs): +def _rcfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create a configuration for ResNet-RS models.""" return _cfg(url=url, **dict({ 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_input_size': (3, 288, 288), 'test_crop_pct': 1.0, 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476' }, **kwargs)) -def _r3cfg(url='', **kwargs): +def _r3cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create a configuration for ResNet-RS models with 160x160 input.""" return _cfg(url=url, **dict({ 'interpolation': 'bicubic', 'input_size': (3, 160, 160), 'pool_size': (5, 5), 'crop_pct': 0.95, 'test_input_size': (3, 224, 224), 'test_crop_pct': 0.95, @@ -679,7 +782,8 @@ def _r3cfg(url='', **kwargs): }, **kwargs)) -def _gcfg(url='', **kwargs): +def _gcfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create a configuration for Gluon pretrained models.""" return _cfg(url=url, **dict({ 'interpolation': 'bicubic', 'origin_url': 'https://cv.gluon.ai/model_zoo/classification.html', diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 5cc164ae1..38cd89ce1 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -31,7 +31,7 @@ from collections import OrderedDict # pylint: disable=g-importing-member from functools import partial -from typing import List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -48,24 +48,39 @@ class PreActBasic(nn.Module): - """ Pre-activation basic block (not in typical 'v2' implementations) - """ + """Pre-activation basic block (not in typical 'v2' implementations).""" def __init__( self, - in_chs, - out_chs=None, - bottle_ratio=1.0, - stride=1, - dilation=1, - first_dilation=None, - groups=1, - act_layer=None, - conv_layer=None, - norm_layer=None, - proj_layer=None, - drop_path_rate=0., + in_chs: int, + out_chs: Optional[int] = None, + bottle_ratio: float = 1.0, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + groups: int = 1, + act_layer: Optional[Callable] = None, + conv_layer: Optional[Callable] = None, + norm_layer: Optional[Callable] = None, + proj_layer: Optional[Callable] = None, + drop_path_rate: float = 0., ): + """Initialize PreActBasic block. + + Args: + in_chs: Input channels. + out_chs: Output channels. + bottle_ratio: Bottleneck ratio (not used in basic block). + stride: Stride for convolution. + dilation: Dilation rate. + first_dilation: First dilation rate. + groups: Group convolution size. 
+ act_layer: Activation layer type. + conv_layer: Convolution layer type. + norm_layer: Normalization layer type. + proj_layer: Projection/downsampling layer type. + drop_path_rate: Stochastic depth drop rate. + """ super().__init__() first_dilation = first_dilation or dilation conv_layer = conv_layer or StdConv2d @@ -93,10 +108,19 @@ def __init__( self.conv2 = conv_layer(mid_chs, out_chs, 3, dilation=dilation, groups=groups) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() - def zero_init_last(self): - nn.init.zeros_(self.conv3.weight) + def zero_init_last(self) -> None: + """Zero-initialize the last convolution weight (not applicable to basic block).""" + nn.init.zeros_(self.conv2.weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ x_preact = self.norm1(x) # shortcut branch @@ -122,19 +146,35 @@ class PreActBottleneck(nn.Module): def __init__( self, - in_chs, - out_chs=None, - bottle_ratio=0.25, - stride=1, - dilation=1, - first_dilation=None, - groups=1, - act_layer=None, - conv_layer=None, - norm_layer=None, - proj_layer=None, - drop_path_rate=0., + in_chs: int, + out_chs: Optional[int] = None, + bottle_ratio: float = 0.25, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + groups: int = 1, + act_layer: Optional[Callable] = None, + conv_layer: Optional[Callable] = None, + norm_layer: Optional[Callable] = None, + proj_layer: Optional[Callable] = None, + drop_path_rate: float = 0., ): + """Initialize PreActBottleneck block. + + Args: + in_chs: Input channels. + out_chs: Output channels. + bottle_ratio: Bottleneck ratio. + stride: Stride for convolution. + dilation: Dilation rate. + first_dilation: First dilation rate. + groups: Group convolution size. + act_layer: Activation layer type. + conv_layer: Convolution layer type. + norm_layer: Normalization layer type. + proj_layer: Projection/downsampling layer type. + drop_path_rate: Stochastic depth drop rate. + """ super().__init__() first_dilation = first_dilation or dilation conv_layer = conv_layer or StdConv2d @@ -164,10 +204,19 @@ def __init__( self.conv3 = conv_layer(mid_chs, out_chs, 1) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() - def zero_init_last(self): + def zero_init_last(self) -> None: + """Zero-initialize the last convolution weight.""" nn.init.zeros_(self.conv3.weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. 
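# Minimal sketch of the pre-activation ordering used by the blocks above: norm + act run
# before each conv, and the branch output is added to an untouched shortcut. Plain
# BatchNorm/ReLU stand in for the fused norm-act layers the real blocks use.
import torch
import torch.nn as nn

class TinyPreActBlock(nn.Module):
    def __init__(self, chs: int):
        super().__init__()
        self.norm1, self.conv1 = nn.BatchNorm2d(chs), nn.Conv2d(chs, chs, 3, padding=1, bias=False)
        self.norm2, self.conv2 = nn.BatchNorm2d(chs), nn.Conv2d(chs, chs, 3, padding=1, bias=False)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x_preact = self.act(self.norm1(x))   # pre-activation; a projection shortcut would consume this
        shortcut = x                         # identity path stays un-normalized
        out = self.conv1(x_preact)
        out = self.conv2(self.act(self.norm2(out)))
        return out + shortcut

print(TinyPreActBlock(16)(torch.randn(1, 16, 8, 8)).shape)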
+ """ x_preact = self.norm1(x) # shortcut branch @@ -188,18 +237,18 @@ class Bottleneck(nn.Module): """ def __init__( self, - in_chs, - out_chs=None, - bottle_ratio=0.25, - stride=1, - dilation=1, - first_dilation=None, - groups=1, - act_layer=None, - conv_layer=None, - norm_layer=None, - proj_layer=None, - drop_path_rate=0., + in_chs: int, + out_chs: Optional[int] = None, + bottle_ratio: float = 0.25, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + groups: int = 1, + act_layer: Optional[Callable] = None, + conv_layer: Optional[Callable] = None, + norm_layer: Optional[Callable] = None, + proj_layer: Optional[Callable] = None, + drop_path_rate: float = 0., ): super().__init__() first_dilation = first_dilation or dilation @@ -231,11 +280,20 @@ def __init__( self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.act3 = act_layer(inplace=True) - def zero_init_last(self): + def zero_init_last(self) -> None: + """Zero-initialize the last batch norm weight.""" if getattr(self.norm3, 'weight', None) is not None: nn.init.zeros_(self.norm3.weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ # shortcut branch shortcut = x if self.downsample is not None: @@ -254,38 +312,49 @@ def forward(self, x): class DownsampleConv(nn.Module): + """1x1 convolution downsampling module.""" + def __init__( self, - in_chs, - out_chs, - stride=1, - dilation=1, - first_dilation=None, - preact=True, - conv_layer=None, - norm_layer=None, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + preact: bool = True, + conv_layer: Optional[Callable] = None, + norm_layer: Optional[Callable] = None, ): super(DownsampleConv, self).__init__() self.conv = conv_layer(in_chs, out_chs, 1, stride=stride) self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Downsampled tensor. + """ return self.norm(self.conv(x)) class DownsampleAvg(nn.Module): + """AvgPool downsampling as in 'D' ResNet variants.""" + def __init__( self, - in_chs, - out_chs, - stride=1, - dilation=1, - first_dilation=None, - preact=True, - conv_layer=None, - norm_layer=None, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + preact: bool = True, + conv_layer: Optional[Callable] = None, + norm_layer: Optional[Callable] = None, ): - """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" super(DownsampleAvg, self).__init__() avg_stride = stride if dilation == 1 else 1 if stride > 1 or dilation > 1: @@ -296,7 +365,15 @@ def __init__( self.conv = conv_layer(in_chs, out_chs, 1, stride=1) self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Downsampled tensor. 
+ """ return self.norm(self.conv(self.pool(x))) @@ -304,20 +381,20 @@ class ResNetStage(nn.Module): """ResNet Stage.""" def __init__( self, - in_chs, - out_chs, - stride, - dilation, - depth, - bottle_ratio=0.25, - groups=1, - avg_down=False, - block_dpr=None, - block_fn=PreActBottleneck, - act_layer=None, - conv_layer=None, - norm_layer=None, - **block_kwargs, + in_chs: int, + out_chs: int, + stride: int, + dilation: int, + depth: int, + bottle_ratio: float = 0.25, + groups: int = 1, + avg_down: bool = False, + block_dpr: Optional[List[float]] = None, + block_fn: Callable = PreActBottleneck, + act_layer: Optional[Callable] = None, + conv_layer: Optional[Callable] = None, + norm_layer: Optional[Callable] = None, + **block_kwargs: Any, ): super(ResNetStage, self).__init__() first_dilation = 1 if dilation in (1, 2) else 2 @@ -345,23 +422,39 @@ def __init__( first_dilation = dilation proj_layer = None - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through all blocks in the stage. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ x = self.blocks(x) return x -def is_stem_deep(stem_type): +def is_stem_deep(stem_type: str) -> bool: + """Check if stem type is deep (has multiple convolutions). + + Args: + stem_type: Type of stem to check. + + Returns: + True if stem is deep, False otherwise. + """ return any([s in stem_type for s in ('deep', 'tiered')]) def create_resnetv2_stem( - in_chs, - out_chs=64, - stem_type='', - preact=True, - conv_layer=StdConv2d, - norm_layer=partial(GroupNormAct, num_groups=32), -): + in_chs: int, + out_chs: int = 64, + stem_type: str = '', + preact: bool = True, + conv_layer: Callable = StdConv2d, + norm_layer: Callable = partial(GroupNormAct, num_groups=32), +) -> nn.Sequential: stem = OrderedDict() assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered') @@ -405,25 +498,25 @@ class ResNetV2(nn.Module): def __init__( self, - layers, - channels=(256, 512, 1024, 2048), - num_classes=1000, - in_chans=3, - global_pool='avg', - output_stride=32, - width_factor=1, - stem_chs=64, - stem_type='', - avg_down=False, - preact=True, - basic=False, - bottle_ratio=0.25, - act_layer=nn.ReLU, - norm_layer=partial(GroupNormAct, num_groups=32), - conv_layer=StdConv2d, - drop_rate=0., - drop_path_rate=0., - zero_init_last=False, + layers: List[int], + channels: Tuple[int, ...] 
= (256, 512, 1024, 2048), + num_classes: int = 1000, + in_chans: int = 3, + global_pool: str = 'avg', + output_stride: int = 32, + width_factor: int = 1, + stem_chs: int = 64, + stem_type: str = '', + avg_down: bool = False, + preact: bool = True, + basic: bool = False, + bottle_ratio: float = 0.25, + act_layer: Callable = nn.ReLU, + norm_layer: Callable = partial(GroupNormAct, num_groups=32), + conv_layer: Callable = StdConv2d, + drop_rate: float = 0., + drop_path_rate: float = 0., + zero_init_last: bool = False, ): """ Args: @@ -514,15 +607,18 @@ def __init__( self.grad_checkpointing = False @torch.jit.ignore - def init_weights(self, zero_init_last=True): + def init_weights(self, zero_init_last: bool = True) -> None: + """Initialize model weights.""" named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) @torch.jit.ignore() - def load_pretrained(self, checkpoint_path, prefix='resnet/'): + def load_pretrained(self, checkpoint_path: str, prefix: str = 'resnet/') -> None: + """Load pretrained weights.""" _load_weights(self, checkpoint_path, prefix) @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Group parameters for optimization.""" matcher = dict( stem=r'^stem', blocks=r'^stages\.(\d+)' if coarse else [ @@ -533,14 +629,22 @@ def group_matcher(self, coarse=False): return matcher @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing.""" self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier head.""" return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes self.head.reset(num_classes, global_pool) @@ -568,7 +672,7 @@ def forward_intermediates( assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' intermediates = [] take_indices, max_index = feature_take_indices(5, indices) - + # forward pass feat_idx = 0 H, W = x.shape[-2:] @@ -591,11 +695,11 @@ def forward_intermediates( x_inter = self.norm(x) if norm else x intermediates.append(x_inter) else: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates - + if feat_idx == last_idx: x = self.norm(x) @@ -617,7 +721,15 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers. + + Args: + x: Input tensor. + + Returns: + Feature tensor. + """ x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.stages, x, flatten=True) @@ -626,16 +738,40 @@ def forward_features(self, x): x = self.norm(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Input features. + pre_logits: Return features before final linear layer. + + Returns: + Classification logits or features. 
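# Usage sketch for the classifier and checkpointing hooks documented above, assuming the
# resnetv2_50 entrypoint registered further down in this file.
import torch
import timm

model = timm.create_model('resnetv2_50', pretrained=False)
model.reset_classifier(num_classes=10)     # swap the head for a 10-class task
model.set_grad_checkpointing(True)         # recompute stage activations during backward
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)                      # -> torch.Size([2, 10])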
+ """ return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output logits. + """ x = self.forward_features(x) x = self.forward_head(x) return x -def _init_weights(module: nn.Module, name: str = '', zero_init_last=True): +def _init_weights(module: nn.Module, name: str = '', zero_init_last: bool = True) -> None: + """Initialize module weights. + + Args: + module: PyTorch module to initialize. + name: Module name. + zero_init_last: Zero-initialize last layer weights. + """ if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)): nn.init.normal_(module.weight, mean=0.0, std=0.01) nn.init.zeros_(module.bias) @@ -688,7 +824,17 @@ def t2p(conv_weights): block.downsample.conv.weight.copy_(t2p(w)) -def _create_resnetv2(variant, pretrained=False, **kwargs): +def _create_resnetv2(variant: str, pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """Create a ResNetV2 model. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + ResNetV2 model instance. + """ feature_cfg = dict(flatten_sequential=True) return build_model_with_cfg( ResNetV2, variant, pretrained, @@ -697,7 +843,17 @@ def _create_resnetv2(variant, pretrained=False, **kwargs): ) -def _create_resnetv2_bit(variant, pretrained=False, **kwargs): +def _create_resnetv2_bit(variant: str, pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """Create a ResNetV2 model with BiT weights. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + ResNetV2 model instance. 
+ """ return _create_resnetv2( variant, pretrained=pretrained, @@ -707,7 +863,7 @@ def _create_resnetv2_bit(variant, pretrained=False, **kwargs): ) -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]: return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), @@ -819,43 +975,50 @@ def _cfg(url='', **kwargs): @register_model -def resnetv2_50x1_bit(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_50x1_bit(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-50x1-BiT model.""" return _create_resnetv2_bit( 'resnetv2_50x1_bit', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs) @register_model -def resnetv2_50x3_bit(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_50x3_bit(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-50x3-BiT model.""" return _create_resnetv2_bit( 'resnetv2_50x3_bit', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs) @register_model -def resnetv2_101x1_bit(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_101x1_bit(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-101x1-BiT model.""" return _create_resnetv2_bit( 'resnetv2_101x1_bit', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs) @register_model -def resnetv2_101x3_bit(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_101x3_bit(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-101x3-BiT model.""" return _create_resnetv2_bit( 'resnetv2_101x3_bit', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs) @register_model -def resnetv2_152x2_bit(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_152x2_bit(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-152x2-BiT model.""" return _create_resnetv2_bit( 'resnetv2_152x2_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs) @register_model -def resnetv2_152x4_bit(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_152x4_bit(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-152x4-BiT model.""" return _create_resnetv2_bit( 'resnetv2_152x4_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs) @register_model -def resnetv2_18(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_18(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-18 model.""" model_args = dict( layers=[2, 2, 2, 2], channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0, conv_layer=create_conv2d, norm_layer=BatchNormAct2d @@ -864,7 +1027,8 @@ def resnetv2_18(pretrained=False, **kwargs) -> ResNetV2: @register_model -def resnetv2_18d(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_18d(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-18d model (deep stem variant).""" model_args = dict( layers=[2, 2, 2, 2], channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0, conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True @@ -873,7 +1037,8 @@ def resnetv2_18d(pretrained=False, **kwargs) -> ResNetV2: @register_model -def resnetv2_34(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_34(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-34 model.""" model_args = dict( layers=(3, 4, 6, 3), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0, conv_layer=create_conv2d, norm_layer=BatchNormAct2d @@ -882,7 +1047,8 @@ def resnetv2_34(pretrained=False, **kwargs) -> ResNetV2: @register_model -def 
resnetv2_34d(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_34d(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-34d model (deep stem variant).""" model_args = dict( layers=(3, 4, 6, 3), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0, conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True @@ -891,13 +1057,15 @@ def resnetv2_34d(pretrained=False, **kwargs) -> ResNetV2: @register_model -def resnetv2_50(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_50(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-50 model.""" model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d) return _create_resnetv2('resnetv2_50', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def resnetv2_50d(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_50d(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-50d model (deep stem variant).""" model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True) @@ -905,7 +1073,8 @@ def resnetv2_50d(pretrained=False, **kwargs) -> ResNetV2: @register_model -def resnetv2_50t(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_50t(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-50t model (tiered stem variant).""" model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='tiered', avg_down=True) @@ -913,13 +1082,15 @@ def resnetv2_50t(pretrained=False, **kwargs) -> ResNetV2: @register_model -def resnetv2_101(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_101(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-101 model.""" model_args = dict(layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d) return _create_resnetv2('resnetv2_101', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def resnetv2_101d(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_101d(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-101d model (deep stem variant).""" model_args = dict( layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True) @@ -927,13 +1098,15 @@ def resnetv2_101d(pretrained=False, **kwargs) -> ResNetV2: @register_model -def resnetv2_152(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_152(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-152 model.""" model_args = dict(layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d) return _create_resnetv2('resnetv2_152', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def resnetv2_152d(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_152d(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-152d model (deep stem variant).""" model_args = dict( layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True) @@ -943,7 +1116,8 @@ def resnetv2_152d(pretrained=False, **kwargs) -> ResNetV2: # Experimental configs (may change / be removed) @register_model -def resnetv2_50d_gn(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_50d_gn(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-50d model with Group Normalization.""" model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct, stem_type='deep', avg_down=True) @@ -951,7 +1125,8 @@ def 
resnetv2_50d_gn(pretrained=False, **kwargs) -> ResNetV2: @register_model -def resnetv2_50d_evos(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_50d_evos(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-50d model with EvoNorm.""" model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0, stem_type='deep', avg_down=True) @@ -959,7 +1134,8 @@ def resnetv2_50d_evos(pretrained=False, **kwargs) -> ResNetV2: @register_model -def resnetv2_50d_frn(pretrained=False, **kwargs) -> ResNetV2: +def resnetv2_50d_frn(pretrained: bool = False, **kwargs: Any) -> ResNetV2: + """ResNetV2-50d model with Filter Response Normalization.""" model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d, stem_type='deep', avg_down=True) diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index dd3cb4f32..04e284158 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -12,7 +12,7 @@ from functools import partial from math import ceil -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -32,19 +32,38 @@ class LinearBottleneck(nn.Module): + """Linear bottleneck block for ReXNet. + + A mobile inverted residual bottleneck block as used in MobileNetV2 and subsequent models. + """ + def __init__( self, - in_chs, - out_chs, - stride, - dilation=(1, 1), - exp_ratio=1.0, - se_ratio=0., - ch_div=1, - act_layer='swish', - dw_act_layer='relu6', - drop_path=None, + in_chs: int, + out_chs: int, + stride: int, + dilation: Tuple[int, int] = (1, 1), + exp_ratio: float = 1.0, + se_ratio: float = 0., + ch_div: int = 1, + act_layer: str = 'swish', + dw_act_layer: str = 'relu6', + drop_path: Optional[nn.Module] = None, ): + """Initialize LinearBottleneck. + + Args: + in_chs: Number of input channels. + out_chs: Number of output channels. + stride: Stride for depthwise conv. + dilation: Dilation rates. + exp_ratio: Expansion ratio. + se_ratio: Squeeze-excitation ratio. + ch_div: Channel divisor. + act_layer: Activation layer for expansion. + dw_act_layer: Activation layer for depthwise. + drop_path: Drop path module. + """ super(LinearBottleneck, self).__init__() self.use_shortcut = stride == 1 and dilation[0] == dilation[1] and in_chs <= out_chs self.in_channels = in_chs @@ -75,10 +94,26 @@ def __init__( self.conv_pwl = ConvNormAct(dw_chs, out_chs, 1, apply_act=False) self.drop_path = drop_path - def feat_channels(self, exp=False): + def feat_channels(self, exp: bool = False) -> int: + """Get feature channel count. + + Args: + exp: Return expanded channels if True. + + Returns: + Number of feature channels. + """ return self.conv_dw.out_channels if exp else self.out_channels - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ shortcut = x if self.conv_exp is not None: x = self.conv_exp(x) @@ -95,13 +130,26 @@ def forward(self, x): def _block_cfg( - width_mult=1.0, - depth_mult=1.0, - initial_chs=16, - final_chs=180, - se_ratio=0., - ch_div=1, -): + width_mult: float = 1.0, + depth_mult: float = 1.0, + initial_chs: int = 16, + final_chs: int = 180, + se_ratio: float = 0., + ch_div: int = 1, +) -> List[Tuple[int, float, int, float]]: + """Generate ReXNet block configuration. + + Args: + width_mult: Width multiplier. + depth_mult: Depth multiplier. + initial_chs: Initial channel count. + final_chs: Final channel count. 
+ se_ratio: Squeeze-excitation ratio. + ch_div: Channel divisor. + + Returns: + List of tuples (out_channels, exp_ratio, stride, se_ratio). + """ layers = [1, 2, 2, 3, 3, 5] strides = [1, 2, 2, 2, 1, 2] layers = [ceil(element * depth_mult) for element in layers] @@ -122,15 +170,30 @@ def _block_cfg( def _build_blocks( - block_cfg, - prev_chs, - width_mult, - ch_div=1, - output_stride=32, - act_layer='swish', - dw_act_layer='relu6', - drop_path_rate=0., -): + block_cfg: List[Tuple[int, float, int, float]], + prev_chs: int, + width_mult: float, + ch_div: int = 1, + output_stride: int = 32, + act_layer: str = 'swish', + dw_act_layer: str = 'relu6', + drop_path_rate: float = 0., +) -> Tuple[List[nn.Module], List[Dict[str, Any]]]: + """Build ReXNet blocks from configuration. + + Args: + block_cfg: Block configuration list. + prev_chs: Previous channel count. + width_mult: Width multiplier. + ch_div: Channel divisor. + output_stride: Target output stride. + act_layer: Activation layer name. + dw_act_layer: Depthwise activation layer name. + drop_path_rate: Drop path rate. + + Returns: + Tuple of (features list, feature_info list). + """ feat_chs = [prev_chs] feature_info = [] curr_stride = 2 @@ -170,23 +233,47 @@ def _build_blocks( class RexNet(nn.Module): + """ReXNet model architecture. + + Based on `ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network` + - https://arxiv.org/abs/2007.00992 + """ + def __init__( self, - in_chans=3, - num_classes=1000, - global_pool='avg', - output_stride=32, - initial_chs=16, - final_chs=180, - width_mult=1.0, - depth_mult=1.0, - se_ratio=1/12., - ch_div=1, - act_layer='swish', - dw_act_layer='relu6', - drop_rate=0.2, - drop_path_rate=0., + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + output_stride: int = 32, + initial_chs: int = 16, + final_chs: int = 180, + width_mult: float = 1.0, + depth_mult: float = 1.0, + se_ratio: float = 1/12., + ch_div: int = 1, + act_layer: str = 'swish', + dw_act_layer: str = 'relu6', + drop_rate: float = 0.2, + drop_path_rate: float = 0., ): + """Initialize ReXNet. + + Args: + in_chans: Number of input channels. + num_classes: Number of classes for classification. + global_pool: Global pooling type. + output_stride: Output stride. + initial_chs: Initial channel count. + final_chs: Final channel count. + width_mult: Width multiplier. + depth_mult: Depth multiplier. + se_ratio: Squeeze-excitation ratio. + ch_div: Channel divisor. + act_layer: Activation layer name. + dw_act_layer: Depthwise activation layer name. + drop_rate: Dropout rate. + drop_path_rate: Drop path rate. + """ super(RexNet, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate @@ -216,7 +303,15 @@ def __init__( efficientnet_init_weights(self) @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Group matcher for parameter groups. + + Args: + coarse: Whether to use coarse grouping. + + Returns: + Dictionary of grouped parameters. + """ matcher = dict( stem=r'^stem', blocks=r'^features\.(\d+)', @@ -224,14 +319,30 @@ def group_matcher(self, coarse=False): return matcher @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing. + + Args: + enable: Whether to enable gradient checkpointing. 
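# Sketch of how depth_mult in _block_cfg above scales the per-stage repeat counts: each
# base count is multiplied and rounded up with ceil, so deeper variants add whole blocks.
from math import ceil

base_layers = [1, 2, 2, 3, 3, 5]
for depth_mult in (1.0, 1.5, 2.0):
    scaled = [ceil(n * depth_mult) for n in base_layers]
    print(depth_mult, scaled, 'total blocks:', sum(scaled))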
+ """ self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier module. + + Returns: + Classifier module. + """ return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classifier. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes self.head.reset(num_classes, global_pool) @@ -273,7 +384,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates @@ -285,8 +396,16 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layer. + prune_head: Whether to prune the classifier head. + + Returns: + List of indices that were kept. """ stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] take_indices, max_index = feature_take_indices(len(stage_ends), indices) @@ -296,7 +415,15 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers. + + Args: + x: Input tensor. + + Returns: + Feature tensor. + """ x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.features, x, flatten=True) @@ -304,16 +431,43 @@ def forward_features(self, x): x = self.features(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through head. + + Args: + x: Input features. + pre_logits: Return features before final linear layer. + + Returns: + Classification logits or features. + """ return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output logits. + """ x = self.forward_features(x) x = self.forward_head(x) return x -def _create_rexnet(variant, pretrained, **kwargs): +def _create_rexnet(variant: str, pretrained: bool, **kwargs) -> RexNet: + """Create a ReXNet model. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + ReXNet model instance. + """ feature_cfg = dict(flatten_sequential=True) return build_model_with_cfg( RexNet, @@ -324,7 +478,16 @@ def _create_rexnet(variant, pretrained, **kwargs): ) -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create default configuration dictionary. + + Args: + url: Model weight URL. + **kwargs: Additional configuration options. + + Returns: + Configuration dictionary. 
+ """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', @@ -361,60 +524,60 @@ def _cfg(url='', **kwargs): @register_model -def rexnet_100(pretrained=False, **kwargs) -> RexNet: +def rexnet_100(pretrained: bool = False, **kwargs) -> RexNet: """ReXNet V1 1.0x""" return _create_rexnet('rexnet_100', pretrained, **kwargs) @register_model -def rexnet_130(pretrained=False, **kwargs) -> RexNet: +def rexnet_130(pretrained: bool = False, **kwargs) -> RexNet: """ReXNet V1 1.3x""" return _create_rexnet('rexnet_130', pretrained, width_mult=1.3, **kwargs) @register_model -def rexnet_150(pretrained=False, **kwargs) -> RexNet: +def rexnet_150(pretrained: bool = False, **kwargs) -> RexNet: """ReXNet V1 1.5x""" return _create_rexnet('rexnet_150', pretrained, width_mult=1.5, **kwargs) @register_model -def rexnet_200(pretrained=False, **kwargs) -> RexNet: +def rexnet_200(pretrained: bool = False, **kwargs) -> RexNet: """ReXNet V1 2.0x""" return _create_rexnet('rexnet_200', pretrained, width_mult=2.0, **kwargs) @register_model -def rexnet_300(pretrained=False, **kwargs) -> RexNet: +def rexnet_300(pretrained: bool = False, **kwargs) -> RexNet: """ReXNet V1 3.0x""" return _create_rexnet('rexnet_300', pretrained, width_mult=3.0, **kwargs) @register_model -def rexnetr_100(pretrained=False, **kwargs) -> RexNet: +def rexnetr_100(pretrained: bool = False, **kwargs) -> RexNet: """ReXNet V1 1.0x w/ rounded (mod 8) channels""" return _create_rexnet('rexnetr_100', pretrained, ch_div=8, **kwargs) @register_model -def rexnetr_130(pretrained=False, **kwargs) -> RexNet: +def rexnetr_130(pretrained: bool = False, **kwargs) -> RexNet: """ReXNet V1 1.3x w/ rounded (mod 8) channels""" return _create_rexnet('rexnetr_130', pretrained, width_mult=1.3, ch_div=8, **kwargs) @register_model -def rexnetr_150(pretrained=False, **kwargs) -> RexNet: +def rexnetr_150(pretrained: bool = False, **kwargs) -> RexNet: """ReXNet V1 1.5x w/ rounded (mod 8) channels""" return _create_rexnet('rexnetr_150', pretrained, width_mult=1.5, ch_div=8, **kwargs) @register_model -def rexnetr_200(pretrained=False, **kwargs) -> RexNet: +def rexnetr_200(pretrained: bool = False, **kwargs) -> RexNet: """ReXNet V1 2.0x w/ rounded (mod 8) channels""" return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs) @register_model -def rexnetr_300(pretrained=False, **kwargs) -> RexNet: +def rexnetr_300(pretrained: bool = False, **kwargs) -> RexNet: """ReXNet V1 3.0x w/ rounded (mod 16) channels""" return _create_rexnet('rexnetr_300', pretrained, width_mult=3.0, ch_div=16, **kwargs) diff --git a/timm/models/shvit.py b/timm/models/shvit.py index 33364b059..be3e206ee 100644 --- a/timm/models/shvit.py +++ b/timm/models/shvit.py @@ -74,12 +74,12 @@ def fuse(self) -> nn.Conv2d: w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 m = nn.Conv2d( - in_channels=w.size(1) * self.c.groups, - out_channels=w.size(0), - kernel_size=w.shape[2:], - stride=self.c.stride, - padding=self.c.padding, - dilation=self.c.dilation, + in_channels=w.size(1) * self.c.groups, + out_channels=w.size(0), + kernel_size=w.shape[2:], + stride=self.c.stride, + padding=self.c.padding, + dilation=self.c.dilation, groups=self.c.groups, device=c.weight.device, dtype=c.weight.dtype, @@ -183,7 +183,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: qkv = self.qkv(x1) q, k, v = torch.split(qkv, [self.qk_dim, self.qk_dim, self.pdim], 
dim=1) q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) - + attn = (q.transpose(-2, -1) @ k) * self.scale attn = attn.softmax(dim=-1) x1 = (v @ attn.transpose(-2, -1)).reshape(B, self.pdim, H, W) @@ -203,9 +203,9 @@ def __init__( ): super().__init__() self.conv = Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0)) - if type == "s": + if type == "s": self.mixer = Residual(SHSA(dim, qk_dim, pdim, norm_layer, act_layer)) - else: + else: self.mixer = nn.Identity() self.ffn = Residual(FFN(dim, int(dim * 2))) @@ -330,7 +330,7 @@ def set_grad_checkpointing(self, enable=True): @torch.jit.ignore def get_classifier(self) -> nn.Module: return self.head.l - + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation @@ -373,7 +373,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates @@ -405,7 +405,7 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tenso if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) return x if pre_logits else self.head(x) - + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) x = self.forward_head(x) diff --git a/timm/models/starnet.py b/timm/models/starnet.py index bc140e00d..646fd324b 100644 --- a/timm/models/starnet.py +++ b/timm/models/starnet.py @@ -94,10 +94,10 @@ def __init__( self.grad_checkpointing = False self.feature_info = [] stem_chs = 32 - + # stem layer self.stem = nn.Sequential( - ConvBN(in_chans, stem_chs, kernel_size=3, stride=2, padding=1), + ConvBN(in_chans, stem_chs, kernel_size=3, stride=2, padding=1), act_layer(), ) prev_chs = stem_chs @@ -204,7 +204,7 @@ def forward_intermediates( x_inter = self.norm(x) # applying final norm last intermediate else: x_inter = x - intermediates.append(x_inter) + intermediates.append(x_inter) if intermediates_only: return intermediates diff --git a/timm/models/swiftformer.py b/timm/models/swiftformer.py index 8a5842823..5998c233f 100644 --- a/timm/models/swiftformer.py +++ b/timm/models/swiftformer.py @@ -36,7 +36,7 @@ def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma - + class Embedding(nn.Module): """ @@ -73,7 +73,7 @@ class ConvEncoder(nn.Module): Output: tensor with shape [B, C, H, W] """ def __init__( - self, + self, dim: int, hidden_dim: int = 64, kernel_size: int = 3, @@ -150,7 +150,7 @@ def __init__(self, in_dims: int = 512, token_dim: int = 256, num_heads: int = 1) self.to_key = nn.Linear(in_dims, token_dim * num_heads) self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1)) - + self.proj = nn.Linear(token_dim * num_heads, token_dim * num_heads) self.final = nn.Linear(token_dim * num_heads, token_dim) @@ -203,7 +203,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.layer_scale(x) x = skip + self.drop_path(x) return x - + class Block(nn.Module): """ @@ -225,7 +225,7 @@ def __init__( ): super().__init__() self.local_representation = LocalRepresentation( - dim=dim, + dim=dim, use_layer_scale=use_layer_scale, act_layer=act_layer, norm_layer=norm_layer, @@ -235,7 +235,7 @@ def __init__( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, - 
norm_layer=norm_layer, + norm_layer=norm_layer, drop=drop_rate, ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -280,7 +280,7 @@ def __init__( block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) if layers[index] - block_idx <= 1: blocks.append(Block( - dim, + dim, mlp_ratio=mlp_ratio, drop_rate=drop_rate, drop_path=block_dpr, @@ -291,8 +291,8 @@ def __init__( )) else: blocks.append(ConvEncoder( - dim=dim, - hidden_dim=int(mlp_ratio * dim), + dim=dim, + hidden_dim=int(mlp_ratio * dim), kernel_size=3, drop_path=block_dpr, act_layer=act_layer, @@ -357,9 +357,9 @@ def __init__( padding=down_pad, ) if downsamples[i] else nn.Identity() stage = Stage( - dim=embed_dims[i], - index=i, - layers=layers, + dim=embed_dims[i], + index=i, + layers=layers, mlp_ratio=mlp_ratios, act_layer=act_layer, drop_rate=drop_rate, @@ -429,7 +429,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): @torch.jit.ignore def set_distilled_training(self, enable: bool = True): self.distilled_training = enable - + def forward_intermediates( self, x: torch.Tensor, @@ -470,7 +470,7 @@ def forward_intermediates( x_inter = self.norm(x) # applying final norm last intermediate else: x_inter = x - intermediates.append(x_inter) + intermediates.append(x_inter) if intermediates_only: return intermediates diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 54d57f814..7eeae8316 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -17,7 +17,7 @@ # -------------------------------------------------------- import logging import math -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Dict, Callable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -43,15 +43,14 @@ def window_partition( x: torch.Tensor, window_size: Tuple[int, int], ) -> torch.Tensor: - """ - Partition into non-overlapping windows with padding if needed. + """Partition into non-overlapping windows. + Args: - x (tensor): input tokens with [B, H, W, C]. - window_size (int): window size. + x: Input tokens with shape [B, H, W, C]. + window_size: Window size. Returns: - windows: windows after partition with [B * num_windows, window_size, window_size, C]. - (Hp, Wp): padded height and width before partition + Windows after partition with shape [B * num_windows, window_size, window_size, C]. """ B, H, W, C = x.shape x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) @@ -60,16 +59,17 @@ def window_partition( @register_notrace_function # reason: int argument is a Proxy -def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int): - """ +def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], H: int, W: int) -> torch.Tensor: + """Reverse window partition. + Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image + windows: Windows with shape (num_windows*B, window_size, window_size, C). + window_size: Window size. + H: Height of image. + W: Width of image. Returns: - x: (B, H, W, C) + Tensor with shape (B, H, W, C). 
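As a small sanity check on the window helpers documented above, a round-trip sketch (shapes and window size are illustrative; the window size must divide H and W evenly here):

import torch
from timm.models.swin_transformer import window_partition, window_reverse

B, H, W, C = 2, 56, 56, 96
window_size = (7, 7)
x = torch.randn(B, H, W, C)

windows = window_partition(x, window_size)      # (B * 64, 7, 7, 96)
assert windows.shape == (B * (H // 7) * (W // 7), 7, 7, C)

y = window_reverse(windows, window_size, H, W)  # back to (B, H, W, C)
assert torch.equal(x, y)                        # pure reshape/permute, exact round trip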
""" C = windows.shape[-1] x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) @@ -77,7 +77,16 @@ def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int): return x -def get_relative_position_index(win_h: int, win_w: int): +def get_relative_position_index(win_h: int, win_w: int) -> torch.Tensor: + """Get pair-wise relative position index for each token inside the window. + + Args: + win_h: Window height. + win_w: Window width. + + Returns: + Relative position index tensor. + """ # get pair-wise relative position index for each token inside the window coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w))) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww @@ -90,8 +99,9 @@ def get_relative_position_index(win_h: int, win_w: int): class WindowAttention(nn.Module): - """ Window based multi-head self attention (W-MSA) module with relative position bias. - It supports shifted and non-shifted windows. + """Window based multi-head self attention (W-MSA) module with relative position bias. + + Supports both shifted and non-shifted windows. """ fused_attn: torch.jit.Final[bool] @@ -167,11 +177,15 @@ def _get_rel_pos_bias(self) -> torch.Tensor: relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww return relative_position_bias.unsqueeze(0) - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass. + Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + x: Input features with shape of (num_windows*B, N, C). + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None. + + Returns: + Output features with shape of (num_windows*B, N, C). """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) @@ -207,7 +221,9 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): class SwinTransformerBlock(nn.Module): - """ Swin Transformer Block. + """Swin Transformer Block. + + A transformer block with window-based self-attention and shifted windows. """ def __init__( @@ -401,7 +417,15 @@ def _attn(self, x): x = shifted_x return x - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input features with shape (B, H, W, C). + + Returns: + Output features with shape (B, H, W, C). + """ B, H, W, C = x.shape x = x + self.drop_path1(self._attn(self.norm1(x))) x = x.reshape(B, -1, C) @@ -411,7 +435,9 @@ def forward(self, x): class PatchMerging(nn.Module): - """ Patch Merging Layer. + """Patch Merging Layer. + + Downsample features by merging 2x2 neighboring patches. """ def __init__( @@ -432,7 +458,15 @@ def __init__( self.norm = norm_layer(4 * dim) self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input features with shape (B, H, W, C). + + Returns: + Output features with shape (B, H//2, W//2, out_dim). + """ B, H, W, C = x.shape pad_values = (0, 0, 0, W % 2, 0, H % 2) @@ -446,7 +480,9 @@ def forward(self, x): class SwinTransformerStage(nn.Module): - """ A basic Swin Transformer layer for one stage. + """A basic Swin Transformer layer for one stage. + + Contains multiple Swin Transformer blocks and optional downsampling. 
""" def __init__( @@ -550,7 +586,15 @@ def set_input_size( always_partition=always_partition, ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input features. + + Returns: + Output features. + """ x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -561,7 +605,7 @@ def forward(self, x): class SwinTransformer(nn.Module): - """ Swin Transformer + """Swin Transformer. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 @@ -690,13 +734,19 @@ def __init__( self.init_weights(weight_init) @torch.jit.ignore - def init_weights(self, mode=''): + def init_weights(self, mode: str = '') -> None: + """Initialize model weights. + + Args: + mode: Weight initialization mode ('jax', 'jax_nlhb', 'moco', or ''). + """ assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. named_apply(get_init_weights_vit(mode, head_bias=head_bias), self) @torch.jit.ignore - def no_weight_decay(self): + def no_weight_decay(self) -> Set[str]: + """Parameters that should not use weight decay.""" nwd = set() for n, _ in self.named_parameters(): if 'relative_position_bias_table' in n: @@ -711,14 +761,14 @@ def set_input_size( window_ratio: int = 8, always_partition: Optional[bool] = None, ) -> None: - """ Updates the image resolution and window size. + """Update the image resolution and window size. Args: - img_size: New input resolution, if None current resolution is used - patch_size (Optional[Tuple[int, int]): New patch size, if None use current patch size - window_size: New window size, if None based on new_img_size // window_div - window_ratio: divisor for calculating window size from grid size - always_partition: always partition into windows and shift (even if window size < feat size) + img_size: New input resolution, if None current resolution is used. + patch_size: New patch size, if None use current patch size. + window_size: New window size, if None based on new_img_size // window_div. + window_ratio: Divisor for calculating window size from grid size. + always_partition: Always partition into windows and shift (even if window size < feat size). """ if img_size is not None or patch_size is not None: self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size) @@ -736,7 +786,8 @@ def set_input_size( ) @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Group parameters for optimization.""" return dict( stem=r'^patch_embed', # stem and embed blocks=r'^layers\.(\d+)' if coarse else [ @@ -747,15 +798,23 @@ def group_matcher(self, coarse=False): ) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing.""" for l in self.layers: l.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier head.""" return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. 
+ """ self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) @@ -768,17 +827,18 @@ def forward_intermediates( output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: - """ Forward features that returns intermediates. + """Forward features that returns intermediates. Args: - x: Input image tensor - indices: Take last n blocks if int, all if None, select matching indices if sequence - norm: Apply norm layer to compatible intermediates - stop_early: Stop iterating over blocks when last desired intermediate hit - output_fmt: Shape of intermediate feature outputs - intermediates_only: Only return intermediate features - Returns: + x: Input image tensor. + indices: Take last n blocks if int, all if None, select matching indices if sequence. + norm: Apply norm layer to compatible intermediates. + stop_early: Stop iterating over blocks when last desired intermediate hit. + output_fmt: Shape of intermediate feature outputs. + intermediates_only: Only return intermediate features. + Returns: + List of intermediate features or tuple of (final features, intermediates). """ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' intermediates = [] @@ -814,8 +874,16 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layer. + prune_head: Whether to prune the classifier head. + + Returns: + List of indices that were kept. """ take_indices, max_index = feature_take_indices(len(self.layers), indices) self.layers = self.layers[:max_index + 1] # truncate blocks @@ -825,23 +893,49 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers.""" x = self.patch_embed(x) x = self.layers(x) x = self.norm(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Feature tensor. + pre_logits: Return features before final classifier. + + Returns: + Output tensor. + """ return self.head(x, pre_logits=True) if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output logits. + """ x = self.forward_features(x) x = self.forward_head(x) return x -def checkpoint_filter_fn(state_dict, model): - """ convert patch embedding weight from manual patchify + linear proj to conv""" +def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> Dict[str, torch.Tensor]: + """Convert patch embedding weight from manual patchify + linear proj to conv. + + Args: + state_dict: State dictionary from checkpoint. + model: Model instance. + + Returns: + Filtered state dictionary. 
+ """ old_weights = True if 'head.fc.weight' in state_dict: old_weights = False @@ -881,7 +975,17 @@ def checkpoint_filter_fn(state_dict, model): return out_dict -def _create_swin_transformer(variant, pretrained=False, **kwargs): +def _create_swin_transformer(variant: str, pretrained: bool = False, **kwargs) -> SwinTransformer: + """Create a Swin Transformer model. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + SwinTransformer model instance. + """ default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) out_indices = kwargs.pop('out_indices', default_out_indices) @@ -894,7 +998,8 @@ def _create_swin_transformer(variant, pretrained=False, **kwargs): return model -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create default configuration for Swin Transformer models.""" return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 6a3330f47..f7b758aa8 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -13,7 +13,7 @@ # Written by Ze Liu # -------------------------------------------------------- import math -from typing import Callable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import torch import torch.nn as nn @@ -34,13 +34,14 @@ def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor: - """ + """Partition into non-overlapping windows. + Args: - x: (B, H, W, C) - window_size (int): window size + x: Input tensor of shape (B, H, W, C). + window_size: Window size (height, width). Returns: - windows: (num_windows*B, window_size, window_size, C) + Windows tensor of shape (num_windows*B, window_size[0], window_size[1], C). """ B, H, W, C = x.shape x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) @@ -50,14 +51,15 @@ def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Ten @register_notrace_function # reason: int argument is a Proxy def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size: Tuple[int, int]) -> torch.Tensor: - """ + """Merge windows back to feature map. + Args: - windows: (num_windows * B, window_size[0], window_size[1], C) - window_size (Tuple[int, int]): Window size - img_size (Tuple[int, int]): Image size + windows: Windows tensor of shape (num_windows * B, window_size[0], window_size[1], C). + window_size: Window size (height, width). + img_size: Image size (height, width). Returns: - x: (B, H, W, C) + Feature map tensor of shape (B, H, W, C). """ H, W = img_size C = windows.shape[-1] @@ -67,17 +69,10 @@ def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. + """Window based multi-head self attention (W-MSA) module with relative position bias. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + Supports both shifted and non-shifted window attention with continuous relative + position bias and cosine attention. """ def __init__( @@ -91,6 +86,18 @@ def __init__( proj_drop: float = 0., pretrained_window_size: Tuple[int, int] = (0, 0), ) -> None: + """Initialize window attention module. + + Args: + dim: Number of input channels. + window_size: The height and width of the window. + num_heads: Number of attention heads. + qkv_bias: If True, add a learnable bias to query, key, value. + qkv_bias_separate: If True, use separate bias for q, k, v projections. + attn_drop: Dropout ratio of attention weight. + proj_drop: Dropout ratio of output. + pretrained_window_size: The height and width of the window in pre-training. + """ super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww @@ -123,7 +130,8 @@ def __init__( self._make_pair_wise_relative_positions() - def _make_pair_wise_relative_positions(self): + def _make_pair_wise_relative_positions(self) -> None: + """Create pair-wise relative position index and coordinates table.""" # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32) @@ -154,9 +162,10 @@ def _make_pair_wise_relative_positions(self): self.register_buffer("relative_position_index", relative_position_index, persistent=False) def set_window_size(self, window_size: Tuple[int, int]) -> None: - """Update window size & interpolate position embeddings + """Update window size and regenerate relative position tables. + Args: - window_size (int): New window size + window_size: New window size (height, width). """ window_size = to_2tuple(window_size) if window_size != self.window_size: @@ -164,10 +173,14 @@ def set_window_size(self, window_size: Tuple[int, int]) -> None: self._make_pair_wise_relative_positions() def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: - """ + """Forward pass of window attention. + Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + x: Input features with shape of (num_windows*B, N, C). + mask: Attention mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None. + + Returns: + Output features with shape of (num_windows*B, N, C). """ B_, N, C = x.shape @@ -212,7 +225,10 @@ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch class SwinTransformerV2Block(nn.Module): - """ Swin Transformer Block. + """Swin Transformer V2 Block. + + A standard transformer block with window attention and shifted window attention + for modeling long-range dependencies efficiently. """ def __init__( @@ -290,6 +306,14 @@ def __init__( ) def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + """Generate attention mask for shifted window attention. + + Args: + x: Input tensor for dynamic shape calculation. + + Returns: + Attention mask or None if no shift. 
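For context on the relative position tables regenerated by _make_pair_wise_relative_positions, a standalone sketch of the log-spaced coordinate transform that feeds the continuous position bias MLP; the constants follow the reference Swin V2 recipe and are an assumption here, not part of the diff:

import math
import torch

win_h = win_w = 8
coords_h = torch.arange(-(win_h - 1), win_h, dtype=torch.float32)
coords_w = torch.arange(-(win_w - 1), win_w, dtype=torch.float32)
table = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'), dim=-1)  # (2H-1, 2W-1, 2)

# normalise offsets to [-8, 8], then compress with a signed log so that larger
# windows at fine-tune time extrapolate smoothly beyond the pretraining range
table[..., 0] /= (win_h - 1)
table[..., 1] /= (win_w - 1)
table *= 8
table = torch.sign(table) * torch.log2(table.abs() + 1.0) / math.log2(8)
# 'table' is what the small cpb MLP consumes to produce per-head biases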
+ """ if any(self.shift_size): # calculate attention mask for SW-MSA if x is None: @@ -322,6 +346,15 @@ def _calc_window_shift( target_window_size: _int_or_tuple_2_t, target_shift_size: Optional[_int_or_tuple_2_t] = None, ) -> Tuple[Tuple[int, int], Tuple[int, int]]: + """Calculate window size and shift size based on input resolution. + + Args: + target_window_size: Target window size. + target_shift_size: Target shift size. + + Returns: + Tuple of (adjusted_window_size, adjusted_shift_size). + """ target_window_size = to_2tuple(target_window_size) if target_shift_size is None: # if passed value is None, recalculate from default window_size // 2 if it was active @@ -346,13 +379,13 @@ def set_input_size( feat_size: Tuple[int, int], window_size: Tuple[int, int], always_partition: Optional[bool] = None, - ): - """ Updates the input resolution, window size. + ) -> None: + """Set input size and update window configuration. Args: - feat_size (Tuple[int, int]): New input resolution - window_size (int): New window size - always_partition: Change always_partition attribute if not None + feat_size: New feature map size. + window_size: New window size. + always_partition: Override always_partition setting. """ # Update input resolution self.input_resolution = feat_size @@ -368,6 +401,14 @@ def set_input_size( ) def _attn(self, x: torch.Tensor) -> torch.Tensor: + """Apply windowed attention with optional shift. + + Args: + x: Input tensor of shape (B, H, W, C). + + Returns: + Output tensor of shape (B, H, W, C). + """ B, H, W, C = x.shape # cyclic shift @@ -415,7 +456,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PatchMerging(nn.Module): - """ Patch Merging Layer. + """Patch Merging Layer. + + Merges 2x2 neighboring patches and projects to higher dimension, + effectively downsampling the feature maps. """ def __init__( @@ -450,7 +494,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SwinTransformerV2Stage(nn.Module): - """ A Swin Transformer V2 Stage. + """A Swin Transformer V2 Stage. + + A single stage consisting of multiple Swin Transformer blocks with + optional downsampling at the beginning. """ def __init__( @@ -538,13 +585,13 @@ def set_input_size( feat_size: Tuple[int, int], window_size: int, always_partition: Optional[bool] = None, - ): - """ Updates the resolution, window size and so the pair-wise relative positions. + ) -> None: + """Update resolution, window size and relative positions. Args: - feat_size: New input (feature) resolution - window_size: New window size - always_partition: Always partition / shift the window + feat_size: New input (feature) resolution. + window_size: New window size. + always_partition: Always partition / shift the window. """ self.input_resolution = feat_size if isinstance(self.downsample, nn.Identity): @@ -560,6 +607,14 @@ def set_input_size( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the stage. + + Args: + x: Input tensor of shape (B, H, W, C). + + Returns: + Output tensor of shape (B, H', W', C'). + """ x = self.downsample(x) for blk in self.blocks: @@ -570,6 +625,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x def _init_respostnorm(self) -> None: + """Initialize residual post-normalization weights.""" for blk in self.blocks: nn.init.constant_(blk.norm1.bias, 0) nn.init.constant_(blk.norm1.weight, 0) @@ -578,7 +634,10 @@ def _init_respostnorm(self) -> None: class SwinTransformerV2(nn.Module): - """ Swin Transformer V2 + """Swin Transformer V2. 
+ + A hierarchical vision transformer using shifted windows for efficient + self-attention computation with continuous position bias. A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` - https://arxiv.org/abs/2111.09883 @@ -700,7 +759,12 @@ def __init__( for bly in self.layers: bly._init_respostnorm() - def _init_weights(self, m): + def _init_weights(self, m: nn.Module) -> None: + """Initialize model weights. + + Args: + m: Module to initialize. + """ if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: @@ -739,7 +803,12 @@ def set_input_size( ) @torch.jit.ignore - def no_weight_decay(self): + def no_weight_decay(self) -> Set[str]: + """Get parameter names that should not use weight decay. + + Returns: + Set of parameter names to exclude from weight decay. + """ nod = set() for n, m in self.named_modules(): if any([kw in n for kw in ("cpb_mlp", "logit_scale")]): @@ -747,7 +816,15 @@ def no_weight_decay(self): return nod @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Create parameter group matcher for optimizer parameter groups. + + Args: + coarse: If True, use coarse grouping. + + Returns: + Dictionary mapping group names to regex patterns. + """ return dict( stem=r'^absolute_pos_embed|patch_embed', # stem and embed blocks=r'^layers\.(\d+)' if coarse else [ @@ -758,15 +835,31 @@ def group_matcher(self, coarse=False): ) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing. + + Args: + enable: If True, enable gradient checkpointing. + """ for l in self.layers: l.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier head. + + Returns: + The classification head module. + """ return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classification head. + + Args: + num_classes: Number of classes for new head. + global_pool: Global pooling type. + """ self.num_classes = num_classes self.head.reset(num_classes, global_pool) @@ -836,22 +929,59 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers. + + Args: + x: Input tensor of shape (B, C, H, W). + + Returns: + Feature tensor of shape (B, H', W', C). + """ x = self.patch_embed(x) x = self.layers(x) x = self.norm(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classification head. + + Args: + x: Feature tensor of shape (B, H, W, C). + pre_logits: If True, return features before final linear layer. + + Returns: + Logits tensor of shape (B, num_classes) or pre-logits. + """ return self.head(x, pre_logits=True) if pre_logits else self.head(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the model. + + Args: + x: Input tensor of shape (B, C, H, W). + + Returns: + Logits tensor of shape (B, num_classes). 
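A sketch of how no_weight_decay() is typically consumed when building optimizer parameter groups; the model name and hyper-parameters are illustrative, and timm's own optimizer factory can apply the same exclusion for you:

import timm
import torch

model = timm.create_model('swinv2_tiny_window8_256', pretrained=False)
skip = model.no_weight_decay()  # cpb_mlp / logit_scale entries

decay, no_decay = [], []
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    if any(s in name for s in skip) or p.ndim <= 1:  # also skip norms/biases, a common convention
        no_decay.append(p)
    else:
        decay.append(p)

optimizer = torch.optim.AdamW([
    {'params': decay, 'weight_decay': 0.05},
    {'params': no_decay, 'weight_decay': 0.0},
], lr=1e-3)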
+ """ x = self.forward_features(x) x = self.forward_head(x) return x -def checkpoint_filter_fn(state_dict, model): +def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]: + """Filter and process checkpoint state dict for loading. + + Handles resizing of patch embeddings and relative position tables + when model size differs from checkpoint. + + Args: + state_dict: Checkpoint state dictionary. + model: Target model to load weights into. + + Returns: + Filtered state dictionary. + """ state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('state_dict', state_dict) native_checkpoint = 'head.fc.weight' in state_dict @@ -881,7 +1011,17 @@ def checkpoint_filter_fn(state_dict, model): return out_dict -def _create_swin_transformer_v2(variant, pretrained=False, **kwargs): +def _create_swin_transformer_v2(variant: str, pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Create a Swin Transformer V2 model. + + Args: + variant: Model variant name. + pretrained: If True, load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + SwinTransformerV2 model instance. + """ default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1)))) out_indices = kwargs.pop('out_indices', default_out_indices) @@ -963,72 +1103,64 @@ def _cfg(url='', **kwargs): @register_model -def swinv2_tiny_window16_256(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_tiny_window16_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-T V2 @ 256x256, window 16x16.""" model_args = dict(window_size=16, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24)) return _create_swin_transformer_v2( 'swinv2_tiny_window16_256', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def swinv2_tiny_window8_256(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_tiny_window8_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-T V2 @ 256x256, window 8x8.""" model_args = dict(window_size=8, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24)) return _create_swin_transformer_v2( 'swinv2_tiny_window8_256', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def swinv2_small_window16_256(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_small_window16_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-S V2 @ 256x256, window 16x16.""" model_args = dict(window_size=16, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24)) return _create_swin_transformer_v2( 'swinv2_small_window16_256', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def swinv2_small_window8_256(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_small_window8_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-S V2 @ 256x256, window 8x8.""" model_args = dict(window_size=8, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24)) return _create_swin_transformer_v2( 'swinv2_small_window8_256', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def swinv2_base_window16_256(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_base_window16_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-B V2 @ 256x256, window 16x16.""" model_args = dict(window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)) return _create_swin_transformer_v2( 
'swinv2_base_window16_256', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def swinv2_base_window8_256(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_base_window8_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-B V2 @ 256x256, window 8x8.""" model_args = dict(window_size=8, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)) return _create_swin_transformer_v2( 'swinv2_base_window8_256', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def swinv2_base_window12_192(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_base_window12_192(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-B V2 @ 192x192, window 12x12.""" model_args = dict(window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)) return _create_swin_transformer_v2( 'swinv2_base_window12_192', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def swinv2_base_window12to16_192to256(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_base_window12to16_192to256(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-B V2 @ 192x192, trained at window 12x12, fine-tuned to 256x256 window 16x16.""" model_args = dict( window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), pretrained_window_sizes=(12, 12, 12, 6)) @@ -1037,9 +1169,8 @@ def swinv2_base_window12to16_192to256(pretrained=False, **kwargs) -> SwinTransfo @register_model -def swinv2_base_window12to24_192to384(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_base_window12to24_192to384(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-B V2 @ 192x192, trained at window 12x12, fine-tuned to 384x384 window 24x24.""" model_args = dict( window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), pretrained_window_sizes=(12, 12, 12, 6)) @@ -1048,18 +1179,16 @@ def swinv2_base_window12to24_192to384(pretrained=False, **kwargs) -> SwinTransfo @register_model -def swinv2_large_window12_192(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_large_window12_192(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-L V2 @ 192x192, window 12x12.""" model_args = dict(window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48)) return _create_swin_transformer_v2( 'swinv2_large_window12_192', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def swinv2_large_window12to16_192to256(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_large_window12to16_192to256(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-L V2 @ 192x192, trained at window 12x12, fine-tuned to 256x256 window 16x16.""" model_args = dict( window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), pretrained_window_sizes=(12, 12, 12, 6)) @@ -1068,9 +1197,8 @@ def swinv2_large_window12to16_192to256(pretrained=False, **kwargs) -> SwinTransf @register_model -def swinv2_large_window12to24_192to384(pretrained=False, **kwargs) -> SwinTransformerV2: - """ - """ +def swinv2_large_window12to24_192to384(pretrained: bool = False, **kwargs) -> SwinTransformerV2: + """Swin-L V2 @ 192x192, trained at window 12x12, fine-tuned to 384x384 window 24x24.""" model_args = dict( window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), pretrained_window_sizes=(12, 12, 12, 6)) diff --git 
a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index d8d247cde..c490fa23c 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -29,7 +29,7 @@ # -------------------------------------------------------- import logging import math -from typing import Tuple, Optional, List, Union, Any, Type +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import torch import torch.nn as nn @@ -49,23 +49,24 @@ def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor: - """Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). """ + """Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C).""" return x.permute(0, 2, 3, 1) def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor: - """Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W). """ + """Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W).""" return x.permute(0, 3, 1, 2) -def window_partition(x, window_size: Tuple[int, int]): - """ +def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor: + """Partition into non-overlapping windows. + Args: - x: (B, H, W, C) - window_size (int): window size + x: Input tensor of shape (B, H, W, C). + window_size: Window size (height, width). Returns: - windows: (num_windows*B, window_size, window_size, C) + Windows tensor of shape (num_windows*B, window_size[0], window_size[1], C). """ B, H, W, C = x.shape x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) @@ -74,15 +75,16 @@ def window_partition(x, window_size: Tuple[int, int]): @register_notrace_function # reason: int argument is a Proxy -def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]): - """ +def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size: Tuple[int, int]) -> torch.Tensor: + """Merge windows back to feature map. + Args: - windows: (num_windows * B, window_size[0], window_size[1], C) - window_size (Tuple[int, int]): Window size - img_size (Tuple[int, int]): Image size + windows: Windows tensor of shape (num_windows * B, window_size[0], window_size[1], C). + window_size: Window size (height, width). + img_size: Image size (height, width). Returns: - x: (B, H, W, C) + Feature map tensor of shape (B, H, W, C). """ H, W = img_size C = windows.shape[-1] @@ -139,7 +141,7 @@ def __init__( self._make_pair_wise_relative_positions() def _make_pair_wise_relative_positions(self) -> None: - """Method initializes the pair-wise relative positions to compute the positional biases.""" + """Initialize the pair-wise relative positions to compute the positional biases.""" device = self.logit_scale.device coordinates = torch.stack(ndgrid( torch.arange(self.window_size[0], device=device), @@ -152,9 +154,10 @@ def _make_pair_wise_relative_positions(self) -> None: self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False) def set_window_size(self, window_size: Tuple[int, int]) -> None: - """Update window size & interpolate position embeddings + """Update window size and regenerate relative position coordinates. + Args: - window_size (int): New window size + window_size: New window size. 
""" window_size = to_2tuple(window_size) if window_size != self.window_size: @@ -162,11 +165,10 @@ def set_window_size(self, window_size: Tuple[int, int]) -> None: self._make_pair_wise_relative_positions() def _relative_positional_encodings(self) -> torch.Tensor: - """Method computes the relative positional encodings + """Compute the relative positional encodings. Returns: - relative_position_bias (torch.Tensor): Relative positional encodings - (1, number of heads, window size ** 2, window size ** 2) + Relative positional encodings of shape (1, num_heads, window_size**2, window_size**2). """ window_area = self.window_size[0] * self.window_size[1] relative_position_bias = self.meta_mlp(self.relative_coordinates_log) @@ -177,13 +179,14 @@ def _relative_positional_encodings(self) -> torch.Tensor: return relative_position_bias def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: - """ Forward pass. + """Forward pass of window multi-head self-attention. + Args: - x (torch.Tensor): Input tensor of the shape (B * windows, N, C) - mask (Optional[torch.Tensor]): Attention mask for the shift case + x: Input tensor of shape (B * windows, N, C). + mask: Attention mask for the shift case. Returns: - Output tensor of the shape [B * windows, N, C] + Output tensor of shape (B * windows, N, C). """ Bw, L, C = x.shape @@ -404,13 +407,13 @@ def _shifted_window_attn(self, x): return x def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass. + """Forward pass of Swin Transformer V2 block. Args: - x (torch.Tensor): Input tensor of the shape [B, C, H, W] + x: Input tensor of shape [B, C, H, W]. Returns: - output (torch.Tensor): Output tensor of the shape [B, C, H, W] + Output tensor of shape [B, C, H, W]. """ # post-norm branches (op -> norm -> drop) x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x))) @@ -424,23 +427,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PatchMerging(nn.Module): - """ This class implements the patch merging as a strided convolution with a normalization before. - Args: - dim (int): Number of input channels - norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. + """Patch merging layer. + + This class implements the patch merging as a strided convolution with a normalization before. """ def __init__(self, dim: int, norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: + """Initialize patch merging layer. + + Args: + dim: Number of input channels. + norm_layer: Type of normalization layer to be utilized. + """ super(PatchMerging, self).__init__() self.norm = norm_layer(4 * dim) self.reduction = nn.Linear(in_features=4 * dim, out_features=2 * dim, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - """ Forward pass. + """Forward pass of patch merging. + Args: - x (torch.Tensor): Input tensor of the shape [B, C, H, W] + x: Input tensor of shape [B, C, H, W]. + Returns: - output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] + Output tensor of shape [B, 2 * C, H // 2, W // 2]. 
""" B, H, W, C = x.shape @@ -455,16 +465,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PatchEmbed(nn.Module): - """ 2D Image to Patch Embedding """ + """2D Image to Patch Embedding.""" def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - norm_layer=None, - strict_img_size=True, - ): + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + strict_img_size: bool = True, + ) -> None: + """Initialize patch embedding. + + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of input channels. + embed_dim: Embedding dimension. + norm_layer: Normalization layer. + strict_img_size: Enforce strict image size. + """ super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -477,14 +497,27 @@ def __init__( self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def set_input_size(self, img_size: Tuple[int, int]): + def set_input_size(self, img_size: Tuple[int, int]) -> None: + """Update input image size. + + Args: + img_size: New image size. + """ img_size = to_2tuple(img_size) if img_size != self.img_size: self.img_size = img_size self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of patch embedding. + + Args: + x: Input tensor of shape [B, C, H, W]. + + Returns: + Output tensor of shape [B, C', H', W']. + """ B, C, H, W = x.shape if self.strict_img_size: _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") @@ -907,7 +940,16 @@ def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs): return model -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create a default configuration dictionary. + + Args: + url: Model weights URL. + **kwargs: Additional configuration parameters. + + Returns: + Configuration dictionary. + """ return { 'url': url, 'num_classes': 1000, @@ -967,8 +1009,8 @@ def _cfg(url='', **kwargs): @register_model -def swinv2_cr_tiny_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-T V2 CR @ 384x384, trained ImageNet-1k""" +def swinv2_cr_tiny_384(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-T V2 CR @ 384x384, trained ImageNet-1k.""" model_args = dict( embed_dim=96, depths=(2, 2, 6, 2), @@ -978,8 +1020,8 @@ def swinv2_cr_tiny_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_tiny_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-T V2 CR @ 224x224, trained ImageNet-1k""" +def swinv2_cr_tiny_224(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-T V2 CR @ 224x224, trained ImageNet-1k.""" model_args = dict( embed_dim=96, depths=(2, 2, 6, 2), @@ -989,8 +1031,9 @@ def swinv2_cr_tiny_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_tiny_ns_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: +def swinv2_cr_tiny_ns_224(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: """Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms. + ** Experimental, may make default if results are improved. 
** """ model_args = dict( @@ -1003,8 +1046,8 @@ def swinv2_cr_tiny_ns_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_small_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-S V2 CR @ 384x384, trained ImageNet-1k""" +def swinv2_cr_small_384(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-S V2 CR @ 384x384, trained ImageNet-1k.""" model_args = dict( embed_dim=96, depths=(2, 2, 18, 2), @@ -1014,8 +1057,8 @@ def swinv2_cr_small_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_small_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-S V2 CR @ 224x224, trained ImageNet-1k""" +def swinv2_cr_small_224(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-S V2 CR @ 224x224, trained ImageNet-1k.""" model_args = dict( embed_dim=96, depths=(2, 2, 18, 2), @@ -1025,8 +1068,8 @@ def swinv2_cr_small_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_small_ns_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-S V2 CR @ 224x224, trained ImageNet-1k""" +def swinv2_cr_small_ns_224(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-S V2 CR @ 224x224, trained ImageNet-1k.""" model_args = dict( embed_dim=96, depths=(2, 2, 18, 2), @@ -1037,8 +1080,8 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_small_ns_256(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-S V2 CR @ 256x256, trained ImageNet-1k""" +def swinv2_cr_small_ns_256(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-S V2 CR @ 256x256, trained ImageNet-1k.""" model_args = dict( embed_dim=96, depths=(2, 2, 18, 2), @@ -1049,8 +1092,8 @@ def swinv2_cr_small_ns_256(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_base_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-B V2 CR @ 384x384, trained ImageNet-1k""" +def swinv2_cr_base_384(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-B V2 CR @ 384x384, trained ImageNet-1k.""" model_args = dict( embed_dim=128, depths=(2, 2, 18, 2), @@ -1060,8 +1103,8 @@ def swinv2_cr_base_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_base_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-B V2 CR @ 224x224, trained ImageNet-1k""" +def swinv2_cr_base_224(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-B V2 CR @ 224x224, trained ImageNet-1k.""" model_args = dict( embed_dim=128, depths=(2, 2, 18, 2), @@ -1071,8 +1114,8 @@ def swinv2_cr_base_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_base_ns_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-B V2 CR @ 224x224, trained ImageNet-1k""" +def swinv2_cr_base_ns_224(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-B V2 CR @ 224x224, trained ImageNet-1k.""" model_args = dict( embed_dim=128, depths=(2, 2, 18, 2), @@ -1083,8 +1126,8 @@ def swinv2_cr_base_ns_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_large_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-L V2 CR @ 384x384, trained ImageNet-1k""" +def swinv2_cr_large_384(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-L V2 CR @ 384x384, trained ImageNet-1k.""" model_args = dict( embed_dim=192, depths=(2, 2, 18, 2), @@ -1094,8 
+1137,8 @@ def swinv2_cr_large_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_large_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-L V2 CR @ 224x224, trained ImageNet-1k""" +def swinv2_cr_large_224(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-L V2 CR @ 224x224, trained ImageNet-1k.""" model_args = dict( embed_dim=192, depths=(2, 2, 18, 2), @@ -1105,8 +1148,8 @@ def swinv2_cr_large_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_huge_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-H V2 CR @ 384x384, trained ImageNet-1k""" +def swinv2_cr_huge_384(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-H V2 CR @ 384x384, trained ImageNet-1k.""" model_args = dict( embed_dim=352, depths=(2, 2, 18, 2), @@ -1117,8 +1160,8 @@ def swinv2_cr_huge_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_huge_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-H V2 CR @ 224x224, trained ImageNet-1k""" +def swinv2_cr_huge_224(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-H V2 CR @ 224x224, trained ImageNet-1k.""" model_args = dict( embed_dim=352, depths=(2, 2, 18, 2), @@ -1129,8 +1172,8 @@ def swinv2_cr_huge_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_giant_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-G V2 CR @ 384x384, trained ImageNet-1k""" +def swinv2_cr_giant_384(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-G V2 CR @ 384x384, trained ImageNet-1k.""" model_args = dict( embed_dim=512, depths=(2, 2, 42, 2), @@ -1141,8 +1184,8 @@ def swinv2_cr_giant_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr: @register_model -def swinv2_cr_giant_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr: - """Swin-G V2 CR @ 224x224, trained ImageNet-1k""" +def swinv2_cr_giant_224(pretrained: bool = False, **kwargs) -> SwinTransformerV2Cr: + """Swin-G V2 CR @ 224x224, trained ImageNet-1k.""" model_args = dict( embed_dim=512, depths=(2, 2, 42, 2), diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index d238fa5b2..366eef709 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -572,7 +572,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates diff --git a/timm/models/vgg.py b/timm/models/vgg.py index a4cfbffdf..23c883489 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -30,17 +30,32 @@ @register_notrace_module # reason: FX can't symbolically trace control flow in forward method class ConvMlp(nn.Module): + """Convolutional MLP block for VGG head. + + Replaces traditional Linear layers with Conv2d layers in the classifier. + """ def __init__( self, - in_features=512, - out_features=4096, - kernel_size=7, - mlp_ratio=1.0, + in_features: int = 512, + out_features: int = 4096, + kernel_size: int = 7, + mlp_ratio: float = 1.0, drop_rate: float = 0.2, act_layer: Type[nn.Module] = nn.ReLU, conv_layer: Type[nn.Module] = nn.Conv2d, ): + """Initialize ConvMlp. + + Args: + in_features: Number of input features. + out_features: Number of output features. + kernel_size: Kernel size for first conv layer. + mlp_ratio: Ratio for hidden layer size. + drop_rate: Dropout rate. + act_layer: Activation layer type. 
+ conv_layer: Convolution layer type. + """ super(ConvMlp, self).__init__() self.input_kernel_size = kernel_size mid_features = int(out_features * mlp_ratio) @@ -50,7 +65,15 @@ def __init__( self.fc2 = conv_layer(mid_features, out_features, 1, bias=True) self.act2 = act_layer(True) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ if x.shape[-2] < self.input_kernel_size or x.shape[-1] < self.input_kernel_size: # keep the input size >= 7x7 output_size = (max(self.input_kernel_size, x.shape[-2]), max(self.input_kernel_size, x.shape[-1])) @@ -64,6 +87,11 @@ def forward(self, x): class VGG(nn.Module): + """VGG model architecture. + + Based on `Very Deep Convolutional Networks for Large-Scale Image Recognition` + - https://arxiv.org/abs/1409.1556 + """ def __init__( self, @@ -78,6 +106,20 @@ def __init__( global_pool: str = 'avg', drop_rate: float = 0., ) -> None: + """Initialize VGG model. + + Args: + cfg: Configuration list defining network architecture. + num_classes: Number of classes for classification. + in_chans: Number of input channels. + output_stride: Output stride of network. + mlp_ratio: Ratio for MLP hidden layer size. + act_layer: Activation layer type. + conv_layer: Convolution layer type. + norm_layer: Normalization layer type. + global_pool: Global pooling type. + drop_rate: Dropout rate. + """ super(VGG, self).__init__() assert output_stride == 32 self.num_classes = num_classes @@ -128,36 +170,86 @@ def __init__( self._initialize_weights() @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Group matcher for parameter groups. + + Args: + coarse: Whether to use coarse grouping. + + Returns: + Dictionary of grouped parameters. + """ # this treats BN layers as separate groups for bn variants, a lot of effort to fix that return dict(stem=r'^features\.0', blocks=r'^features\.(\d+)') @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing. + + Args: + enable: Whether to enable gradient checkpointing. + """ assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier module. + + Returns: + Classifier module. + """ return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classifier. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes self.head.reset(num_classes, global_pool) def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction layers. + + Args: + x: Input tensor. + + Returns: + Feature tensor. + """ x = self.features(x) return x - def forward_head(self, x: torch.Tensor, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through head. + + Args: + x: Input features. + pre_logits: Return features before final linear layer. + + Returns: + Classification logits or features. 
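A short sketch of the VGG head API documented above (pre-logits features vs. class logits); the model name is an example:

import torch
import timm

model = timm.create_model('vgg16', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

logits = model(x)                                        # (1, 1000)
feats = model.forward_head(model.forward_features(x), pre_logits=True)
model.reset_classifier(num_classes=10)                   # swap in a 10-class head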
+ """ x = self.pre_logits(x) return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output logits. + """ x = self.forward_features(x) x = self.forward_head(x) return x def _initialize_weights(self) -> None: + """Initialize model weights.""" for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') @@ -171,8 +263,15 @@ def _initialize_weights(self) -> None: nn.init.constant_(m.bias, 0) -def _filter_fn(state_dict): - """ convert patch embedding weight from manual patchify + linear proj to conv""" +def _filter_fn(state_dict: dict) -> Dict[str, torch.Tensor]: + """Convert patch embedding weight from manual patchify + linear proj to conv. + + Args: + state_dict: State dictionary to filter. + + Returns: + Filtered state dictionary. + """ out_dict = {} for k, v in state_dict.items(): k_r = k @@ -188,6 +287,16 @@ def _filter_fn(state_dict): def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG: + """Create a VGG model. + + Args: + variant: Model variant name. + pretrained: Load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + VGG model instance. + """ cfg = variant.split('_')[0] # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5] out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5)) @@ -203,7 +312,16 @@ def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG: return model -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Create default configuration dictionary. + + Args: + url: Model weight URL. + **kwargs: Additional configuration options. + + Returns: + Configuration dictionary. + """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 3fcf400ee..b57e2f213 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -78,21 +78,37 @@ class LayerScale(nn.Module): + """Layer scale module. + + References: + - https://arxiv.org/abs/2103.17239 + """ + def __init__( self, dim: int, init_values: float = 1e-5, inplace: bool = False, ) -> None: + """Initialize LayerScale module. + + Args: + dim: Dimension. + init_values: Initial value for scaling. + inplace: If True, perform inplace operations. + """ super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply layer scaling.""" return x.mul_(self.gamma) if self.inplace else x * self.gamma class Block(nn.Module): + """Transformer block with pre-normalization.""" + def __init__( self, dim: int, @@ -111,6 +127,23 @@ def __init__( norm_layer: Type[nn.Module] = LayerNorm, mlp_layer: Type[nn.Module] = Mlp, ) -> None: + """Initialize Block. + + Args: + dim: Number of input channels. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + qk_norm: If True, apply normalization to query and key. + proj_bias: If True, add bias to output projection. + proj_drop: Projection dropout rate. + attn_drop: Attention dropout rate. + init_values: Initial values for layer scale. + drop_path: Stochastic depth rate. + act_layer: Activation layer. + norm_layer: Normalization layer. 
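Quick demonstration of the LayerScale behaviour documented above, written standalone so the effect is easy to see (equivalent math, not the module itself).

import torch
import torch.nn as nn

dim = 8
gamma = nn.Parameter(1e-5 * torch.ones(dim))   # per-channel scale, tiny at init
x = torch.randn(2, 16, dim)                    # (batch, tokens, dim)
out = x * gamma                                # what LayerScale(dim, init_values=1e-5)(x) computes
print(out.abs().max())                         # near zero, so residual branches start near-identity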
+ mlp_layer: MLP layer. + """ super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( @@ -293,7 +326,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x_attn = attn @ v - + x_attn = x_attn.transpose(1, 2).reshape(B, N, C) x_attn = self.attn_out_proj(x_attn) @@ -589,7 +622,8 @@ def __init__( if fix_init: self.fix_init_weight() - def fix_init_weight(self): + def fix_init_weight(self) -> None: + """Apply weight initialization fix (scaling w/ layer index).""" def rescale(param, _layer_id): param.div_(math.sqrt(2.0 * _layer_id)) @@ -598,6 +632,11 @@ def rescale(param, _layer_id): rescale(layer.mlp.fc2.weight.data, layer_id + 1) def init_weights(self, mode: str = '') -> None: + """Initialize model weights. + + Args: + mode: Weight initialization mode ('jax', 'jax_nlhb', 'moco', or ''). + """ assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. if self.pos_embed is not None: @@ -609,19 +648,35 @@ def init_weights(self, mode: str = '') -> None: named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m: nn.Module) -> None: + """Initialize weights for a single module (compatibility method).""" # this fn left here for compat with downstream users init_weights_vit_timm(m) @torch.jit.ignore() def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None: + """Load pretrained weights. + + Args: + checkpoint_path: Path to checkpoint. + prefix: Prefix for state dict keys. + """ _load_weights(self, checkpoint_path, prefix) @torch.jit.ignore - def no_weight_decay(self) -> Set: + def no_weight_decay(self) -> Set[str]: + """Set of parameters that should not use weight decay.""" return {'pos_embed', 'cls_token', 'dist_token'} @torch.jit.ignore - def group_matcher(self, coarse: bool = False) -> Dict: + def group_matcher(self, coarse: bool = False) -> Dict[str, Union[str, List]]: + """Create regex patterns for parameter grouping. + + Args: + coarse: Use coarse grouping. + + Returns: + Dictionary mapping group names to regex patterns. + """ return dict( stem=r'^cls_token|pos_embed|patch_embed', # stem and embed blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] @@ -629,15 +684,27 @@ def group_matcher(self, coarse: bool = False) -> Dict: @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing. + + Args: + enable: Whether to enable gradient checkpointing. + """ self.grad_checkpointing = enable if hasattr(self.patch_embed, 'set_grad_checkpointing'): self.patch_embed.set_grad_checkpointing(enable) @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classifier head.""" return self.head - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') @@ -652,12 +719,12 @@ def set_input_size( self, img_size: Optional[Tuple[int, int]] = None, patch_size: Optional[Tuple[int, int]] = None, - ): - """Method updates the input image resolution, patch size + ) -> None: + """Update the input image resolution and patch size. 
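Usage sketch for set_input_size() as documented above (requires a timm version that ships this method; model name and resolution are illustrative).

import timm
import torch

model = timm.create_model('vit_small_patch16_224', pretrained=False)
model.set_input_size(img_size=(384, 384))      # re-grids patch embed and pos embed
out = model(torch.randn(1, 3, 384, 384))
print(out.shape)                               # (1, 1000)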
Args: - img_size: New input resolution, if None current resolution is used - patch_size: New patch size, if None existing patch size is used + img_size: New input resolution, if None current resolution is used. + patch_size: New patch size, if None existing patch size is used. """ prev_grid_size = self.patch_embed.grid_size self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size) @@ -674,6 +741,7 @@ def set_input_size( )) def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + """Apply positional embedding to input.""" if self.pos_embed is None: return x.view(x.shape[0], -1, x.shape[-1]) @@ -773,7 +841,7 @@ def forward_intermediates( # reshape to BCHW output format H, W = self.patch_embed.dynamic_feat_size((height, width)) intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] - + # For dictionary output, handle prefix tokens separately if output_dict: result_dict = {} @@ -781,14 +849,14 @@ def forward_intermediates( result_dict['image_intermediates'] = intermediates if prefix_tokens is not None and return_prefix_tokens: result_dict['image_intermediates_prefix'] = prefix_tokens - + # Only include features if not intermediates_only if not intermediates_only: x_final = self.norm(x) result_dict['image_features'] = x_final - + return result_dict - + # For non-dictionary output, maintain the original behavior if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: # return_prefix not support in torchscript due to poor type handling @@ -806,8 +874,16 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): - """ Prune layers not required for specified intermediates. + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layer. + prune_head: Whether to prune the classifier head. + + Returns: + List of indices that were kept. """ take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks @@ -827,8 +903,19 @@ def get_intermediate_layers( norm: bool = False, attn_mask: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: - """ Intermediate layer accessor inspired by DINO / DINOv2 interface. + """Get intermediate layer outputs (DINO interface compatibility). + NOTE: This API is for backwards compat, favour using forward_intermediates() directly. + + Args: + x: Input tensor. + n: Number or indices of layers. + reshape: Reshape to NCHW format. + return_prefix_tokens: Return prefix tokens. + norm: Apply normalization. + + Returns: + List of intermediate features. 
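Usage sketch for the intermediate-feature accessors documented above (argument names follow the signatures in this diff; exact availability depends on the timm version).

import timm
import torch

model = timm.create_model('vit_small_patch16_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# Take the outputs of the last two blocks as BCHW feature maps.
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
print([f.shape for f in feats])                # e.g. [(1, 384, 14, 14), (1, 384, 14, 14)]

# Backwards-compatible DINO-style accessor wrapping the same machinery.
feats = model.get_intermediate_layers(x, n=2, reshape=True, norm=True)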
""" return self.forward_intermediates( x, n, @@ -840,11 +927,12 @@ def get_intermediate_layers( ) def forward_features(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm).""" x = self.patch_embed(x) x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) - + if attn_mask is not None: # If mask provided, we need to apply blocks one by one for blk in self.blocks: @@ -853,11 +941,20 @@ def forward_features(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) - + x = self.norm(x) return x def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor: + """Apply pooling to feature tokens. + + Args: + x: Feature tensor. + pool_type: Pooling type override. + + Returns: + Pooled features. + """ if self.attn_pool is not None: if not self.pool_include_prefix: x = x[:, self.num_prefix_tokens:] @@ -873,6 +970,15 @@ def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor return x def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classifier head. + + Args: + x: Feature tensor. + pre_logits: Return features before final classifier. + + Returns: + Output tensor. + """ x = self.pool(x) x = self.fc_norm(x) x = self.head_drop(x) @@ -885,7 +991,12 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> def init_weights_vit_timm(module: nn.Module, name: str = '') -> None: - """ ViT weight initialization, original timm impl (for reproducibility) """ + """ViT weight initialization, original timm impl (for reproducibility). + + Args: + module: Module to initialize. + name: Module name for context. + """ if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=.02) if module.bias is not None: @@ -895,7 +1006,13 @@ def init_weights_vit_timm(module: nn.Module, name: str = '') -> None: def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.0) -> None: - """ ViT weight initialization, matching JAX (Flax) impl """ + """ViT weight initialization, matching JAX (Flax) impl. + + Args: + module: Module to initialize. + name: Module name for context. + head_bias: Bias value for head layer. + """ if isinstance(module, nn.Linear): if name.startswith('head'): nn.init.zeros_(module.weight) @@ -913,7 +1030,12 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0 def init_weights_vit_moco(module: nn.Module, name: str = '') -> None: - """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ + """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed. + + Args: + module: Module to initialize. + name: Module name for context. 
+ """ if isinstance(module, nn.Linear): if 'qkv' in name: # treat the weights of Q, K, V separately diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 7d023e0d3..fa7627542 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -124,11 +124,11 @@ class StridedConv(nn.Module): """ downsample 2d as well """ def __init__( - self, - kernel_size=3, - stride=2, + self, + kernel_size=3, + stride=2, padding=1, - in_chans=3, + in_chans=3, embed_dim=768 ): super().__init__() @@ -138,7 +138,7 @@ def __init__( self.norm = norm_layer(in_chans) # affine over C def forward(self, x): - x = self.norm(x) + x = self.norm(x) x = self.proj(x) return x @@ -163,7 +163,7 @@ def __init__( mid_chs = make_divisible(out_chs * expand_ratio) prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) - if stride == 2: + if stride == 2: self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True) elif in_chs != out_chs: self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True) @@ -188,11 +188,11 @@ def forward(self, x): shortcut = self.shortcut(x) x = self.pre_norm(x) - x = self.down(x) # nn.Identity() + x = self.down(x) # nn.Identity() # 1x1 expansion conv & act x = self.conv1_1x1(x) - x = self.act1(x) + x = self.act1(x) # (strided) depthwise 3x3 conv & act x = self.conv2_kxk(x) @@ -255,8 +255,8 @@ def forward(self, x): class GeGluMlp(nn.Module): def __init__( - self, - in_features, + self, + in_features, hidden_features, act_layer = 'gelu', norm_layer = None, diff --git a/timm/models/volo.py b/timm/models/volo.py index 46be778f6..f76a8361a 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -20,9 +20,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -38,18 +37,31 @@ class OutlookAttention(nn.Module): + """Outlook attention mechanism for VOLO models.""" def __init__( self, - dim, - num_heads, - kernel_size=3, - padding=1, - stride=1, - qkv_bias=False, - attn_drop=0., - proj_drop=0., + dim: int, + num_heads: int, + kernel_size: int = 3, + padding: int = 1, + stride: int = 1, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., ): + """Initialize OutlookAttention. + + Args: + dim: Input feature dimension. + num_heads: Number of attention heads. + kernel_size: Kernel size for attention computation. + padding: Padding for attention computation. + stride: Stride for attention computation. + qkv_bias: Whether to use bias in linear layers. + attn_drop: Attention dropout rate. + proj_drop: Projection dropout rate. + """ super().__init__() head_dim = dim // num_heads self.num_heads = num_heads @@ -68,7 +80,15 @@ def __init__( self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride) self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor of shape (B, H, W, C). + + Returns: + Output tensor of shape (B, H, W, C). 
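OutlookAttention (documented above) leans on nn.Unfold and F.fold; a small round-trip sketch of those primitives, independent of the VOLO code.

import torch
import torch.nn.functional as F
from torch import nn

B, C, H, W = 1, 4, 8, 8
x = torch.randn(B, C, H, W)

# Gather the 3x3 neighbourhood around every spatial position (stride 1, pad 1)...
unfold = nn.Unfold(kernel_size=3, padding=1, stride=1)
patches = unfold(x)                    # (1, C*3*3, H*W) = (1, 36, 64)

# ...and scatter-add the windows back. Overlapping positions sum, which is why
# outlook attention normalises its per-window weights with a softmax first.
y = F.fold(patches, output_size=(H, W), kernel_size=3, padding=1, stride=1)
print(patches.shape, y.shape)          # (1, 36, 64) (1, 4, 8, 8)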
+ """ B, H, W, C = x.shape v = self.v(x).permute(0, 3, 1, 2) # B, C, H, W @@ -96,20 +116,37 @@ def forward(self, x): class Outlooker(nn.Module): + """Outlooker block that combines outlook attention with MLP.""" + def __init__( self, - dim, - kernel_size, - padding, - stride=1, - num_heads=1, - mlp_ratio=3., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - qkv_bias=False, + dim: int, + kernel_size: int, + padding: int, + stride: int = 1, + num_heads: int = 1, + mlp_ratio: float = 3., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + qkv_bias: bool = False, ): + """Initialize Outlooker block. + + Args: + dim: Input feature dimension. + kernel_size: Kernel size for outlook attention. + padding: Padding for outlook attention. + stride: Stride for outlook attention. + num_heads: Number of attention heads. + mlp_ratio: Ratio for MLP hidden dimension. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth drop rate. + act_layer: Activation layer type. + norm_layer: Normalization layer type. + qkv_bias: Whether to use bias in linear layers. + """ super().__init__() self.norm1 = norm_layer(dim) self.attn = OutlookAttention( @@ -131,23 +168,41 @@ def __init__( ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ x = x + self.drop_path1(self.attn(self.norm1(x))) x = x + self.drop_path2(self.mlp(self.norm2(x))) return x class Attention(nn.Module): + """Multi-head self-attention module.""" fused_attn: torch.jit.Final[bool] def __init__( self, - dim, - num_heads=8, - qkv_bias=False, - attn_drop=0., - proj_drop=0., + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., ): + """Initialize Attention module. + + Args: + dim: Input feature dimension. + num_heads: Number of attention heads. + qkv_bias: Whether to use bias in QKV projection. + attn_drop: Attention dropout rate. + proj_drop: Projection dropout rate. + """ super().__init__() self.num_heads = num_heads head_dim = dim // num_heads @@ -159,7 +214,15 @@ def __init__( self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor of shape (B, H, W, C). + + Returns: + Output tensor of shape (B, H, W, C). + """ B, H, W, C = x.shape qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) @@ -185,18 +248,31 @@ def forward(self, x): class Transformer(nn.Module): + """Transformer block with multi-head self-attention and MLP.""" def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, ): + """Initialize Transformer block. + + Args: + dim: Input feature dimension. + num_heads: Number of attention heads. + mlp_ratio: Ratio for MLP hidden dimension. + qkv_bias: Whether to use bias in QKV projection. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth drop rate. + act_layer: Activation layer type. 
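The Attention module above keeps a fused_attn flag so it can use PyTorch's fused kernel; a minimal equivalence check between that path and the explicit softmax formulation (PyTorch >= 2.0 assumed).

import torch
import torch.nn.functional as F

B, num_heads, N, head_dim = 2, 8, 64, 32
q, k, v = (torch.randn(B, num_heads, N, head_dim) for _ in range(3))

out_fused = F.scaled_dot_product_attention(q, k, v)

attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5   # softmax(QK^T / sqrt(d)) V
out_manual = attn.softmax(dim=-1) @ v
print(torch.allclose(out_fused, out_manual, atol=1e-5))   # True up to numerics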
+ norm_layer: Normalization layer type. + """ super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop) @@ -206,23 +282,42 @@ def __init__( self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ x = x + self.drop_path1(self.attn(self.norm1(x))) x = x + self.drop_path2(self.mlp(self.norm2(x))) return x class ClassAttention(nn.Module): + """Class attention mechanism for class token interaction.""" def __init__( self, - dim, - num_heads=8, - head_dim=None, - qkv_bias=False, - attn_drop=0., - proj_drop=0., + dim: int, + num_heads: int = 8, + head_dim: Optional[int] = None, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., ): + """Initialize ClassAttention. + + Args: + dim: Input feature dimension. + num_heads: Number of attention heads. + head_dim: Dimension per head. If None, computed as dim // num_heads. + qkv_bias: Whether to use bias in QKV projection. + attn_drop: Attention dropout rate. + proj_drop: Projection dropout rate. + """ super().__init__() self.num_heads = num_heads if head_dim is not None: @@ -238,7 +333,15 @@ def __init__( self.proj = nn.Linear(self.head_dim * self.num_heads, dim) self.proj_drop = nn.Dropout(proj_drop) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor of shape (B, N, C) where first token is class token. + + Returns: + Class token output of shape (B, 1, C). + """ B, N, C = x.shape kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) @@ -256,20 +359,35 @@ def forward(self, x): class ClassBlock(nn.Module): + """Class block that combines class attention with MLP.""" def __init__( self, - dim, - num_heads, - head_dim=None, - mlp_ratio=4., - qkv_bias=False, - drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, + dim: int, + num_heads: int, + head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = False, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, ): + """Initialize ClassBlock. + + Args: + dim: Input feature dimension. + num_heads: Number of attention heads. + head_dim: Dimension per head. If None, computed as dim // num_heads. + mlp_ratio: Ratio for MLP hidden dimension. + qkv_bias: Whether to use bias in QKV projection. + drop: Dropout rate. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth drop rate. + act_layer: Activation layer type. + norm_layer: Normalization layer type. + """ super().__init__() self.norm1 = norm_layer(dim) self.attn = ClassAttention( @@ -291,56 +409,94 @@ def __init__( ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor of shape (B, N, C) where first token is class token. + + Returns: + Output tensor with updated class token. 
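Shape-level sketch of the class-attention pattern documented above: only the class token forms a query, so a single token is updated (illustrative tensors, not the VOLO module).

import torch

B, N, num_heads, head_dim = 2, 197, 8, 48
C = num_heads * head_dim                            # 384

x = torch.randn(B, N, C)                            # token 0 is the class token
q = x[:, :1].reshape(B, 1, num_heads, head_dim).transpose(1, 2)   # (B, h, 1, d)
k = x.reshape(B, N, num_heads, head_dim).transpose(1, 2)          # (B, h, N, d)
v = x.reshape(B, N, num_heads, head_dim).transpose(1, 2)

attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5               # (B, h, 1, N)
cls_out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, 1, C)
print(cls_out.shape)                                # (2, 1, 384): only the class token is updated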
+        """
         cls_embed = x[:, :1]
         cls_embed = cls_embed + self.drop_path1(self.attn(self.norm1(x)))
         cls_embed = cls_embed + self.drop_path2(self.mlp(self.norm2(cls_embed)))
         return torch.cat([cls_embed, x[:, 1:]], dim=1)


-def get_block(block_type, **kargs):
+def get_block(block_type: str, **kargs: Any) -> nn.Module:
+    """Get block based on type.
+
+    Args:
+        block_type: Type of block ('ca' for ClassBlock).
+        **kargs: Additional keyword arguments for block.
+
+    Returns:
+        The requested block module.
+    """
     if block_type == 'ca':
         return ClassBlock(**kargs)


-def rand_bbox(size, lam, scale=1):
-    """
-    get bounding box as token labeling (https://github.com/zihangJiang/TokenLabeling)
-    return: bounding box
+def rand_bbox(size: Tuple[int, ...], lam: Union[float, torch.Tensor], scale: int = 1) -> Tuple[int, int, int, int]:
+    """Get random bounding box for token labeling.
+
+    Reference: https://github.com/zihangJiang/TokenLabeling
+
+    Args:
+        size: Input tensor size tuple.
+        lam: Mixing ratio (lambda) for cutmix, a scalar float or 0-dim tensor.
+        scale: Scaling factor.
+
+    Returns:
+        Bounding box coordinates (bbx1, bby1, bbx2, bby2).
     """
     W = size[1] // scale
     H = size[2] // scale
-    cut_rat = np.sqrt(1. - lam)
-    cut_w = (W * cut_rat).astype(int)
-    cut_h = (H * cut_rat).astype(int)
+    W_t = torch.tensor(W, dtype=torch.float32)
+    H_t = torch.tensor(H, dtype=torch.float32)
+    cut_rat = torch.sqrt(torch.as_tensor(1. - lam, dtype=torch.float32))
+    cut_w = (W_t * cut_rat).int()
+    cut_h = (H_t * cut_rat).int()

     # uniform
-    cx = np.random.randint(W)
-    cy = np.random.randint(H)
+    cx = torch.randint(0, W, (1,))
+    cy = torch.randint(0, H, (1,))

-    bbx1 = np.clip(cx - cut_w // 2, 0, W)
-    bby1 = np.clip(cy - cut_h // 2, 0, H)
-    bbx2 = np.clip(cx + cut_w // 2, 0, W)
-    bby2 = np.clip(cy + cut_h // 2, 0, H)
+    bbx1 = torch.clamp(cx - cut_w // 2, 0, W)
+    bby1 = torch.clamp(cy - cut_h // 2, 0, H)
+    bbx2 = torch.clamp(cx + cut_w // 2, 0, W)
+    bby2 = torch.clamp(cy + cut_h // 2, 0, H)

-    return bbx1, bby1, bbx2, bby2
+    return bbx1.item(), bby1.item(), bbx2.item(), bby2.item()


 class PatchEmbed(nn.Module):
-    """ Image to Patch Embedding.
-    Different with ViT use 1 conv layer, we use 4 conv layers to do patch embedding
-    """
+    """Image to patch embedding with multi-layer convolution."""

     def __init__(
             self,
-            img_size=224,
-            stem_conv=False,
-            stem_stride=1,
-            patch_size=8,
-            in_chans=3,
-            hidden_dim=64,
-            embed_dim=384,
+            img_size: int = 224,
+            stem_conv: bool = False,
+            stem_stride: int = 1,
+            patch_size: int = 8,
+            in_chans: int = 3,
+            hidden_dim: int = 64,
+            embed_dim: int = 384,
     ):
+        """Initialize PatchEmbed.
+
+        Different from ViT which uses 1 conv layer, VOLO uses multiple conv layers for patch embedding.
+
+        Args:
+            img_size: Input image size.
+            stem_conv: Whether to use stem convolution layers.
+            stem_stride: Stride for stem convolution.
+            patch_size: Patch size (must be 4, 8, or 16).
+            in_chans: Number of input channels.
+            hidden_dim: Hidden dimension for stem convolution.
+            embed_dim: Output embedding dimension.
+        """
         super().__init__()
         assert patch_size in [4, 8, 16]
         if stem_conv:
@@ -362,7 +518,15 @@ def __init__(
                 hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride)
         self.num_patches = (img_size // patch_size) * (img_size // patch_size)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            x: Input tensor of shape (B, C, H, W).
+
+        Returns:
+            Output tensor of shape (B, embed_dim, H', W').
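Worked example of the cutmix box arithmetic behind rand_bbox (standalone re-derivation of the same formula, not an import from timm).

import torch

lam = torch.tensor(0.8)            # mixing ratio: keep roughly 80% of the original tokens
H = W = 14                         # e.g. a pooled token grid

cut_rat = torch.sqrt(1. - lam)     # side ratio chosen so cut area ~= (1 - lam)
cut_h, cut_w = int(H * cut_rat), int(W * cut_rat)
cy, cx = torch.randint(0, H, (1,)).item(), torch.randint(0, W, (1,)).item()

y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)
x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)
print((y2 - y1) * (x2 - x1) / (H * W))   # close to 1 - lam = 0.2, smaller when clipped at an edge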
+ """ if self.conv is not None: x = self.conv(x) x = self.proj(x) # B, C, H, W @@ -370,14 +534,28 @@ def forward(self, x): class Downsample(nn.Module): - """ Image to Patch Embedding, downsampling between stage1 and stage2 - """ + """Downsampling module between stages.""" + + def __init__(self, in_embed_dim: int, out_embed_dim: int, patch_size: int = 2): + """Initialize Downsample. - def __init__(self, in_embed_dim, out_embed_dim, patch_size=2): + Args: + in_embed_dim: Input embedding dimension. + out_embed_dim: Output embedding dimension. + patch_size: Patch size for downsampling. + """ super().__init__() self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor of shape (B, H, W, C). + + Returns: + Output tensor of shape (B, H', W', C'). + """ x = x.permute(0, 3, 1, 2) x = self.proj(x) # B, C, H, W x = x.permute(0, 2, 3, 1) @@ -385,23 +563,39 @@ def forward(self, x): def outlooker_blocks( - block_fn, - index, - dim, - layers, - num_heads=1, - kernel_size=3, - padding=1, - stride=2, - mlp_ratio=3., - qkv_bias=False, - attn_drop=0, - drop_path_rate=0., - **kwargs, -): - """ - generate outlooker layer in stage1 - return: outlooker layers + block_fn: Callable, + index: int, + dim: int, + layers: List[int], + num_heads: int = 1, + kernel_size: int = 3, + padding: int = 1, + stride: int = 2, + mlp_ratio: float = 3., + qkv_bias: bool = False, + attn_drop: float = 0, + drop_path_rate: float = 0., + **kwargs: Any, +) -> nn.Sequential: + """Generate outlooker layers for stage 1. + + Args: + block_fn: Block function to use (typically Outlooker). + index: Index of current stage. + dim: Feature dimension. + layers: List of layer counts for each stage. + num_heads: Number of attention heads. + kernel_size: Kernel size for outlook attention. + padding: Padding for outlook attention. + stride: Stride for outlook attention. + mlp_ratio: Ratio for MLP hidden dimension. + qkv_bias: Whether to use bias in QKV projection. + attn_drop: Attention dropout rate. + drop_path_rate: Stochastic depth drop rate. + **kwargs: Additional keyword arguments. + + Returns: + Sequential module containing outlooker blocks. """ blocks = [] for block_idx in range(layers[index]): @@ -422,20 +616,33 @@ def outlooker_blocks( def transformer_blocks( - block_fn, - index, - dim, - layers, - num_heads, - mlp_ratio=3., - qkv_bias=False, - attn_drop=0, - drop_path_rate=0., - **kwargs, -): - """ - generate transformer layers in stage2 - return: transformer layers + block_fn: Callable, + index: int, + dim: int, + layers: List[int], + num_heads: int, + mlp_ratio: float = 3., + qkv_bias: bool = False, + attn_drop: float = 0, + drop_path_rate: float = 0., + **kwargs: Any, +) -> nn.Sequential: + """Generate transformer layers for stage 2. + + Args: + block_fn: Block function to use (typically Transformer). + index: Index of current stage. + dim: Feature dimension. + layers: List of layer counts for each stage. + num_heads: Number of attention heads. + mlp_ratio: Ratio for MLP hidden dimension. + qkv_bias: Whether to use bias in QKV projection. + attn_drop: Attention dropout rate. + drop_path_rate: Stochastic depth drop rate. + **kwargs: Additional keyword arguments. + + Returns: + Sequential module containing transformer blocks. 
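The outlooker_blocks/transformer_blocks builders above pass a per-block stochastic-depth rate; a generic sketch of the usual linear schedule across all blocks (standalone arithmetic, shown under the assumption that the rate grows linearly with the global block index).

layers = (4, 4, 8, 2)              # blocks per stage (VOLO-D1-like)
drop_path_rate = 0.1
total = sum(layers)

schedule = []
for index, n_blocks in enumerate(layers):
    for block_idx in range(n_blocks):
        global_idx = block_idx + sum(layers[:index])
        schedule.append(drop_path_rate * global_idx / (total - 1))

print(schedule[0], schedule[-1])   # 0.0 at the first block ... 0.1 at the last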
""" blocks = [] for block_idx in range(layers[index]): @@ -453,35 +660,59 @@ def transformer_blocks( class VOLO(nn.Module): - """ - Vision Outlooker, the main class of our model - """ + """Vision Outlooker (VOLO) model.""" def __init__( self, - layers, - img_size=224, - in_chans=3, - num_classes=1000, - global_pool='token', - patch_size=8, - stem_hidden_dim=64, - embed_dims=None, - num_heads=None, - downsamples=(True, False, False, False), - outlook_attention=(True, False, False, False), - mlp_ratio=3.0, - qkv_bias=False, - drop_rate=0., - pos_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - norm_layer=nn.LayerNorm, - post_layers=('ca', 'ca'), - use_aux_head=True, - use_mix_token=False, - pooling_scale=2, + layers: List[int], + img_size: int = 224, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'token', + patch_size: int = 8, + stem_hidden_dim: int = 64, + embed_dims: Optional[List[int]] = None, + num_heads: Optional[List[int]] = None, + downsamples: Tuple[bool, ...] = (True, False, False, False), + outlook_attention: Tuple[bool, ...] = (True, False, False, False), + mlp_ratio: float = 3.0, + qkv_bias: bool = False, + drop_rate: float = 0., + pos_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_layer: Callable = nn.LayerNorm, + post_layers: Optional[Tuple[str, ...]] = ('ca', 'ca'), + use_aux_head: bool = True, + use_mix_token: bool = False, + pooling_scale: int = 2, ): + """Initialize VOLO model. + + Args: + layers: Number of blocks in each stage. + img_size: Input image size. + in_chans: Number of input channels. + num_classes: Number of classes for classification. + global_pool: Global pooling type ('token', 'avg', or ''). + patch_size: Patch size for patch embedding. + stem_hidden_dim: Hidden dimension for stem convolution. + embed_dims: List of embedding dimensions for each stage. + num_heads: List of number of attention heads for each stage. + downsamples: Whether to downsample between stages. + outlook_attention: Whether to use outlook attention in each stage. + mlp_ratio: Ratio for MLP hidden dimension. + qkv_bias: Whether to use bias in QKV projection. + drop_rate: Dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth drop rate. + norm_layer: Normalization layer type. + post_layers: Post-processing layer types. + use_aux_head: Whether to use auxiliary head. + use_mix_token: Whether to use token mixing for training. + pooling_scale: Pooling scale factor. + """ super().__init__() num_layers = len(layers) mlp_ratio = to_ntuple(num_layers)(mlp_ratio) @@ -589,18 +820,36 @@ def __init__( trunc_normal_(self.pos_embed, std=.02) self.apply(self._init_weights) - def _init_weights(self, m): + def _init_weights(self, m: nn.Module) -> None: + """Initialize weights for modules. + + Args: + m: Module to initialize. + """ if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) @torch.jit.ignore - def no_weight_decay(self): + def no_weight_decay(self) -> set: + """Get set of parameters that should not have weight decay. + + Returns: + Set of parameter names. + """ return {'pos_embed', 'cls_token'} @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + """Get parameter grouping for optimizer. + + Args: + coarse: Whether to use coarse grouping. 
+ + Returns: + Parameter grouping dictionary. + """ return dict( stem=r'^cls_token|pos_embed|patch_embed', # stem and embed blocks=[ @@ -615,14 +864,30 @@ def group_matcher(self, coarse=False): ) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Set gradient checkpointing. + + Args: + enable: Whether to enable gradient checkpointing. + """ self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get classifier module. + + Returns: + The classifier head module. + """ return self.head - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset classifier head. + + Args: + num_classes: Number of classes for new classifier. + global_pool: Global pooling type. + """ self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool @@ -630,7 +895,15 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): if self.aux_head is not None: self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - def forward_tokens(self, x): + def forward_tokens(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through token processing stages. + + Args: + x: Input tensor of shape (B, H, W, C). + + Returns: + Token tensor of shape (B, N, C). + """ for idx, block in enumerate(self.network): if idx == 2: # add positional encoding after outlooker blocks @@ -645,7 +918,15 @@ def forward_tokens(self, x): x = x.reshape(B, -1, C) return x - def forward_cls(self, x): + def forward_cls(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through class attention blocks. + + Args: + x: Input token tensor of shape (B, N, C). + + Returns: + Output tensor with class token of shape (B, N+1, C). + """ B, N, C = x.shape cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat([cls_tokens, x], dim=1) @@ -656,7 +937,16 @@ def forward_cls(self, x): x = block(x) return x - def forward_train(self, x): + def forward_train(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Tuple[int, int, int, int]]]: + """Forward pass for training with mix token support. + + Args: + x: Input tensor of shape (B, C, H, W). + + Returns: + If training with mix_token: tuple of (class_token, aux_tokens, bbox). + Otherwise: class_token tensor. + """ """ A separate forward fn for training with mix_token (if a train script supports). Combining multiple modes in as single forward with different return types is torchscript hell. """ @@ -665,7 +955,7 @@ def forward_train(self, x): # mix token, see token labeling for details. if self.mix_token and self.training: - lam = np.random.beta(self.beta, self.beta) + lam = torch.distributions.Beta(self.beta, self.beta).sample() patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale) temp_x = x.clone() @@ -775,7 +1065,17 @@ def prune_intermediate_layers( indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, - ): + ) -> List[int]: + """Prune layers not required for specified intermediates. + + Args: + indices: Indices of intermediate layers to keep. + prune_norm: Whether to prune normalization layer. + prune_head: Whether to prune classification head. + + Returns: + List of kept intermediate indices. 
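Usage sketch for the classifier API documented above (assumes timm with this model registered; note that reset_classifier() also rebuilds the auxiliary head when one is present).

import timm
import torch

model = timm.create_model('volo_d1_224', pretrained=False)
model.reset_classifier(num_classes=10)         # main head and aux head now emit 10 classes
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)                               # (1, 10)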
+ """ """ Prune layers not required for specified intermediates. """ take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) @@ -788,7 +1088,15 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through feature extraction. + + Args: + x: Input tensor of shape (B, C, H, W). + + Returns: + Feature tensor. + """ x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C # step2: tokens learning in the two stages @@ -800,7 +1108,16 @@ def forward_features(self, x): x = self.norm(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + """Forward pass through classification head. + + Args: + x: Input feature tensor. + pre_logits: Whether to return pre-logits features. + + Returns: + Classification logits or pre-logits features. + """ if self.global_pool == 'avg': out = x.mean(dim=1) elif self.global_pool == 'token': @@ -817,14 +1134,32 @@ def forward_head(self, x, pre_logits: bool = False): out = out + 0.5 * aux.max(1)[0] return out - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass (simplified, without mix token training). + + Args: + x: Input tensor of shape (B, C, H, W). + + Returns: + Classification logits. + """ """ simplified forward (without mix token training) """ x = self.forward_features(x) x = self.forward_head(x) return x -def _create_volo(variant, pretrained=False, **kwargs): +def _create_volo(variant: str, pretrained: bool = False, **kwargs: Any) -> VOLO: + """Create VOLO model. + + Args: + variant: Model variant name. + pretrained: Whether to load pretrained weights. + **kwargs: Additional model arguments. + + Returns: + VOLO model instance. + """ out_indices = kwargs.pop('out_indices', 3) return build_model_with_cfg( VOLO, @@ -835,7 +1170,16 @@ def _create_volo(variant, pretrained=False, **kwargs): ) -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]: + """Create model configuration. + + Args: + url: URL for pretrained weights. + **kwargs: Additional configuration options. + + Returns: + Model configuration dictionary. 
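Standalone illustration of the head blending documented above: the class-token logits are combined with the strongest per-token auxiliary logits.

import torch

B, N, num_classes = 2, 196, 1000
cls_logits = torch.randn(B, num_classes)       # from the class token (or pooled features)
aux_logits = torch.randn(B, N, num_classes)    # aux head applied to every patch token

out = cls_logits + 0.5 * aux_logits.max(dim=1)[0]   # i.e. out = out + 0.5 * aux.max(1)[0]
print(out.shape)                               # (2, 1000)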
+ """ return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, @@ -895,73 +1239,74 @@ def _cfg(url='', **kwargs): @register_model -def volo_d1_224(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D1 model, Params: 27M """ +def volo_d1_224(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D1 model, Params: 27M.""" model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs) model = _create_volo('volo_d1_224', pretrained=pretrained, **model_args) return model @register_model -def volo_d1_384(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D1 model, Params: 27M """ +def volo_d1_384(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D1 model, Params: 27M.""" model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs) model = _create_volo('volo_d1_384', pretrained=pretrained, **model_args) return model @register_model -def volo_d2_224(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D2 model, Params: 59M """ +def volo_d2_224(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D2 model, Params: 59M.""" model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) model = _create_volo('volo_d2_224', pretrained=pretrained, **model_args) return model @register_model -def volo_d2_384(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D2 model, Params: 59M """ +def volo_d2_384(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D2 model, Params: 59M.""" model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) model = _create_volo('volo_d2_384', pretrained=pretrained, **model_args) return model @register_model -def volo_d3_224(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D3 model, Params: 86M """ +def volo_d3_224(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D3 model, Params: 86M.""" model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) model = _create_volo('volo_d3_224', pretrained=pretrained, **model_args) return model @register_model -def volo_d3_448(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D3 model, Params: 86M """ +def volo_d3_448(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D3 model, Params: 86M.""" model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) model = _create_volo('volo_d3_448', pretrained=pretrained, **model_args) return model @register_model -def volo_d4_224(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D4 model, Params: 193M """ +def volo_d4_224(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D4 model, Params: 193M.""" model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs) model = _create_volo('volo_d4_224', pretrained=pretrained, **model_args) return model @register_model -def volo_d4_448(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D4 model, Params: 193M """ +def volo_d4_448(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D4 model, Params: 193M.""" model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs) model = _create_volo('volo_d4_448', pretrained=pretrained, **model_args) return model @register_model -def volo_d5_224(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D5 model, Params: 296M - stem_hidden_dim=128, the dim in patch embedding 
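Usage sketch for the registered variants above (the output of list_models depends on the installed timm version; no pretrained weights are downloaded here).

import timm

print(timm.list_models('volo_d5*'))            # e.g. ['volo_d5_224', 'volo_d5_448', 'volo_d5_512']

model = timm.create_model('volo_d5_224', pretrained=False)
print(sum(p.numel() for p in model.parameters()) / 1e6)   # roughly 296M parameters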
is 128 for VOLO-D5 +def volo_d5_224(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D5 model, Params: 296M. + + stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5. """ model_args = dict( layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), @@ -971,9 +1316,10 @@ def volo_d5_224(pretrained=False, **kwargs) -> VOLO: @register_model -def volo_d5_448(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D5 model, Params: 296M - stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5 +def volo_d5_448(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D5 model, Params: 296M. + + stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5. """ model_args = dict( layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), @@ -983,9 +1329,10 @@ def volo_d5_448(pretrained=False, **kwargs) -> VOLO: @register_model -def volo_d5_512(pretrained=False, **kwargs) -> VOLO: - """ VOLO-D5 model, Params: 296M - stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5 +def volo_d5_512(pretrained: bool = False, **kwargs: Any) -> VOLO: + """VOLO-D5 model, Params: 296M. + + stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5. """ model_args = dict( layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 0b48a34c1..8da9431f0 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -305,7 +305,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages, start=1): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x) if intermediates_only: return intermediates
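The tiny_vit and vovnet hunks in this diff touch the same intermediate-collection loop; a generic, self-contained sketch of that pattern (toy stages, not timm modules).

import torch
from torch import nn

stages = nn.ModuleList([nn.Conv2d(3 if i == 0 else 8, 8, 3, padding=1) for i in range(4)])
take_indices = {1, 3}                  # which stage outputs the caller asked for

x = torch.randn(1, 3, 32, 32)
intermediates = []
for feat_idx, stage in enumerate(stages):
    x = stage(x)
    if feat_idx in take_indices:       # only keep the requested stage outputs
        intermediates.append(x)
print([t.shape for t in intermediates])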