Skip to content

[proto] Fix kernel passthrough and types of Normalize #6490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 25, 2022

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Aug 25, 2022

Addresses some of the remarks at #6486

Depends on #6487

This PR:

  • Ensures Normalize uses the right _transformed_types
  • Removes the passthrough functionality from the F.normalize() kernel
  • Updates the error message to the one on main branch

@datumbox datumbox marked this pull request as draft August 25, 2022 09:05
@@ -15,10 +15,8 @@


def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also have Sequence[float] here and convert to list again before calling the kernel? Since we already do this conversion in the transform, there will be no extra conversion when repeatedly calling this from there. On the other hand, users that only want the functional get the benefit of not only using lists.

Copy link
Contributor Author

@datumbox datumbox Aug 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agreed this could be a good change. But it's quite big and should be discussed further because it's changing the signature of a public method. F.normalize() exists already on the old API. Perhaps we can do this because we already change the type of inpt. But I think that's beyond the scope of this PR.

I think whether we can do this or not, depends on whether we will be able to keep F jit-scriptable. Our earlier efforts with @vfdev-5 showed that this might not be possible, but it would be really good if we could keep it as this would be the biggest BC-break of the proposal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me add a bit more info here to have a bit more context.

If we had a way to detect that a certain Tensor object is a "simple tensor" while JIT-scripting, then we would be able refactor the entire F API to keep it JIT-scriptable and maintain BC. This could require some nasty tricks like:

  • Faking types
  • Splitting the non-jit scriptable part of F high-level kernels in separate methods that are marked with @torch.jit.unused.
  • Extensive use of torch.jit.is_scripting() and torch.jit.is_tracing()

Another thing we could do if we want to avoid BC breakages is to provide a jit-scriptable F that has exactly the same signatures as the high-level kernels but only supports pure tensors.

To keep the above a possibility, we shouldn't modify prematurely the existing signatures to receive parameters that are not JIT-scriptable (like Sequence). We might end up doing it at the end, but let's hold off to that until we know for sure.

@datumbox datumbox marked this pull request as ready for review August 25, 2022 11:07
@datumbox datumbox requested a review from vfdev-5 August 25, 2022 11:33
Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, thanks @datumbox !

@datumbox
Copy link
Contributor Author

The failing test is unrelated.

@datumbox datumbox merged commit 79098ad into pytorch:main Aug 25, 2022
@datumbox datumbox deleted the prototype/refactor_normalize branch August 25, 2022 16:26
facebook-github-bot pushed a commit that referenced this pull request Aug 30, 2022
Summary:
* Fix pass-through and supported types of Normalize

* update error message on kernel

* Fix linter.

* Fix the tests.

* Update type.

* Update type.

* Remove unnecessary tests for bboxes and masks.

Reviewed By: NicolasHug

Differential Revision: D39131017

fbshipit-source-id: d7847f4974b083395022471ba33dff0dbf7c9c55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants