Skip to content
Merged
52 changes: 51 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,57 @@ ignore_errors = True

ignore_errors=True

[mypy-torchvision.models.detection.*]
[mypy-torchvision.models.detection.anchor_utils]

ignore_errors = True

[mypy-torchvision.models.detection.backbone_utils]

ignore_errors = True


[mypy-torchvision.models.detection.image_list]

ignore_errors = True


[mypy-torchvision.models.detection.transform]

ignore_errors = True

[mypy-torchvision.models.detection.rpn]

ignore_errors = True

[mypy-torchvision.models.detection.roi_heads]

ignore_errors = True

[mypy-torchvision.models.detection.generalized_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.faster_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.mask_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.keypoint_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.retinanet]

ignore_errors = True

[mypy-torchvision.models.detection.ssd]

ignore_errors = True

[mypy-torchvision.models.detection.ssdlite]

ignore_errors = True

Expand Down
9 changes: 4 additions & 5 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Tuple

import torch
from torch import Tensor
from torch import Tensor, nn
from torchvision.ops.misc import FrozenBatchNorm2d


Expand All @@ -12,8 +12,7 @@ class BalancedPositiveNegativeSampler(object):
This class samples batches, ensuring that they contain a fixed proportion of positives
"""

def __init__(self, batch_size_per_image, positive_fraction):
# type: (int, float) -> None
def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
"""
Args:
batch_size_per_image (int): number of elements to be selected per image
Expand Down Expand Up @@ -350,7 +349,7 @@ def __call__(self, match_quality_matrix):
return matches


def overwrite_eps(model, eps):
def overwrite_eps(model: nn.Module, eps: float) -> None:
"""
This method overwrites the default eps values of all the
FrozenBatchNorm2d layers of the model with the provided value.
Expand All @@ -368,7 +367,7 @@ def overwrite_eps(model, eps):
module.eps = eps


def retrieve_out_channels(model, size):
def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
"""
This method retrieves the number of output channels of a specific model.

Expand Down