Skip to content

Commit 03f40b9

Browse files
authored
Implement aten._local_scalar_dense | feat(torchlib) (#847)
This op is used in llama when capturing graph using dynamo.
1 parent 52c8531 commit 03f40b9

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@
3939
_MATH_PI = math.pi
4040

4141

42+
@torch_op("aten::_local_scalar_dense")
43+
def aten__local_scalar_dense(self: TTensor) -> TTensor:
44+
"""_local_scalar_dense(Tensor self) -> Scalar"""
45+
46+
# Return the first element in tensor as a scalar.
47+
return op.Gather(op.Reshape(self, [-1]), 0)
48+
49+
4250
@torch_op("aten::abs")
4351
def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
4452
"""abs(Tensor self) -> Tensor"""

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,33 @@
1717
from torch.testing._internal.opinfo import core as opinfo_core
1818

1919

20+
def sample_inputs__local_scalar_dense(op_info, device, dtype, requires_grad, **kwargs):
21+
del op_info
22+
23+
shapes = (
24+
(),
25+
(1,),
26+
(3,),
27+
(1, 1),
28+
(1, 2),
29+
(2, 1),
30+
(1, 1, 1),
31+
(2, 2, 2),
32+
)
33+
34+
for shape in shapes:
35+
t = torch_testing.make_tensor(
36+
shape,
37+
low=0,
38+
high=1,
39+
device=device,
40+
dtype=dtype,
41+
requires_grad=requires_grad,
42+
**kwargs,
43+
)
44+
yield opinfo_core.SampleInput(t)
45+
46+
2047
def sample_inputs_conv3d(op_info, device, dtype, requires_grad, **kwargs):
2148
del op_info
2249
make_arg = functools.partial(
@@ -527,6 +554,13 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
527554

528555

529556
OP_DB: List[opinfo_core.OpInfo] = [
557+
opinfo_core.OpInfo(
558+
"aten._local_scalar_dense",
559+
op=torch.ops.aten._local_scalar_dense, # pylint: disable=protected-access
560+
aten_name="_local_scalar_dense",
561+
dtypes=common_dtype.all_types(),
562+
sample_inputs_func=sample_inputs__local_scalar_dense,
563+
),
530564
opinfo_core.OpInfo(
531565
"col2im",
532566
op=torch.ops.aten.col2im,

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,10 @@ def _where_input_wrangler(
408408
# Ops to be tested for numerical consistency between onnx and pytorch
409409
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
410410
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
411+
TorchLibOpInfo(
412+
"aten._local_scalar_dense",
413+
core_ops.aten__local_scalar_dense,
414+
),
411415
TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
412416
matcher=lambda sample: not (len(sample.kwargs) > 0),
413417
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",

0 commit comments

Comments
 (0)