@@ -70,7 +70,7 @@ def process_value_info(
7070 self , value_info : onnx .ValueInfoProto
7171 ) -> tuple [ir .FunctionId | None , ir .Value ]:
7272 name = value_info .name
73- if len (splits := name .split ("/" )) == 2 : # noqa: PLR2004
73+ if len (splits := name .split ("/" )) == 2 :
7474 # Experimental function value info format.
7575 # To be deprecated after ONNX 1.16, where value_info is introduced in FunctionProto.
7676 function_id , value_name = splits
@@ -79,7 +79,7 @@ def process_value_info(
7979 # 'overload' is introduced in ONNX 1.16, consider it as empty string prior to that.
8080 # The code is for future proof, in case overload is encoded in this format.
8181 overload = ""
82- if len (splits ) == 3 : # noqa: PLR2004
82+ if len (splits ) == 3 :
8383 overload = splits [2 ]
8484 function_id = (domain , function_name , overload )
8585 else :
@@ -96,19 +96,14 @@ def save_to_value_info(
9696 function_id = f"{ domain } ::{ function_name } "
9797
9898 if value .type is not None :
99- return onnx .helper .make_value_info (
100- f"{ function_id } /{ value .name } " , value .type
101- )
99+ return onnx .helper .make_value_info (f"{ function_id } /{ value .name } " , value .type )
102100 return None
103101
104102 def lookup (self , function : onnx .FunctionProto , value_name : str ) -> ir .Value | None :
105103 """Lookup ir value of 'value_name' inside 'function'."""
106104 function_id = ir .get_function_id (function )
107105 function_values = self ._function_values .get (function_id )
108- if (
109- function_values is None
110- or (ir_value := function_values .get (value_name )) is None
111- ):
106+ if function_values is None or (ir_value := function_values .get (value_name )) is None :
112107 logger .debug (
113108 "Lookup Missed %s torch symbolic value info in function %s::%s." ,
114109 value_name ,
@@ -124,9 +119,7 @@ def lookup(self, function: onnx.FunctionProto, value_name: str) -> ir.Value | No
124119 )
125120 return ir_value
126121
127- def bind (
128- self , value : ir .Value , domain : str , function_name : str , overload : str
129- ) -> None :
122+ def bind (self , value : ir .Value , domain : str , function_name : str , overload : str ) -> None :
130123 """Bind ir value 'value' to 'value_name' inside 'function'."""
131124 function_id = (domain , function_name , overload )
132125 self ._function_values .setdefault (function_id , {})[value .name ] = value
@@ -309,9 +302,7 @@ def enter_function_scope(self, function: onnx.FunctionProto) -> None:
309302
310303 def exit_function_scope (self ) -> SubScope :
311304 sub_scope = self .current_scope ().exit_sub_scope ()
312- assert isinstance (
313- sub_scope .owner , onnx .FunctionProto
314- ), "Expected function scope."
305+ assert isinstance (sub_scope .owner , onnx .FunctionProto ), "Expected function scope."
315306 self ._scopes .pop ()
316307 return sub_scope
317308
@@ -483,9 +474,7 @@ def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None:
483474 info = self .get_input (node , index )
484475 return info .element_type if info is not None else None
485476
486- def input_shape (
487- self , node : onnx .NodeProto , index : int
488- ) -> onnx .TensorShapeProto | None :
477+ def input_shape (self , node : onnx .NodeProto , index : int ) -> onnx .TensorShapeProto | None :
489478 info = self .get_input (node , index )
490479 return info .tensor_shape_proto () if info is not None else None
491480
@@ -570,13 +559,11 @@ def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
570559 # TODO: handle optional inputs
571560 def get_constant_value (i : int ) -> onnx .TensorProto | None :
572561 value = self .input_const_value (node , i )
573- if isinstance (value , np .ndarray ) and value .size < 20 : # noqa: PLR2004
562+ if isinstance (value , np .ndarray ) and value .size < 20 :
574563 return onnx .numpy_helper .from_array (value , node .input [i ])
575564 return None
576565
577- input_types = {
578- x : self .input_type (node , i ) for i , x in enumerate (node .input )
579- }
566+ input_types = {x : self .input_type (node , i ) for i , x in enumerate (node .input )}
580567 input_data = {x : get_constant_value (i ) for i , x in enumerate (node .input )}
581568 input_data = {k : v for k , v in input_data .items () if v is not None }
582569 if any (t is None for t in input_types .values ()):
@@ -593,7 +580,7 @@ def get_constant_value(i: int) -> onnx.TensorProto | None:
593580 output_types = onnx .shape_inference .infer_node_outputs (
594581 schema , node , input_types , input_data
595582 )
596- except Exception as e : # noqa: BLE001
583+ except Exception as e :
597584 logger .debug (
598585 "Skipping shape inference for node %s due to exception: %s" ,
599586 node .name ,
@@ -854,9 +841,7 @@ def process_function_node(
854841
855842 self .enter_function_scope (mutable_function )
856843 if logger .level <= logging .INFO :
857- printable_actual_input_value_infos = [
858- str (x ) for x in actual_input_value_infos
859- ]
844+ printable_actual_input_value_infos = [str (x ) for x in actual_input_value_infos ]
860845 logger .info (
861846 "Actual input value infos: %s" ,
862847 printable_actual_input_value_infos ,
0 commit comments