@@ -41,9 +41,7 @@ def is_non_deterministic_op(node: onnx.NodeProto) -> bool:
4141
4242
4343def is_constant_op (node : onnx .NodeProto ) -> bool :
44- return node .op_type in {"Constant" , "ConstantOfShape" } and is_onnx_domain (
45- node .domain
46- )
44+ return node .op_type in {"Constant" , "ConstantOfShape" } and is_onnx_domain (node .domain )
4745
4846
4947class ConstantFolder (visitor .FunctionCallsiteProtoTransformer ):
@@ -119,14 +117,10 @@ def new_constant(self, name, value):
119117 info .type = onnx .helper .make_tensor_type_proto (
120118 onnx .helper .np_dtype_to_tensor_dtype (value .dtype ), value .shape
121119 )
122- node = onnx .helper .make_node (
123- "Constant" , inputs = [], outputs = [name ], value = tensor
124- )
120+ node = onnx .helper .make_node ("Constant" , inputs = [], outputs = [name ], value = tensor )
125121 return [node ]
126122
127- def convert_attributes (
128- self , attributes : Sequence [onnx .AttributeProto ]
129- ) -> dict [str , Any ]:
123+ def convert_attributes (self , attributes : Sequence [onnx .AttributeProto ]) -> dict [str , Any ]:
130124 if self .scopes .current_scope ().current_function_scope ():
131125 # Need to resolve ref_attr_name if inside a function.
132126 attr_dict = {}
@@ -138,9 +132,7 @@ def convert_attributes(
138132 )
139133 if concrete_attribute is None :
140134 continue
141- attr_dict [attribute .name ] = onnx .helper .get_attribute_value (
142- concrete_attribute
143- )
135+ attr_dict [attribute .name ] = onnx .helper .get_attribute_value (concrete_attribute )
144136 return attr_dict
145137 return {attr .name : onnx .helper .get_attribute_value (attr ) for attr in attributes }
146138
@@ -226,9 +218,7 @@ def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
226218 self .add_count (op , outputs .size )
227219 return replacement
228220 else :
229- logger .warning (
230- "Skipping constant folding for op %s with multiple outputs." , op
231- )
221+ logger .warning ("Skipping constant folding for op %s with multiple outputs." , op )
232222 return None
233223
234224 def process_function_node (
@@ -241,9 +231,7 @@ def process_function_node(
241231 # Replace function node with Constant if all outputs are constants
242232 ir_values = [self .lookup (output_name ) for output_name in node .output ]
243233 tensors = [
244- self .foldable_value (
245- output_name , ir_value .value if ir_value is not None else None
246- )
234+ self .foldable_value (output_name , ir_value .value if ir_value is not None else None )
247235 for output_name , ir_value in zip (node .output , ir_values )
248236 ]
249237 if all (tensor is not None for tensor in tensors ):
0 commit comments