diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ffa21ae382..1e27b1840b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -39,6 +39,14 @@ _MATH_PI = math.pi +@torch_op("aten::_local_scalar_dense") +def aten__local_scalar_dense(self: TTensor) -> TTensor: + """_local_scalar_dense(Tensor self) -> Scalar""" + + # Return the first element in tensor as a scalar. + return op.Gather(op.Reshape(self, [-1]), 0) + + @torch_op("aten::abs") def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8: """abs(Tensor self) -> Tensor""" diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 7877061a2e..85b66bacd9 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -17,6 +17,33 @@ from torch.testing._internal.opinfo import core as opinfo_core +def sample_inputs__local_scalar_dense(op_info, device, dtype, requires_grad, **kwargs): + del op_info + + shapes = ( + (), + (1,), + (3,), + (1, 1), + (1, 2), + (2, 1), + (1, 1, 1), + (2, 2, 2), + ) + + for shape in shapes: + t = torch_testing.make_tensor( + shape, + low=0, + high=1, + device=device, + dtype=dtype, + requires_grad=requires_grad, + **kwargs, + ) + yield opinfo_core.SampleInput(t) + + def sample_inputs_conv3d(op_info, device, dtype, requires_grad, **kwargs): del op_info make_arg = functools.partial( @@ -527,6 +554,13 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra OP_DB: List[opinfo_core.OpInfo] = [ + opinfo_core.OpInfo( + "aten._local_scalar_dense", + op=torch.ops.aten._local_scalar_dense, # pylint: disable=protected-access + aten_name="_local_scalar_dense", + dtypes=common_dtype.all_types(), + sample_inputs_func=sample_inputs__local_scalar_dense, + ), opinfo_core.OpInfo( "col2im", op=torch.ops.aten.col2im, diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 072e001b5a..89987b5fcc 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -408,6 +408,10 @@ def _where_input_wrangler( # Ops to be tested for numerical consistency between onnx and pytorch # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = ( + TorchLibOpInfo( + "aten._local_scalar_dense", + core_ops.aten__local_scalar_dense, + ), TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail( matcher=lambda sample: not (len(sample.kwargs) > 0), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",