Skip to content

Commit c2bbefc

Browse files
authored
fix type hints and spelling mistake in generalized_rcnn and poolers (#2550)
* fix type hints and move degenerate boxes to a function in torchvision.models.detection.generalized_rcnn * format code * format code * changed to static method * revert imports * changed to method * revert procedure for degenerating boxes
1 parent df6a796 commit c2bbefc

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

torchvision/models/detection/generalized_rcnn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
from collections import OrderedDict
7+
from typing import Union
78
import torch
89
from torch import nn
910
import warnings
@@ -35,7 +36,7 @@ def __init__(self, backbone, rpn, roi_heads, transform):
3536

3637
@torch.jit.unused
3738
def eager_outputs(self, losses, detections):
38-
# type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
39+
# type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
3940
if self.training:
4041
return losses
4142

@@ -85,11 +86,11 @@ def forward(self, images, targets=None):
8586
boxes = target["boxes"]
8687
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
8788
if degenerate_boxes.any():
88-
# print the first degenrate box
89+
# print the first degenerate box
8990
bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0]
9091
degen_bb: List[float] = boxes[bb_idx].tolist()
9192
raise ValueError("All bounding boxes should have positive height and width."
92-
" Found invaid box {} for target at index {}."
93+
" Found invalid box {} for target at index {}."
9394
.format(degen_bb, target_idx))
9495

9596
features = self.backbone(images.tensors)

torchvision/ops/poolers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2+
from typing import Union
3+
24
import torch
35
import torch.nn.functional as F
46
from torch import nn, Tensor
@@ -119,7 +121,7 @@ class MultiScaleRoIAlign(nn.Module):
119121
def __init__(
120122
self,
121123
featmap_names: List[str],
122-
output_size: List[int],
124+
output_size: Union[int, Tuple[int], List[int]],
123125
sampling_ratio: int,
124126
):
125127
super(MultiScaleRoIAlign, self).__init__()

0 commit comments

Comments
 (0)