Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,10 +1092,40 @@ def aten_conj_physical(self: TensorType) -> TensorType:
raise NotImplementedError()


def aten_constant_pad_nd(self: TensorType, pad: INT64, value: float = 0.0) -> TensorType:
@torch_op("aten::constant_pad_nd")
def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTensor:
"""constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor"""

raise NotImplementedError()
# The desired order of paddings is
# dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
# n is the dimension of input.
# assume zero-dimensions in the beginning
# rank = len(self.shape) # rank must be scalar
# paddings = list(pad[:]) + [0] * (rank * 2 - len(pad))
# reverse order and collate first beginnings and then ends
# paddings = paddings[-2::-2] + paddings[-1::-2]

neg_1 = op.Constant(value_ints=[-1])

rank = op.Size(op.Shape(self))
zero_count = op.Sub(op.Mul(rank, 2), op.Size(pad))
zero_count = op.Reshape(zero_count, neg_1)
zero = op.Constant(value_ints=[0])
zeros = op.Expand(zero, zero_count)
torch_paddings = op.Concat(pad, zeros, axis=0)
size_d = op.Size(torch_paddings)
steps = op.Constant(value_ints=[-2])

starts = steps
ends = op.Sub(starts, size_d)
odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)

starts = neg_1
ends = op.Sub(starts, size_d)
even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)

onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
return op.Pad(self, onnx_padding, value)


@torch_op("aten::contiguous", trace_only=True)
Expand Down Expand Up @@ -4866,10 +4896,11 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
return op.Sub(other, op.Mul(self, alpha))


def aten_scalar_tensor(s: float) -> TensorType:
@torch_op("aten::scalar_tensor")
def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var]
"""scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

raise NotImplementedError()
return op.Cast(s, to=dtype)


def aten_scatter_add(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def _where_input_wrangler(
"clamp_max": core_ops.aten_clamp_max,
"clamp_min": core_ops.aten_clamp_min,
"clone": core_ops.aten_clone,
"constant_pad_nd": core_ops.aten_constant_pad_nd,
# "copy": core_ops.aten_copy, # copy is not in OPS_DB
"cos": core_ops.aten_cos,
"cosh": core_ops.aten_cosh,
Expand Down Expand Up @@ -373,6 +374,7 @@ def _where_input_wrangler(
"rsqrt": core_ops.aten_rsqrt,
"rsub": core_ops.aten_rsub,
"select": core_ops.aten_select,
# "scalar_tensor": core_ops.aten_scalar_tensor, # no test case in OPS_DB
"sigmoid": core_ops.aten_sigmoid,
"sign": core_ops.aten_sign,
"sin": core_ops.aten_sin,
Expand Down