
fix: Update lowering passes in aten tracer FX #1708


Closed
wants to merge 1 commit
7 changes: 5 additions & 2 deletions py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
@@ -243,7 +243,7 @@ def remove_ops(
module: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
"""
- 1. Remove clone, _unsafe_view node. #TODO Remove this func after functionalization is workable
+ 1. Remove clone, _unsafe_view, view node. #TODO Remove this func after functionalization is workable
2. Remove inefficient op getitem(index=slice) P561572458
"""
modified = False
@@ -258,6 +258,7 @@ def remove_ops(
for n in module.graph.nodes:
if n.op == "call_function" and n.target in (
torch.ops.aten._unsafe_view.default,
+ torch.ops.aten.view.default,
Contributor
It is not necessary to remove aten.view, since the reshape operation is decomposed into aten.view (which is safe), and we have a converter to support aten.view.
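
For reference, the decomposition can be checked with a small sketch (assuming make_fx from torch.fx.experimental.proxy_tensor is available in the installed torch build; this snippet is illustrative and not part of the PR):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        # reshape is a composite op; on a contiguous input it can be satisfied by a pure view
        return x.reshape(2, -1)

    gm = make_fx(f)(torch.randn(4, 3))
    print(gm.graph)  # the printed graph is expected to contain torch.ops.aten.view.default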

Collaborator Author
I see - thank you for the clarification on that. The reason I had removed the view operator was for cases like this:

    def forward(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_shape = x.size()[:-2] + (-1,)
        return x.view(new_shape)

These show up in the GPT2 code, and when using the aten tracer, they result in the following error (though they run fine in Torch):

  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 161, in opt_trace
    fx_module(*args)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 662, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 281, in __call__
    raise e
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 271, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.15", line 9, in forward
  File "/usr/local/lib/python3.8/dist-packages/torch/_ops.py", line 329, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
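
For reference, the same RuntimeError can be reproduced outside the tracer with a minimal sketch (illustrative shapes only, not taken from the GPT2 graph): flattening the trailing dimensions of a permuted, non-contiguous tensor fails under .view but succeeds under .reshape, which copies when a pure view is impossible.

    import torch

    x = torch.randn(2, 3, 4, 5)
    y = x.permute(0, 2, 1, 3)           # non-contiguous strides after permute
    new_shape = y.size()[:-2] + (-1,)   # flatten the last two dimensions

    print(y.reshape(new_shape).shape)   # works: torch.Size([2, 4, 15])
    y.view(new_shape)                   # raises the "view size is not compatible" RuntimeError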

):
modified = True
node = n
@@ -437,8 +438,10 @@ def compose_bmm(
real_other = input_other_n.all_input_nodes[0]
if len(real_other.meta["val"].size()) == 2:
new_func = aten_compose_bmm_2d
- if len(real_other.meta["val"].size()) == 3:
+ elif len(real_other.meta["val"].size()) == 3:
new_func = aten_compose_bmm_3d
+ else:
Contributor
It is not clear why we need this new_func = torch.ops.aten.matmul. Is there an example or unit test?

Collaborator Author
@gs-olive · Mar 30, 2023

This addition is related to an issue in the compose_bmm lowering pass. I noticed that input_n can have a different shape than real_input, which causes the batch matrix multiply to have 4 dimensions instead of 3 and reach this else statement. I don't yet have a minimal reproducing example, as #1789 would likely need to be addressed first.
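
To illustrate the fallback (a minimal sketch with hypothetical shapes, not taken from the failing graph): torch.bmm strictly requires 3-D inputs, while matmul broadcasts over any extra leading batch dimensions, so it can absorb the 4-D case described above.

    import torch

    a = torch.randn(2, 8, 4, 6)          # 4-D "batched" left operand
    b = torch.randn(2, 8, 6, 5)

    print(torch.matmul(a, b).shape)      # works: torch.Size([2, 8, 4, 5]), batch dims broadcast
    torch.bmm(a, b)                      # raises: bmm expects exactly 3-D tensors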

+ new_func = torch.ops.aten.matmul

with module.graph.inserting_after(node):
new_args = (real_input, real_other)