@@ -1538,10 +1538,21 @@ def elastic_image_tensor(
1538
1538
1539
1539
device = image.device
1540
1540
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
1541
+
1542
+ # Patch: elastic transform should support (cpu,f16) input
1543
+ is_cpu_half = device.type == "cpu" and dtype == torch.float16
1544
+ if is_cpu_half:
1545
+ image = image.to(torch.float32)
1546
+ dtype = torch.float32
1547
+
1541
1548
# We are aware that if input image dtype is uint8 and displacement is float64 then
1542
1549
# displacement will be casted to float32 and all computations will be done with float32
1543
1550
# We can fix this later if needed
1544
1551
1552
+ expected_shape = (1,) + shape[-2:] + (2,)
1553
+ if expected_shape != displacement.shape:
1554
+ raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")
1555
+
1545
1556
if ndim > 4:
1546
1557
image = image.reshape((-1,) + shape[-3:])
1547
1558
needs_unsquash = True
@@ -1561,6 +1572,9 @@ def elastic_image_tensor(
1561
1572
if needs_unsquash:
1562
1573
output = output.reshape(shape)
1563
1574
1575
+ if is_cpu_half:
1576
+ output = output.to(torch.float16)
1577
+
1564
1578
return output
1565
1579
1566
1580
@@ -1676,6 +1690,9 @@ def elastic(
1676
1690
if not torch.jit.is_scripting():
1677
1691
_log_api_usage_once(elastic)
1678
1692
1693
+ if not isinstance(displacement, torch.Tensor):
1694
+ raise TypeError("Argument displacement should be a Tensor")
1695
+
1679
1696
if torch.jit.is_scripting() or is_simple_tensor(inpt):
1680
1697
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
1681
1698
elif isinstance(inpt, datapoints._datapoint.Datapoint):
0 commit comments