
Commit 456f0ec

[torchlib] Fix reflection pad (#2037)
Fixes pytorch/pytorch#144382

Co-authored-by: Ti-Tai Wang <[email protected]>
1 parent 71e24d9 commit 456f0ec

File tree

2 files changed: +86, -102 lines

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 77 additions & 102 deletions
@@ -19,7 +19,7 @@

 import onnx

-from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64
+from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import (
@@ -1479,12 +1479,56 @@ def aten_one_hot(self: TensorType, num_classes: int = -1) -> TensorType:
     raise NotImplementedError()


+def _process_padding(padding: Sequence[INT64 | int], rank: int) -> INT64:
+    """Convert PyTorch padding for ONNX Pad."""
+    assert isinstance(padding, (list, tuple))
+    if all(isinstance(pad, int) for pad in padding):
+        paddings = padding
+        zeros = [0] * (rank * 2 - len(paddings))
+        paddings = [*paddings, *zeros]
+        paddings = paddings[-2::-2] + paddings[-1::-2]
+        return op.Constant(value=ir.tensor(paddings, dtype=ir.DataType.INT64))
+    else:
+        paddings = []
+        for pad in padding:
+            if isinstance(pad, int):
+                paddings.append(op.Constant(value_ints=[pad]))
+            else:
+                # Dynamic value
+                paddings.append(op.Reshape(pad, [-1]))
+        # Create a series of 1d zero tensors
+        zero = op.Constant(value_ints=[0])
+        zeros = [zero] * (rank * 2 - len(paddings))
+        paddings = [*paddings, *zeros]
+        # Interleave the padding values
+        paddings = paddings[-2::-2] + paddings[-1::-2]
+        return op.Concat(paddings, axis=0)
+
+
+@torch_op("aten::pad", trace_only=True)
 def aten_pad(
-    self: TensorType, pad: INT64, mode: str = "constant", value: Optional[float] = None
+    self: TensorType,
+    pad: Sequence[INT64],
+    mode: str = "constant",
+    value: Optional[float] = None,
 ) -> TensorType:
     """pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor"""

-    raise NotImplementedError()
+    rank = len(self.shape)
+    paddings = _process_padding(pad, rank)
+    const_value = (
+        op.Constant(value=ir.tensor(value, dtype=ir.DataType(self.dtype)))
+        if value is not None
+        else None
+    )
+    onnx_mode = {
+        "constant": "constant",
+        "reflect": "reflect",
+        "replicate": "edge",
+        "circular": "wrap",
+    }[mode]
+
+    return op.Pad(self, paddings, constant_value=const_value, mode=onnx_mode)


 def aten_pad_sequence(
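
Note: the key step in _process_padding is the interleave. PyTorch orders pads as (last_dim_begin, last_dim_end, second_last_begin, second_last_end, ...), while ONNX Pad expects all begin values followed by all end values, ordered from the first dimension. A minimal plain-Python sketch of the static branch (to_onnx_pads is an illustrative name, not part of the patch):

# Sketch: replicate the static (all-int) branch of _process_padding.
# PyTorch pads: [last_begin, last_end, 2nd_last_begin, 2nd_last_end, ...]
# ONNX pads:    [dim0_begin, ..., dimN_begin, dim0_end, ..., dimN_end]
def to_onnx_pads(padding: list, rank: int) -> list:
    pads = [*padding, *[0] * (rank * 2 - len(padding))]
    return pads[-2::-2] + pads[-1::-2]

# Pad the last dim of a rank-3 tensor by 1 (begin) and 2 (end):
assert to_onnx_pads([1, 2], rank=3) == [0, 0, 1, 0, 0, 2]
# Pad the last two dims of a rank-4 tensor:
assert to_onnx_pads([1, 2, 3, 4], rank=4) == [0, 0, 3, 1, 0, 0, 4, 2]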
@@ -1495,18 +1539,15 @@ def aten_pad_sequence(
     raise NotImplementedError()


-@torch_op("aten::reflection_pad1d")
-def aten_reflection_pad1d(self: TFloat, padding: INT64) -> TFloat:
+@torch_op("aten::reflection_pad1d", trace_only=True)
+def aten_reflection_pad1d(self: TFloat, padding: Sequence[INT64]) -> TFloat:
     """reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor"""

     # assert len(padding) == 2
     # Input of padding argument should be [x,y], need change to onnx format [0, x, 0, y]
-    start = op.Slice(padding, [0], [1], axes=[0])
-    end = op.Slice(padding, [1], [2], axes=[0])
-    padding_onnx = op.Concat(
-        op.Constant(value_ints=[0]), start, op.Constant(value_ints=[0]), end, axis=0
-    )
-    return op.Pad(self, padding_onnx, mode="reflect")
+    rank = len(self.shape)
+    paddings = _process_padding(padding, rank)
+    return op.Pad(self, paddings, mode="reflect")


 def aten_reflection_pad1d_backward(
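
Note: the reflection_pad functions now all funnel through _process_padding into ONNX Pad with mode="reflect", which uses NumPy-style reflect semantics (the edge element is not repeated), matching PyTorch. A quick sketch of that semantic match, assuming standard NumPy/PyTorch behavior and independent of this patch:

# Sketch: torch "reflect" and ONNX/NumPy "reflect" agree on 1-D padding.
import numpy as np
import torch
import torch.nn.functional as F

x = torch.arange(5, dtype=torch.float32).reshape(1, 1, 5)
torch_out = F.pad(x, [2, 1], mode="reflect")  # pads the last dim
np_out = np.pad(x.numpy(), [(0, 0), (0, 0), (2, 1)], mode="reflect")
assert np.array_equal(torch_out.numpy(), np_out)  # both give [2,1,0,1,2,3,4,3]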
@@ -1517,37 +1558,12 @@ def aten_reflection_pad1d_backward(
     raise NotImplementedError()


-@torch_op("aten::reflection_pad2d")
-def aten_reflection_pad2d(self: TTensor, padding: INT64) -> TTensor:
+@torch_op("aten::reflection_pad2d", trace_only=True)
+def aten_reflection_pad2d(self: TTensor, padding: Sequence[INT64]) -> TTensor:
     """reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor"""
-    # Convert torch padding format to onnx padding format
-    # Python code is:
-    # dim = len(self.shape)
-    # paddings = list(padding[:]) + [0] * (dim * 2 - len(padding))
-    # paddings = paddings[-2::-2] + paddings[-1::-2]
-
-    neg_1 = op.Constant(value_ints=[-1])
-    zero = op.Constant(value_ints=[0])
-    # [0] * (rank * 2 - len(padding))
-    rank = Rank(self)
-    zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1)
-    zeros = op.Expand(zero, zero_count)
-    # list(padding[:]) + [0] * (dim * 2 - len(padding))
-    torch_paddings = op.Concat(padding, zeros, axis=0)
-    # paddings[-2::-2]
-    size_d = op.Size(torch_paddings)
-    steps = op.Constant(value_ints=[-2])
-    starts = steps
-    ends = op.Sub(starts, size_d)
-    odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
-    # paddings[-1::-2]
-    starts = neg_1
-    ends = op.Sub(starts, size_d)
-    even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
-    # paddings[-2::-2] + paddings[-1::-2]
-    onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
-
-    return op.Pad(self, onnx_padding, mode="reflect")
+    rank = len(self.shape)
+    paddings = _process_padding(padding, rank)
+    return op.Pad(self, paddings, mode="reflect")


 def aten_reflection_pad2d_backward(
@@ -1558,10 +1574,12 @@ def aten_reflection_pad2d_backward(
     raise NotImplementedError()


-def aten_reflection_pad3d(self: TensorType, padding: INT64) -> TensorType:
+@torch_op("aten::reflection_pad3d", trace_only=True)
+def aten_reflection_pad3d(self: TensorType, padding: Sequence[INT64]) -> TensorType:
     """reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor"""
-
-    raise NotImplementedError()
+    rank = len(self.shape)
+    paddings = _process_padding(padding, rank)
+    return op.Pad(self, paddings, mode="reflect")


 def aten_reflection_pad3d_backward(
@@ -1587,18 +1605,13 @@ def aten_relu6(self: TReal) -> TReal:
     return op.Min(op.Relu(self), six)


-@torch_op("aten::replication_pad1d")
-def aten_replication_pad1d(self: TensorType, padding: INT64) -> TensorType:
+@torch_op("aten::replication_pad1d", trace_only=True)
+def aten_replication_pad1d(self: TensorType, padding: Sequence[INT64]) -> TensorType:
     """replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor"""

-    # assert len(padding) == 2
-    # Input of padding argument should be [x,y], need change to onnx format [0, x, 0, y]
-    start = op.Slice(padding, [0], [1], axes=[0])
-    end = op.Slice(padding, [1], [2], axes=[0])
-    padding_onnx = op.Concat(
-        op.Constant(value_ints=[0]), start, op.Constant(value_ints=[0]), end, axis=0
-    )
-    return op.Pad(self, padding_onnx, mode="edge")
+    rank = len(self.shape)
+    paddings = _process_padding(padding, rank)
+    return op.Pad(self, paddings, mode="edge")


 def aten_replication_pad1d_backward(
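
Note: the replication_pad family maps to ONNX mode "edge", and aten_pad above additionally maps torch's "circular" to ONNX "wrap". Both correspondences can be sanity-checked against NumPy, which uses the same mode names as ONNX Pad; a small sketch assuming standard NumPy/PyTorch semantics:

# Sketch: torch "replicate" == ONNX/NumPy "edge"; torch "circular" == ONNX/NumPy "wrap".
import numpy as np
import torch
import torch.nn.functional as F

x = torch.arange(4, dtype=torch.float32).reshape(1, 1, 4)
for torch_mode, onnx_mode in [("replicate", "edge"), ("circular", "wrap")]:
    t = F.pad(x, [2, 2], mode=torch_mode).numpy()
    n = np.pad(x.numpy(), [(0, 0), (0, 0), (2, 2)], mode=onnx_mode)
    assert np.array_equal(t, n), torch_mode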
@@ -1609,32 +1622,13 @@ def aten_replication_pad1d_backward(
     raise NotImplementedError()


-@torch_op("aten::replication_pad2d")
-def aten_replication_pad2d(self: TTensor, padding: INT64) -> TTensor:
+@torch_op("aten::replication_pad2d", trace_only=True)
+def aten_replication_pad2d(self: TTensor, padding: Sequence[INT64]) -> TTensor:
     """replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor"""

-    neg_1 = op.Constant(value_ints=[-1])
-    zero = op.Constant(value_ints=[0])
-    # [0] * (rank * 2 - len(padding))
-    rank = Rank(self)
-    zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1)
-    zeros = op.Expand(zero, zero_count)
-    # list(padding[:]) + [0] * (dim * 2 - len(padding))
-    torch_paddings = op.Concat(padding, zeros, axis=0)
-    # paddings[-2::-2]
-    size_d = op.Size(torch_paddings)
-    steps = op.Constant(value_ints=[-2])
-    starts = steps
-    ends = op.Sub(starts, size_d)
-    odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
-    # paddings[-1::-2]
-    starts = neg_1
-    ends = op.Sub(starts, size_d)
-    even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
-    # paddings[-2::-2] + paddings[-1::-2]
-    onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
-
-    return op.Pad(self, onnx_padding, mode="edge")
+    rank = len(self.shape)
+    paddings = _process_padding(padding, rank)
+    return op.Pad(self, paddings, mode="edge")


 def aten_replication_pad2d_backward(
@@ -1645,32 +1639,13 @@ def aten_replication_pad2d_backward(
     raise NotImplementedError()


-@torch_op("aten::replication_pad3d")
-def aten_replication_pad3d(self: TTensor, padding: INT64) -> TTensor:
+@torch_op("aten::replication_pad3d", trace_only=True)
+def aten_replication_pad3d(self: TTensor, padding: Sequence[INT64]) -> TTensor:
     """replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor"""

-    neg_1 = op.Constant(value_ints=[-1])
-    zero = op.Constant(value_ints=[0])
-    # [0] * (rank * 2 - len(padding))
-    rank = Rank(self)
-    zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1)
-    zeros = op.Expand(zero, zero_count)
-    # list(padding[:]) + [0] * (dim * 2 - len(padding))
-    torch_paddings = op.Concat(padding, zeros, axis=0)
-    # paddings[-2::-2]
-    size_d = op.Size(torch_paddings)
-    steps = op.Constant(value_ints=[-2])
-    starts = steps
-    ends = op.Sub(starts, size_d)
-    odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
-    # paddings[-1::-2]
-    starts = neg_1
-    ends = op.Sub(starts, size_d)
-    even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
-    # paddings[-2::-2] + paddings[-1::-2]
-    onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
-
-    return op.Pad(self, onnx_padding, mode="edge")
+    rank = len(self.shape)
+    paddings = _process_padding(padding, rank)
+    return op.Pad(self, paddings, mode="edge")


 def aten_replication_pad3d_backward(

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 9 additions & 0 deletions
@@ -1082,6 +1082,15 @@ def _where_input_wrangler(
         input_wrangler=_nll_loss_input_wrangler,
         tolerance={torch.float16: (5e-2, 1e-2)},
     ),
+    TorchLibOpInfo("nn.functional.pad", nn_ops.aten_pad)
+    .skip(
+        variant_name="circular",
+        reason="fixme: ORT does not support the circular mode",
+    )
+    .skip(
+        variant_name="replicate_negative",
+        reason="fixme: The implementation for negative paddings is not correct",
+    ),
     TorchLibOpInfo(
         "nn.functional.pixel_shuffle",
         core_ops.aten_pixel_shuffle,
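
Note: the replicate_negative skip refers to PyTorch's convention that negative entries in pad crop the tensor rather than pad it, which this Pad-based lowering does not yet handle correctly (per the skip reason above). A small sketch of the behavior that variant exercises, assuming torch's documented crop-on-negative semantics:

# Sketch: negative pads crop in torch; the skip above notes the lowering
# does not yet reproduce this correctly.
import torch
import torch.nn.functional as F

x = torch.arange(6, dtype=torch.float32).reshape(1, 1, 6)
y = F.pad(x, [-1, 2], mode="replicate")  # crop 1 on the left, edge-pad 2 on the right
# Expected under torch's cropping semantics: tensor([[[1., 2., 3., 4., 5., 5., 5.]]])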
