Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,55 @@ ignore_errors = True

ignore_errors=True

[mypy-torchvision.models.detection.*]
[mypy-torchvision.models.detection.anchor_utils]

ignore_errors = True

[mypy-torchvision.models.detection.backbone_utils]

ignore_errors = True

[mypy-torchvision.models.detection.image_list]

ignore_errors = True

[mypy-torchvision.models.detection.transform]

ignore_errors = True

[mypy-torchvision.models.detection.rpn]

ignore_errors = True

[mypy-torchvision.models.detection.roi_heads]

ignore_errors = True

[mypy-torchvision.models.detection.generalized_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.faster_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.mask_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.keypoint_rcnn]

ignore_errors = True

[mypy-torchvision.models.detection.retinanet]

ignore_errors = True

[mypy-torchvision.models.detection.ssd]

ignore_errors = True

[mypy-torchvision.models.detection.ssdlite]

ignore_errors = True

Expand Down
43 changes: 19 additions & 24 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Tuple

import torch
from torch import Tensor
from torch import Tensor, nn
from torchvision.ops.misc import FrozenBatchNorm2d


Expand All @@ -12,18 +12,16 @@ class BalancedPositiveNegativeSampler(object):
This class samples batches, ensuring that they contain a fixed proportion of positives
"""

def __init__(self, batch_size_per_image, positive_fraction):
# type: (int, float) -> None
def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
"""
Args:
batch_size_per_image (int): number of elements to be selected per image
positive_fraction (float): percentace of positive elements per batch
positive_fraction (float): percentage of positive elements per batch
"""
self.batch_size_per_image = batch_size_per_image
self.positive_fraction = positive_fraction

def __call__(self, matched_idxs):
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
"""
Args:
matched idxs: list of tensors containing -1, 0 or positive values.
Expand Down Expand Up @@ -73,8 +71,7 @@ def __call__(self, matched_idxs):


@torch.jit._script_if_tracing
def encode_boxes(reference_boxes, proposals, weights):
# type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
"""
Encode a set of proposals with respect to some
reference boxes
Expand Down Expand Up @@ -127,8 +124,9 @@ class BoxCoder(object):
the representation used for training the regressors.
"""

def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
# type: (Tuple[float, float, float, float], float) -> None
def __init__(
self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
) -> None:
"""
Args:
weights (4-element tuple)
Expand All @@ -137,15 +135,14 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
self.weights = weights
self.bbox_xform_clip = bbox_xform_clip

def encode(self, reference_boxes, proposals):
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
boxes_per_image = [len(b) for b in reference_boxes]
reference_boxes = torch.cat(reference_boxes, dim=0)
proposals = torch.cat(proposals, dim=0)
targets = self.encode_single(reference_boxes, proposals)
return targets.split(boxes_per_image, 0)

def encode_single(self, reference_boxes, proposals):
def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
"""
Encode a set of proposals with respect to some
reference boxes
Expand All @@ -161,8 +158,7 @@ def encode_single(self, reference_boxes, proposals):

return targets

def decode(self, rel_codes, boxes):
# type: (Tensor, List[Tensor]) -> Tensor
def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
assert isinstance(boxes, (list, tuple))
assert isinstance(rel_codes, torch.Tensor)
boxes_per_image = [b.size(0) for b in boxes]
Expand All @@ -177,7 +173,7 @@ def decode(self, rel_codes, boxes):
pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
return pred_boxes

def decode_single(self, rel_codes, boxes):
def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
"""
From a set of original boxes and encoded relative box offsets,
get the decoded boxes.
Expand Down Expand Up @@ -244,8 +240,7 @@ class Matcher(object):
"BETWEEN_THRESHOLDS": int,
}

def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
# type: (float, float, bool) -> None
def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
"""
Args:
high_threshold (float): quality values greater than or equal to
Expand All @@ -266,7 +261,7 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals
self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches

def __call__(self, match_quality_matrix):
def __call__(self, match_quality_matrix: Tensor) -> Tensor:
"""
Args:
match_quality_matrix (Tensor[float]): an MxN tensor, containing the
Expand Down Expand Up @@ -304,7 +299,7 @@ def __call__(self, match_quality_matrix):

return matches

def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
"""
Produce additional matches for predictions that have only low-quality matches.
Specifically, for each ground-truth find the set of predictions that have
Expand Down Expand Up @@ -335,10 +330,10 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):


class SSDMatcher(Matcher):
def __init__(self, threshold):
def __init__(self, threshold: float) -> None:
super().__init__(threshold, threshold, allow_low_quality_matches=False)

def __call__(self, match_quality_matrix):
def __call__(self, match_quality_matrix: Tensor) -> Tensor:
matches = super().__call__(match_quality_matrix)

# For each gt, find the prediction with which it has the highest quality
Expand All @@ -350,7 +345,7 @@ def __call__(self, match_quality_matrix):
return matches


def overwrite_eps(model, eps):
def overwrite_eps(model: nn.Module, eps: float) -> None:
"""
This method overwrites the default eps values of all the
FrozenBatchNorm2d layers of the model with the provided value.
Expand All @@ -368,7 +363,7 @@ def overwrite_eps(model, eps):
module.eps = eps


def retrieve_out_channels(model, size):
def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
"""
This method retrieves the number of output channels of a specific model.

Expand Down