2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_aten/ops/core.py
@@ -15,7 +15,7 @@

from onnxscript import BOOL, DOUBLE, FLOAT, INT16, INT32, INT64
from onnxscript.function_libs.torch_aten.registration import torch_op
- from onnxscript.function_libs.torch_aten.typing import (
+ from onnxscript.function_libs.torch_aten.tensor_typing import (
IntType,
TFloat,
TFloatOrBFloat16,
6 changes: 5 additions & 1 deletion onnxscript/function_libs/torch_aten/ops/nn.py
@@ -18,7 +18,11 @@

from onnxscript import FLOAT, INT64
from onnxscript.function_libs.torch_aten.registration import torch_op
- from onnxscript.function_libs.torch_aten.typing import TFloat, TFloatOrBFloat16, TReal
+ from onnxscript.function_libs.torch_aten.tensor_typing import (
+     TFloat,
+     TFloatOrBFloat16,
+     TReal,
+ )
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType

2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_aten/ops/special.py
@@ -15,7 +15,7 @@

from onnxscript import FLOAT
from onnxscript.function_libs.torch_aten.registration import torch_op
- from onnxscript.function_libs.torch_aten.typing import TFloatOrBFloat16
+ from onnxscript.function_libs.torch_aten.tensor_typing import TFloatOrBFloat16
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType

171 changes: 171 additions & 0 deletions onnxscript/function_libs/torch_aten/param_manipulation.py
@@ -0,0 +1,171 @@
"""Function for manipulating input parameters of an Op or a OnnxFunction."""
from __future__ import annotations

import collections
import dataclasses
from typing import Any, List, OrderedDict, Sequence, Tuple

import onnx

from onnxscript import values

# A special value to indicate that the default value is not specified
_EmptyDefault = object()


@dataclasses.dataclass(frozen=True)
class ParamSchema:
"""A schema for a parameter of an Op or a OnnxFunction.

Attributes:
name: The name of the parameter.
type: The type of the parameter.
Collaborator:
Since we have multiple type representations, which one is this supposed to be? E.g., is it supposed to be Python types (like "int" and "List[FLOAT]")?

Collaborator (author):
I think Python types sound good. For now the type is not really enforced, so it could be anything.

        default: The default value of the parameter.
        is_input: Whether the parameter is an ONNX input.
    """

    name: str
    type: Any = None  # Op input does not have a type, for now
    default: Any = _EmptyDefault
    is_input: bool = True

    def __repr__(self) -> str:
Collaborator:
Is the return value supposed to be valid Python or just a description?

Collaborator:
Just wondering whether __str__ is more appropriate than __repr__ here?

Collaborator (author):
Good catch, thanks!
param_kind = "INPUT" if self.is_input else "ATTRIBUTE"
text = f"{self.name}<{param_kind}>: {self.type}"
Collaborator:
Nit (optional): f"{self.name} : {param_kind}({self.type})" or f"{param_kind} {self.name} : {self.type}" would be more readable?

Collaborator (author):
Sounds good!

Collaborator (author) @justinchuby, Feb 3, 2023:
a<INPUT>: INT64

vs

INPUT a : INT64
a : INPUT(INT64)

Collaborator (author) @justinchuby, Feb 3, 2023:
I will go with a: Input[INT64]

        if self.default is not _EmptyDefault:
            text += f" = {self.default}"
        return text

    @property
    def is_attribute(self) -> bool:
        """Returns True if the parameter is an ONNX attribute."""
        return not self.is_input
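
Based on the thread above, a minimal sketch of how __repr__ could render the agreed "a: Input[INT64]" format (a drop-in variant of the method shown; the final wording in the PR may differ):

    def __repr__(self) -> str:
        # Renders e.g. "a: Input[INT64]" or "alpha: Attribute[float] = 1.0"
        param_kind = "Input" if self.is_input else "Attribute"
        text = f"{self.name}: {param_kind}[{self.type}]"
        if self.default is not _EmptyDefault:
            text += f" = {self.default}"
        return text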


def extract_param_schema_from_function(onnx_func: values.OnnxFunction) -> List[ParamSchema]:
Collaborator:
The current OnnxFunction IR loses the positional-order information between inputs and attributes (since we didn't care about it then, given our assumption that all inputs come first). We can't reconstruct it afterwards.

Perhaps we should consider updating the OnnxFunction IR to store a List[ParamSchema] in the first place... just thinking out loud; I haven't read the rest of the PR yet.

Collaborator (author):
I like this idea! I was going to propose exposing ParamSchema somewhere in the function and op as the next thing to do, so that we have a common interface, as in #320.


    """Extracts the parameter schemas from an OnnxFunction."""
    function_ir = onnx_func.function_ir
    # The first len(function_ir.inputs) arguments are ONNX inputs
    inputs = function_ir.inputs
    # The rest are ONNX attributes
    attributes = function_ir.attrs
    # Construct a dictionary of attributes with their names specified in the
    # function definition
    attr_name_to_protos = collections.OrderedDict(
        (attr.name, attr) for attr in function_ir.attr_protos
    )

    # Args with default values are attributes
    param_schemas = []
    for arg in inputs:
        param_schema = ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True)
        param_schemas.append(param_schema)

    for attr_name in attributes:
        # FIXME(justinchuby): Where can we find the type?
        param_schema = ParamSchema(name=attr_name, type=None, is_input=False)
        param_schemas.append(param_schema)

    for name, attr_value in attr_name_to_protos.items():
        param_schema = ParamSchema(
            name=name,
            type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type],
            default=_get_attribute_value(attr_value.attr_proto),
            is_input=False,
        )
        param_schemas.append(param_schema)
    return param_schemas


_ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
    onnx.defs.OpSchema.AttrType.FLOAT: float,
    onnx.defs.OpSchema.AttrType.INT: int,
    onnx.defs.OpSchema.AttrType.STRING: str,
    onnx.defs.OpSchema.AttrType.TENSOR: None,
    onnx.defs.OpSchema.AttrType.GRAPH: None,
    onnx.defs.OpSchema.AttrType.SPARSE_TENSOR: None,
    onnx.defs.OpSchema.AttrType.TYPE_PROTO: None,
    onnx.defs.OpSchema.AttrType.FLOATS: Sequence[float],
    onnx.defs.OpSchema.AttrType.INTS: Sequence[int],
    onnx.defs.OpSchema.AttrType.STRINGS: Sequence[str],
    onnx.defs.OpSchema.AttrType.TENSORS: None,
    onnx.defs.OpSchema.AttrType.GRAPHS: None,
    onnx.defs.OpSchema.AttrType.SPARSE_TENSORS: None,
    onnx.defs.OpSchema.AttrType.TYPE_PROTOS: None,
}
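
The None entries mark attribute kinds (tensors, graphs, type protos) that are not mapped to a Python type yet. A small illustrative guard a caller might use (hypothetical; not part of this PR):

import onnx

attr_kind = onnx.defs.OpSchema.AttrType.GRAPH
py_type = _ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_kind]
if py_type is None:
    # Graph/tensor/type-proto attributes have no Python-type mapping yet
    print(f"No Python type mapping for {attr_kind}")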


def _get_attribute_value(attr_proto):
    """Returns the default value of an AttributeProto as a Python object."""
    if attr_proto.type == onnx.AttributeProto.UNDEFINED:
        return _EmptyDefault
    if attr_proto.type == onnx.AttributeProto.FLOAT:
        return attr_proto.f
    if attr_proto.type == onnx.AttributeProto.INT:
        return attr_proto.i
    if attr_proto.type == onnx.AttributeProto.STRING:
        return attr_proto.s
    if attr_proto.type == onnx.AttributeProto.FLOATS:
        # Repeated float values live in the `floats` field, not `f`
        return [float(v) for v in attr_proto.floats]
    if attr_proto.type == onnx.AttributeProto.INTS:
        # Repeated int values live in the `ints` field, not `i`
        return [int(v) for v in attr_proto.ints]
    raise TypeError(f"Unsupported attribute type: {attr_proto.type}")


def extract_param_schema_from_op_schema(op_schema: onnx.defs.OpSchema) -> List[ParamSchema]:
    """Extracts the parameter schemas from an ONNX OpSchema."""
    param_schemas = []
    for input_ in op_schema.inputs:
        param_schema = ParamSchema(name=input_.name, is_input=True)
        param_schemas.append(param_schema)
    for attr_name, attribute in op_schema.attributes.items():
        default_attr_proto = attribute.default_value
        param_schema = ParamSchema(
            name=attr_name,
            type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type],
            default=_get_attribute_value(default_attr_proto),
            is_input=False,
        )
        param_schemas.append(param_schema)

    return param_schemas
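
As a quick usage sketch against a standard ONNX op schema (assuming an installed onnx package; the printed format follows the __repr__ above):

import onnx

gemm_schema = onnx.defs.get_schema("Gemm")
for param in extract_param_schema_from_op_schema(gemm_schema):
    print(param)
# roughly: A<INPUT>: None ... alpha<ATTRIBUTE>: <class 'float'> = 1.0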


def separate_input_attributes_from_arguments(
    param_schemas: Sequence[ParamSchema],
    args,
    kwargs,
) -> Tuple[List[Any], OrderedDict[str, Any]]:
    """Separate Python args and kwargs into ONNX inputs and attributes.

    Args:
        param_schemas: The parameter schemas of an Op or an OnnxFunction.
        args: The Python positional arguments supplied by the caller.
        kwargs: The Python keyword arguments supplied by the caller.

    Returns:
        A tuple of two elements:
        - A list of ONNX inputs.
        - An ordered dictionary of ONNX attribute names and values.
    """
    # args, kwargs, and param_schemas should all be in order;
    # the caller might not specify all attributes
    if len(args) + len(kwargs) > len(param_schemas):
        raise TypeError(
            f"Expected at most {len(param_schemas)} arguments, "
            f"got {len(args) + len(kwargs)}"
        )

    onnx_inputs = []
    onnx_attributes = OrderedDict()
    for i, param in enumerate(param_schemas):
        if i < len(args):
            if not param.is_attribute:
                onnx_inputs.append(args[i])
            else:
                onnx_attributes[param.name] = args[i]
        elif param.name in kwargs:
            if not param.is_attribute:
                onnx_inputs.append(kwargs[param.name])
            else:
                onnx_attributes[param.name] = kwargs[param.name]
        else:
            # Not supplied by the caller: fill in the attribute's default
            # value (inputs do not carry defaults)
            onnx_attributes[param.name] = param.default
Collaborator:
I haven't seen the usage of this function yet. Perhaps some usage contexts would not want default values to be filled in. Anyway, I guess we can cross that bridge when we come to it.

Collaborator (author):
I added a fill_defaults option to the function.

    return onnx_inputs, onnx_attributes
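
A minimal sketch of the fill_defaults option mentioned in the thread above (assumed flag name and behavior; the merged version may differ). Only the fallback branch changes: absent attributes are skipped when fill_defaults is False.

def separate_input_attributes_from_arguments(
    param_schemas: Sequence[ParamSchema],
    args,
    kwargs,
    fill_defaults: bool = True,  # hypothetical flag from the review thread
) -> Tuple[List[Any], OrderedDict[str, Any]]:
    if len(args) + len(kwargs) > len(param_schemas):
        raise TypeError(
            f"Expected at most {len(param_schemas)} arguments, "
            f"got {len(args) + len(kwargs)}"
        )
    onnx_inputs: List[Any] = []
    onnx_attributes: OrderedDict[str, Any] = OrderedDict()
    for i, param in enumerate(param_schemas):
        if i < len(args):
            if param.is_attribute:
                onnx_attributes[param.name] = args[i]
            else:
                onnx_inputs.append(args[i])
        elif param.name in kwargs:
            if param.is_attribute:
                onnx_attributes[param.name] = kwargs[param.name]
            else:
                onnx_inputs.append(kwargs[param.name])
        elif fill_defaults and param.is_attribute:
            # Skip absent attributes entirely when fill_defaults is False
            onnx_attributes[param.name] = param.default
    return onnx_inputs, onnx_attributes
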
124 changes: 124 additions & 0 deletions onnxscript/function_libs/torch_aten/param_manipulation_test.py
@@ -0,0 +1,124 @@
import collections
import unittest

import parameterized

from onnxscript import INT64
from onnxscript.function_libs.torch_aten.param_manipulation import (
    ParamSchema,
    separate_input_attributes_from_arguments,
)

TEST_INPUT = "TEST_INPUT"


class TestParamManipulation(unittest.TestCase):
    @parameterized.parameterized.expand(
        [
            (
                "all_positional",
                (TEST_INPUT, 42, 0.0),
                {},
                0.0,
            ),
            (
                "positional_with_default",
                (TEST_INPUT, 42),
                {},
                100.0,
            ),
            (
                "positional_with_default_and_kwargs",
                (TEST_INPUT,),
                {"b": 42},
                100.0,
            ),
            (
                "positional_with_kwargs",
                (TEST_INPUT, 42),
                {"c": 0.0},
                0.0,
            ),
            (
                "positional_input_with_kwargs_attribute",
                (TEST_INPUT,),
                {"b": 42, "c": 0.0},
                0.0,
            ),
            (
                "all_kwargs",
                (),
                {"a": TEST_INPUT, "b": 42, "c": 0.0},
                0.0,
            ),
            (
                "all_kwargs_with_default",
                (),
                {"a": TEST_INPUT, "b": 42},
                100.0,
            ),
        ]
    )
    def test_separate_input_attributes_from_arguments_correct_on(
        self, _, args, kwargs, expected_c
    ):
        param_schemas = (
            ParamSchema(name="a", type=INT64, is_input=True),
            ParamSchema(name="b", type=int, is_input=False),
            ParamSchema(name="c", type=float, default=100.0, is_input=False),
        )

        expected_inputs = [TEST_INPUT]
        expected_attributes = collections.OrderedDict(
            [
                ("b", 42),
                ("c", expected_c),
            ]
        )

        inputs, attributes = separate_input_attributes_from_arguments(
            param_schemas, args, kwargs
        )

        self.assertEqual(len(inputs), len(expected_inputs))
        for input_, expected_input in zip(inputs, expected_inputs):
            self.assertIs(input_, expected_input)
        self.assertEqual(attributes, expected_attributes)

    @parameterized.parameterized.expand(
        [
            (
                "extra_positional",
                (TEST_INPUT, 42, 0.0, -1),
                {},
            ),
            (
                "extra_keyword",
                (TEST_INPUT, 42, 0.0),
                {"unknown": -1},
            ),
            (
                "extra_positional_and_keyword",
                (TEST_INPUT, 42, 0.0, -1),
                {"unknown": -1},
            ),
        ]
    )
    def test_separate_input_attributes_from_arguments_raises_on_extra_args(
        self, _, args, kwargs
    ):
        param_schemas = (
            ParamSchema(name="a", type=INT64, is_input=True),
            ParamSchema(name="b", type=int, is_input=False),
            ParamSchema(name="c", type=float, default=100.0, is_input=False),
        )

        with self.assertRaises(TypeError):
            _, _ = separate_input_attributes_from_arguments(param_schemas, args, kwargs)


if __name__ == "__main__":
    unittest.main()
8 changes: 8 additions & 0 deletions onnxscript/irbuilder.py
@@ -120,6 +120,14 @@ def __str__(self):
        # self.name + " = " + self.value
        return helper.printable_attribute(self.attr_proto)

+     @property
+     def name(self):
+         return self.attr_proto.name
+
+     @property
+     def type(self):
+         return self.attr_proto.type


class IRStmt:
    def __init__(
11 changes: 6 additions & 5 deletions onnxscript/values.py
@@ -98,21 +98,22 @@ class Op:
    It belongs to a particular Opset and has a name.
    """

-     def __init__(self, opset, opname, opschema=None) -> None:
+     def __init__(
+         self, opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None
+     ) -> None:
        self.opset = opset
        self.opname = opname
        self.opschema = opschema

-     def is_single_op(self):
+     def is_single_op(self) -> bool:
        return isinstance(self.opname, str)

-     def get_schema(self):
+     def get_schema(self) -> onnx.defs.OpSchema:
        if self.opschema:
            return self.opschema
        return self.opset[self.opname]

-     def has_schema(self):
+     def has_schema(self) -> bool:
        return self.opschema is not None

    def adapt_kwargs(self, kwargs):
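
For context, a small sketch of how the newly annotated Op methods can be exercised (constructing an Op by hand with the signature shown above; in practice Op instances are created by the framework, so this is illustrative only):

import onnx
from onnxscript import values
from onnxscript.onnx_opset import opset18

add_op = values.Op(opset18, "Add", onnx.defs.get_schema("Add", 18))
assert add_op.has_schema()
print(add_op.get_schema().name)  # "Add"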