
Fixed issues with dtype in geom functional transforms v2 #7211

Merged
6 changes: 3 additions & 3 deletions test/prototype_common_utils.py
@@ -304,7 +304,7 @@ def make_image_loaders(
"RGBA",
),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.float32, torch.uint8),
dtypes=(torch.float32, torch.float64, torch.uint8),
constant_alpha=True,
):
for params in combinations_grid(size=sizes, color_space=color_spaces, extra_dims=extra_dims, dtype=dtypes):
@@ -426,7 +426,7 @@ def make_bounding_box_loaders(
extra_dims=DEFAULT_EXTRA_DIMS,
formats=tuple(datapoints.BoundingBoxFormat),
spatial_size="random",
dtypes=(torch.float32, torch.int64),
dtypes=(torch.float32, torch.float64, torch.int64),
):
for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
yield make_bounding_box_loader(**params, spatial_size=spatial_size)
@@ -618,7 +618,7 @@ def make_video_loaders(
),
num_frames=(1, 0, "random"),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.uint8,),
dtypes=(torch.uint8, torch.float32, torch.float64),
):
for params in combinations_grid(
size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes
23 changes: 21 additions & 2 deletions test/prototype_transforms_kernel_infos.py
@@ -109,6 +109,12 @@ def float32_vs_uint8_pixel_difference(atol=1, mae=False):
}


def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
return {
(("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
}


def pil_reference_wrapper(pil_kernel):
@functools.wraps(pil_kernel)
def wrapper(input_tensor, *other_args, **kwargs):
@@ -541,8 +547,10 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
def transform(bbox, affine_matrix_, format_):
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
in_dtype = bbox.dtype
if not torch.is_floating_point(bbox):
bbox = bbox.float()
bbox_xyxy = F.convert_format_bounding_box(
bbox.float(), old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
)
points = np.array(
[
@@ -560,6 +568,7 @@ def transform(bbox, affine_matrix_, format_):
np.max(transformed_points[:, 0]).item(),
np.max(transformed_points[:, 1]).item(),
],
dtype=bbox_xyxy.dtype,
)
out_bbox = F.convert_format_bounding_box(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
@@ -844,6 +853,10 @@ def sample_inputs_rotate_video():
KernelInfo(
F.rotate_bounding_box,
sample_inputs_fn=sample_inputs_rotate_bounding_box,
closeness_kwargs={
**scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6),
**scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
},
),
KernelInfo(
F.rotate_mask,
@@ -1275,6 +1288,8 @@ def sample_inputs_perspective_video():
**pil_reference_pixel_difference(2, mae=True),
**cuda_vs_cpu_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
**scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
},
),
KernelInfo(
@@ -1294,7 +1309,11 @@ def sample_inputs_perspective_video():
KernelInfo(
F.perspective_video,
sample_inputs_fn=sample_inputs_perspective_video,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
closeness_kwargs={
**cuda_vs_cpu_pixel_difference(),
**scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
},
),
]
)
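For orientation, here is a rough sketch of how the per-dtype closeness_kwargs added above are consumed. Only the key layout, a (test_id, dtype, device) tuple mapping to tolerance kwargs, is taken from this diff and from test_scripted_vs_eager further down; the get_closeness_kwargs helper below is a simplified stand-in for the real KernelInfo method.

import torch

closeness_kwargs = {
    (("TestKernels", "test_scripted_vs_eager"), torch.float64, "cpu"): {"atol": 1e-6, "rtol": 1e-6, "mae": False},
    (("TestKernels", "test_scripted_vs_eager"), torch.float64, "cuda"): {"atol": 1e-5, "rtol": 1e-5, "mae": False},
}

def get_closeness_kwargs(test_id, *, dtype, device):
    # simplified stand-in: the real method may normalize the device differently
    if isinstance(device, torch.device):
        device = device.type
    return closeness_kwargs.get((test_id, dtype, device), {})

# A float64 input on CPU picks up the relaxed tolerances; other dtypes fall back to the defaults.
assert get_closeness_kwargs(("TestKernels", "test_scripted_vs_eager"), dtype=torch.float64, device="cpu")["atol"] == 1e-6
assert get_closeness_kwargs(("TestKernels", "test_scripted_vs_eager"), dtype=torch.uint8, device="cpu") == {}
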
33 changes: 22 additions & 11 deletions test/test_prototype_transforms_consistency.py
@@ -138,17 +138,28 @@ def __init__(
NotScriptableArgsKwargs(5, padding_mode="symmetric"),
],
),
ConsistencyConfig(
prototype_transforms.LinearTransformation,
legacy_transforms.LinearTransformation,
[
ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX, LINEAR_TRANSFORMATION_MEAN),
],
# Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]),
supports_pil=False,
),
*[
ConsistencyConfig(
prototype_transforms.LinearTransformation,
legacy_transforms.LinearTransformation,
[
ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
],
# Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
),
supports_pil=False,
)
for matrix_dtype, image_dtype in [
(torch.float32, torch.float32),
(torch.float64, torch.float64),
(torch.float32, torch.uint8),
(torch.float64, torch.float32),
(torch.float32, torch.float64),
]
],
ConsistencyConfig(
prototype_transforms.Grayscale,
legacy_transforms.Grayscale,
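The mixed matrix/image dtype combinations above can only be consistent because the v2 transform casts the matrix to the flattened input's dtype before the matmul (see the _misc.py hunk below). A minimal numeric sketch of that behaviour, not taken from the test file:

import torch

matrix = torch.eye(36, dtype=torch.float32)
mean = torch.zeros(36, dtype=torch.float32)
image = torch.rand(3, 4, 3, dtype=torch.float64)  # 4 * 3 * 3 == 36 elements, as in the comment above

flat = image.reshape(-1, 36) - mean            # type promotion makes this float64
out = torch.mm(flat, matrix.to(flat.dtype))    # the cast keeps mm() from raising a dtype mismatch
assert out.dtype == torch.float64
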
2 changes: 1 addition & 1 deletion test/test_prototype_transforms_functional.py
@@ -142,7 +142,7 @@ def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*other_args, **kwargs),
msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs),
)

def _unbatch(self, batch, *, data_dims):
9 changes: 8 additions & 1 deletion torchvision/prototype/transforms/_misc.py
@@ -64,6 +64,11 @@ def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tenso
f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
)

if transformation_matrix.dtype != mean_vector.dtype:
raise ValueError(
f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
)

self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector

@@ -93,7 +98,9 @@ def _transform(
)

flat_tensor = inpt.reshape(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)

transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
return transformed_tensor.reshape(shape)


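A short usage sketch of the new constructor check (the import path for the prototype namespace is assumed here, everything else follows the hunk above):

import torch
from torchvision.prototype import transforms  # assumed import path for the prototype API

matrix = torch.eye(36, dtype=torch.float64)
mean = torch.zeros(36, dtype=torch.float32)

try:
    transforms.LinearTransformation(matrix, mean)
except ValueError as e:
    # "Input tensors should have the same dtype. Got torch.float64 and torch.float32"
    print(e)
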
72 changes: 37 additions & 35 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -404,9 +404,13 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in


def _apply_grid_transform(
float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT
img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT
) -> torch.Tensor:

# We are using context knowledge that grid should have float dtype
fp = img.dtype == grid.dtype
Member: Could you please explain why we're sure that img.dtype is float iff it's the same as the grid dtype? Can't we just use is_floating_dtype()?

Collaborator (author): Yes, we can use is_floating_dtype. I used the context knowledge that the grid should have a float dtype.
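A minimal sketch to make the exchange concrete (not part of the PR): the callers in this diff build the sampling grid with the image dtype when the image is floating point and float32 otherwise, so under that construction the dtype-equality check agrees with torch.is_floating_point.

import torch

def grid_dtype_for(image):
    # dtype rule used by the callers in this PR (affine/rotate/perspective/elastic)
    return image.dtype if torch.is_floating_point(image) else torch.float32

for image_dtype in (torch.uint8, torch.float32, torch.float64):
    image = torch.zeros(3, 4, 4, dtype=image_dtype)
    grid = torch.zeros(1, 4, 4, 2, dtype=grid_dtype_for(image))
    # dtype equality holds exactly when the image is already floating point
    assert (image.dtype == grid.dtype) == torch.is_floating_point(image)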

float_img = img if fp else img.to(grid.dtype)

shape = float_img.shape
if shape[0] > 1:
# Apply same grid to a batch of images
Expand All @@ -433,7 +437,9 @@ def _apply_grid_transform(
# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

return float_img
img = float_img.round_().to(img.dtype) if not fp else float_img

return img


def _assert_grid_transform_inputs(
@@ -511,7 +517,6 @@ def affine_image_tensor(

shape = image.shape
ndim = image.ndim
fp = torch.is_floating_point(image)

if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
@@ -535,13 +540,10 @@

_assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

dtype = image.dtype if fp else torch.float32
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)

if not fp:
output = output.round_().to(image.dtype)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

if needs_unsquash:
output = output.reshape(shape)
@@ -612,7 +614,7 @@ def _affine_bounding_box_xyxy(
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
# 2) Now let's transform the points using affine matrix
transformed_points = torch.matmul(points, transposed_affine_matrix)
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
@@ -797,19 +799,15 @@ def rotate_image_tensor(
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

if image.numel() > 0:
fp = torch.is_floating_point(image)
image = image.reshape(-1, num_channels, height, width)

_assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
dtype = image.dtype if fp else torch.float32
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)

if not fp:
output = output.round_().to(image.dtype)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

new_height, new_width = output.shape[-2:]
else:
@@ -1237,9 +1235,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,

d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device)
x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)

@@ -1283,7 +1281,6 @@ def perspective_image_tensor(

shape = image.shape
ndim = image.ndim
fp = torch.is_floating_point(image)

if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
@@ -1304,12 +1301,9 @@
)

oh, ow = shape[-2:]
dtype = image.dtype if fp else torch.float32
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)

if not fp:
output = output.round_().to(image.dtype)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

if needs_unsquash:
output = output.reshape(shape)
@@ -1494,8 +1488,12 @@ def elastic_image_tensor(

shape = image.shape
ndim = image.ndim

device = image.device
fp = torch.is_floating_point(image)
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
# We are aware that if input image dtype is uint8 and displacement is float64 then
# displacement will be casted to float32 and all computations will be done with float32
# We can fix this later if needed

if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
@@ -1506,12 +1504,12 @@
else:
needs_unsquash = False

image_height, image_width = shape[-2:]
grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill)
if displacement.dtype != dtype or displacement.device != device:
displacement = displacement.to(dtype=dtype, device=device)

if not fp:
output = output.round_().to(image.dtype)
image_height, image_width = shape[-2:]
grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

if needs_unsquash:
output = output.reshape(shape)
@@ -1531,13 +1529,13 @@
return to_pil_image(output, mode=image.mode)


def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch.Tensor:
def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sy, sx = size
base_grid = torch.empty(1, sy, sx, 2, device=device)
x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device)
base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
base_grid[..., 0].copy_(x_grid)

y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1)
y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)

return base_grid
@@ -1552,7 +1550,11 @@ def elastic_bounding_box(
return bounding_box

# TODO: add in docstring about approximation we are doing for grid inversion
displacement = displacement.to(bounding_box.device)
device = bounding_box.device
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32

if displacement.dtype != dtype or displacement.device != device:
displacement = displacement.to(dtype=dtype, device=device)

original_shape = bounding_box.shape
bounding_box = (
@@ -1563,7 +1565,7 @@
# Or add spatial_size arg and check displacement shape
spatial_size = displacement.shape[-3], displacement.shape[-2]

id_grid = _create_identity_grid(spatial_size, bounding_box.device)
id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
# This is not an exact inverse of the grid
inv_grid = id_grid.sub_(displacement)
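Stepping back from the individual hunks, the dtype round-trip is now owned by _apply_grid_transform for affine, rotate, perspective and elastic alike. A condensed sketch of that behaviour follows; the real function also handles extra batch dims and fill values, so this is only an illustration:

import torch

def apply_grid_transform_dtype_sketch(img, grid, mode="bilinear"):
    # img: (N, C, H, W); grid: (N, H_out, W_out, 2), always a floating dtype
    fp = img.dtype == grid.dtype
    float_img = img if fp else img.to(grid.dtype)
    out = torch.nn.functional.grid_sample(float_img, grid, mode=mode, align_corners=False)
    # integer inputs are rounded and cast back; floating inputs keep their dtype
    return out if fp else out.round_().to(img.dtype)

img_u8 = torch.randint(0, 256, (1, 3, 8, 8), dtype=torch.uint8)
grid = torch.rand(1, 8, 8, 2) * 2 - 1  # float32 sampling grid in [-1, 1]
assert apply_grid_transform_dtype_sketch(img_u8, grid).dtype == torch.uint8

img_f64 = torch.rand(1, 3, 8, 8, dtype=torch.float64)
assert apply_grid_transform_dtype_sketch(img_f64, grid.to(torch.float64)).dtype == torch.float64
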
12 changes: 9 additions & 3 deletions torchvision/transforms/transforms.py
@@ -1078,6 +1078,11 @@ def __init__(self, transformation_matrix, mean_vector):
f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
)

if transformation_matrix.dtype != mean_vector.dtype:
raise ValueError(
f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
)

self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector

@@ -1105,9 +1110,10 @@ def forward(self, tensor: Tensor) -> Tensor:
)

flat_tensor = tensor.view(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
tensor = transformed_tensor.view(shape)
return tensor

transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
return transformed_tensor.view(shape)

def __repr__(self) -> str:
s = (
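Finally, a hedged before/after sketch for the stable transforms.LinearTransformation: with the matrix now cast to the promoted input dtype, a float32 matrix can be applied to a float64 tensor, a case where torch.mm would previously have raised a dtype mismatch.

import torch
from torchvision import transforms

matrix = torch.eye(36, dtype=torch.float32)
mean = torch.zeros(36, dtype=torch.float32)
t = transforms.LinearTransformation(matrix, mean)

img = torch.rand(3, 4, 3, dtype=torch.float64)  # 4 * 3 * 3 == 36 elements
out = t(img)
# subtracting the mean promotes to float64 and the matrix is cast to match
assert out.dtype == torch.float64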