diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 89358ee7dcf..c53fecaef7e 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -304,7 +304,7 @@ def make_image_loaders( "RGBA", ), extra_dims=DEFAULT_EXTRA_DIMS, - dtypes=(torch.float32, torch.uint8), + dtypes=(torch.float32, torch.float64, torch.uint8), constant_alpha=True, ): for params in combinations_grid(size=sizes, color_space=color_spaces, extra_dims=extra_dims, dtype=dtypes): @@ -426,7 +426,7 @@ def make_bounding_box_loaders( extra_dims=DEFAULT_EXTRA_DIMS, formats=tuple(datapoints.BoundingBoxFormat), spatial_size="random", - dtypes=(torch.float32, torch.int64), + dtypes=(torch.float32, torch.float64, torch.int64), ): for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes): yield make_bounding_box_loader(**params, spatial_size=spatial_size) @@ -618,7 +618,7 @@ def make_video_loaders( ), num_frames=(1, 0, "random"), extra_dims=DEFAULT_EXTRA_DIMS, - dtypes=(torch.uint8,), + dtypes=(torch.uint8, torch.float32, torch.float64), ): for params in combinations_grid( size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 2ddf085ea19..eddf76440c5 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -109,6 +109,12 @@ def float32_vs_uint8_pixel_difference(atol=1, mae=False): } +def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6): + return { + (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False}, + } + + def pil_reference_wrapper(pil_kernel): @functools.wraps(pil_kernel) def wrapper(input_tensor, *other_args, **kwargs): @@ -541,8 +547,10 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix) def transform(bbox, affine_matrix_, format_): # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 in_dtype = bbox.dtype + if not torch.is_floating_point(bbox): + bbox = bbox.float() bbox_xyxy = F.convert_format_bounding_box( - bbox.float(), old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True + bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True ) points = np.array( [ @@ -560,6 +568,7 @@ def transform(bbox, affine_matrix_, format_): np.max(transformed_points[:, 0]).item(), np.max(transformed_points[:, 1]).item(), ], + dtype=bbox_xyxy.dtype, ) out_bbox = F.convert_format_bounding_box( out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True @@ -844,6 +853,10 @@ def sample_inputs_rotate_video(): KernelInfo( F.rotate_bounding_box, sample_inputs_fn=sample_inputs_rotate_bounding_box, + closeness_kwargs={ + **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6), + **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5), + }, ), KernelInfo( F.rotate_mask, @@ -1275,6 +1288,8 @@ def sample_inputs_perspective_video(): **pil_reference_pixel_difference(2, mae=True), **cuda_vs_cpu_pixel_difference(), **float32_vs_uint8_pixel_difference(), + **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5), + **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5), }, ), KernelInfo( @@ -1294,7 +1309,11 @@ def sample_inputs_perspective_video(): KernelInfo( F.perspective_video, sample_inputs_fn=sample_inputs_perspective_video, - closeness_kwargs=cuda_vs_cpu_pixel_difference(), + closeness_kwargs={ + **cuda_vs_cpu_pixel_difference(), + **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5), + **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5), + }, ), ] ) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 758acc7b10a..f0a7b44db3b 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -138,17 +138,28 @@ def __init__( NotScriptableArgsKwargs(5, padding_mode="symmetric"), ], ), - ConsistencyConfig( - prototype_transforms.LinearTransformation, - legacy_transforms.LinearTransformation, - [ - ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX, LINEAR_TRANSFORMATION_MEAN), - ], - # Make sure that the product of the height, width and number of channels matches the number of elements in - # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]), - supports_pil=False, - ), + *[ + ConsistencyConfig( + prototype_transforms.LinearTransformation, + legacy_transforms.LinearTransformation, + [ + ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)), + ], + # Make sure that the product of the height, width and number of channels matches the number of elements in + # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. + make_images_kwargs=dict( + DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype] + ), + supports_pil=False, + ) + for matrix_dtype, image_dtype in [ + (torch.float32, torch.float32), + (torch.float64, torch.float64), + (torch.float32, torch.uint8), + (torch.float64, torch.float32), + (torch.float32, torch.float64), + ] + ], ConsistencyConfig( prototype_transforms.Grayscale, legacy_transforms.Grayscale, diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 5469e56df96..539cbce7787 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -142,7 +142,7 @@ def test_scripted_vs_eager(self, test_id, info, args_kwargs, device): actual, expected, **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device), - msg=parametrized_error_message(*other_args, **kwargs), + msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs), ) def _unbatch(self, batch, *, data_dims): diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index e7bb62da18e..39d9dc103f4 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -64,6 +64,11 @@ def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tenso f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}" ) + if transformation_matrix.dtype != mean_vector.dtype: + raise ValueError( + f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}" + ) + self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector @@ -93,7 +98,9 @@ def _transform( ) flat_tensor = inpt.reshape(-1, n) - self.mean_vector - transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) + + transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype) + transformed_tensor = torch.mm(flat_tensor, transformation_matrix) return transformed_tensor.reshape(shape) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 66e777dbdcc..aa16dc0afed 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -404,9 +404,13 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in def _apply_grid_transform( - float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT + img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT ) -> torch.Tensor: + # We are using context knowledge that grid should have float dtype + fp = img.dtype == grid.dtype + float_img = img if fp else img.to(grid.dtype) + shape = float_img.shape if shape[0] > 1: # Apply same grid to a batch of images @@ -433,7 +437,9 @@ def _apply_grid_transform( # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img) - return float_img + img = float_img.round_().to(img.dtype) if not fp else float_img + + return img def _assert_grid_transform_inputs( @@ -511,7 +517,6 @@ def affine_image_tensor( shape = image.shape ndim = image.ndim - fp = torch.is_floating_point(image) if ndim > 4: image = image.reshape((-1,) + shape[-3:]) @@ -535,13 +540,10 @@ def affine_image_tensor( _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"]) - dtype = image.dtype if fp else torch.float32 + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3) grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height) - output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill) - - if not fp: - output = output.round_().to(image.dtype) + output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) if needs_unsquash: output = output.reshape(shape) @@ -612,7 +614,7 @@ def _affine_bounding_box_xyxy( # Single point structure is similar to # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) - points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1) + points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1) # 2) Now let's transform the points using affine matrix transformed_points = torch.matmul(points, transposed_affine_matrix) # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] @@ -797,19 +799,15 @@ def rotate_image_tensor( matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) if image.numel() > 0: - fp = torch.is_floating_point(image) image = image.reshape(-1, num_channels, height, width) _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"]) ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height) - dtype = image.dtype if fp else torch.float32 + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3) grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh) - output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill) - - if not fp: - output = output.round_().to(image.dtype) + output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) new_height, new_width = output.shape[-2:] else: @@ -1237,9 +1235,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, d = 0.5 base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) - x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device) + x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype) base_grid[..., 0].copy_(x_grid) - y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1) + y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1) base_grid[..., 1].copy_(y_grid) base_grid[..., 2].fill_(1) @@ -1283,7 +1281,6 @@ def perspective_image_tensor( shape = image.shape ndim = image.ndim - fp = torch.is_floating_point(image) if ndim > 4: image = image.reshape((-1,) + shape[-3:]) @@ -1304,12 +1301,9 @@ def perspective_image_tensor( ) oh, ow = shape[-2:] - dtype = image.dtype if fp else torch.float32 + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device) - output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill) - - if not fp: - output = output.round_().to(image.dtype) + output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) if needs_unsquash: output = output.reshape(shape) @@ -1494,8 +1488,12 @@ def elastic_image_tensor( shape = image.shape ndim = image.ndim + device = image.device - fp = torch.is_floating_point(image) + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + # We are aware that if input image dtype is uint8 and displacement is float64 then + # displacement will be casted to float32 and all computations will be done with float32 + # We can fix this later if needed if ndim > 4: image = image.reshape((-1,) + shape[-3:]) @@ -1506,12 +1504,12 @@ def elastic_image_tensor( else: needs_unsquash = False - image_height, image_width = shape[-2:] - grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device)) - output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill) + if displacement.dtype != dtype or displacement.device != device: + displacement = displacement.to(dtype=dtype, device=device) - if not fp: - output = output.round_().to(image.dtype) + image_height, image_width = shape[-2:] + grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement) + output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) if needs_unsquash: output = output.reshape(shape) @@ -1531,13 +1529,13 @@ def elastic_image_pil( return to_pil_image(output, mode=image.mode) -def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch.Tensor: +def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor: sy, sx = size - base_grid = torch.empty(1, sy, sx, 2, device=device) - x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device) + base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype) + x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype) base_grid[..., 0].copy_(x_grid) - y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1) + y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1) base_grid[..., 1].copy_(y_grid) return base_grid @@ -1552,7 +1550,11 @@ def elastic_bounding_box( return bounding_box # TODO: add in docstring about approximation we are doing for grid inversion - displacement = displacement.to(bounding_box.device) + device = bounding_box.device + dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 + + if displacement.dtype != dtype or displacement.device != device: + displacement = displacement.to(dtype=dtype, device=device) original_shape = bounding_box.shape bounding_box = ( @@ -1563,7 +1565,7 @@ def elastic_bounding_box( # Or add spatial_size arg and check displacement shape spatial_size = displacement.shape[-3], displacement.shape[-2] - id_grid = _create_identity_grid(spatial_size, bounding_box.device) + id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype) # We construct an approximation of inverse grid as inv_grid = id_grid - displacement # This is not an exact inverse of the grid inv_grid = id_grid.sub_(displacement) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 9395ca674f4..e39e04c3478 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1078,6 +1078,11 @@ def __init__(self, transformation_matrix, mean_vector): f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}" ) + if transformation_matrix.dtype != mean_vector.dtype: + raise ValueError( + f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}" + ) + self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector @@ -1105,9 +1110,10 @@ def forward(self, tensor: Tensor) -> Tensor: ) flat_tensor = tensor.view(-1, n) - self.mean_vector - transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) - tensor = transformed_tensor.view(shape) - return tensor + + transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype) + transformed_tensor = torch.mm(flat_tensor, transformation_matrix) + return transformed_tensor.view(shape) def __repr__(self) -> str: s = (