Skip to content

Use pytorch smooth_l1_loss and remove private custom implem #3539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,19 +344,6 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
matches[pred_inds_to_update] = all_matches[pred_inds_to_update]


def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = True):
"""
very similar to the smooth_l1_loss from pytorch, but with
the extra beta parameter
"""
n = torch.abs(input - target)
cond = n < beta
loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
if size_average:
return loss.mean()
return loss.sum()


def overwrite_eps(model, eps):
"""
This method overwrites the default eps values of all the
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
N, num_classes = class_logits.shape
box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

box_loss = det_utils.smooth_l1_loss(
box_loss = F.smooth_l1_loss(
box_regression[sampled_pos_inds_subset, labels_pos],
regression_targets[sampled_pos_inds_subset],
beta=1 / 9,
size_average=False,
reduction='sum',
)
box_loss = box_loss / labels.numel()

Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,11 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)

box_loss = det_utils.smooth_l1_loss(
box_loss = F.smooth_l1_loss(
pred_bbox_deltas[sampled_pos_inds],
regression_targets[sampled_pos_inds],
beta=1 / 9,
size_average=False,
reduction='sum',
) / (sampled_inds.numel())

objectness_loss = F.binary_cross_entropy_with_logits(
Expand Down