Skip to content

Commit 6a46040

Browse files
datumboxvfdev-5
authored andcommitted
Change reshape to support empty batches. (pytorch#3031)
1 parent c62ec9d commit 6a46040

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchvision/models/detection/roi_heads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
4040
sampled_pos_inds_subset = torch.where(labels > 0)[0]
4141
labels_pos = labels[sampled_pos_inds_subset]
4242
N, num_classes = class_logits.shape
43-
box_regression = box_regression.reshape(N, -1, 4)
43+
box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
4444

4545
box_loss = det_utils.smooth_l1_loss(
4646
box_regression[sampled_pos_inds_subset, labels_pos],

0 commit comments

Comments
 (0)