Skip to content

[DONOTMERGE] Testing CI failures #3053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions test/test_models_detection_negative_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def test_forward_negative_sample_krcnn(self):
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))

def test_forward_negative_sample_retinanet(self):
model = torchvision.models.detection.retinanet_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)

images, targets = self._make_empty_sample()
loss_dict = model(images, targets)

self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.))


if __name__ == '__main__':
unittest.main()
44 changes: 14 additions & 30 deletions torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
# determine only the foreground
foreground_idxs_per_image = matched_idxs_per_image >= 0
num_foreground = foreground_idxs_per_image.sum()
# no matched_idxs means there were no annotations in this image
# TODO: enable support for images without annotations that works on distributed
if False: # matched_idxs_per_image.numel() == 0:
gt_classes_target = torch.zeros_like(cls_logits_per_image)
valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0])
else:
# create the target classification
gt_classes_target = torch.zeros_like(cls_logits_per_image)
gt_classes_target[
foreground_idxs_per_image,
targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
] = 1.0

# find indices for which anchors should be ignored
valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

# create the target classification
gt_classes_target = torch.zeros_like(cls_logits_per_image)
gt_classes_target[
foreground_idxs_per_image,
targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
] = 1.0

# find indices for which anchors should be ignored
valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

# compute the classification loss
losses.append(sigmoid_focal_loss(
Expand Down Expand Up @@ -191,23 +186,12 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):

for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
zip(targets, bbox_regression, anchors, matched_idxs):
# no matched_idxs means there were no annotations in this image
# TODO enable support for images without annotations with distributed support
# if matched_idxs_per_image.numel() == 0:
# continue

# get the targets corresponding GT for each proposal
# NB: need to clamp the indices because we can have a single
# GT in the image, and matched_idxs can be -2, which goes
# out of bounds
matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]

# determine only the foreground indices, ignore the rest
foreground_idxs_per_image = matched_idxs_per_image >= 0
num_foreground = foreground_idxs_per_image.sum()
foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
num_foreground = foreground_idxs_per_image.numel()

# select only the foreground boxes
matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]]
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

Expand Down Expand Up @@ -403,7 +387,7 @@ def compute_loss(self, targets, head_outputs, anchors):
matched_idxs = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
if targets_per_image['boxes'].numel() == 0:
matched_idxs.append(torch.empty((0,), dtype=torch.int32))
matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64))
continue

match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
Expand Down