-
Notifications
You must be signed in to change notification settings - Fork 97
feat(atenlib): separate inputs from attributes #368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b3df419
c255877
2df9e09
6fd5c63
071dd6f
4f2eb8d
9507b8d
7bfbfd7
69cac74
aedb99f
a3175e3
0764da8
22e9b84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,171 @@ | ||
| """Function for manipulating input parameters of an Op or a OnnxFunction.""" | ||
| from __future__ import annotations | ||
|
|
||
| import collections | ||
| import dataclasses | ||
| from typing import Any, List, OrderedDict, Sequence, Tuple | ||
|
|
||
| import onnx | ||
|
|
||
| from onnxscript import values | ||
|
|
||
| # A special value to indicate that the default value is not specified | ||
| _EmptyDefault = object() | ||
|
|
||
|
|
||
| @dataclasses.dataclass(frozen=True) | ||
| class ParamSchema: | ||
| """A schema for a parameter of an Op or a OnnxFunction. | ||
|
|
||
| Attributes: | ||
| name: The name of the parameter. | ||
| type: The type of the parameter. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we have multiple type representations, which one is this supposed to be? Eg., is this supposed to be python-types (like "int" and "List[FLOAT]") ?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think Python types sg. For now not really enforced and it could be anything |
||
| default: The default value of the parameter. | ||
| is_input: Whether the parameter is an ONNX input. | ||
| """ | ||
|
|
||
| name: str | ||
| type: Any = None # Op input does not have a type, for now | ||
| default: Any = _EmptyDefault | ||
| is_input: bool = True | ||
|
|
||
| def __repr__(self) -> str: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the return value supposed to be valid python or just a description?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just wondering whether str is more appropriate or repr ?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, thanks! |
||
| param_kind = "INPUT" if self.is_input else "ATTRIBUTE" | ||
| text = f"{self.name}<{param_kind}>: {self.type}" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit (optional): f"{self.name} : {param_kind}({self.type})" or f"{param_kind} {self.name} : {self.type}" would be more readable?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SG!
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will go with |
||
| if self.default is not _EmptyDefault: | ||
| text += f" = {self.default}" | ||
| return text | ||
|
|
||
| @property | ||
| def is_attribute(self) -> bool: | ||
| """Returns True if the parameter is an ONNX attribute.""" | ||
| return not self.is_input | ||
|
|
||
|
|
||
| def extract_param_schema_from_function(onnx_func: values.OnnxFunction) -> List[ParamSchema]: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current Perhaps we should consider updating
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this idea! I was going to propose exposing ParamSchema somewhere in the function and op as the next thing to do so we have a common interface as #320 |
||
|
|
||
| function_ir = onnx_func.function_ir | ||
| # The first len(func_ir.inputs) arguments are onnx inputs | ||
| inputs = function_ir.inputs | ||
| # The rest is onnx attributes | ||
| attributes = function_ir.attrs | ||
| # Construct a dictionary of attributes with their names specified in the function | ||
| # definition | ||
| attr_name_to_protos = collections.OrderedDict( | ||
| (attr.name, attr) for attr in function_ir.attr_protos | ||
| ) | ||
|
|
||
| # args with default value are attributes | ||
| param_schemas = [] | ||
| for arg in inputs: | ||
| param_schema = ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True) | ||
| param_schemas.append(param_schema) | ||
|
|
||
| for attr_name in attributes: | ||
| # FIXME(justinchuby): Where can we find the type? | ||
| param_schema = ParamSchema(name=attr_name, type=None, is_input=False) | ||
| param_schemas.append(param_schema) | ||
|
|
||
| for name, attr_value in attr_name_to_protos.items(): | ||
| param_schema = ParamSchema( | ||
| name=name, | ||
| type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type], | ||
| default=_get_attribute_value(attr_value.attr_proto), | ||
| is_input=False, | ||
| ) | ||
| param_schemas.append(param_schema) | ||
| return param_schemas | ||
|
|
||
|
|
||
| _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { | ||
| onnx.defs.OpSchema.AttrType.FLOAT: float, | ||
| onnx.defs.OpSchema.AttrType.INT: int, | ||
| onnx.defs.OpSchema.AttrType.STRING: str, | ||
| onnx.defs.OpSchema.AttrType.TENSOR: None, | ||
| onnx.defs.OpSchema.AttrType.GRAPH: None, | ||
| onnx.defs.OpSchema.AttrType.SPARSE_TENSOR: None, | ||
| onnx.defs.OpSchema.AttrType.TYPE_PROTO: None, | ||
| onnx.defs.OpSchema.AttrType.FLOATS: Sequence[float], | ||
| onnx.defs.OpSchema.AttrType.INTS: Sequence[int], | ||
| onnx.defs.OpSchema.AttrType.STRINGS: Sequence[str], | ||
| onnx.defs.OpSchema.AttrType.TENSORS: None, | ||
| onnx.defs.OpSchema.AttrType.GRAPHS: None, | ||
| onnx.defs.OpSchema.AttrType.SPARSE_TENSORS: None, | ||
| onnx.defs.OpSchema.AttrType.TYPE_PROTOS: None, | ||
| } | ||
|
|
||
|
|
||
| def _get_attribute_value(attr_proto): | ||
| if attr_proto.type == onnx.AttributeProto.UNDEFINED: | ||
| return _EmptyDefault | ||
| if attr_proto.type == onnx.AttributeProto.FLOAT: | ||
| return attr_proto.f | ||
| if attr_proto.type == onnx.AttributeProto.INT: | ||
| return attr_proto.i | ||
| if attr_proto.type == onnx.AttributeProto.STRING: | ||
| return attr_proto.s | ||
| if attr_proto.type == onnx.AttributeProto.FLOATS: | ||
| return [float(v) for v in attr_proto.f] | ||
| if attr_proto.type == onnx.AttributeProto.INTS: | ||
| return [int(v) for v in attr_proto.i] | ||
| raise TypeError(f"Unsupported attribute type: {attr_proto.type}") | ||
|
|
||
|
|
||
| def extract_param_schema_from_op_schema(op_schema: onnx.defs.OpSchema) -> List[ParamSchema]: | ||
| param_schemas = [] | ||
| for input_ in op_schema.inputs: | ||
| param_schema = ParamSchema(name=input_.name, is_input=True) | ||
| param_schemas.append(param_schema) | ||
| for attr_name, attribute in op_schema.attributes.items(): | ||
| default_attr_proto = attribute.default_value | ||
| param_schema = ParamSchema( | ||
| name=attr_name, | ||
| type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type], | ||
| default=_get_attribute_value(default_attr_proto), | ||
| is_input=False, | ||
| ) | ||
| param_schemas.append(param_schema) | ||
|
|
||
| return param_schemas | ||
|
|
||
|
|
||
| def separate_input_attributes_from_arguments( | ||
| param_schemas: Sequence[ParamSchema], | ||
| args, | ||
| kwargs, | ||
| ) -> Tuple[List[Any], OrderedDict[str, Any]]: | ||
| """Separate Python args and kwargs into ONNX inputs and attributes. | ||
|
|
||
| Args: | ||
| param_schemas: The parameter schemas of an Op or a OnnxFunction. | ||
| args: The Python positional arguments supplied by the caller. | ||
| kwargs: The Python keyword arguments supplied by the caller. | ||
|
|
||
| Returns: | ||
| A tuple of two elements: | ||
| - A list of ONNX inputs. | ||
| - An ordered dictionary of ONNX attribute names and values. | ||
| """ | ||
| # args, kwargs and param_schemas should be all in order | ||
| # user might not specify all attributes | ||
| if len(args) + len(kwargs) > len(param_schemas): | ||
| raise TypeError("Inputs are more than expected in schema") | ||
|
|
||
| onnx_inputs = [] | ||
| onnx_attributes = OrderedDict() | ||
| for i, param in enumerate(param_schemas): | ||
| if i < len(args): | ||
| if not param.is_attribute: | ||
| onnx_inputs.append(args[i]) | ||
| else: | ||
| onnx_attributes[param.name] = args[i] | ||
| elif param.name in kwargs: | ||
| if not param.is_attribute: | ||
| onnx_inputs.append(kwargs[param.name]) | ||
| else: | ||
| onnx_attributes[param.name] = kwargs[param.name] | ||
| else: | ||
| # input doesn't have default | ||
| onnx_attributes[param.name] = param.default | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I haven't seen the usage of this function. Perhaps some usage contexts would not want default-values to be filled in. Anyway, I guess we can cross that bridge when we come to it.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a fill_defaults option to the function |
||
|
|
||
| return onnx_inputs, onnx_attributes | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| import collections | ||
| import unittest | ||
|
|
||
| import parameterized | ||
|
|
||
| from onnxscript import INT64 | ||
| from onnxscript.function_libs.torch_aten.param_manipulation import ( | ||
| ParamSchema, | ||
| separate_input_attributes_from_arguments, | ||
| ) | ||
|
|
||
| TEST_INPUT = "TEST_INPUT" | ||
|
|
||
|
|
||
| class TestParamManipulation(unittest.TestCase): | ||
| @parameterized.parameterized.expand( | ||
| [ | ||
| ( | ||
| "all_positional", | ||
| (TEST_INPUT, 42, 0.0), | ||
| {}, | ||
| 0.0, | ||
| ), | ||
| ( | ||
| "positional_with_default", | ||
| (TEST_INPUT, 42), | ||
| {}, | ||
| 100.0, | ||
| ), | ||
| ( | ||
| "positional_with_default_and_kwargs", | ||
| (TEST_INPUT,), | ||
| {"b": 42}, | ||
| 100.0, | ||
| ), | ||
| ( | ||
| "positional_with_kwargs", | ||
| (TEST_INPUT, 42), | ||
| {"c": 0.0}, | ||
| 0.0, | ||
| ), | ||
| ( | ||
| "positional_input_with_kwargs_attribute", | ||
| (TEST_INPUT,), | ||
| {"b": 42, "c": 0.0}, | ||
| 0.0, | ||
| ), | ||
| ( | ||
| "all_kwargs", | ||
| (), | ||
| {"a": TEST_INPUT, "b": 42, "c": 0.0}, | ||
| 0.0, | ||
| ), | ||
| ( | ||
| "all_kwargs_with_default", | ||
| (), | ||
| {"a": TEST_INPUT, "b": 42}, | ||
| 100.0, | ||
| ), | ||
| ] | ||
| ) | ||
| def test_separate_input_attributes_from_arguments_correct_on( | ||
| self, _, args, kwargs, expected_c | ||
| ): | ||
| param_schemas = ( | ||
| ParamSchema(name="a", type=INT64, is_input=True), | ||
| ParamSchema(name="b", type=int, is_input=False), | ||
| ParamSchema(name="c", type=float, default=100.0, is_input=False), | ||
| ) | ||
|
|
||
| expected_inputs = [TEST_INPUT] | ||
| expected_attributes = collections.OrderedDict( | ||
| [ | ||
| ("b", 42), | ||
| ("c", expected_c), | ||
| ] | ||
| ) | ||
|
|
||
| inputs, attributes = separate_input_attributes_from_arguments( | ||
| param_schemas, args, kwargs | ||
| ) | ||
|
|
||
| print("\ninputs: ", inputs) | ||
| print("\nexpected_inputs: ", expected_inputs) | ||
|
|
||
| self.assertEqual(len(inputs), len(expected_inputs)) | ||
| for input_, expected_input in zip(inputs, expected_inputs): | ||
| self.assertIs(input_, expected_input) | ||
| self.assertEqual(attributes, expected_attributes) | ||
|
|
||
| @parameterized.parameterized.expand( | ||
| [ | ||
| ( | ||
| "extra_positional", | ||
| (TEST_INPUT, 42, 0.0, -1), | ||
| {}, | ||
| ), | ||
| ( | ||
| "extra_keyword", | ||
| (TEST_INPUT, 42, 0.0), | ||
| {"unknown": -1}, | ||
| ), | ||
| ( | ||
| "extra_positional_and_keyword", | ||
| (TEST_INPUT, 42, 0.0, -1), | ||
| {"unknown": -1}, | ||
| ), | ||
| ] | ||
| ) | ||
| def test_separate_input_attributes_from_arguments_raises_on_extra_args( | ||
| self, _, args, kwargs | ||
| ): | ||
| param_schemas = ( | ||
| ParamSchema(name="a", type=INT64, is_input=True), | ||
| ParamSchema(name="b", type=int, is_input=False), | ||
| ParamSchema(name="c", type=float, default=100.0, is_input=False), | ||
| ) | ||
|
|
||
| with self.assertRaises(TypeError): | ||
| _, _ = separate_input_attributes_from_arguments(param_schemas, args, kwargs) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
Uh oh!
There was an error while loading. Please reload this page.