Skip to content

Commit aa2f683

Browse files
authored by [author name missing from page extraction]
Replace if-guards in fusion_pass.py with asserts
Differential Revision: D92744743 Pull Request resolved: #17318
1 parent 4630347 commit aa2f683

File tree

1 file changed

+39
-55
lines changed

1 file changed

+39
-55
lines changed

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 39 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -160,20 +160,15 @@ def get_args_and_kwargs_layer_norm(
160160
),
161161
{"dtype": torch.float32},
162162
)
163-
if len(inputs_inputs) > 0:
164-
if "val" in inputs_inputs[0].meta:
165-
fake_mode = inputs_inputs[0].meta["val"].fake_mode
166-
if fake_mode is not None:
167-
with fake_mode:
168-
fake_weight = torch.full(
169-
other_inputs[0], 1, dtype=torch.float32
170-
)
171-
weight.meta["val"] = fake_weight
172-
else:
173-
weight.meta["val"] = torch.full(
174-
other_inputs[0], 1, dtype=torch.float32
175-
)
176-
copy_node_metadata(weight, inputs_inputs[0])
163+
assert (
164+
len(inputs_inputs) == 1
165+
), f"Expected 1 input for layer norm weight, got {len(inputs_inputs)}"
166+
assert "val" in inputs_inputs[0].meta, "Missing val metadata on input node"
167+
fake_mode = inputs_inputs[0].meta["val"].fake_mode
168+
assert fake_mode is not None, "fake_mode is None on input node"
169+
with fake_mode:
170+
weight.meta["val"] = torch.full(other_inputs[0], 1, dtype=torch.float32)
171+
copy_node_metadata(weight, inputs_inputs[0])
177172

178173
bias = other_inputs[2] if len(other_inputs) > 2 else None
179174

@@ -186,18 +181,15 @@ def get_args_and_kwargs_layer_norm(
186181
),
187182
{"dtype": torch.float32},
188183
)
189-
if len(inputs_inputs) > 0:
190-
if "val" in inputs_inputs[0].meta:
191-
fake_mode = inputs_inputs[0].meta["val"].fake_mode
192-
if fake_mode is not None:
193-
with fake_mode:
194-
fake_bias = torch.full(other_inputs[0], 0, dtype=torch.float32)
195-
bias.meta["val"] = fake_bias
196-
else:
197-
bias.meta["val"] = torch.full(
198-
other_inputs[0], 0, dtype=torch.float32
199-
)
200-
copy_node_metadata(bias, inputs_inputs[0])
184+
assert (
185+
len(inputs_inputs) == 1
186+
), f"Expected 1 input for layer norm bias, got {len(inputs_inputs)}"
187+
assert "val" in inputs_inputs[0].meta, "Missing val metadata on input node"
188+
fake_mode = inputs_inputs[0].meta["val"].fake_mode
189+
assert fake_mode is not None, "fake_mode is None on input node"
190+
with fake_mode:
191+
bias.meta["val"] = torch.full(other_inputs[0], 0, dtype=torch.float32)
192+
copy_node_metadata(bias, inputs_inputs[0])
201193

202194
# Make the args and kwargs for the replacement op
203195
args = tuple(inputs_inputs + [scale, zero_point])
@@ -373,16 +365,15 @@ def get_args_and_kwargs_softmax(
373365
),
374366
{"dtype": torch.int32},
375367
)
376-
if len(inputs_inputs) > 0:
377-
if "val" in inputs_inputs[0].meta:
378-
fake_mode = inputs_inputs[0].meta["val"].fake_mode
379-
if fake_mode is not None:
380-
with fake_mode:
381-
fake_mask = torch.full(mask_shape, 0.0, dtype=torch.int32)
382-
mask_tensor.meta["val"] = fake_mask
383-
else:
384-
mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
385-
copy_node_metadata(mask_tensor, inputs_inputs[0])
368+
assert (
369+
len(inputs_inputs) == 1
370+
), f"Expected 1 input for softmax, got {len(inputs_inputs)}"
371+
assert "val" in inputs_inputs[0].meta, "Missing val metadata on input node"
372+
fake_mode = inputs_inputs[0].meta["val"].fake_mode
373+
assert fake_mode is not None, "fake_mode is None on input node"
374+
with fake_mode:
375+
mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
376+
copy_node_metadata(mask_tensor, inputs_inputs[0])
386377
# Make the scale and zero_point tensors
387378
in_scale = dequants_inputs[0].args[1]
388379
in_zero_point = dequants_inputs[0].args[2]
@@ -636,25 +627,18 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
636627
torch.ops.aten.transpose.int,
637628
(weights_inputs[0], 0, 1),
638629
)
639-
if "val" in weights_inputs[0].meta:
640-
original_val = weights_inputs[0].meta["val"]
641-
fake_mode = original_val.fake_mode
642-
if fake_mode is not None:
643-
with fake_mode:
644-
transposed_val = torch.ops.aten.transpose.int(
645-
original_val, 0, 1
646-
)
647-
transposed_weights.meta["val"] = transposed_val
648-
else:
649-
transposed_shape = list(original_val.shape)
650-
transposed_shape[0], transposed_shape[1] = (
651-
transposed_shape[1],
652-
transposed_shape[0],
653-
)
654-
transposed_weights.meta["val"] = torch.zeros(
655-
transposed_shape, dtype=original_val.dtype
656-
)
657-
copy_node_metadata(transposed_weights, weights_inputs[0])
630+
assert (
631+
"val" in weights_inputs[0].meta
632+
), "Missing val metadata on weight node"
633+
original_val = weights_inputs[0].meta["val"]
634+
assert (
635+
original_val.fake_mode is not None
636+
), "fake_mode is None on weight node"
637+
with original_val.fake_mode:
638+
transposed_weights.meta["val"] = (
639+
torch.ops.aten.transpose.int(original_val, 0, 1)
640+
)
641+
copy_node_metadata(transposed_weights, weights_inputs[0])
658642

659643
# Call linear with transposed weight
660644
args, kwargs = get_args_and_kwargs_linear(

0 commit comments

Comments (0)