@@ -108,10 +108,16 @@ def _opt_var_to_str(x):
108
108
109
109
110
110
class IRAttributeValue :
111
- """An attribute value (representing an actual parameter)."""
111
+ """An attribute value (representing an actual parameter).
112
112
113
- def __init__ (self , attrproto ) -> None :
113
+ Attributes:
114
+ attr_proto: The attribute proto
115
+ has_default: Whether the attribute has a default value.
116
+ """
117
+
118
+ def __init__ (self , attrproto , has_default : bool ) -> None :
114
119
self .attr_proto = attrproto
120
+ self .has_default = has_default
115
121
116
122
def __str__ (self ):
117
123
if self .attr_proto .HasField ("ref_attr_name" ):
@@ -191,9 +197,7 @@ def __init__(self, name: str, domain: str = "") -> None:
191
197
self .outputs : list [IRVar ] = []
192
198
self .stmts : list [IRStmt ] = []
193
199
# attribute parameters
194
- self .attrs : list [str ] = []
195
- # attribute parameters with default value
196
- self .attr_protos : list [IRAttributeValue ] = []
200
+ self .attrs : list [IRAttributeValue ] = []
197
201
self .called_functions : dict [str , onnx .FunctionProto ] = {}
198
202
self .docstring : str = ""
199
203
# a dictionary of nested function-definitions
@@ -207,11 +211,10 @@ def assigned_names(self) -> Sequence[str]:
207
211
208
212
def __str__ (self ):
209
213
attrs = _format (self .attrs , "<" , ", " , ">" ) if self .attrs else ""
210
- attr_protos = _format (self .attr_protos , "<" , ", " , ">" ) if self .attr_protos else ""
211
214
inputs = _format ([x .typed_str () for x in self .inputs ], "(" , ", " , ")" )
212
215
outputs = _format ([x .typed_str () for x in self .outputs ], "(" , ", " , ")" )
213
216
stmts = _format (self .stmts , "\n {\n " , "\n " , "\n }\n " )
214
- return f"{ self .name } { attrs } { attr_protos } { inputs } => { outputs } { stmts } "
217
+ return f"{ self .name } { attrs } { inputs } => { outputs } { stmts } "
215
218
216
219
def append_docstring (self , docstring ):
217
220
self .docstring += docstring
@@ -225,11 +228,8 @@ def append_input(self, name: IRVar) -> None:
225
228
def append_output (self , name : IRVar ) -> None :
226
229
self .outputs .append (name )
227
230
228
- def add_attr_parameter (self , attr : str | IRAttributeValue ) -> None :
229
- if isinstance (attr , IRAttributeValue ):
230
- self .attr_protos .append (attr )
231
- else :
232
- self .attrs .append (attr )
231
+ def add_attr_parameter (self , attr : IRAttributeValue ) -> None :
232
+ self .attrs .append (attr )
233
233
234
234
def debug_print (self ):
235
235
if logger .isEnabledFor (logging .DEBUG ):
@@ -398,19 +398,19 @@ def to_function_proto(self) -> onnx.FunctionProto:
398
398
onnx .helper .make_opsetid (domain , version ) for domain , version in opsets .items ()
399
399
]
400
400
401
- # attribute_proto is introduced in version onnx==1.13 .0.
401
+ # attribute_proto is introduced in version onnx==1.14 .0.
402
402
# If this attribute is available, onnxscript uses it to
403
403
# default values for attributes. The function has then two
404
404
# lists, one list for attributes without default values,
405
405
# another one for attributes with default values.
406
406
# If this *attribute_proto* is not available,
407
- # all attributes with a default value are moved to the first
407
+ # all attributes are moved to the first
408
408
# list, default values are removed.
409
409
# TODO: remove this when onnx with attribute_proto is released.
410
410
if hasattr (onnx .FunctionProto , "attribute_proto" ):
411
- atts = self .attrs
411
+ attribute_names = [ attr . name for attr in self .attrs if not attr . has_default ]
412
412
else :
413
- atts = self . attrs + [ a . attr_proto . name for a in self .attr_protos ]
413
+ attribute_names = [ attr . name for attr in self .attrs ]
414
414
415
415
f = helper .make_function (
416
416
self .domain ,
@@ -419,11 +419,13 @@ def to_function_proto(self) -> onnx.FunctionProto:
419
419
outputs = [y .name for y in self .outputs ],
420
420
nodes = nodes ,
421
421
opset_imports = opset_imports , # TODO
422
- attributes = atts ,
422
+ attributes = attribute_names ,
423
423
doc_string = self .docstring ,
424
424
)
425
425
if hasattr (onnx .FunctionProto , "attribute_proto" ):
426
- f .attribute_proto .extend ([a .attr_proto for a in self .attr_protos ])
426
+ f .attribute_proto .extend (
427
+ [attr .attr_proto for attr in self .attrs if attr .has_default ]
428
+ )
427
429
return f
428
430
429
431
@@ -463,25 +465,35 @@ def add_input(
463
465
v = IRVar (varname , type , info )
464
466
fn .append_input (v )
465
467
466
- def add_attr_parameter (self , fn : IRFunction , varname : str , default_value ) -> None :
468
+ def add_attr_parameter (
469
+ self ,
470
+ fn : IRFunction ,
471
+ varname : str ,
472
+ attribute_type : onnx .AttributeProto .AttributeType ,
473
+ default_value ,
474
+ ) -> None :
467
475
if default_value is not None :
468
- a = IRAttributeValue (helper .make_attribute (varname , default_value ))
469
- fn .add_attr_parameter (a )
476
+ fn .add_attr_parameter (
477
+ IRAttributeValue (
478
+ helper .make_attribute (varname , default_value ), has_default = True
479
+ )
480
+ )
470
481
else :
471
- fn .add_attr_parameter (varname )
482
+ proto = onnx .AttributeProto ()
483
+ proto .name = varname
484
+ proto .type = attribute_type
485
+ fn .add_attr_parameter (IRAttributeValue (proto , has_default = False ))
472
486
473
487
def add_output (self , fn : IRFunction , varname : str , type , info ) -> None :
474
488
v = IRVar (varname , type , info )
475
489
fn .append_output (v )
476
490
477
491
def make_attr (self , attrname : str , attrval : Any ) -> IRAttributeValue :
478
- return IRAttributeValue (helper .make_attribute (attrname , attrval ))
492
+ return IRAttributeValue (helper .make_attribute (attrname , attrval ), has_default = True )
479
493
480
494
def make_attr_ref (self , attrname : str , refname : str , pytype : type ) -> IRAttributeValue :
481
495
a = onnx .AttributeProto ()
482
496
a .name = attrname
483
497
a .ref_attr_name = refname
484
- type_ = ta .pytype_to_attrtype (pytype )
485
- assert type_ is not None
486
- a .type = type_
487
- return IRAttributeValue (a )
498
+ a .type = ta .pytype_to_attrtype (pytype )
499
+ return IRAttributeValue (a , has_default = False )
0 commit comments