Skip to content

Commit c42e309

Browse files
committed
Fixing for onnx.
1 parent 64b33a9 commit c42e309

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

test/test_onnx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from common_utils import set_rng_seed
12
import io
23
import torch
34
from torchvision import ops
@@ -257,6 +258,8 @@ def get_features(self, images):
257258
return features
258259

259260
def test_rpn(self):
261+
set_rng_seed(0)
262+
260263
class RPNModule(torch.nn.Module):
261264
def __init__(self_module):
262265
super(RPNModule, self_module).__init__()

torchvision/models/detection/rpn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
264264
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
265265

266266
# remove low scoring boxes
267-
keep = torch.where(scores > self.score_thresh)[0]
267+
# use >= for Backwards compatibility
268+
keep = torch.where(scores >= self.score_thresh)[0]
268269
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
269270

270271
# non-maximum suppression, independently done per level

0 commit comments

Comments
 (0)