Eliminate unnecessary ScatterND #2422
Conversation
Signed-off-by: Ganesan Ramalingam <[email protected]>
❌ 10 Tests Failed:
Can the torchlib implementation of slice_scatter be optimized to do the same, I wonder? I thought this part was potentially inside the PyTorch decomposition of
Definitely!
The torchlib implementation also seems off: `SymInt? start=None, SymInt? end=None` means they can be None, even though the comment says otherwise. We should improve torchlib for this op. Related: #2372
Right. I was puzzled by the comments. Not entirely sure what's happening there.
shape = op.Shape(data, start=0)
dim = op.Gather(shape, axis, axis=0)
full_range = op.Range(0, dim, 1)
full_range_2d = op.Unsqueeze(full_range, [-1])
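For reference, this op sequence builds the column vector of indices [[0], ..., [S-1]] for the scattered dimension. A NumPy sketch (not the actual ONNX implementation) of what the four ops evaluate to:

```python
import numpy as np

def full_range_indices(data: np.ndarray, axis: int) -> np.ndarray:
    """NumPy model of the Shape -> Gather -> Range -> Unsqueeze sequence above."""
    shape = np.array(data.shape)           # op.Shape(data, start=0)
    dim = shape[axis]                      # op.Gather(shape, axis, axis=0)
    full_range = np.arange(0, dim, 1)      # op.Range(0, dim, 1)
    return np.expand_dims(full_range, -1)  # op.Unsqueeze(full_range, [-1])

data = np.zeros((3, 4))
print(full_range_indices(data, 0))  # [[0], [1], [2]]
```

When `dim` is a constant, all four ops fold away at compile time, which is why the pattern may never appear literally in a fully static graph.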
I don't see these ops in the repro: pytorch/pytorch#157289
Maybe you can delete
remove_redundant_scatternd = pattern.RewriteRule(
I think we can consolidate the rules separately. (I am thinking of trying out Copilot to do it.)
We may need to make the other dimensions symbolic; otherwise, all of these ops will be constant-folded and the indices become a constant.
Identify `ScatterND(data, indices, updates)` ops that can be replaced by `Identity(updates)`. This pattern is generated by the translation of
x[:, ...] = y
in PyTorch. Specifically, the indices for the first dimension take the form [[0], ..., [S-1]], where S is the size of the first dimension of the updated-data tensor. In effect, the scatter-update ends up assigning a new value to the entire tensor.
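As a sanity check of that claim, here is a minimal NumPy simulation of ONNX ScatterND semantics (no reduction attribute), not the onnxscript rewrite rule itself, showing that full-range indices make the op equivalent to Identity(updates):

```python
import numpy as np

def scatter_nd(data: np.ndarray, indices: np.ndarray, updates: np.ndarray) -> np.ndarray:
    """Minimal NumPy model of ONNX ScatterND (default reduction='none')."""
    output = np.copy(data)
    # Each entry of indices (except its last axis) selects a slice of output
    # that gets overwritten by the corresponding slice of updates.
    for idx in np.ndindex(indices.shape[:-1]):
        output[tuple(indices[idx])] = updates[idx]
    return output

data = np.zeros((3, 4))
updates = np.arange(12, dtype=float).reshape(3, 4)
indices = np.array([[0], [1], [2]])  # [[0], ..., [S-1]] with S = 3
result = scatter_nd(data, indices, updates)
assert np.array_equal(result, updates)  # ScatterND reduces to Identity(updates)
```

Because every row of `data` is overwritten exactly once, the original `data` values are irrelevant and the whole op can be replaced by `updates`.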