-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[proto] Fix kernel passthrough and types of Normalize #6490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -15,10 +15,8 @@ | |||
|
|||
|
|||
def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we also have Sequence[float]
here and convert to list again before calling the kernel? Since we already do this conversion in the transform, there will be no extra conversion when repeatedly calling this from there. On the other hand, users that only want the functional get the benefit of not only using lists.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agreed this could be a good change. But it's quite big and should be discussed further because it's changing the signature of a public method. F.normalize()
exists already on the old API. Perhaps we can do this because we already change the type of inpt
. But I think that's beyond the scope of this PR.
I think whether we can do this or not, depends on whether we will be able to keep F
jit-scriptable. Our earlier efforts with @vfdev-5 showed that this might not be possible, but it would be really good if we could keep it as this would be the biggest BC-break of the proposal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me add a bit more info here to have a bit more context.
If we had a way to detect that a certain Tensor object is a "simple tensor" while JIT-scripting, then we would be able refactor the entire F
API to keep it JIT-scriptable and maintain BC. This could require some nasty tricks like:
- Faking types
- Splitting the non-jit scriptable part of
F
high-level kernels in separate methods that are marked with @torch.jit.unused. - Extensive use of
torch.jit.is_scripting()
andtorch.jit.is_tracing()
Another thing we could do if we want to avoid BC breakages is to provide a jit-scriptable F
that has exactly the same signatures as the high-level kernels but only supports pure tensors.
To keep the above a possibility, we shouldn't modify prematurely the existing signatures to receive parameters that are not JIT-scriptable (like Sequence
). We might end up doing it at the end, but let's hold off to that until we know for sure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, thanks @datumbox !
The failing test is unrelated. |
Summary: * Fix pass-through and supported types of Normalize * update error message on kernel * Fix linter. * Fix the tests. * Update type. * Update type. * Remove unnecessary tests for bboxes and masks. Reviewed By: NicolasHug Differential Revision: D39131017 fbshipit-source-id: d7847f4974b083395022471ba33dff0dbf7c9c55
Addresses some of the remarks at #6486
Depends on #6487
This PR:
Normalize
uses the right_transformed_types
F.normalize()
kernel