Skip to content

Commit d805aea

Browse files
vfdev-5NicolasHug
andauthored
Fixed issues in elastic transform (#7257)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 3080082 commit d805aea

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,10 +1538,21 @@ def elastic_image_tensor(
15381538

15391539
device = image.device
15401540
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
1541+
1542+
# Patch: elastic transform should support (cpu,f16) input
1543+
is_cpu_half = device.type == "cpu" and dtype == torch.float16
1544+
if is_cpu_half:
1545+
image = image.to(torch.float32)
1546+
dtype = torch.float32
1547+
15411548
# We are aware that if input image dtype is uint8 and displacement is float64 then
15421549
# displacement will be casted to float32 and all computations will be done with float32
15431550
# We can fix this later if needed
15441551

1552+
expected_shape = (1,) + shape[-2:] + (2,)
1553+
if expected_shape != displacement.shape:
1554+
raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")
1555+
15451556
if ndim > 4:
15461557
image = image.reshape((-1,) + shape[-3:])
15471558
needs_unsquash = True
@@ -1561,6 +1572,9 @@ def elastic_image_tensor(
15611572
if needs_unsquash:
15621573
output = output.reshape(shape)
15631574

1575+
if is_cpu_half:
1576+
output = output.to(torch.float16)
1577+
15641578
return output
15651579

15661580

@@ -1676,6 +1690,9 @@ def elastic(
16761690
if not torch.jit.is_scripting():
16771691
_log_api_usage_once(elastic)
16781692

1693+
if not isinstance(displacement, torch.Tensor):
1694+
raise TypeError("Argument displacement should be a Tensor")
1695+
16791696
if torch.jit.is_scripting() or is_simple_tensor(inpt):
16801697
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
16811698
elif isinstance(inpt, datapoints._datapoint.Datapoint):

0 commit comments

Comments
 (0)