Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@
_MATH_PI = math.pi


@torch_op("aten::_local_scalar_dense")
def aten__local_scalar_dense(self: TTensor) -> TTensor:
Copy link
Copy Markdown
Collaborator

@justinchuby justinchuby Jul 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a docstring with its aten signature from https://aka.ms/native-functions like other functions do (looks like it's _local_scalar_dense(Tensor self) -> Scalar. Would be nice to have a reference to its implementation or documentation too.

"""_local_scalar_dense(Tensor self) -> Scalar"""

# Return the first element in tensor as a scalar.
return op.Gather(op.Reshape(self, [-1]), 0)


@torch_op("aten::abs")
def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
"""abs(Tensor self) -> Tensor"""
Expand Down
35 changes: 35 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,33 @@
from torch.testing._internal.opinfo import core as opinfo_core


def sample_inputs__local_scalar_dense(op_info, device, dtype, requires_grad, **kwargs):
del op_info

shapes = (
(),
(1,),
(3,),
(1, 1),
(1, 2),
(2, 1),
(1, 1, 1),
(2, 2, 2),
)

for shape in shapes:
t = torch_testing.make_tensor(
shape,
low=0,
high=1,
device=device,
dtype=dtype,
requires_grad=requires_grad,
**kwargs,
)
yield opinfo_core.SampleInput(t)


def sample_inputs_conv3d(op_info, device, dtype, requires_grad, **kwargs):
del op_info
make_arg = functools.partial(
Expand Down Expand Up @@ -527,6 +554,14 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra


OP_DB: List[opinfo_core.OpInfo] = [
opinfo_core.OpInfo(
"aten._local_scalar_dense",
# pylint: disable=protected-access
op=torch.ops.aten._local_scalar_dense,
aten_name="_local_scalar_dense",
dtypes=common_dtype.all_types(),
sample_inputs_func=sample_inputs__local_scalar_dense,
),
opinfo_core.OpInfo(
"col2im",
op=torch.ops.aten.col2im,
Expand Down
4 changes: 4 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,10 @@ def _where_input_wrangler(
# Ops to be tested for numerical consistency between onnx and pytorch
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
TorchLibOpInfo(
"aten._local_scalar_dense",
core_ops.aten__local_scalar_dense,
),
TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
matcher=lambda sample: not (len(sample.kwargs) > 0),
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
Expand Down