Skip to content

Commit 869c9ee

Browse files
authored
feat(atenlib): add ops (scalar_tensor, constant_pad_nd) (#492)
1 parent a6accf0 commit 869c9ee

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,10 +1092,40 @@ def aten_conj_physical(self: TensorType) -> TensorType:
10921092
raise NotImplementedError()
10931093

10941094

1095-
def aten_constant_pad_nd(self: TensorType, pad: INT64, value: float = 0.0) -> TensorType:
1095+
@torch_op("aten::constant_pad_nd")
def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTensor:
    """constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor"""

    # PyTorch supplies `pad` interleaved and starting from the LAST dimension:
    # [dimN_begin, dimN_end, dimN-1_begin, dimN-1_end, ...], possibly covering
    # only the trailing dimensions.  ONNX Pad instead wants every begin first
    # and then every end, ordered from dimension 0:
    # [dim0_begin, dim1_begin, ..., dim0_end, ..., dimN_end].
    # Strategy: right-pad the torch list with zeros up to length 2*rank, then
    # extract the begin values (indices -2, -4, ...) and the end values
    # (indices -1, -3, ...) by slicing backwards with step -2, which also
    # reverses them into dimension-0-first order.

    minus_one = op.Constant(value_ints=[-1])
    axis_zero = op.Constant(value_ints=[0])

    rank = op.Size(op.Shape(self))
    # How many zeros are missing from `pad` to cover all `rank` dimensions.
    gap = op.Sub(op.Mul(rank, 2), op.Size(pad))
    gap = op.Reshape(gap, minus_one)
    trailing_zeros = op.Expand(axis_zero, gap)
    full_pad = op.Concat(pad, trailing_zeros, axis=0)
    full_pad_len = op.Size(full_pad)

    back_two = op.Constant(value_ints=[-2])

    # Begin paddings: walk from index -2 backwards in steps of -2
    # (axes=[0]), yielding dim0_begin, dim1_begin, ..., dimN_begin.
    begin_start = back_two
    begin_pads = op.Slice(full_pad, begin_start, op.Sub(begin_start, full_pad_len), axis_zero, back_two)

    # End paddings: walk from index -1 backwards in steps of -2,
    # yielding dim0_end, dim1_end, ..., dimN_end.
    end_start = minus_one
    end_pads = op.Slice(full_pad, end_start, op.Sub(end_start, full_pad_len), axis_zero, back_two)

    onnx_pads = op.Concat(begin_pads, end_pads, axis=0)
    return op.Pad(self, onnx_pads, value)
10991129

11001130

11011131
@torch_op("aten::contiguous", trace_only=True)
@@ -4866,10 +4896,11 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
48664896
return op.Sub(other, op.Mul(self, alpha))
48674897

48684898

4869-
def aten_scalar_tensor(s: float) -> TensorType:
4899+
@torch_op("aten::scalar_tensor")
def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> TTensor:  # type: ignore[type-var]
    """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

    # Turn the Python scalar into a tensor of the requested ONNX element
    # type; when no dtype is passed, FLOAT.dtype is used.
    result = op.Cast(s, to=dtype)
    return result
48734904

48744905

48754906
def aten_scatter_add(

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ def _where_input_wrangler(
301301
"clamp_max": core_ops.aten_clamp_max,
302302
"clamp_min": core_ops.aten_clamp_min,
303303
"clone": core_ops.aten_clone,
304+
"constant_pad_nd": core_ops.aten_constant_pad_nd,
304305
# "copy": core_ops.aten_copy, # copy is not in OPS_DB
305306
"cos": core_ops.aten_cos,
306307
"cosh": core_ops.aten_cosh,
@@ -373,6 +374,7 @@ def _where_input_wrangler(
373374
"rsqrt": core_ops.aten_rsqrt,
374375
"rsub": core_ops.aten_rsub,
375376
"select": core_ops.aten_select,
377+
# "scalar_tensor": core_ops.aten_scalar_tensor, # no test case in OPS_DB
376378
"sigmoid": core_ops.aten_sigmoid,
377379
"sign": core_ops.aten_sign,
378380
"sin": core_ops.aten_sin,

0 commit comments

Comments (0)