@@ -8,12 +8,14 @@
 import logging
 import types
 from enum import IntFlag
-from typing import Any, Optional, Sequence, _GenericAlias  # type: ignore[attr-defined]
+from typing import _GenericAlias  # type: ignore[attr-defined]
+from typing import Any, Optional, Sequence
 
 import onnx
 import onnx.defs
 
-from onnxscript import irbuilder, sourceinfo
+from onnxscript import irbuilder, sourceinfo, type_annotation
+from onnxscript._internal import version_utils
 
 _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
     onnx.defs.OpSchema.AttrType.FLOAT: float,
@@ -34,6 +36,7 @@
 
 # A special value to indicate that the default value is not specified
 _EmptyDefault = object()
+_ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14")
 
 
 class Opset:
@@ -173,7 +176,7 @@ def __init__(
     ) -> None:
         self.opset = opset
         self.opname = opname
-        self.opschema = opschema
+        self._opschema = opschema
         self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
 
     def __call__(self, *args, **kwargs):
@@ -190,9 +193,13 @@ def __call__(self, *args, **kwargs):
     def is_single_op(self) -> bool:
         return isinstance(self.opname, str)
 
+    @property
+    def opschema(self) -> Optional[onnx.defs.OpSchema]:
+        return self._opschema
+
     def get_schema(self) -> Optional[onnx.defs.OpSchema]:
         """Returns the ONNX OpSchema for this op."""
-        if self.opschema:
+        if self.opschema is not None:
             return self.opschema
         return self.opset[self.opname]
 
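
A minimal sketch (not part of the diff) of the change to Op above, assuming this file is onnxscript/values.py; the opset15 import and the "Add" op are illustrative choices. When no OpSchema is supplied at construction, the new opschema property returns None and get_schema() falls back to the opset lookup:

# Sketch only; opset15 and "Add" are assumptions for illustration.
from onnxscript import opset15
from onnxscript import values

add_op = values.Op(opset15, "Add")   # constructed without an explicit OpSchema
print(add_op.opschema)               # None: nothing was stored
print(add_op.get_schema().name)      # "Add", resolved via opset15["Add"]
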
@@ -249,6 +256,100 @@ class OnnxClosure:
     function: Any
 
 
+@dataclasses.dataclass
+class TypeConstraint:
+    """Represents a type constraint for an ONNX op.
+
+    Attributes:
+        name: The name of the type constraint.
+        allowed_types: The allowed types for the type constraint.
+    """
+
+    name: str
+    allowed_types: list[str]
+    description: str = ""
+
+    def as_tuple(self) -> tuple[str, list[str], str]:
+        """Returns the type constraint as a tuple."""
+        return (self.name, self.allowed_types, self.description)
+
+
+def op_schema_from_function_ir(
+    function_ir: irbuilder.IRFunction, opset: Opset
+) -> onnx.defs.OpSchema:
+    """Construct an ONNX OpSchema from an IRFunction."""
+
+    # Find all distinct types in the inputs and outputs
+    distinct_types = {arg.typeinfo for arg in function_ir.inputs}.union(
+        {arg.typeinfo for arg in function_ir.outputs}
+    )
+    # Create a mapping from type to a unique name
+    type_to_constraint = {}
+    for i, type_ in enumerate(distinct_types):
+        name = f"T{i}"
+        type_to_constraint[type_] = TypeConstraint(
+            name=type_annotation.get_type_constraint_name(type_) or name,
+            allowed_types=type_annotation.pytype_to_type_strings(type_),
+        )
+
+    formal_inputs = [
+        onnx.defs.OpSchema.FormalParameter(
+            arg.name,
+            type_to_constraint[arg.typeinfo].name,
+            param_option=(
+                onnx.defs.OpSchema.FormalParameterOption.Optional
+                if type_annotation.is_optional(arg.typeinfo)
+                else onnx.defs.OpSchema.FormalParameterOption.Single
+            ),
+            # TODO(justinchu): Check this is_homogeneous thing
+            is_homogeneous=True,
+        )
+        for arg in function_ir.inputs
+    ]
+    formal_outputs = [
+        onnx.defs.OpSchema.FormalParameter(
+            arg.name,
+            type_to_constraint[arg.typeinfo].name,
+            param_option=(
+                onnx.defs.OpSchema.FormalParameterOption.Optional
+                if type_annotation.is_optional(arg.typeinfo)
+                else onnx.defs.OpSchema.FormalParameterOption.Single
+            ),
+            # TODO(justinchu): Check this is_homogeneous thing
+            is_homogeneous=True,
+        )
+        for arg in function_ir.outputs
+    ]
+
+    return onnx.defs.OpSchema(
+        function_ir.name,
+        opset.domain,
+        since_version=opset.version,
+        doc=function_ir.docstring,
+        inputs=formal_inputs,
+        outputs=formal_outputs,
+        type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()],
+        attributes=[
+            *[
+                onnx.defs.OpSchema.Attribute(
+                    attr.name,
+                    type=onnx.defs.OpSchema.AttrType(attr.type),
+                )
+                for attr in function_ir.attrs
+                if not attr.has_default
+            ],
+            *[
+                onnx.defs.OpSchema.Attribute(
+                    attr.name,
+                    default_value=attr.attr_proto,
+                )
+                for attr in function_ir.attrs
+                if attr.has_default
+            ],
+        ],
+    )
+
+
 class OnnxFunction(Op):
     """Represents an ONNX op for which a function-body has been defined in onnxscript.
 
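
A rough end-to-end sketch of the helper added above; the @script() function, its name, and the opset15 import are assumptions for illustration, and building an OpSchema this way requires onnx>=1.14:

# Sketch only; add_twice is a made-up scripted function.
from onnxscript import FLOAT, script, values
from onnxscript import opset15 as op

@script()
def add_twice(X: FLOAT[...], Y: FLOAT[...]) -> FLOAT[...]:
    return op.Add(op.Add(X, Y), Y)

# Feed the scripted function's IR and opset to the new helper, which is what
# the OnnxFunction.opschema property added later in this diff does internally.
schema = values.op_schema_from_function_ir(add_twice.function_ir, add_twice.opset)
print(schema.name, [formal.name for formal in schema.inputs])
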
@@ -276,12 +377,26 @@ def __init__(
         self.source = source
         self.kwargs = kwargs
         self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
+        self._opschema: Optional[onnx.defs.OpSchema] = None
 
     @property
     def name(self):
         """Returns the function name."""
         return self.opname
 
+    @property
+    def opschema(self) -> Optional[onnx.defs.OpSchema]:
+        """Construct an OpSchema from function_ir."""
+        if self._opschema is not None:
+            return self._opschema
+
+        if not _ONNX_OP_SCHEMA_WRITABLE:
+            return None
+
+        self._opschema = op_schema_from_function_ir(self.function_ir, self.opset)
+
+        return self._opschema
+
     def __getitem__(self, instance):
         """Returns a lambda to evaluate function using given evaluator instance.
 
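
For context, a hedged sketch of the version guard the property above depends on. onnx_older_than is the helper imported at the top of this diff; the body below is an assumed typical implementation, not copied from onnxscript._internal.version_utils:

import onnx
import packaging.version

def onnx_older_than(version: str) -> bool:
    """True if the installed onnx release predates `version` (sketch)."""
    return (
        packaging.version.parse(onnx.__version__).release
        < packaging.version.parse(version).release
    )

# OpSchema only became constructible from Python around onnx 1.14, which is
# presumably why the property returns None instead of raising on older installs.
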
@@ -311,6 +426,9 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
         if self._param_schemas is not None:
             return self._param_schemas
 
+        # NOTE: We generate the parameter schemas from the function_ir instead
+        # of relying on the auto generated OpSchema because we need to preserve the keyword
+        # argument order from the Python function definition, which is lost in OpSchema.
         function_ir = self.function_ir
         # The first len(func_ir.inputs) arguments are onnx inputs
         inputs = function_ir.inputs
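
A small sketch of the ordering point made in the NOTE above; the scripted function and its parameter names are made up. The schemas come back in Python-signature order (input first, then the axis keyword), which an auto-generated OpSchema would not guarantee:

from onnxscript import FLOAT, script
from onnxscript import opset15 as op

@script()
def flatten_from(X: FLOAT[...], axis: int = 1) -> FLOAT[...]:
    return op.Flatten(X, axis=axis)

print([schema.name for schema in flatten_from.param_schemas()])
# expected: ['X', 'axis']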