@@ -160,20 +160,15 @@ def get_args_and_kwargs_layer_norm(
160160 ),
161161 {"dtype" : torch .float32 },
162162 )
163- if len (inputs_inputs ) > 0 :
164- if "val" in inputs_inputs [0 ].meta :
165- fake_mode = inputs_inputs [0 ].meta ["val" ].fake_mode
166- if fake_mode is not None :
167- with fake_mode :
168- fake_weight = torch .full (
169- other_inputs [0 ], 1 , dtype = torch .float32
170- )
171- weight .meta ["val" ] = fake_weight
172- else :
173- weight .meta ["val" ] = torch .full (
174- other_inputs [0 ], 1 , dtype = torch .float32
175- )
176- copy_node_metadata (weight , inputs_inputs [0 ])
163+ assert (
164+ len (inputs_inputs ) == 1
165+ ), f"Expected 1 input for layer norm weight, got { len (inputs_inputs )} "
166+ assert "val" in inputs_inputs [0 ].meta , "Missing val metadata on input node"
167+ fake_mode = inputs_inputs [0 ].meta ["val" ].fake_mode
168+ assert fake_mode is not None , "fake_mode is None on input node"
169+ with fake_mode :
170+ weight .meta ["val" ] = torch .full (other_inputs [0 ], 1 , dtype = torch .float32 )
171+ copy_node_metadata (weight , inputs_inputs [0 ])
177172
178173 bias = other_inputs [2 ] if len (other_inputs ) > 2 else None
179174
@@ -186,18 +181,15 @@ def get_args_and_kwargs_layer_norm(
186181 ),
187182 {"dtype" : torch .float32 },
188183 )
189- if len (inputs_inputs ) > 0 :
190- if "val" in inputs_inputs [0 ].meta :
191- fake_mode = inputs_inputs [0 ].meta ["val" ].fake_mode
192- if fake_mode is not None :
193- with fake_mode :
194- fake_bias = torch .full (other_inputs [0 ], 0 , dtype = torch .float32 )
195- bias .meta ["val" ] = fake_bias
196- else :
197- bias .meta ["val" ] = torch .full (
198- other_inputs [0 ], 0 , dtype = torch .float32
199- )
200- copy_node_metadata (bias , inputs_inputs [0 ])
184+ assert (
185+ len (inputs_inputs ) == 1
186+ ), f"Expected 1 input for layer norm bias, got { len (inputs_inputs )} "
187+ assert "val" in inputs_inputs [0 ].meta , "Missing val metadata on input node"
188+ fake_mode = inputs_inputs [0 ].meta ["val" ].fake_mode
189+ assert fake_mode is not None , "fake_mode is None on input node"
190+ with fake_mode :
191+ bias .meta ["val" ] = torch .full (other_inputs [0 ], 0 , dtype = torch .float32 )
192+ copy_node_metadata (bias , inputs_inputs [0 ])
201193
202194 # Make the args and kwargs for the replacement op
203195 args = tuple (inputs_inputs + [scale , zero_point ])
@@ -373,16 +365,15 @@ def get_args_and_kwargs_softmax(
373365 ),
374366 {"dtype" : torch .int32 },
375367 )
376- if len (inputs_inputs ) > 0 :
377- if "val" in inputs_inputs [0 ].meta :
378- fake_mode = inputs_inputs [0 ].meta ["val" ].fake_mode
379- if fake_mode is not None :
380- with fake_mode :
381- fake_mask = torch .full (mask_shape , 0.0 , dtype = torch .int32 )
382- mask_tensor .meta ["val" ] = fake_mask
383- else :
384- mask_tensor .meta ["val" ] = torch .full (mask_shape , 0.0 , dtype = torch .int32 )
385- copy_node_metadata (mask_tensor , inputs_inputs [0 ])
368+ assert (
369+ len (inputs_inputs ) == 1
370+ ), f"Expected 1 input for softmax, got { len (inputs_inputs )} "
371+ assert "val" in inputs_inputs [0 ].meta , "Missing val metadata on input node"
372+ fake_mode = inputs_inputs [0 ].meta ["val" ].fake_mode
373+ assert fake_mode is not None , "fake_mode is None on input node"
374+ with fake_mode :
375+ mask_tensor .meta ["val" ] = torch .full (mask_shape , 0.0 , dtype = torch .int32 )
376+ copy_node_metadata (mask_tensor , inputs_inputs [0 ])
386377 # Make the scale and zero_point tensors
387378 in_scale = dequants_inputs [0 ].args [1 ]
388379 in_zero_point = dequants_inputs [0 ].args [2 ]
@@ -636,25 +627,18 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
636627 torch .ops .aten .transpose .int ,
637628 (weights_inputs [0 ], 0 , 1 ),
638629 )
639- if "val" in weights_inputs [0 ].meta :
640- original_val = weights_inputs [0 ].meta ["val" ]
641- fake_mode = original_val .fake_mode
642- if fake_mode is not None :
643- with fake_mode :
644- transposed_val = torch .ops .aten .transpose .int (
645- original_val , 0 , 1
646- )
647- transposed_weights .meta ["val" ] = transposed_val
648- else :
649- transposed_shape = list (original_val .shape )
650- transposed_shape [0 ], transposed_shape [1 ] = (
651- transposed_shape [1 ],
652- transposed_shape [0 ],
653- )
654- transposed_weights .meta ["val" ] = torch .zeros (
655- transposed_shape , dtype = original_val .dtype
656- )
657- copy_node_metadata (transposed_weights , weights_inputs [0 ])
630+ assert (
631+ "val" in weights_inputs [0 ].meta
632+ ), "Missing val metadata on weight node"
633+ original_val = weights_inputs [0 ].meta ["val" ]
634+ assert (
635+ original_val .fake_mode is not None
636+ ), "fake_mode is None on weight node"
637+ with original_val .fake_mode :
638+ transposed_weights .meta ["val" ] = (
639+ torch .ops .aten .transpose .int (original_val , 0 , 1 )
640+ )
641+ copy_node_metadata (transposed_weights , weights_inputs [0 ])
658642
659643 # Call linear with transposed weight
660644 args , kwargs = get_args_and_kwargs_linear (
0 commit comments