diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 0544f2effb..6cf5700abc 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -61,20 +61,18 @@ Rank = common_ops.Rank -@torch_op("aten::_local_scalar_dense") -def aten__local_scalar_dense(self: Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]) -> FLOAT: +@torch_op("aten::_local_scalar_dense", trace_only=True) +def aten__local_scalar_dense(self: TensorType) -> TensorType: """_local_scalar_dense(Tensor self) -> Scalar""" # Return the first element in tensor as a scalar. - return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=FLOAT.dtype) - - -@torch_op("aten::_local_scalar_dense") -def aten__local_scalar_dense_int(self: IntType) -> INT64: - """_local_scalar_dense(Tensor self) -> Scalar""" - - # Return the first element in tensor as a scalar. - return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=INT64.dtype) + if self.dtype.is_floating_point(): + dtype = ir.DataType.FLOAT + elif self.dtype == ir.DataType.BOOL: + dtype = ir.DataType.BOOL + else: + dtype = ir.DataType.INT64 + return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=dtype) @torch_op("aten::_log_softmax", trace_only=True) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 26b75bf93b..3d73d8b9b0 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2308,7 +2308,7 @@ def __init__(self): opinfo_core.OpInfo( "ops.aten._local_scalar_dense", aten_name="_local_scalar_dense", - dtypes=common_dtype.all_types(), + dtypes=common_dtype.all_types_and(torch.bool), sample_inputs_func=sample_inputs__local_scalar_dense, supports_out=False, ),