19
19
20
20
import onnx
21
21
22
- from onnxscript import BFLOAT16 , BOOL , DOUBLE , FLOAT , FLOAT16 , INT64
22
+ from onnxscript import BFLOAT16 , BOOL , DOUBLE , FLOAT , FLOAT16 , INT64 , ir
23
23
from onnxscript .function_libs .torch_lib .ops import common as common_ops
24
24
from onnxscript .function_libs .torch_lib .registration import torch_op
25
25
from onnxscript .function_libs .torch_lib .tensor_typing import (
@@ -1479,12 +1479,56 @@ def aten_one_hot(self: TensorType, num_classes: int = -1) -> TensorType:
1479
1479
raise NotImplementedError ()
1480
1480
1481
1481
1482
def _process_padding(padding: Sequence[INT64 | int], rank: int) -> INT64:
    """Convert PyTorch-style padding into the ONNX ``Pad`` ``pads`` input.

    PyTorch lists padding for the *last* dimensions first, interleaved as
    ``[dim_{-1}_begin, dim_{-1}_end, dim_{-2}_begin, ...]``, and may cover
    fewer than ``rank`` dimensions. ONNX ``Pad`` instead expects
    ``[x1_begin, x2_begin, ..., x1_end, x2_end, ...]`` over all ``rank``
    dimensions.

    Args:
        padding: PyTorch padding values; each entry is either a Python int
            or a dynamic (symbolic) INT64 value.
        rank: Rank of the tensor being padded.

    Returns:
        A 1-D INT64 value of length ``rank * 2`` suitable as the ``pads``
        input of ONNX ``Pad``.

    Raises:
        TypeError: If ``padding`` is not a list or tuple.
    """
    # Explicit validation instead of `assert`, which is stripped under -O.
    if not isinstance(padding, (list, tuple)):
        raise TypeError(f"padding must be a list or tuple, got {type(padding)}")
    if all(isinstance(pad, int) for pad in padding):
        # All-static case: compute the pads entirely in Python.
        paddings = [*padding, *([0] * (rank * 2 - len(padding)))]
        # Reverse the per-dimension order and split begins/ends to match ONNX.
        paddings = paddings[-2::-2] + paddings[-1::-2]
        return op.Constant(value=ir.tensor(paddings, dtype=ir.DataType.INT64))
    else:
        # At least one value is dynamic: build the pads tensor in-graph.
        paddings = []
        for pad in padding:
            if isinstance(pad, int):
                paddings.append(op.Constant(value_ints=[pad]))
            else:
                # Dynamic value: reshape to 1-D so Concat can combine them.
                paddings.append(op.Reshape(pad, [-1]))
        # Pad out with 1-D zero tensors for dimensions PyTorch left implicit.
        zero = op.Constant(value_ints=[0])
        paddings = [*paddings, *([zero] * (rank * 2 - len(paddings)))]
        # Interleave the padding values into ONNX begin/end order.
        paddings = paddings[-2::-2] + paddings[-1::-2]
        return op.Concat(paddings, axis=0)
1508
@torch_op("aten::pad", trace_only=True)
def aten_pad(
    self: TensorType,
    pad: Sequence[INT64],
    mode: str = "constant",
    value: Optional[float] = None,
) -> TensorType:
    """pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor"""
    paddings = _process_padding(pad, len(self.shape))
    if value is None:
        const_value = None
    else:
        # Materialize the fill value with the same dtype as the input tensor.
        const_value = op.Constant(value=ir.tensor(value, dtype=ir.DataType(self.dtype)))
    # Map the PyTorch mode name onto its ONNX Pad equivalent.
    torch_to_onnx_mode = {
        "constant": "constant",
        "reflect": "reflect",
        "replicate": "edge",
        "circular": "wrap",
    }
    onnx_mode = torch_to_onnx_mode[mode]
    return op.Pad(self, paddings, constant_value=const_value, mode=onnx_mode)
1488
1532
1489
1533
1490
1534
def aten_pad_sequence (
@@ -1495,18 +1539,15 @@ def aten_pad_sequence(
1495
1539
raise NotImplementedError ()
1496
1540
1497
1541
1498
@torch_op("aten::reflection_pad1d", trace_only=True)
def aten_reflection_pad1d(self: TFloat, padding: Sequence[INT64]) -> TFloat:
    """reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor"""
    # Translate the PyTorch [begin, end] pair into the ONNX pads layout.
    onnx_paddings = _process_padding(padding, len(self.shape))
    return op.Pad(self, onnx_paddings, mode="reflect")
1510
1551
1511
1552
1512
1553
def aten_reflection_pad1d_backward (
@@ -1517,37 +1558,12 @@ def aten_reflection_pad1d_backward(
1517
1558
raise NotImplementedError ()
1518
1559
1519
1560
1520
@torch_op("aten::reflection_pad2d", trace_only=True)
def aten_reflection_pad2d(self: TTensor, padding: Sequence[INT64]) -> TTensor:
    """reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor"""
    input_rank = len(self.shape)
    # PyTorch padding order/length differs from ONNX; convert before Pad.
    onnx_paddings = _process_padding(padding, input_rank)
    return op.Pad(self, onnx_paddings, mode="reflect")
1551
1567
1552
1568
1553
1569
def aten_reflection_pad2d_backward (
@@ -1558,10 +1574,12 @@ def aten_reflection_pad2d_backward(
1558
1574
raise NotImplementedError ()
1559
1575
1560
1576
1561
@torch_op("aten::reflection_pad3d", trace_only=True)
def aten_reflection_pad3d(self: TensorType, padding: Sequence[INT64]) -> TensorType:
    """reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor"""
    # Convert the 6-element PyTorch padding into ONNX begin/end order.
    onnx_paddings = _process_padding(padding, len(self.shape))
    return op.Pad(self, onnx_paddings, mode="reflect")
1565
1583
1566
1584
1567
1585
def aten_reflection_pad3d_backward (
@@ -1587,18 +1605,13 @@ def aten_relu6(self: TReal) -> TReal:
1587
1605
return op .Min (op .Relu (self ), six )
1588
1606
1589
1607
1590
@torch_op("aten::replication_pad1d", trace_only=True)
def aten_replication_pad1d(self: TensorType, padding: Sequence[INT64]) -> TensorType:
    """replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor"""
    # Replication padding maps onto ONNX Pad's "edge" mode.
    onnx_paddings = _process_padding(padding, len(self.shape))
    return op.Pad(self, onnx_paddings, mode="edge")
1602
1615
1603
1616
1604
1617
def aten_replication_pad1d_backward (
@@ -1609,32 +1622,13 @@ def aten_replication_pad1d_backward(
1609
1622
raise NotImplementedError ()
1610
1623
1611
1624
1612
@torch_op("aten::replication_pad2d", trace_only=True)
def aten_replication_pad2d(self: TTensor, padding: Sequence[INT64]) -> TTensor:
    """replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor"""
    input_rank = len(self.shape)
    # PyTorch's reversed, possibly-partial padding becomes full ONNX pads.
    onnx_paddings = _process_padding(padding, input_rank)
    return op.Pad(self, onnx_paddings, mode="edge")
1638
1632
1639
1633
1640
1634
def aten_replication_pad2d_backward (
@@ -1645,32 +1639,13 @@ def aten_replication_pad2d_backward(
1645
1639
raise NotImplementedError ()
1646
1640
1647
1641
1648
@torch_op("aten::replication_pad3d", trace_only=True)
def aten_replication_pad3d(self: TTensor, padding: Sequence[INT64]) -> TTensor:
    """replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor"""
    # "replicate" in PyTorch corresponds to ONNX Pad's "edge" mode.
    onnx_paddings = _process_padding(padding, len(self.shape))
    return op.Pad(self, onnx_paddings, mode="edge")
1674
1649
1675
1650
1676
1651
def aten_replication_pad3d_backward (
0 commit comments