|
| 1 | +"""Function for manipulating input parameters of an Op or a OnnxFunction.""" |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +import collections |
| 5 | +import dataclasses |
| 6 | +from typing import Any, List, OrderedDict, Sequence, Tuple |
| 7 | + |
| 8 | +import onnx |
| 9 | + |
| 10 | +from onnxscript import values |
| 11 | + |
| 12 | +# A special value to indicate that the default value is not specified |
| 13 | +_EmptyDefault = object() |
| 14 | + |
| 15 | + |
| 16 | +@dataclasses.dataclass(frozen=True) |
| 17 | +class ParamSchema: |
| 18 | + """A schema for a parameter of an Op or a OnnxFunction. |
| 19 | +
|
| 20 | + Attributes: |
| 21 | + name: The name of the parameter. |
| 22 | + type: The type of the parameter. |
| 23 | + default: The default value of the parameter. |
| 24 | + is_input: Whether the parameter is an ONNX input. |
| 25 | + """ |
| 26 | + |
| 27 | + name: str |
| 28 | + type: Any = None # Op input does not have a type, for now |
| 29 | + default: Any = _EmptyDefault |
| 30 | + is_input: bool = True |
| 31 | + |
| 32 | + def __repr__(self) -> str: |
| 33 | + param_kind = "INPUT" if self.is_input else "ATTRIBUTE" |
| 34 | + text = f"{self.name}<{param_kind}>: {self.type}" |
| 35 | + if self.default is not _EmptyDefault: |
| 36 | + text += f" = {self.default}" |
| 37 | + return text |
| 38 | + |
| 39 | + @property |
| 40 | + def is_attribute(self) -> bool: |
| 41 | + """Returns True if the parameter is an ONNX attribute.""" |
| 42 | + return not self.is_input |
| 43 | + |
| 44 | + |
| 45 | +def extract_param_schema_from_function(onnx_func: values.OnnxFunction) -> List[ParamSchema]: |
| 46 | + |
| 47 | + function_ir = onnx_func.function_ir |
| 48 | + # The first len(func_ir.inputs) arguments are onnx inputs |
| 49 | + inputs = function_ir.inputs |
| 50 | + # The rest is onnx attributes |
| 51 | + attributes = function_ir.attrs |
| 52 | + # Construct a dictionary of attributes with their names specified in the function |
| 53 | + # definition |
| 54 | + attr_name_to_protos = collections.OrderedDict( |
| 55 | + (attr.name, attr) for attr in function_ir.attr_protos |
| 56 | + ) |
| 57 | + |
| 58 | + # args with default value are attributes |
| 59 | + param_schemas = [] |
| 60 | + for arg in inputs: |
| 61 | + param_schema = ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True) |
| 62 | + param_schemas.append(param_schema) |
| 63 | + |
| 64 | + for attr_name in attributes: |
| 65 | + # FIXME(justinchuby): Where can we find the type? |
| 66 | + param_schema = ParamSchema(name=attr_name, type=None, is_input=False) |
| 67 | + param_schemas.append(param_schema) |
| 68 | + |
| 69 | + for name, attr_value in attr_name_to_protos.items(): |
| 70 | + param_schema = ParamSchema( |
| 71 | + name=name, |
| 72 | + type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type], |
| 73 | + default=_get_attribute_value(attr_value.attr_proto), |
| 74 | + is_input=False, |
| 75 | + ) |
| 76 | + param_schemas.append(param_schema) |
| 77 | + return param_schemas |
| 78 | + |
| 79 | + |
| 80 | +_ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { |
| 81 | + onnx.defs.OpSchema.AttrType.FLOAT: float, |
| 82 | + onnx.defs.OpSchema.AttrType.INT: int, |
| 83 | + onnx.defs.OpSchema.AttrType.STRING: str, |
| 84 | + onnx.defs.OpSchema.AttrType.TENSOR: None, |
| 85 | + onnx.defs.OpSchema.AttrType.GRAPH: None, |
| 86 | + onnx.defs.OpSchema.AttrType.SPARSE_TENSOR: None, |
| 87 | + onnx.defs.OpSchema.AttrType.TYPE_PROTO: None, |
| 88 | + onnx.defs.OpSchema.AttrType.FLOATS: Sequence[float], |
| 89 | + onnx.defs.OpSchema.AttrType.INTS: Sequence[int], |
| 90 | + onnx.defs.OpSchema.AttrType.STRINGS: Sequence[str], |
| 91 | + onnx.defs.OpSchema.AttrType.TENSORS: None, |
| 92 | + onnx.defs.OpSchema.AttrType.GRAPHS: None, |
| 93 | + onnx.defs.OpSchema.AttrType.SPARSE_TENSORS: None, |
| 94 | + onnx.defs.OpSchema.AttrType.TYPE_PROTOS: None, |
| 95 | +} |
| 96 | + |
| 97 | + |
| 98 | +def _get_attribute_value(attr_proto): |
| 99 | + if attr_proto.type == onnx.AttributeProto.UNDEFINED: |
| 100 | + return _EmptyDefault |
| 101 | + if attr_proto.type == onnx.AttributeProto.FLOAT: |
| 102 | + return attr_proto.f |
| 103 | + if attr_proto.type == onnx.AttributeProto.INT: |
| 104 | + return attr_proto.i |
| 105 | + if attr_proto.type == onnx.AttributeProto.STRING: |
| 106 | + return attr_proto.s |
| 107 | + if attr_proto.type == onnx.AttributeProto.FLOATS: |
| 108 | + return [float(v) for v in attr_proto.f] |
| 109 | + if attr_proto.type == onnx.AttributeProto.INTS: |
| 110 | + return [int(v) for v in attr_proto.i] |
| 111 | + raise TypeError(f"Unsupported attribute type: {attr_proto.type}") |
| 112 | + |
| 113 | + |
| 114 | +def extract_param_schema_from_op_schema(op_schema: onnx.defs.OpSchema) -> List[ParamSchema]: |
| 115 | + param_schemas = [] |
| 116 | + for input_ in op_schema.inputs: |
| 117 | + param_schema = ParamSchema(name=input_.name, is_input=True) |
| 118 | + param_schemas.append(param_schema) |
| 119 | + for attr_name, attribute in op_schema.attributes.items(): |
| 120 | + default_attr_proto = attribute.default_value |
| 121 | + param_schema = ParamSchema( |
| 122 | + name=attr_name, |
| 123 | + type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type], |
| 124 | + default=_get_attribute_value(default_attr_proto), |
| 125 | + is_input=False, |
| 126 | + ) |
| 127 | + param_schemas.append(param_schema) |
| 128 | + |
| 129 | + return param_schemas |
| 130 | + |
| 131 | + |
| 132 | +def separate_input_attributes_from_arguments( |
| 133 | + param_schemas: Sequence[ParamSchema], |
| 134 | + args, |
| 135 | + kwargs, |
| 136 | +) -> Tuple[List[Any], OrderedDict[str, Any]]: |
| 137 | + """Separate Python args and kwargs into ONNX inputs and attributes. |
| 138 | +
|
| 139 | + Args: |
| 140 | + param_schemas: The parameter schemas of an Op or a OnnxFunction. |
| 141 | + args: The Python positional arguments supplied by the caller. |
| 142 | + kwargs: The Python keyword arguments supplied by the caller. |
| 143 | +
|
| 144 | + Returns: |
| 145 | + A tuple of two elements: |
| 146 | + - A list of ONNX inputs. |
| 147 | + - An ordered dictionary of ONNX attribute names and values. |
| 148 | + """ |
| 149 | + # args, kwargs and param_schemas should be all in order |
| 150 | + # user might not specify all attributes |
| 151 | + if len(args) + len(kwargs) > len(param_schemas): |
| 152 | + raise TypeError("Inputs are more than expected in schema") |
| 153 | + |
| 154 | + onnx_inputs = [] |
| 155 | + onnx_attributes = OrderedDict() |
| 156 | + for i, param in enumerate(param_schemas): |
| 157 | + if i < len(args): |
| 158 | + if not param.is_attribute: |
| 159 | + onnx_inputs.append(args[i]) |
| 160 | + else: |
| 161 | + onnx_attributes[param.name] = args[i] |
| 162 | + elif param.name in kwargs: |
| 163 | + if not param.is_attribute: |
| 164 | + onnx_inputs.append(kwargs[param.name]) |
| 165 | + else: |
| 166 | + onnx_attributes[param.name] = kwargs[param.name] |
| 167 | + else: |
| 168 | + # input doesn't have default |
| 169 | + onnx_attributes[param.name] = param.default |
| 170 | + |
| 171 | + return onnx_inputs, onnx_attributes |
0 commit comments