Skip to content

Commit dcf98c8

Browse files
authored
Add missing converter for _local_scalar_dense (#2367)
1 parent 51ecf47 commit dcf98c8

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,18 @@
6161
Rank = common_ops.Rank
6262

6363

64-
@torch_op("aten::_local_scalar_dense")
65-
def aten__local_scalar_dense(self: Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]) -> FLOAT:
64+
@torch_op("aten::_local_scalar_dense", trace_only=True)
65+
def aten__local_scalar_dense(self: TensorType) -> TensorType:
6666
"""_local_scalar_dense(Tensor self) -> Scalar"""
6767

6868
# Return the first element in tensor as a scalar.
69-
return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=FLOAT.dtype)
70-
71-
72-
@torch_op("aten::_local_scalar_dense")
73-
def aten__local_scalar_dense_int(self: IntType) -> INT64:
74-
"""_local_scalar_dense(Tensor self) -> Scalar"""
75-
76-
# Return the first element in tensor as a scalar.
77-
return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=INT64.dtype)
69+
if self.dtype.is_floating_point():
70+
dtype = ir.DataType.FLOAT
71+
elif self.dtype == ir.DataType.BOOL:
72+
dtype = ir.DataType.BOOL
73+
else:
74+
dtype = ir.DataType.INT64
75+
return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=dtype)
7876

7977

8078
@torch_op("aten::_log_softmax", trace_only=True)

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2308,7 +2308,7 @@ def __init__(self):
23082308
opinfo_core.OpInfo(
23092309
"ops.aten._local_scalar_dense",
23102310
aten_name="_local_scalar_dense",
2311-
dtypes=common_dtype.all_types(),
2311+
dtypes=common_dtype.all_types_and(torch.bool),
23122312
sample_inputs_func=sample_inputs__local_scalar_dense,
23132313
supports_out=False,
23142314
),

0 commit comments

Comments
 (0)