-
Notifications
You must be signed in to change notification settings - Fork 72
Add Op (aten::masked_scatter) | feat (torchlib) #2112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Op (aten::masked_scatter) | feat (torchlib) #2112
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #2112 +/- ##
==========================================
+ Coverage 72.97% 72.98% +0.01%
==========================================
Files 216 216
Lines 28918 28928 +10
Branches 3425 3426 +1
==========================================
+ Hits 21103 21114 +11
+ Misses 6663 6662 -1
Partials 1152 1152 ☔ View full report in Codecov by Sentry. |
ends = op.Gather(op.Shape(index), op.Constant(value_ints=[0]), axis=0) | ||
source = op.Slice(source, starts, ends, axes) | ||
|
||
return op.ScatterND(self, index, source) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about self * (1-mask) + mask * source
or mask * (source - self) + self
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation is following torchscript converter. If we find it's too slow, we can change it later. I am unblocking Gemma3 for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe create a follow up issue for tracking?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From Gemma3, the error lacks of support is raised.
https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/gemma3/modeling_gemma3.py#L1339