Skip to content

Commit a063082

Browse files
committed
fix: Update lowering passes in aten tracer
- Enable translation from `view` to `reshape`; `view` was causing failures when compiling the BERT model because of the memory layout of the tensors.
- Default to `matmul` within the `compose_bmm` lowering pass when the inputs have more than 3 dimensions.
1 parent c2126b1 commit a063082

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def remove_ops(
243243
module: torch.fx.GraphModule,
244244
) -> torch.fx.GraphModule:
245245
"""
246-
1. Remove clone, _unsafe_view node. #TODO Remove this func after functionalization is workable
246+
1. Remove clone, _unsafe_view, view node. #TODO Remove this func after functionalization is workable
247247
2. Remove inefficient op getitem(index=slice) P561572458
248248
"""
249249
modified = False
@@ -258,6 +258,7 @@ def remove_ops(
258258
for n in module.graph.nodes:
259259
if n.op == "call_function" and n.target in (
260260
torch.ops.aten._unsafe_view.default,
261+
torch.ops.aten.view.default,
261262
):
262263
modified = True
263264
node = n
@@ -437,8 +438,10 @@ def compose_bmm(
437438
real_other = input_other_n.all_input_nodes[0]
438439
if len(real_other.meta["val"].size()) == 2:
439440
new_func = aten_compose_bmm_2d
440-
if len(real_other.meta["val"].size()) == 3:
441+
elif len(real_other.meta["val"].size()) == 3:
441442
new_func = aten_compose_bmm_3d
443+
else:
444+
new_func = torch.ops.aten.matmul
442445

443446
with module.graph.inserting_after(node):
444447
new_args = (real_input, real_other)

0 commit comments

Comments
 (0)