File tree Expand file tree Collapse file tree 2 files changed +10
-12
lines changed
onnxscript/function_libs/torch_lib/ops
tests/function_libs/torch_lib Expand file tree Collapse file tree 2 files changed +10
-12
lines changed Original file line number Diff line number Diff line change 61
61
Rank = common_ops .Rank
62
62
63
63
64
- @torch_op ("aten::_local_scalar_dense" )
65
- def aten__local_scalar_dense (self : Union [ FLOAT16 , FLOAT , DOUBLE , BFLOAT16 ] ) -> FLOAT :
64
+ @torch_op ("aten::_local_scalar_dense" , trace_only = True )
65
+ def aten__local_scalar_dense (self : TensorType ) -> TensorType :
66
66
"""_local_scalar_dense(Tensor self) -> Scalar"""
67
67
68
68
# Return the first element in tensor as a scalar.
69
- return op .Cast (op .Gather (op .Reshape (self , [- 1 ]), 0 ), to = FLOAT .dtype )
70
-
71
-
72
- @torch_op ("aten::_local_scalar_dense" )
73
- def aten__local_scalar_dense_int (self : IntType ) -> INT64 :
74
- """_local_scalar_dense(Tensor self) -> Scalar"""
75
-
76
- # Return the first element in tensor as a scalar.
77
- return op .Cast (op .Gather (op .Reshape (self , [- 1 ]), 0 ), to = INT64 .dtype )
69
+ if self .dtype .is_floating_point ():
70
+ dtype = ir .DataType .FLOAT
71
+ elif self .dtype == ir .DataType .BOOL :
72
+ dtype = ir .DataType .BOOL
73
+ else :
74
+ dtype = ir .DataType .INT64
75
+ return op .Cast (op .Gather (op .Reshape (self , [- 1 ]), 0 ), to = dtype )
78
76
79
77
80
78
@torch_op ("aten::_log_softmax" , trace_only = True )
Original file line number Diff line number Diff line change @@ -2308,7 +2308,7 @@ def __init__(self):
2308
2308
opinfo_core .OpInfo (
2309
2309
"ops.aten._local_scalar_dense" ,
2310
2310
aten_name = "_local_scalar_dense" ,
2311
- dtypes = common_dtype .all_types ( ),
2311
+ dtypes = common_dtype .all_types_and ( torch . bool ),
2312
2312
sample_inputs_func = sample_inputs__local_scalar_dense ,
2313
2313
supports_out = False ,
2314
2314
),
You can’t perform that action at this time.
0 commit comments