Commit 9fb0a7d

Create the Rank and IsScalar shared functions | feat(torchlib) (#1105)
This change introduces two shared operators, `Rank` and `IsScalar`, which replace the `Size(Shape())` pattern for code reuse and readability. I used a hack to always include these shared functions in the model proto, because without #834 we cannot dynamically add them to the model only when they are used; I added a TODO for this. The first usage is in `aten_all`; I will update the rest of the functions in a separate PR. #1095
1 parent: 0035390

File tree: 6 files changed, +54 -5 lines
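
The commit message above is about replacing the inline `Size(Shape())` pattern with the shared helpers. A minimal sketch of the intended call pattern follows; `example_is_vector` is a hypothetical function used only for illustration, and the imports mirror the ones this commit adds to `core.py` below.

    # Sketch (hypothetical op): reuse the shared helpers instead of repeating
    # the op.Size(op.Shape(...)) pattern in every torchlib function.
    from onnxscript import opset18 as op
    from onnxscript.function_libs.torch_lib.ops import common as common_ops

    IsScalar = common_ops.IsScalar  # rank == 0
    Rank = common_ops.Rank          # one name for op.Size(op.Shape(x))


    def example_is_vector(x):
        # True when the input has rank exactly 1.
        return op.Equal(Rank(x), op.Constant(value_int=1))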

onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ class TestDeduceTypeConstraints(unittest.TestCase):
         "_aten_embedding_bag_onnx",
         "_aten_embedding_bag_1d_padding_idx_onnx",
     )
-    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ()
+    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ("aten_all",)

     @parameterized.parameterized.expand(
         ((op,) for op in torch_lib_onnx_functions_from_registry()),

onnxscript/function_libs/torch_lib/_constants.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+"""Shared constants for the library."""
+
+DOMAIN = "pkg.onnxscript.torch_lib"

onnxscript/function_libs/torch_lib/graph_building.py

Lines changed: 19 additions & 1 deletion
@@ -21,6 +21,7 @@
 from onnxscript import evaluator
 from onnxscript import tensor as onnxscript_tensor
 from onnxscript._internal import param_manipulation, runtime_typing
+from onnxscript.function_libs.torch_lib.ops import common as common_ops

 __all__ = [
     "TorchScriptTensor",
@@ -363,6 +364,16 @@ def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
     return tensor.numel() * tensor.element_size()


+def _shared_functions() -> list[onnx.FunctionProto]:
+    """Hack to always include the share ops."""
+
+    # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
+    return [
+        common_ops.Rank.to_function_proto(),
+        common_ops.IsScalar.to_function_proto(),
+    ]
+
+
 class TorchScriptGraph:
     def __init__(
         self,
@@ -717,7 +728,6 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func
             opset_imports=onnx_model.opset_import,
             doc_string=onnx_model.doc_string,
         )
-        # TODO: onnx.checker.check_function(onnx_function)?
         return onnx_function

     @runtime_typing.checked
@@ -786,6 +796,7 @@ def to_model_proto(
         onnx_model = onnx.load_from_string(proto)

         onnx_model.functions.extend(function_proto_dict.values())
+        onnx_model.functions.extend(_shared_functions())

         # `_export_onnx` only exports opset_imports that is visible to it. It does not
         # export opset_imports for nested functions, since it does not have access to
@@ -800,6 +811,13 @@ def to_model_proto(
                 for domain, version in unique_custom_domains.items()
             ]
         )
+        # Include the library shared opset domain
+        # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
+        onnx_model.opset_import.append(
+            onnx.helper.make_opsetid(
+                common_ops.common_opset.domain, common_ops.common_opset.version
+            )
+        )

         try:
             if not cache_model_to_disk:
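
With these changes, every model exported through `to_model_proto` should carry the two shared functions and an opset import for their domain, even when nothing in the graph calls them yet. A small inspection snippet, assuming a model was already saved to `exported_model.onnx` (a hypothetical path):

    # Sketch: confirm the shared functions were embedded in the exported model.
    import onnx

    model = onnx.load("exported_model.onnx")  # hypothetical path
    print(sorted(f.name for f in model.functions))
    # expected to include "IsScalar" and "Rank"
    print([(opset.domain, opset.version) for opset in model.opset_import])
    # expected to include ("pkg.onnxscript.torch_lib.common", 1)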

onnxscript/function_libs/torch_lib/ops/common.py

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+"""Common operators shared in the torchlib library."""
+
+import onnxscript
+import onnxscript.values
+from onnxscript import BOOL, INT64
+from onnxscript import opset18 as op
+from onnxscript.function_libs.torch_lib import _constants, tensor_typing
+
+DOMAIN = f"{_constants.DOMAIN}.common"
+
+common_opset = onnxscript.values.Opset(domain=DOMAIN, version=1)
+
+
+@onnxscript.script(common_opset)
+def Rank(input: tensor_typing.TTensor) -> INT64:
+    """Take the rank of the input tensor."""
+
+    return op.Size(op.Shape(input))
+
+
+@onnxscript.script(common_opset)
+def IsScalar(input: tensor_typing.TTensor) -> BOOL:
+    """Return whether the input has rank 0, or is a scalar."""
+
+    return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))
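
Because `Rank` and `IsScalar` are ordinary onnxscript functions, they can be serialized with `to_function_proto()` (as `graph_building.py` does above) or, assuming eager evaluation is available in your environment, exercised directly on numpy inputs:

    # Sketch, assuming onnxscript eager mode: sanity-check the shared helpers.
    import numpy as np

    from onnxscript.function_libs.torch_lib.ops import common as common_ops

    print(common_ops.Rank(np.zeros((2, 3), dtype=np.float32)))    # expected: 2
    print(common_ops.IsScalar(np.array(1.0, dtype=np.float32)))   # expected: True
    print(common_ops.Rank.to_function_proto().domain)             # "pkg.onnxscript.torch_lib.common"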

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 4 additions & 2 deletions
@@ -30,6 +30,7 @@
     UINT64,
     graph,
 )
+from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import (
     IntType,
@@ -52,6 +53,8 @@
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
 _MATH_PI = math.pi
+IsScalar = common_ops.IsScalar
+Rank = common_ops.Rank


 @torch_op("aten::_local_scalar_dense")
@@ -320,8 +323,7 @@ def aten_align_to(self: TensorType, names: Sequence[str]) -> TensorType:
 def aten_all(self: TTensor) -> BOOL:
     """all(Tensor self) -> Tensor"""

-    self_rank = op.Size(op.Shape(self))
-    if self_rank == 0:
+    if IsScalar(self):
         result = op.Cast(self, to=BOOL.dtype)
     else:
         self_bool = op.Cast(self, to=BOOL.dtype)

onnxscript/function_libs/torch_lib/registration.py

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@
 from typing import Any, Callable, Generator, Optional

 import onnxscript
+from onnxscript.function_libs.torch_lib import _constants

 # Regex that will match "<namespace>::<op_name>[.<overload>]"
 _QUALIFIED_OPERATOR_NAME_REGEX = re.compile(
@@ -119,7 +120,7 @@ def wrapper(
         func: FunctionType,
     ) -> onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction:
         # Compile the function
-        custom_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib", version=1)
+        custom_opset = onnxscript.values.Opset(domain=_constants.DOMAIN, version=1)

         processed_func: onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction
         if trace_only:
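
With the new `_constants` module, the registered torchlib domain and the shared-ops domain are both derived from a single constant rather than a repeated string literal. A quick check under that assumption:

    # Sketch: both domains now come from the single _constants.DOMAIN value.
    from onnxscript.function_libs.torch_lib import _constants
    from onnxscript.function_libs.torch_lib.ops import common as common_ops

    assert _constants.DOMAIN == "pkg.onnxscript.torch_lib"
    assert common_ops.DOMAIN == f"{_constants.DOMAIN}.common"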
