2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/functional/_color.py
@@ -384,7 +384,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
     maximum = float_image.amax(dim=(-2, -1), keepdim=True)

     eq_idxs = maximum == minimum
-    inv_scale = maximum.sub_(minimum).div_(bound)
+    inv_scale = maximum.sub_(minimum).mul_(1.0 / bound)
     minimum[eq_idxs] = 0.0
     inv_scale[eq_idxs] = 1.0
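The `_color.py` change swaps an in-place division for an in-place multiplication by the precomputed reciprocal; on most backends a tensor-scalar multiply is cheaper than a divide, and the results agree up to floating-point rounding. A minimal sketch of the equivalence (the shape and `bound` value here are illustrative, not taken from the PR):

```python
import torch

bound = 255.0
x = torch.rand(8, 3, 224, 224)

old = x.clone().div_(bound)        # in-place division by the bound
new = x.clone().mul_(1.0 / bound)  # in-place multiply by the reciprocal

# Same values up to floating-point rounding.
torch.testing.assert_close(old, new)
```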
23 changes: 13 additions & 10 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -390,7 +390,7 @@ def _affine_bounding_box_xyxy(
         device=device,
     )
     new_points = torch.matmul(points, transposed_affine_matrix)
-    tr, _ = torch.min(new_points, dim=0, keepdim=True)
+    tr = torch.amin(new_points, dim=0, keepdim=True)
     # Translate bounding boxes
     out_bboxes.sub_(tr.repeat((1, 2)))
     # Estimate meta-data for image with inverted=True and with center=[0,0]
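`torch.min` with a `dim` argument returns a `(values, indices)` pair, so the old code materialized an argmin tensor only to discard it; `torch.amin` computes just the values. A small sketch of the difference:

```python
import torch

new_points = torch.rand(4, 2)

# torch.min along a dim also computes indices, unused here.
values, _ = torch.min(new_points, dim=0, keepdim=True)

# torch.amin returns only the minimum values.
assert torch.equal(values, torch.amin(new_points, dim=0, keepdim=True))
```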
@@ -701,7 +701,7 @@ def pad_image_tensor(
     # internally.
     torch_padding = _parse_pad_padding(padding)

-    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
         raise ValueError(
             f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
             f"but got `'{padding_mode}'`."
@@ -917,17 +917,17 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #
-
+    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
     )
     theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

     d = 0.5
     base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
-    x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
+    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device)
     base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
+    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
     base_grid[..., 1].copy_(y_grid)
     base_grid[..., 2].fill_(1)

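The `linspace` change drops a redundant `* 1.0`: since `d` is `0.5`, `ow + d - 1.0` already promotes to float, so both spellings produce the same endpoints, namely the pixel-center coordinates `d, d + 1, ..., ow + d - 1`. A quick check (the width is illustrative):

```python
import torch

ow, d = 8, 0.5

old = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow)
new = torch.linspace(d, ow + d - 1.0, steps=ow)

# Identical pixel-center coordinates: 0.5, 1.5, ..., 7.5
assert torch.equal(old, new)
```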
@@ -1059,6 +1059,7 @@ def perspective_bounding_box(
         (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
     ]

+    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
         dtype=dtype,
@@ -1165,14 +1166,18 @@ def elastic_image_tensor(
         return image

     shape = image.shape
+    device = image.device

     if image.ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
     else:
         needs_unsquash = False

-    output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
+    image_height, image_width = shape[-2:]
+    identity_grid = _create_identity_grid((image_height, image_width), device=device)
+    grid = identity_grid + displacement.to(device)
+    output = _FT._apply_grid_transform(image, grid, interpolation.value, fill)

     if needs_unsquash:
         output = output.reshape(shape)
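Instead of delegating to `_FT.elastic_transform`, the new code builds the sampling grid itself: an identity grid (each output pixel sampling its own location) plus the displacement field, which then goes through the shared grid-transform helper. It also hoists `image.device` once so the grid is created on, and the displacement moved to, the image's device. A minimal self-contained sketch of the same idea using public ops, since `_create_identity_grid` and `_apply_grid_transform` are internal helpers, with `affine_grid`/`grid_sample` standing in for them:

```python
import torch
import torch.nn.functional as F

def elastic_sketch(image: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
    # image: (N, C, H, W); displacement: (N, H, W, 2) in normalized coords.
    n, c, h, w = image.shape
    # Identity grid: every output pixel samples its own input location.
    theta = torch.eye(2, 3, device=image.device).unsqueeze(0).repeat(n, 1, 1)
    identity_grid = F.affine_grid(theta, (n, c, h, w), align_corners=False)
    grid = identity_grid + displacement.to(image.device)
    return F.grid_sample(image, grid, align_corners=False)

img = torch.rand(1, 3, 32, 32)
disp = torch.zeros(1, 32, 32, 2)  # zero displacement -> identity transform
torch.testing.assert_close(elastic_sketch(img, disp), img)
```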
@@ -1505,8 +1510,7 @@ def five_crop_image_tensor(
     image_height, image_width = image.shape[-2:]

     if crop_width > image_width or crop_height > image_height:
-        msg = "Requested crop size {} is bigger than input size {}"
-        raise ValueError(msg.format(size, (image_height, image_width)))
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

     tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
     tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width)
@@ -1525,8 +1529,7 @@ def five_crop_image_pil(
     image_height, image_width = get_spatial_size_image_pil(image)

     if crop_width > image_width or crop_height > image_height:
-        msg = "Requested crop size {} is bigger than input size {}"
-        raise ValueError(msg.format(size, (image_height, image_width)))
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

     tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
     tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
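Both `five_crop` variants now raise with a single f-string instead of building a template and calling `.format`, which reads more directly and avoids a temporary. The messages are identical, as a quick check shows (the sizes here are made up):

```python
size = (224, 224)
image_height, image_width = 200, 300

msg = "Requested crop size {} is bigger than input size {}"
old = msg.format(size, (image_height, image_width))
new = f"Requested crop size {size} is bigger than input size {(image_height, image_width)}"

assert old == new
```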