diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index db9711760dc..24dc9399fd6 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -344,19 +344,6 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): matches[pred_inds_to_update] = all_matches[pred_inds_to_update] -def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = True): - """ - very similar to the smooth_l1_loss from pytorch, but with - the extra beta parameter - """ - n = torch.abs(input - target) - cond = n < beta - loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta) - if size_average: - return loss.mean() - return loss.sum() - - def overwrite_eps(model, eps): """ This method overwrites the default eps values of all the diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 5f476f63827..ab6e87a86e0 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -42,11 +42,11 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): N, num_classes = class_logits.shape box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4) - box_loss = det_utils.smooth_l1_loss( + box_loss = F.smooth_l1_loss( box_regression[sampled_pos_inds_subset, labels_pos], regression_targets[sampled_pos_inds_subset], beta=1 / 9, - size_average=False, + reduction='sum', ) box_loss = box_loss / labels.numel() diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 736c82a9009..9f9bf9da5f5 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -304,11 +304,11 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) labels = torch.cat(labels, dim=0) regression_targets = torch.cat(regression_targets, dim=0) - box_loss = det_utils.smooth_l1_loss( + box_loss = F.smooth_l1_loss( pred_bbox_deltas[sampled_pos_inds], regression_targets[sampled_pos_inds], beta=1 / 9, - size_average=False, + reduction='sum', ) / (sampled_inds.numel()) objectness_loss = F.binary_cross_entropy_with_logits(