Commit 3f433e7

feat(atenlib): separate inputs from attributes (#368)

Create `param_manipulation` to separate inputs and attributes from Python args and kwargs. Rename `typing.py` to `onnxscript/function_libs/torch_aten/tensor_typing.py`.

- #305
- #287
- #320

Co-authored-by: Ti-Tai Wang <[email protected]>

1 parent 2b84597 · commit 3f433e7

File tree

8 files changed: +316 -8 lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 
 from onnxscript import BOOL, DOUBLE, FLOAT, INT16, INT32, INT64
 from onnxscript.function_libs.torch_aten.registration import torch_op
-from onnxscript.function_libs.torch_aten.typing import (
+from onnxscript.function_libs.torch_aten.tensor_typing import (
     IntType,
     TFloat,
     TFloatOrBFloat16,
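
This rename (applied identically in `nn.py` and `special.py` below) only touches the import paths; presumably the module moves from `typing` to `tensor_typing` so it no longer shadows the standard-library `typing` module.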

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 5 additions & 1 deletion
@@ -18,7 +18,11 @@
 
 from onnxscript import FLOAT, INT64
 from onnxscript.function_libs.torch_aten.registration import torch_op
-from onnxscript.function_libs.torch_aten.typing import TFloat, TFloatOrBFloat16, TReal
+from onnxscript.function_libs.torch_aten.tensor_typing import (
+    TFloat,
+    TFloatOrBFloat16,
+    TReal,
+)
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 

onnxscript/function_libs/torch_aten/ops/special.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 
 from onnxscript import FLOAT
 from onnxscript.function_libs.torch_aten.registration import torch_op
-from onnxscript.function_libs.torch_aten.typing import TFloatOrBFloat16
+from onnxscript.function_libs.torch_aten.tensor_typing import TFloatOrBFloat16
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 

onnxscript/function_libs/torch_aten/param_manipulation.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+"""Functions for manipulating input parameters of an Op or an OnnxFunction."""
+from __future__ import annotations
+
+import collections
+import dataclasses
+from typing import Any, List, OrderedDict, Sequence, Tuple
+
+import onnx
+
+from onnxscript import values
+
+# A special value to indicate that the default value is not specified
+_EmptyDefault = object()
+
+
+@dataclasses.dataclass(frozen=True)
+class ParamSchema:
+    """A schema for a parameter of an Op or an OnnxFunction.
+
+    Attributes:
+        name: The name of the parameter.
+        type: The type of the parameter.
+        default: The default value of the parameter.
+        is_input: Whether the parameter is an ONNX input.
+    """
+
+    name: str
+    type: Any = None  # Op input does not have a type, for now
+    default: Any = _EmptyDefault
+    is_input: bool = True
+
+    def __repr__(self) -> str:
+        param_kind = "INPUT" if self.is_input else "ATTRIBUTE"
+        text = f"{self.name}<{param_kind}>: {self.type}"
+        if self.default is not _EmptyDefault:
+            text += f" = {self.default}"
+        return text
+
+    @property
+    def is_attribute(self) -> bool:
+        """Returns True if the parameter is an ONNX attribute."""
+        return not self.is_input
+
+
+def extract_param_schema_from_function(onnx_func: values.OnnxFunction) -> List[ParamSchema]:
+    function_ir = onnx_func.function_ir
+    # The first len(function_ir.inputs) arguments are ONNX inputs
+    inputs = function_ir.inputs
+    # The rest are ONNX attributes
+    attributes = function_ir.attrs
+    # Construct a dictionary of attributes with their names specified in the
+    # function definition
+    attr_name_to_protos = collections.OrderedDict(
+        (attr.name, attr) for attr in function_ir.attr_protos
+    )
+
+    # Args with default values are attributes
+    param_schemas = []
+    for arg in inputs:
+        param_schema = ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True)
+        param_schemas.append(param_schema)
+
+    for attr_name in attributes:
+        # FIXME(justinchuby): Where can we find the type?
+        param_schema = ParamSchema(name=attr_name, type=None, is_input=False)
+        param_schemas.append(param_schema)
+
+    for name, attr_value in attr_name_to_protos.items():
+        param_schema = ParamSchema(
+            name=name,
+            type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type],
+            default=_get_attribute_value(attr_value.attr_proto),
+            is_input=False,
+        )
+        param_schemas.append(param_schema)
+    return param_schemas
+
+
+_ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
+    onnx.defs.OpSchema.AttrType.FLOAT: float,
+    onnx.defs.OpSchema.AttrType.INT: int,
+    onnx.defs.OpSchema.AttrType.STRING: str,
+    onnx.defs.OpSchema.AttrType.TENSOR: None,
+    onnx.defs.OpSchema.AttrType.GRAPH: None,
+    onnx.defs.OpSchema.AttrType.SPARSE_TENSOR: None,
+    onnx.defs.OpSchema.AttrType.TYPE_PROTO: None,
+    onnx.defs.OpSchema.AttrType.FLOATS: Sequence[float],
+    onnx.defs.OpSchema.AttrType.INTS: Sequence[int],
+    onnx.defs.OpSchema.AttrType.STRINGS: Sequence[str],
+    onnx.defs.OpSchema.AttrType.TENSORS: None,
+    onnx.defs.OpSchema.AttrType.GRAPHS: None,
+    onnx.defs.OpSchema.AttrType.SPARSE_TENSORS: None,
+    onnx.defs.OpSchema.AttrType.TYPE_PROTOS: None,
+}
+
+
+def _get_attribute_value(attr_proto):
+    if attr_proto.type == onnx.AttributeProto.UNDEFINED:
+        return _EmptyDefault
+    if attr_proto.type == onnx.AttributeProto.FLOAT:
+        return attr_proto.f
+    if attr_proto.type == onnx.AttributeProto.INT:
+        return attr_proto.i
+    if attr_proto.type == onnx.AttributeProto.STRING:
+        return attr_proto.s
+    if attr_proto.type == onnx.AttributeProto.FLOATS:
+        # Repeated floats live in the `floats` field of AttributeProto
+        return [float(v) for v in attr_proto.floats]
+    if attr_proto.type == onnx.AttributeProto.INTS:
+        # Repeated ints live in the `ints` field of AttributeProto
+        return [int(v) for v in attr_proto.ints]
+    raise TypeError(f"Unsupported attribute type: {attr_proto.type}")
+
+
+def extract_param_schema_from_op_schema(op_schema: onnx.defs.OpSchema) -> List[ParamSchema]:
+    param_schemas = []
+    for input_ in op_schema.inputs:
+        param_schema = ParamSchema(name=input_.name, is_input=True)
+        param_schemas.append(param_schema)
+    for attr_name, attribute in op_schema.attributes.items():
+        default_attr_proto = attribute.default_value
+        param_schema = ParamSchema(
+            name=attr_name,
+            type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type],
+            default=_get_attribute_value(default_attr_proto),
+            is_input=False,
+        )
+        param_schemas.append(param_schema)
+
+    return param_schemas
+
+
+def separate_input_attributes_from_arguments(
+    param_schemas: Sequence[ParamSchema],
+    args,
+    kwargs,
+) -> Tuple[List[Any], OrderedDict[str, Any]]:
+    """Separate Python args and kwargs into ONNX inputs and attributes.
+
+    Args:
+        param_schemas: The parameter schemas of an Op or an OnnxFunction.
+        args: The Python positional arguments supplied by the caller.
+        kwargs: The Python keyword arguments supplied by the caller.
+
+    Returns:
+        A tuple of two elements:
+        - A list of ONNX inputs.
+        - An ordered dictionary of ONNX attribute names and values.
+    """
+    # args, kwargs and param_schemas should all be in order;
+    # the caller may omit attributes that have defaults
+    if len(args) + len(kwargs) > len(param_schemas):
+        raise TypeError("Inputs are more than expected in schema")
+
+    onnx_inputs = []
+    onnx_attributes = collections.OrderedDict()
+    for i, param in enumerate(param_schemas):
+        if i < len(args):
+            if not param.is_attribute:
+                onnx_inputs.append(args[i])
+            else:
+                onnx_attributes[param.name] = args[i]
+        elif param.name in kwargs:
+            if not param.is_attribute:
+                onnx_inputs.append(kwargs[param.name])
+            else:
+                onnx_attributes[param.name] = kwargs[param.name]
+        else:
+            # The parameter was not supplied; fall back to its schema default
+            onnx_attributes[param.name] = param.default
+
+    return onnx_inputs, onnx_attributes
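
To make the intended call flow concrete, here is a small usage sketch (not part of the commit; the schema names mirror the unit tests below and are illustrative only). It splits a mixed positional call into ONNX inputs and an ordered attribute dict, with the omitted attribute falling back to its schema default:

import collections

from onnxscript import INT64
from onnxscript.function_libs.torch_aten.param_manipulation import (
    ParamSchema,
    separate_input_attributes_from_arguments,
)

param_schemas = (
    ParamSchema(name="a", type=INT64, is_input=True),   # ONNX input
    ParamSchema(name="b", type=int, is_input=False),    # ONNX attribute
    ParamSchema(name="c", type=float, default=100.0, is_input=False),
)

# "a" and "b" are passed positionally; "c" is omitted and takes its default.
inputs, attributes = separate_input_attributes_from_arguments(
    param_schemas, ("some_tensor", 42), {}
)
assert inputs == ["some_tensor"]
assert attributes == collections.OrderedDict([("b", 42), ("c", 100.0)])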
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+import collections
+import unittest
+
+import parameterized
+
+from onnxscript import INT64
+from onnxscript.function_libs.torch_aten.param_manipulation import (
+    ParamSchema,
+    separate_input_attributes_from_arguments,
+)
+
+TEST_INPUT = "TEST_INPUT"
+
+
+class TestParamManipulation(unittest.TestCase):
+    @parameterized.parameterized.expand(
+        [
+            (
+                "all_positional",
+                (TEST_INPUT, 42, 0.0),
+                {},
+                0.0,
+            ),
+            (
+                "positional_with_default",
+                (TEST_INPUT, 42),
+                {},
+                100.0,
+            ),
+            (
+                "positional_with_default_and_kwargs",
+                (TEST_INPUT,),
+                {"b": 42},
+                100.0,
+            ),
+            (
+                "positional_with_kwargs",
+                (TEST_INPUT, 42),
+                {"c": 0.0},
+                0.0,
+            ),
+            (
+                "positional_input_with_kwargs_attribute",
+                (TEST_INPUT,),
+                {"b": 42, "c": 0.0},
+                0.0,
+            ),
+            (
+                "all_kwargs",
+                (),
+                {"a": TEST_INPUT, "b": 42, "c": 0.0},
+                0.0,
+            ),
+            (
+                "all_kwargs_with_default",
+                (),
+                {"a": TEST_INPUT, "b": 42},
+                100.0,
+            ),
+        ]
+    )
+    def test_separate_input_attributes_from_arguments_correct_on(
+        self, _, args, kwargs, expected_c
+    ):
+        param_schemas = (
+            ParamSchema(name="a", type=INT64, is_input=True),
+            ParamSchema(name="b", type=int, is_input=False),
+            ParamSchema(name="c", type=float, default=100.0, is_input=False),
+        )
+
+        expected_inputs = [TEST_INPUT]
+        expected_attributes = collections.OrderedDict(
+            [
+                ("b", 42),
+                ("c", expected_c),
+            ]
+        )
+
+        inputs, attributes = separate_input_attributes_from_arguments(
+            param_schemas, args, kwargs
+        )
+
+        self.assertEqual(len(inputs), len(expected_inputs))
+        for input_, expected_input in zip(inputs, expected_inputs):
+            self.assertIs(input_, expected_input)
+        self.assertEqual(attributes, expected_attributes)
+
+    @parameterized.parameterized.expand(
+        [
+            (
+                "extra_positional",
+                (TEST_INPUT, 42, 0.0, -1),
+                {},
+            ),
+            (
+                "extra_keyword",
+                (TEST_INPUT, 42, 0.0),
+                {"unknown": -1},
+            ),
+            (
+                "extra_positional_and_keyword",
+                (TEST_INPUT, 42, 0.0, -1),
+                {"unknown": -1},
+            ),
+        ]
+    )
+    def test_separate_input_attributes_from_arguments_raises_on_extra_args(
+        self, _, args, kwargs
+    ):
+        param_schemas = (
+            ParamSchema(name="a", type=INT64, is_input=True),
+            ParamSchema(name="b", type=int, is_input=False),
+            ParamSchema(name="c", type=float, default=100.0, is_input=False),
+        )
+
+        with self.assertRaises(TypeError):
+            _, _ = separate_input_attributes_from_arguments(param_schemas, args, kwargs)
+
+
+if __name__ == "__main__":
+    unittest.main()

onnxscript/irbuilder.py

Lines changed: 8 additions & 0 deletions
@@ -120,6 +120,14 @@ def __str__(self):
         # self.name + " = " + self.value
         return helper.printable_attribute(self.attr_proto)
 
+    @property
+    def name(self):
+        return self.attr_proto.name
+
+    @property
+    def type(self):
+        return self.attr_proto.type
+
 
 class IRStmt:
     def __init__(
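
These properties support the new module above: `extract_param_schema_from_function` reads `attr.name` and `attr_value.type` from the entries of `function_ir.attr_protos`, which this IR attribute class wraps.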

onnxscript/values.py

Lines changed: 6 additions & 5 deletions
@@ -98,21 +98,22 @@ class Op:
     It belongs to a particular Opset and has a name.
     """
 
-    def __init__(self, opset, opname, opschema=None) -> None:
-
+    def __init__(
+        self, opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None
+    ) -> None:
         self.opset = opset
         self.opname = opname
         self.opschema = opschema
 
-    def is_single_op(self):
+    def is_single_op(self) -> bool:
         return isinstance(self.opname, str)
 
-    def get_schema(self):
+    def get_schema(self) -> onnx.defs.OpSchema:
         if self.opschema:
             return self.opschema
         return self.opset[self.opname]
 
-    def has_schema(self):
+    def has_schema(self) -> bool:
         return self.opschema is not None
 
     def adapt_kwargs(self, kwargs):
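
The annotations make explicit that `opschema` may be absent and that `get_schema()` yields an `onnx.defs.OpSchema`, the same type `extract_param_schema_from_op_schema` consumes, so an `Op`'s schema can presumably be fed straight into the new parameter-schema extraction.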
