@@ -41,9 +41,7 @@ def is_non_deterministic_op(node: onnx.NodeProto) -> bool:
41
41
42
42
43
43
def is_constant_op (node : onnx .NodeProto ) -> bool :
44
- return node .op_type in {"Constant" , "ConstantOfShape" } and is_onnx_domain (
45
- node .domain
46
- )
44
+ return node .op_type in {"Constant" , "ConstantOfShape" } and is_onnx_domain (node .domain )
47
45
48
46
49
47
class ConstantFolder (visitor .FunctionCallsiteProtoTransformer ):
@@ -119,14 +117,10 @@ def new_constant(self, name, value):
119
117
info .type = onnx .helper .make_tensor_type_proto (
120
118
onnx .helper .np_dtype_to_tensor_dtype (value .dtype ), value .shape
121
119
)
122
- node = onnx .helper .make_node (
123
- "Constant" , inputs = [], outputs = [name ], value = tensor
124
- )
120
+ node = onnx .helper .make_node ("Constant" , inputs = [], outputs = [name ], value = tensor )
125
121
return [node ]
126
122
127
- def convert_attributes (
128
- self , attributes : Sequence [onnx .AttributeProto ]
129
- ) -> dict [str , Any ]:
123
+ def convert_attributes (self , attributes : Sequence [onnx .AttributeProto ]) -> dict [str , Any ]:
130
124
if self .scopes .current_scope ().current_function_scope ():
131
125
# Need to resolve ref_attr_name if inside a function.
132
126
attr_dict = {}
@@ -138,9 +132,7 @@ def convert_attributes(
138
132
)
139
133
if concrete_attribute is None :
140
134
continue
141
- attr_dict [attribute .name ] = onnx .helper .get_attribute_value (
142
- concrete_attribute
143
- )
135
+ attr_dict [attribute .name ] = onnx .helper .get_attribute_value (concrete_attribute )
144
136
return attr_dict
145
137
return {attr .name : onnx .helper .get_attribute_value (attr ) for attr in attributes }
146
138
@@ -226,9 +218,7 @@ def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
226
218
self .add_count (op , outputs .size )
227
219
return replacement
228
220
else :
229
- logger .warning (
230
- "Skipping constant folding for op %s with multiple outputs." , op
231
- )
221
+ logger .warning ("Skipping constant folding for op %s with multiple outputs." , op )
232
222
return None
233
223
234
224
def process_function_node (
@@ -241,9 +231,7 @@ def process_function_node(
241
231
# Replace function node with Constant if all outputs are constants
242
232
ir_values = [self .lookup (output_name ) for output_name in node .output ]
243
233
tensors = [
244
- self .foldable_value (
245
- output_name , ir_value .value if ir_value is not None else None
246
- )
234
+ self .foldable_value (output_name , ir_value .value if ir_value is not None else None )
247
235
for output_name , ir_value in zip (node .output , ir_values )
248
236
]
249
237
if all (tensor is not None for tensor in tensors ):
0 commit comments