Skip to content

Commit 707a69c

Browse files
fmassafacebook-github-bot
authored andcommitted
[fbsync] Unify onnx and JIT resize implementations (#3654)
Summary: * Make two methods as similar as possible. * Introducing conditional fake casting. * Change the casting mechanism. Reviewed By: NicolasHug Differential Revision: D27706950 fbshipit-source-id: ef7503817cd64ffc8723fec89f1cd94647490eaf
1 parent a603cdd commit 707a69c

File tree

1 file changed

+23
-27
lines changed

1 file changed

+23
-27
lines changed

torchvision/models/detection/transform.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,35 @@
1010

1111

1212
@torch.jit.unused
13-
def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
14-
# type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
13+
def _get_shape_onnx(image):
14+
# type: (Tensor) -> Tensor
1515
from torch.onnx import operators
16-
im_shape = operators.shape_as_tensor(image)[-2:]
17-
min_size = torch.min(im_shape).to(dtype=torch.float32)
18-
max_size = torch.max(im_shape).to(dtype=torch.float32)
19-
scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
20-
21-
image = torch.nn.functional.interpolate(
22-
image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
23-
align_corners=False)[0]
16+
return operators.shape_as_tensor(image)[-2:]
2417

25-
if target is None:
26-
return image, target
2718

28-
if "masks" in target:
29-
mask = target["masks"]
30-
mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
31-
target["masks"] = mask
32-
return image, target
19+
@torch.jit.unused
20+
def _fake_cast_onnx(v):
21+
# type: (Tensor) -> float
22+
# ONNX requires a tensor but here we fake its type for JIT.
23+
return v
3324

3425

3526
def _resize_image_and_masks(image, self_min_size, self_max_size, target):
3627
# type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
37-
im_shape = torch.tensor(image.shape[-2:])
38-
min_size = float(torch.min(im_shape))
39-
max_size = float(torch.max(im_shape))
40-
scale_factor = self_min_size / min_size
41-
if max_size * scale_factor > self_max_size:
42-
scale_factor = self_max_size / max_size
28+
if torchvision._is_tracing():
29+
im_shape = _get_shape_onnx(image)
30+
else:
31+
im_shape = torch.tensor(image.shape[-2:])
32+
33+
min_size = torch.min(im_shape).to(dtype=torch.float32)
34+
max_size = torch.max(im_shape).to(dtype=torch.float32)
35+
scale = torch.min(self_min_size / min_size, self_max_size / max_size)
36+
37+
if torchvision._is_tracing():
38+
scale_factor = _fake_cast_onnx(scale)
39+
else:
40+
scale_factor = scale.item()
41+
4342
image = torch.nn.functional.interpolate(
4443
image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
4544
align_corners=False)[0]
@@ -145,10 +144,7 @@ def resize(self, image, target):
145144
else:
146145
# FIXME assume for now that testing uses the largest scale
147146
size = float(self.min_size[-1])
148-
if torchvision._is_tracing():
149-
image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
150-
else:
151-
image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
147+
image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
152148

153149
if target is None:
154150
return image, target

0 commit comments

Comments
 (0)