|
10 | 10 |
|
11 | 11 |
|
12 | 12 | @torch.jit.unused
|
13 |
| -def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target): |
14 |
| - # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]] |
| 13 | +def _get_shape_onnx(image): |
| 14 | + # type: (Tensor) -> Tensor |
15 | 15 | from torch.onnx import operators
|
16 |
| - im_shape = operators.shape_as_tensor(image)[-2:] |
17 |
| - min_size = torch.min(im_shape).to(dtype=torch.float32) |
18 |
| - max_size = torch.max(im_shape).to(dtype=torch.float32) |
19 |
| - scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size) |
20 |
| - |
21 |
| - image = torch.nn.functional.interpolate( |
22 |
| - image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True, |
23 |
| - align_corners=False)[0] |
| 16 | + return operators.shape_as_tensor(image)[-2:] |
24 | 17 |
|
25 |
| - if target is None: |
26 |
| - return image, target |
27 | 18 |
|
28 |
| - if "masks" in target: |
29 |
| - mask = target["masks"] |
30 |
| - mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte() |
31 |
| - target["masks"] = mask |
32 |
| - return image, target |
| 19 | +@torch.jit.unused |
| 20 | +def _fake_cast_onnx(v): |
| 21 | + # type: (Tensor) -> float |
| 22 | + # ONNX requires a tensor but here we fake its type for JIT. |
| 23 | + return v |
33 | 24 |
|
34 | 25 |
|
35 | 26 | def _resize_image_and_masks(image, self_min_size, self_max_size, target):
|
36 | 27 | # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
|
37 |
| - im_shape = torch.tensor(image.shape[-2:]) |
38 |
| - min_size = float(torch.min(im_shape)) |
39 |
| - max_size = float(torch.max(im_shape)) |
40 |
| - scale_factor = self_min_size / min_size |
41 |
| - if max_size * scale_factor > self_max_size: |
42 |
| - scale_factor = self_max_size / max_size |
| 28 | + if torchvision._is_tracing(): |
| 29 | + im_shape = _get_shape_onnx(image) |
| 30 | + else: |
| 31 | + im_shape = torch.tensor(image.shape[-2:]) |
| 32 | + |
| 33 | + min_size = torch.min(im_shape).to(dtype=torch.float32) |
| 34 | + max_size = torch.max(im_shape).to(dtype=torch.float32) |
| 35 | + scale = torch.min(self_min_size / min_size, self_max_size / max_size) |
| 36 | + |
| 37 | + if torchvision._is_tracing(): |
| 38 | + scale_factor = _fake_cast_onnx(scale) |
| 39 | + else: |
| 40 | + scale_factor = scale.item() |
| 41 | + |
43 | 42 | image = torch.nn.functional.interpolate(
|
44 | 43 | image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
|
45 | 44 | align_corners=False)[0]
|
@@ -145,10 +144,7 @@ def resize(self, image, target):
|
145 | 144 | else:
|
146 | 145 | # FIXME assume for now that testing uses the largest scale
|
147 | 146 | size = float(self.min_size[-1])
|
148 |
| - if torchvision._is_tracing(): |
149 |
| - image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target) |
150 |
| - else: |
151 |
| - image, target = _resize_image_and_masks(image, size, float(self.max_size), target) |
| 147 | + image, target = _resize_image_and_masks(image, size, float(self.max_size), target) |
152 | 148 |
|
153 | 149 | if target is None:
|
154 | 150 | return image, target
|
|
0 commit comments