@@ -1092,10 +1092,40 @@ def aten_conj_physical(self: TensorType) -> TensorType:
10921092 raise NotImplementedError ()
10931093
10941094
@torch_op("aten::constant_pad_nd")
def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTensor:
    """constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor"""

    # The desired order of paddings is
    # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
    # n is the dimension of input.
    # assume zero-dimensions in the beginning
    # rank = len(self.shape) # rank must be scalar
    # paddings = list(pad[:]) + [0] * (rank * 2 - len(pad))
    # reverse order and collate first beginnings and then ends
    # paddings = paddings[-2::-2] + paddings[-1::-2]
    #
    # NOTE(review): the Python pseudo-code above is the reference; the graph ops
    # below reproduce it with ONNX operators, since this function is compiled by
    # onnxscript and only a restricted Python subset is allowed in the body.

    # Reusable [-1] constant, used both as a Reshape target shape and as a
    # Slice start index below.
    neg_1 = op.Constant(value_ints=[-1])

    # Torch's `pad` may cover only the trailing dimensions; right-pad it with
    # zeros so it has exactly 2 entries per input dimension (rank * 2 total).
    rank = op.Size(op.Shape(self))
    zero_count = op.Sub(op.Mul(rank, 2), op.Size(pad))
    # Reshape the scalar count to a 1-D tensor so Expand accepts it as a shape.
    zero_count = op.Reshape(zero_count, neg_1)
    zero = op.Constant(value_ints=[0])
    zeros = op.Expand(zero, zero_count)
    torch_paddings = op.Concat(pad, zeros, axis=0)
    size_d = op.Size(torch_paddings)
    steps = op.Constant(value_ints=[-2])

    # Walk the padded list backwards two at a time. Slicing from index -2 with
    # step -2 collects every other element starting at the second-to-last one
    # (the `paddings[-2::-2]` of the pseudo-code above).
    starts = steps
    ends = op.Sub(starts, size_d)
    odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)

    # Slicing from index -1 with step -2 collects the remaining elements
    # (the `paddings[-1::-2]` of the pseudo-code above).
    starts = neg_1
    ends = op.Sub(starts, size_d)
    even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)

    # Concatenate into the ONNX Pad layout (all begins, then all ends) and pad
    # with the constant `value`.
    onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
    return op.Pad(self, onnx_padding, value)
11001130
11011131@torch_op ("aten::contiguous" , trace_only = True )
@@ -4866,10 +4896,11 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
48664896 return op .Sub (other , op .Mul (self , alpha ))
48674897
48684898
@torch_op("aten::scalar_tensor")
def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> TTensor:  # type: ignore[type-var]
    """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

    # Materialize the scalar as a tensor of the requested element type; when no
    # dtype is given, the result defaults to float32 (FLOAT.dtype).
    casted = op.Cast(s, to=dtype)
    return casted
48744905
48754906def aten_scatter_add (
0 commit comments