fix: Update lowering passes in aten tracer FX #1708
@@ -243,7 +243,7 @@ def remove_ops(
     module: torch.fx.GraphModule,
 ) -> torch.fx.GraphModule:
     """
-    1. Remove clone, _unsafe_view node. #TODO Remove this func after functionalization is workable
+    1. Remove clone, _unsafe_view, view node. #TODO Remove this func after functionalization is workable
     2. Remove inefficient op getitem(index=slice) P561572458
     """
     modified = False
@@ -258,6 +258,7 @@ def remove_ops(
     for n in module.graph.nodes:
         if n.op == "call_function" and n.target in (
             torch.ops.aten._unsafe_view.default,
+            torch.ops.aten.view.default,
         ):
             modified = True
             node = n
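The pattern added in the hunk above can be sketched as a standalone FX pass: for each aten.view / aten._unsafe_view call, rewire its users to the view's input and erase the node. This is a simplified illustration of the removal technique, not the PR's actual remove_ops (which also handles clone and getitem and tracks a modified flag); the function name remove_view_nodes is a placeholder.

```python
import torch

def remove_view_nodes(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Iterate over a snapshot of the nodes, since we mutate the graph.
    for n in list(module.graph.nodes):
        if n.op == "call_function" and n.target in (
            torch.ops.aten._unsafe_view.default,
            torch.ops.aten.view.default,
        ):
            # Point every user of the view at the view's tensor input,
            # then delete the now-unused view node.
            n.replace_all_uses_with(n.all_input_nodes[0])
            module.graph.erase_node(n)
    module.graph.lint()
    module.recompile()
    return module
```

Note this is only sound as a temporary workaround (as the #TODO in the docstring says): dropping a view changes the shape flowing to downstream nodes, which the intended fix, functionalization, would make unnecessary.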
@@ -437,8 +438,10 @@ def compose_bmm(
     real_other = input_other_n.all_input_nodes[0]
     if len(real_other.meta["val"].size()) == 2:
         new_func = aten_compose_bmm_2d
-    if len(real_other.meta["val"].size()) == 3:
+    elif len(real_other.meta["val"].size()) == 3:
         new_func = aten_compose_bmm_3d
+    else:
+        new_func = torch.ops.aten.matmul

     with module.graph.inserting_after(node):
         new_args = (real_input, real_other)

Review comment: Not clear why we need this

Review comment (reply): This addition is related to an issue in the
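A note on the new else branch above: torch.ops.aten.matmul broadcasts over leading batch dimensions, so it can serve as a generic fallback when real_other is neither 2-D nor 3-D and no specialized compose helper applies. A minimal check (the shapes here are illustrative, not taken from the PR):

```python
import torch

# 4-D x 2-D matmul: the trailing two dimensions are contracted (4x5 @ 5x6)
# and the leading batch dimensions (2, 3) broadcast through unchanged.
a = torch.randn(2, 3, 4, 5)
b = torch.randn(5, 6)
out = torch.ops.aten.matmul(a, b)
print(out.shape)  # torch.Size([2, 3, 4, 6])
```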
Review comment: It is not necessary to remove aten.view, since the reshape operation is decomposed into aten.view (which is safe) and we have a converter to support aten.view.
Review comment (reply): I see - thank you for the clarification on that. The reason I had removed the view operator was for cases like this:

These show up in the GPT2 code and, when using the aten tracer, they result in the following error (though they run fine in Torch):
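The reviewer's point that reshape is decomposed into aten.view at the aten level can be checked with a short trace. Here make_fx stands in for the aten tracer; that substitution, and the exact op the decomposition records, are assumptions rather than details taken from this PR:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # reshape on a contiguous tensor; at the aten level this should
    # lower to a view-style op rather than a copy
    return x.reshape(2, 8)

gm = make_fx(f)(torch.randn(4, 4))
aten_targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
print(aten_targets)
```

Printing the graph's call targets should show the view-style op that the remove_ops pass would then match on.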