4
4
# --------------------------------------------------------------------------
5
5
from __future__ import annotations
6
6
7
+ import dataclasses
7
8
import io
8
9
import logging
9
10
import warnings
@@ -108,9 +109,15 @@ def _opt_var_to_str(x):
108
109
109
110
110
111
class IRAttributeValue :
111
- """An attribute value (representing an actual parameter)."""
112
+ """An attribute value (representing an actual parameter).
112
113
113
- def __init__ (self , attrproto ) -> None :
114
+ Attributes:
115
+ name: The name of the attribute.
116
+ type: The type of the attribute.
117
+ attr_proto: The attribute proto.
118
+ """
119
+
120
+ def __init__ (self , attrproto : onnx .AttributeProto ) -> None :
114
121
self .attr_proto = attrproto
115
122
116
123
def __str__ (self ):
@@ -120,14 +127,54 @@ def __str__(self):
120
127
return helper .printable_attribute (self .attr_proto )
121
128
122
129
@property
123
- def name (self ):
130
+ def name (self ) -> str :
124
131
return self .attr_proto .name
125
132
126
133
@property
127
- def type (self ):
134
+ def type (self ) -> onnx . AttributeProto . AttributeType :
128
135
return self .attr_proto .type
129
136
130
137
138
+ @dataclasses .dataclass (frozen = True )
139
+ class IRAttributeParameter :
140
+ """An attribute parameter (representing a formal parameter).
141
+
142
+ It may or may not carry a default value.
143
+
144
+ Attributes:
145
+ name: The name of the attribute.
146
+ type: The type of the attribute.
147
+ default_value: The default value of the attribute.
148
+ has_default: Whether the attribute has a default value.
149
+ attr_proto: The attribute proto.
150
+ """
151
+
152
+ name : str
153
+ type : onnx .AttributeProto .AttributeType
154
+ default_value : str | int | float | None = None
155
+
156
+ # TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType.
157
+
158
+ def __str__ (self ):
159
+ if self .has_default :
160
+ return helper .printable_attribute (self .attr_proto )
161
+ # TODO(justinchuby): Include a readable type name.
162
+ return self .name
163
+
164
+ @property
165
+ def has_default (self ):
166
+ return self .default_value is not None
167
+
168
+ @property
169
+ def attr_proto (self ) -> onnx .AttributeProto :
170
+ if not self .has_default :
171
+ raise ValueError (
172
+ "Attribute has no default value. Only attributes with default "
173
+ "values can be converted to AttributeProto."
174
+ )
175
+ return helper .make_attribute (self .name , self .default_value )
176
+
177
+
131
178
class IRStmt :
132
179
def __init__ (
133
180
self ,
@@ -191,9 +238,7 @@ def __init__(self, name: str, domain: str = "") -> None:
191
238
self .outputs : list [IRVar ] = []
192
239
self .stmts : list [IRStmt ] = []
193
240
# attribute parameters
194
- self .attrs : list [str ] = []
195
- # attribute parameters with default value
196
- self .attr_protos : list [IRAttributeValue ] = []
241
+ self .attrs : list [IRAttributeParameter ] = []
197
242
self .called_functions : dict [str , onnx .FunctionProto ] = {}
198
243
self .docstring : str = ""
199
244
# a dictionary of nested function-definitions
@@ -207,11 +252,10 @@ def assigned_names(self) -> Sequence[str]:
207
252
208
253
def __str__ (self ):
209
254
attrs = _format (self .attrs , "<" , ", " , ">" ) if self .attrs else ""
210
- attr_protos = _format (self .attr_protos , "<" , ", " , ">" ) if self .attr_protos else ""
211
255
inputs = _format ([x .typed_str () for x in self .inputs ], "(" , ", " , ")" )
212
256
outputs = _format ([x .typed_str () for x in self .outputs ], "(" , ", " , ")" )
213
257
stmts = _format (self .stmts , "\n {\n " , "\n " , "\n }\n " )
214
- return f"{ self .name } { attrs } { attr_protos } { inputs } => { outputs } { stmts } "
258
+ return f"{ self .name } { attrs } { inputs } => { outputs } { stmts } "
215
259
216
260
def append_docstring (self , docstring ):
217
261
self .docstring += docstring
@@ -225,11 +269,8 @@ def append_input(self, name: IRVar) -> None:
225
269
def append_output (self , name : IRVar ) -> None :
226
270
self .outputs .append (name )
227
271
228
- def add_attr_parameter (self , attr : str | IRAttributeValue ) -> None :
229
- if isinstance (attr , IRAttributeValue ):
230
- self .attr_protos .append (attr )
231
- else :
232
- self .attrs .append (attr )
272
+ def add_attr_parameter (self , attr : IRAttributeParameter ) -> None :
273
+ self .attrs .append (attr )
233
274
234
275
def debug_print (self ):
235
276
if logger .isEnabledFor (logging .DEBUG ):
@@ -398,19 +439,19 @@ def to_function_proto(self) -> onnx.FunctionProto:
398
439
onnx .helper .make_opsetid (domain , version ) for domain , version in opsets .items ()
399
440
]
400
441
401
- # attribute_proto is introduced in version onnx==1.13 .0.
442
+ # attribute_proto is introduced in version onnx==1.14 .0.
402
443
# If this attribute is available, onnxscript uses it to
403
444
# default values for attributes. The function has then two
404
445
# lists, one list for attributes without default values,
405
446
# another one for attributes with default values.
406
447
# If this *attribute_proto* is not available,
407
- # all attributes with a default value are moved to the first
448
+ # all attributes are moved to the first
408
449
# list, default values are removed.
409
450
# TODO: remove this when onnx with attribute_proto is released.
410
451
if hasattr (onnx .FunctionProto , "attribute_proto" ):
411
- atts = self .attrs
452
+ attribute_names = [ attr . name for attr in self .attrs if not attr . has_default ]
412
453
else :
413
- atts = self . attrs + [ a . attr_proto . name for a in self .attr_protos ]
454
+ attribute_names = [ attr . name for attr in self .attrs ]
414
455
415
456
f = helper .make_function (
416
457
self .domain ,
@@ -419,11 +460,13 @@ def to_function_proto(self) -> onnx.FunctionProto:
419
460
outputs = [y .name for y in self .outputs ],
420
461
nodes = nodes ,
421
462
opset_imports = opset_imports , # TODO
422
- attributes = atts ,
463
+ attributes = attribute_names ,
423
464
doc_string = self .docstring ,
424
465
)
425
466
if hasattr (onnx .FunctionProto , "attribute_proto" ):
426
- f .attribute_proto .extend ([a .attr_proto for a in self .attr_protos ])
467
+ f .attribute_proto .extend (
468
+ [attr .attr_proto for attr in self .attrs if attr .has_default ]
469
+ )
427
470
return f
428
471
429
472
@@ -437,10 +480,10 @@ def __init__(self):
437
480
def new_function (self , name : str , domain : str = "" , register : bool = False ):
438
481
if register and (domain , name ) in self .functions :
439
482
raise RuntimeError (f"Function '{ name } ' already exists in domain '{ domain } '." )
440
- fct = IRFunction (name , domain )
483
+ function = IRFunction (name , domain )
441
484
if register :
442
- self .functions [domain , name ] = fct
443
- return fct
485
+ self .functions [domain , name ] = function
486
+ return function
444
487
445
488
def add_docstring (self , fn : IRFunction , docstring : str ):
446
489
fn .append_docstring (docstring )
@@ -454,34 +497,34 @@ def add_stmt(
454
497
attrs : Sequence [IRAttributeValue ],
455
498
sub_functions = None ,
456
499
) -> None :
457
- s = IRStmt (results , callee , args , attrs , sub_functions = sub_functions )
458
- fn .append_stmt (s )
500
+ stmt = IRStmt (results , callee , args , attrs , sub_functions = sub_functions )
501
+ fn .append_stmt (stmt )
459
502
460
503
def add_input (
461
504
self , fn : IRFunction , varname : str , type : IRTypeLike , info : SourceInfo
462
505
) -> None :
463
- v = IRVar (varname , type , info )
464
- fn .append_input (v )
506
+ var = IRVar (varname , type , info )
507
+ fn .append_input (var )
465
508
466
- def add_attr_parameter (self , fn : IRFunction , varname : str , default_value ) -> None :
467
- if default_value is not None :
468
- a = IRAttributeValue (helper .make_attribute (varname , default_value ))
469
- fn .add_attr_parameter (a )
470
- else :
471
- fn .add_attr_parameter (varname )
509
+ def add_attr_parameter (
510
+ self ,
511
+ fn : IRFunction ,
512
+ varname : str ,
513
+ attribute_type : onnx .AttributeProto .AttributeType ,
514
+ default_value : int | float | str | None ,
515
+ ) -> None :
516
+ fn .add_attr_parameter (IRAttributeParameter (varname , attribute_type , default_value ))
472
517
473
- def add_output (self , fn : IRFunction , varname : str , type , info ) -> None :
474
- v = IRVar (varname , type , info )
475
- fn .append_output (v )
518
+ def add_output (self , fn : IRFunction , varname : str , typeinfo , sourceinfo ) -> None :
519
+ var = IRVar (varname , typeinfo , sourceinfo )
520
+ fn .append_output (var )
476
521
477
522
def make_attr (self , attrname : str , attrval : Any ) -> IRAttributeValue :
478
523
return IRAttributeValue (helper .make_attribute (attrname , attrval ))
479
524
480
525
def make_attr_ref (self , attrname : str , refname : str , pytype : type ) -> IRAttributeValue :
481
- a = onnx .AttributeProto ()
482
- a .name = attrname
483
- a .ref_attr_name = refname
484
- type_ = ta .pytype_to_attrtype (pytype )
485
- assert type_ is not None
486
- a .type = type_
487
- return IRAttributeValue (a )
526
+ proto = onnx .AttributeProto ()
527
+ proto .name = attrname
528
+ proto .ref_attr_name = refname
529
+ proto .type = ta .pytype_to_attrtype (pytype )
530
+ return IRAttributeValue (proto )
0 commit comments