diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index 5e962f4bad9..56e502f726d 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -10,36 +10,35 @@
 @torch.jit.unused
-def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
-    # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
+def _get_shape_onnx(image):
+    # type: (Tensor) -> Tensor
     from torch.onnx import operators
-    im_shape = operators.shape_as_tensor(image)[-2:]
-    min_size = torch.min(im_shape).to(dtype=torch.float32)
-    max_size = torch.max(im_shape).to(dtype=torch.float32)
-    scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
-
-    image = torch.nn.functional.interpolate(
-        image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
-        align_corners=False)[0]
+    return operators.shape_as_tensor(image)[-2:]
 
-    if target is None:
-        return image, target
 
-    if "masks" in target:
-        mask = target["masks"]
-        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
-        target["masks"] = mask
-    return image, target
+@torch.jit.unused
+def _fake_cast_onnx(v):
+    # type: (Tensor) -> float
+    # ONNX requires a tensor but here we fake its type for JIT.
+    return v
 
 
 def _resize_image_and_masks(image, self_min_size, self_max_size, target):
     # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
-    im_shape = torch.tensor(image.shape[-2:])
-    min_size = float(torch.min(im_shape))
-    max_size = float(torch.max(im_shape))
-    scale_factor = self_min_size / min_size
-    if max_size * scale_factor > self_max_size:
-        scale_factor = self_max_size / max_size
+    if torchvision._is_tracing():
+        im_shape = _get_shape_onnx(image)
+    else:
+        im_shape = torch.tensor(image.shape[-2:])
+
+    min_size = torch.min(im_shape).to(dtype=torch.float32)
+    max_size = torch.max(im_shape).to(dtype=torch.float32)
+    scale = torch.min(self_min_size / min_size, self_max_size / max_size)
+
+    if torchvision._is_tracing():
+        scale_factor = _fake_cast_onnx(scale)
+    else:
+        scale_factor = scale.item()
+
     image = torch.nn.functional.interpolate(
         image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
         align_corners=False)[0]
 
     if target is None:
@@ -145,10 +144,7 @@ def resize(self, image, target):
         else:
             # FIXME assume for now that testing uses the largest scale
             size = float(self.min_size[-1])
-        if torchvision._is_tracing():
-            image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
-        else:
-            image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
+        image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
 
         if target is None:
             return image, target
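
Note: as a quick sanity check of what the unified path computes, here is a minimal standalone sketch (not part of the patch) of the eager-mode branch of the patched _resize_image_and_masks. The single torch.min over the two ratios replaces the old two-step formulation (scale by min_size/min_side, then cap at max_size/max_side); the helper name and the sample sizes below are illustrative only.

    import torch

    def resize_keeping_aspect_ratio(image, min_size=800.0, max_size=1333.0):
        # Scale so the short side approaches min_size while the long side
        # never exceeds max_size -- the invariant both old paths enforced.
        im_shape = torch.tensor(image.shape[-2:])
        min_side = torch.min(im_shape).to(dtype=torch.float32)
        max_side = torch.max(im_shape).to(dtype=torch.float32)
        scale = torch.min(min_size / min_side, max_size / max_side)
        # Eager mode extracts a plain float; under tracing the patch instead
        # keeps the tensor alive via _fake_cast_onnx so the exported graph
        # stays shape-dynamic.
        scale_factor = scale.item()
        return torch.nn.functional.interpolate(
            image[None], scale_factor=scale_factor, mode='bilinear',
            recompute_scale_factor=True, align_corners=False)[0]

    img = torch.rand(3, 600, 1400)  # 800/600 would push the long side past 1333
    print(resize_keeping_aspect_ratio(img).shape)  # torch.Size([3, 571, 1333])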