@@ -163,21 +163,105 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any:
163
163
return onnx .helper .get_attribute_value (attr_proto )
164
164
165
165
166
- class Op :
166
+ def param_schemas_from_op_schema (
167
+ op_schema : onnx .defs .OpSchema ,
168
+ ) -> tuple [ParamSchema , ...]:
169
+ """Get the parameter schemas from an ONNX OpSchema."""
170
+ schemas = []
171
+ for input_ in op_schema .inputs :
172
+ param_schema = ParamSchema (
173
+ name = input_ .name ,
174
+ is_input = True ,
175
+ required = (input_ .option != onnx .defs .OpSchema .FormalParameterOption .Optional ),
176
+ is_variadic_input = (
177
+ input_ .option == onnx .defs .OpSchema .FormalParameterOption .Variadic
178
+ ),
179
+ )
180
+ schemas .append (param_schema )
181
+ for attr_name , attribute in op_schema .attributes .items ():
182
+ default_attr_proto = attribute .default_value
183
+ param_schema = ParamSchema (
184
+ name = attr_name ,
185
+ type = _ATTRIBUTE_TYPE_TO_PYTHON_TYPE [attribute .type ],
186
+ default = _get_attribute_value (default_attr_proto ),
187
+ is_input = False ,
188
+ required = attribute .required ,
189
+ )
190
+ schemas .append (param_schema )
191
+
192
+ return tuple (schemas )
193
+
194
+
195
+ def param_schemas_from_function_ir (
196
+ function_ir : irbuilder .IRFunction ,
197
+ ) -> tuple [ParamSchema , ...]:
198
+ """Get the parameter schemas from a FunctionIR."""
199
+ # The first len(func_ir.inputs) arguments are onnx inputs
200
+ # The rest is onnx attributes
201
+
202
+ schemas = []
203
+ for arg in function_ir .inputs :
204
+ if isinstance (arg .typeinfo , onnx .TypeProto .Optional ):
205
+ required = False
206
+ else :
207
+ required = True
208
+ schemas .append (
209
+ ParamSchema (name = arg .name , type = arg .typeinfo , is_input = True , required = required )
210
+ )
211
+
212
+ for attr_parameter in function_ir .attrs :
213
+ schemas .append (
214
+ ParamSchema (
215
+ name = attr_parameter .name ,
216
+ type = _ATTRIBUTE_TYPE_TO_PYTHON_TYPE .get (
217
+ onnx .defs .OpSchema .AttrType (attr_parameter .type ) # type: ignore[call-arg]
218
+ ),
219
+ default = _EmptyDefault
220
+ if attr_parameter .default_value is None
221
+ else attr_parameter .default_value ,
222
+ is_input = False ,
223
+ required = not attr_parameter .has_default ,
224
+ )
225
+ )
226
+
227
+ return tuple (schemas )
228
+
229
+
230
+ @typing .runtime_checkable
231
+ class OpLike (Protocol ):
232
+ """A protocol for objects that have an ONNX OpSchema."""
233
+
234
+ @property
235
+ def name (self ) -> str :
236
+ ...
237
+
238
+ @property
239
+ def opset (self ) -> Opset :
240
+ ...
241
+
242
+ @property
243
+ def opschema (self ) -> Optional [onnx .defs .OpSchema ]:
244
+ ...
245
+
246
+ def param_schemas (self ) -> Optional [tuple [ParamSchema , ...]]:
247
+ ...
248
+
249
+
250
+ class Op (OpLike ):
167
251
"""Represents an ONNX op instance (for example, the MatMul op from ONNX opset version 13).
168
252
It belongs to a particular Opset and has a name.
169
253
170
254
Attributes:
171
255
opset: The Opset that this op belongs to.
172
- opname : The name of the op.
256
+ name : The name of the op.
173
257
opschema: The ONNX OpSchema for the op.
174
258
"""
175
259
176
260
def __init__ (
177
- self , opset , opname : str , opschema : Optional [onnx .defs .OpSchema ] = None
261
+ self , opset : Opset , opname : str , opschema : Optional [onnx .defs .OpSchema ] = None
178
262
) -> None :
179
- self .opset = opset
180
- self .opname = opname
263
+ self ._opset = opset
264
+ self ._name = opname
181
265
self ._opschema = opschema
182
266
self ._param_schemas : Optional [tuple [ParamSchema , ...]] = None
183
267
@@ -188,12 +272,17 @@ def __call__(self, *args, **kwargs):
188
272
schema = self .get_schema ()
189
273
if schema is None :
190
274
raise RuntimeError (
191
- f"Op '{ self .opname } ' does not have an OpSchema and cannot be evaluated."
275
+ f"Op '{ self .name } ' does not have an OpSchema and cannot be evaluated."
192
276
)
193
277
return evaluator .default ().eval (schema , args , kwargs )
194
278
195
- def is_single_op (self ) -> bool :
196
- return isinstance (self .opname , str )
279
+ @property
280
+ def name (self ) -> str :
281
+ return self ._name
282
+
283
+ @property
284
+ def opset (self ) -> Opset :
285
+ return self ._opset
197
286
198
287
@property
199
288
def opschema (self ) -> Optional [onnx .defs .OpSchema ]:
@@ -203,7 +292,7 @@ def get_schema(self) -> Optional[onnx.defs.OpSchema]:
203
292
"""Returns the ONNX OpSchema for this op."""
204
293
if self .opschema is not None :
205
294
return self .opschema
206
- return self .opset [self .opname ]
295
+ return self .opset [self .name ]
207
296
208
297
def has_schema (self ) -> bool :
209
298
"""Returns True if this op has an OpSchema."""
@@ -217,30 +306,9 @@ def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]:
217
306
op_schema = self .get_schema ()
218
307
if op_schema is None :
219
308
return None
220
- schemas = []
221
- for input_ in op_schema .inputs :
222
- param_schema = ParamSchema (
223
- name = input_ .name ,
224
- is_input = True ,
225
- required = (input_ .option != onnx .defs .OpSchema .FormalParameterOption .Optional ),
226
- is_variadic_input = (
227
- input_ .option == onnx .defs .OpSchema .FormalParameterOption .Variadic
228
- ),
229
- )
230
- schemas .append (param_schema )
231
- for attr_name , attribute in op_schema .attributes .items ():
232
- default_attr_proto = attribute .default_value
233
- param_schema = ParamSchema (
234
- name = attr_name ,
235
- type = _ATTRIBUTE_TYPE_TO_PYTHON_TYPE [attribute .type ],
236
- default = _get_attribute_value (default_attr_proto ),
237
- is_input = False ,
238
- required = attribute .required ,
239
- )
240
- schemas .append (param_schema )
241
309
242
- self ._param_schemas = tuple ( schemas )
243
- return self ._param_schemas # type: ignore[return-value]
310
+ self ._param_schemas = param_schemas_from_op_schema ( op_schema )
311
+ return self ._param_schemas
244
312
245
313
246
314
@dataclasses .dataclass (repr = False , eq = False )
@@ -355,13 +423,14 @@ def op_schema_from_function_ir(
355
423
class OnnxFunction (Op ):
356
424
"""Represents an ONNX op for which a function-body has been defined in onnxscript.
357
425
358
- Args:
359
- opset: opset the function belongs to
360
- pyfun: python function
361
- irfun: python code parsed by class
362
- :class:`onnxscript.converter.Converter`
363
- source: source code used to generate the function
364
- kwargs: additional properties used to construct a ModelProto
426
+ Attributes:
427
+ opset: Opset the function belongs to.
428
+ name: Name of the function.
429
+ function: Python function.
430
+ function_ir: Python code parsed as an :class:`irbuilder.IRFunction`.
431
+ source: Source code used to generate the function.
432
+ kwargs: Additional properties used to construct a ModelProto.
433
+ opschema: Generated ONNX OpSchema for this op.
365
434
"""
366
435
367
436
def __init__ (
@@ -372,6 +441,16 @@ def __init__(
372
441
source : str ,
373
442
kwargs : dict [str , Any ],
374
443
):
444
+ """Constructs an OnnxFunction.
445
+
446
+ Args:
447
+ opset: opset the function belongs to
448
+ pyfun: python function
449
+ irfun: python code parsed by class
450
+ :class:`onnxscript.converter.Converter`
451
+ source: source code used to generate the function
452
+ kwargs: additional properties used to construct a ModelProto
453
+ """
375
454
opset = opset or Opset (irfun .domain , 1 )
376
455
super ().__init__ (opset , irfun .name )
377
456
self .function = pyfun
@@ -383,11 +462,6 @@ def __init__(
383
462
# Set the signature of the class to function's
384
463
self .__signature__ = inspect .signature (pyfun )
385
464
386
- @property
387
- def name (self ):
388
- """Returns the function name."""
389
- return self .opname
390
-
391
465
@property
392
466
def opschema (self ) -> Optional [onnx .defs .OpSchema ]:
393
467
"""Construct an OpSchema from function_ir."""
@@ -433,38 +507,8 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
433
507
# NOTE: We generate the parameter schemas from the function_ir instead
434
508
# of relying on the auto generated OpSchema because we need to preserve the keyword
435
509
# argument order from the Python function definition, which is lost in OpSchema.
436
- function_ir = self .function_ir
437
- # The first len(func_ir.inputs) arguments are onnx inputs
438
- inputs = function_ir .inputs
439
- # The rest is onnx attributes
440
-
441
- schemas = []
442
- for arg in inputs :
443
- if isinstance (arg .typeinfo , onnx .TypeProto .Optional ):
444
- required = False
445
- else :
446
- required = True
447
- schemas .append (
448
- ParamSchema (name = arg .name , type = arg .typeinfo , is_input = True , required = required )
449
- )
450
-
451
- for attr_parameter in function_ir .attrs :
452
- schemas .append (
453
- ParamSchema (
454
- name = attr_parameter .name ,
455
- type = _ATTRIBUTE_TYPE_TO_PYTHON_TYPE .get (
456
- onnx .defs .OpSchema .AttrType (attr_parameter .type ) # type: ignore[call-arg]
457
- ),
458
- default = _EmptyDefault
459
- if attr_parameter .default_value is None
460
- else attr_parameter .default_value ,
461
- is_input = False ,
462
- required = not attr_parameter .has_default ,
463
- )
464
- )
465
-
466
- self ._param_schemas = tuple (schemas )
467
- return self ._param_schemas # type: ignore[return-value]
510
+ self ._param_schemas = param_schemas_from_function_ir (self .function_ir )
511
+ return self ._param_schemas
468
512
469
513
def to_function_proto (self ):
470
514
"""Converts the function into :class:`onnx.FunctionProto`."""
0 commit comments