From 599a1a597d5fe85b1f143aec75202e5ea07c88a0 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 9 Apr 2021 13:47:31 +0100
Subject: [PATCH 1/3] Make two methods as similar as possible.

---
 torchvision/models/detection/transform.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index 5e962f4bad9..b6ea1b683a7 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -35,11 +35,10 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
 def _resize_image_and_masks(image, self_min_size, self_max_size, target):
     # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
     im_shape = torch.tensor(image.shape[-2:])
-    min_size = float(torch.min(im_shape))
-    max_size = float(torch.max(im_shape))
-    scale_factor = self_min_size / min_size
-    if max_size * scale_factor > self_max_size:
-        scale_factor = self_max_size / max_size
+    min_size = torch.min(im_shape).to(dtype=torch.float32)
+    max_size = torch.max(im_shape).to(dtype=torch.float32)
+    scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size).item()
+
     image = torch.nn.functional.interpolate(
         image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
         align_corners=False)[0]

From b31127b5d2f6b5bca2e97bf2598faab881beb84f Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 9 Apr 2021 14:15:09 +0100
Subject: [PATCH 2/3] Introducing conditional fake casting.

---
 torchvision/models/detection/transform.py | 39 ++++++++++-------------
 1 file changed, 17 insertions(+), 22 deletions(-)

diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index b6ea1b683a7..3c0991fe033 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -10,35 +10,33 @@


 @torch.jit.unused
-def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
-    # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
+def _get_shape_onnx(image):
+    # type: (Tensor) -> Tensor
     from torch.onnx import operators
-    im_shape = operators.shape_as_tensor(image)[-2:]
-    min_size = torch.min(im_shape).to(dtype=torch.float32)
-    max_size = torch.max(im_shape).to(dtype=torch.float32)
-    scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
-
-    image = torch.nn.functional.interpolate(
-        image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
-        align_corners=False)[0]
+    return operators.shape_as_tensor(image)[-2:]

-    if target is None:
-        return image, target
-    if "masks" in target:
-        mask = target["masks"]
-        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
-        target["masks"] = mask
-    return image, target

+@torch.jit.unused
+def _float_to_tensor_onnx(v):
+    # type: (float) -> float
+    # ONNX requires a tensor but here we fake its type for JIT.
+    return torch.tensor(v)


 def _resize_image_and_masks(image, self_min_size, self_max_size, target):
     # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
-    im_shape = torch.tensor(image.shape[-2:])
+    if torchvision._is_tracing():
+        im_shape = _get_shape_onnx(image)
+    else:
+        im_shape = torch.tensor(image.shape[-2:])
+
     min_size = torch.min(im_shape).to(dtype=torch.float32)
     max_size = torch.max(im_shape).to(dtype=torch.float32)
     scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size).item()

+    if torchvision._is_tracing():
+        scale_factor = _float_to_tensor_onnx(scale_factor)
+
     image = torch.nn.functional.interpolate(
         image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
         align_corners=False)[0]
@@ -144,10 +142,7 @@ def resize(self, image, target):
         else:
             # FIXME assume for now that testing uses the largest scale
             size = float(self.min_size[-1])
-        if torchvision._is_tracing():
-            image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
-        else:
-            image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
+        image, target = _resize_image_and_masks(image, size, float(self.max_size), target)

         if target is None:
             return image, target

From 0f95df5e2044bfef9a12b03a1438cf6bb8cad7d9 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 9 Apr 2021 14:28:10 +0100
Subject: [PATCH 3/3] Change the casting mechanism.

---
 torchvision/models/detection/transform.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index 3c0991fe033..56e502f726d 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -17,10 +17,10 @@ def _get_shape_onnx(image):


 @torch.jit.unused
-def _float_to_tensor_onnx(v):
-    # type: (float) -> float
+def _fake_cast_onnx(v):
+    # type: (Tensor) -> float
     # ONNX requires a tensor but here we fake its type for JIT.
-    return torch.tensor(v)
+    return v


 def _resize_image_and_masks(image, self_min_size, self_max_size, target):
@@ -32,10 +32,12 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target):

     min_size = torch.min(im_shape).to(dtype=torch.float32)
     max_size = torch.max(im_shape).to(dtype=torch.float32)
-    scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size).item()
+    scale = torch.min(self_min_size / min_size, self_max_size / max_size)

     if torchvision._is_tracing():
-        scale_factor = _float_to_tensor_onnx(scale_factor)
+        scale_factor = _fake_cast_onnx(scale)
+    else:
+        scale_factor = scale.item()

     image = torch.nn.functional.interpolate(
         image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
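
Note on the final behavior (not part of the patches above): the three commits converge on a single _resize_image_and_masks in which one torch.min replaces the old if/else rescaling check, and only the tracing path differs in how the resulting scale is consumed. Below is a minimal standalone sketch of the eager path, assuming torchvision._is_tracing() is False; resize_scale_factor is a hypothetical helper name introduced here for illustration, and the 800.0/1333.0 defaults mirror the usual min_size/max_size of torchvision's detection transform.

import torch

def resize_scale_factor(image, self_min_size=800.0, self_max_size=1333.0):
    # Hypothetical standalone version of the unified logic: a single
    # torch.min picks whichever constraint binds -- grow the smaller side
    # to self_min_size unless the larger side would exceed self_max_size.
    im_shape = torch.tensor(image.shape[-2:])
    min_size = torch.min(im_shape).to(dtype=torch.float32)
    max_size = torch.max(im_shape).to(dtype=torch.float32)
    scale = torch.min(self_min_size / min_size, self_max_size / max_size)
    return scale.item()  # eager path; tracing keeps the tensor instead

image = torch.rand(3, 400, 600)
sf = resize_scale_factor(image)  # min(800/400, 1333/600) == 2.0
resized = torch.nn.functional.interpolate(
    image[None], scale_factor=sf, mode='bilinear',
    recompute_scale_factor=True, align_corners=False)[0]
print(sf, tuple(resized.shape))  # 2.0 (3, 800, 1200)

Under tracing, _fake_cast_onnx returns the scale tensor unchanged: it is a runtime no-op whose (Tensor) -> float annotation satisfies TorchScript's type checker, while ONNX export sees the scale as a graph tensor that stays dynamic with the input shape.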