
Transforms without dispatcher #5421


Merged: 36 commits into pytorch:main, Feb 25, 2022

Conversation

@pmeier (Collaborator) commented Feb 14, 2022

This is the alternative proposal to #5418. The main difference is that this PR eliminates the high-level kernels, a.k.a. dispatchers, introduced in #5407 by merging them into the transforms. That means that, in contrast to #5418, in this PR the transforms are responsible for the dispatch.

Pros

  • One reason why Revamp prototype features and transforms #5407 had such a lengthy discussion is that it is really hard to properly document the function of a dispatcher. The approach in this PR circumvents this issue completely, since the expectations for calling a transform are very different from those for calling a high-level kernel. This means we can simply link to the low-level kernels without needing to specify exactly how they are called.
  • By removing the high-level kernels we shrink the public API surface that we have to maintain / test / document / ... Although the functionality is equivalent, this PR removes ~500 LoC.
  • The dispatch mechanism from Revamp prototype features and transforms #5407 involves a lot of magic to enable reasonable flexibility. Still, in some cases (I intentionally avoid "edge cases" here, because I don't think they are that uncommon) we need even more complicated logic to handle everything. Out of the 68 registered kernels, 12 need manual handling (~17%).
  • As noted in add prototype transforms that use the prototype dispatchers #5418 (comment), the other design requires get_params to return all parameters that are needed, whereas in this PR we can again generate only the dynamic ones there and simply access the constant ones on the instance (see the sketch below).
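
    A minimal sketch of this, assuming the Transform base class from this PR; the transform name, the kernel K.rotate_image, and its parameters are illustrative, not the final API:

    from typing import Any, Dict

    import torch

    class RandomRotation(Transform):
        def __init__(self, degrees: float, expand: bool = False) -> None:
            super().__init__()
            self.degrees = degrees
            self.expand = expand  # constant parameter, simply stored on the instance

        def _get_params(self, sample: Any) -> Dict[str, Any]:
            # only the dynamic parameters are sampled here
            angle = float(torch.empty(1).uniform_(-self.degrees, self.degrees))
            return dict(angle=angle)

        def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
            # the constant parameter is read directly from the instance
            return K.rotate_image(input, angle=params["angle"], expand=self.expand)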

Cons

  • By removing the magic of the dispatchers, we will need to write more boilerplate code ourselves. Note that this does not necessarily translate to more LoC. Compare

    @dispatch(
        {
            torch.Tensor: _F.resize,
            PIL.Image.Image: _F.resize,
            features.Image: K.resize_image,
            features.SegmentationMask: K.resize_segmentation_mask,
            features.BoundingBox: None,
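            # None: no kernel is registered; this type is handled manually in the body below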
        }
    )
    def resize(input: Any, *args: Any, **kwargs: Any) -> Any:
        """TODO: add docstring"""
        if isinstance(input, features.BoundingBox):
            size = kwargs.pop("size")
            output = K.resize_bounding_box(input, size=size, image_size=input.image_size)
            return features.BoundingBox.new_like(input, output, image_size=size)
    
        raise RuntimeError
    
    class Resize(Transform):
        _DISPATCHER = F.resize

    to

    class Resize(Transform):
        def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
            if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image):
                return _F.resize(input, size=self.size, interpolation=self.interpolation)
            elif type(input) is features.Image:
                output = K.resize_image(input, size=self.size, interpolation=self.interpolation)
                return features.Image.new_like(input, output)
            elif type(input) is features.SegmentationMask:
                output = K.resize_segmentation_mask(input, size=self.size)
                return features.SegmentationMask.new_like(input, output)
            elif type(input) is features.BoundingBox:
                output = K.resize_bounding_box(input, size=self.size, image_size=input.image_size)
                return features.BoundingBox.new_like(input, output, image_size=self.size)
            else:
                return input
  • Logging API calls is easier with dispatchers, since the call can be placed inside the decorator. Since this layer of the API is removed here, but we still want to log the calls to the kernels, we need to perform the logging in the kernels themselves. try api call logging with decorator #5424 investigates whether this is possible through a decorator (see the sketch after this list), which would weaken this con.

  • Writing custom composite transforms is harder. With the dispatchers we can do something like

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        input = F.foo(input)
        input = F.bar(input)
        input = F.baz(input)
        return input

    Without dispatchers we could either call the low-level kernels directly

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, features.Image):
            input = K.foo_image(input)
            input = K.bar_image(input)
            input = K.baz_image(input)
        elif isinstance(input, features.BoundingBox):
            input = K.foo_bounding_box(input)
            input = K.bar_bounding_box(input)
            input = K.baz_bounding_box(input)
    
        return input

    or rely on the transform objects

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        input = transforms.Foo()(input)
        input = transforms.Bar()(input)
        input = transforms.Baz()(input)
        return input

    Out of the two options, I prefer the latter, since it duplicates less code, especially if the dispatch is not as straightforward as in the example.

  • (This only applies if we go with option 2 from the point above.) If we use other transforms to write composite ops, we need a non-random primitive counterpart for every transform with random functionality. For example, we would need a transforms.Affine as well as a transforms.RandomAffine. From an implementation perspective this is fairly easy to achieve, with RandomAffine subclassing Affine and overriding the _get_params() method with random sampling (see the sketch after this list). This also ties directly into add prototype transforms that use the prototype dispatchers #5418 (comment)
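
    Regarding the logging con above, a minimal sketch of the decorator idea from #5424; the decorator name is hypothetical, while torch._C._log_api_usage_once is the existing logging hook:

    import functools

    import torch

    def log_kernel_call(kernel):
        # hypothetical decorator: record the kernel call, then run it unchanged
        @functools.wraps(kernel)
        def wrapper(*args, **kwargs):
            torch._C._log_api_usage_once(f"{kernel.__module__}.{kernel.__name__}")
            return kernel(*args, **kwargs)
        return wrapper

    And a minimal sketch of the Affine / RandomAffine subclassing pattern; the kernel K.affine_image and the parameters are illustrative:

    class Affine(Transform):
        def __init__(self, angle: float) -> None:
            super().__init__()
            self.angle = angle

        def _get_params(self, sample: Any) -> Dict[str, Any]:
            # deterministic primitive: the parameters are just the constructor arguments
            return dict(angle=self.angle)

        def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
            return K.affine_image(input, angle=params["angle"])

    class RandomAffine(Affine):
        def __init__(self, degrees: float) -> None:
            super().__init__(angle=0.0)
            self.degrees = degrees

        def _get_params(self, sample: Any) -> Dict[str, Any]:
            # only the sampling is overridden; _transform is inherited unchanged
            angle = float(torch.empty(1).uniform_(-self.degrees, self.degrees))
            return dict(angle=angle)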

@facebook-github-bot commented Feb 14, 2022
💊 CI failures summary and remediations

As of commit 0943de0 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@datumbox marked this pull request as draft February 16, 2022 14:15
@pmeier (Collaborator, Author) commented Feb 17, 2022

After some offline discussion with @vfdev-5, it became clear that always having a primitive transform will significantly increase the number of transforms that we provide. Thus, although this PR removes the high-level kernels and effectively turns three API layers into two, the surface reduction is not as significant as originally thought. There will still be some reduction, since some transforms, e.g. Resize, are already primitives, so no extra transform is needed there.

@datumbox (Contributor) left a comment

Overall I like this approach a lot. I've added a few comments; let's sync offline for next steps.

@pmeier marked this pull request as ready for review February 22, 2022 08:18
@datumbox (Contributor) left a comment

I think it's looking very good. A few more comments:

@datumbox (Contributor) left a comment

Nits and notes:


    raise RuntimeError

    def gaussian_blur_image_pil(img: PIL.Image, kernel_size: List[int], sigma: Optional[List[float]] = None) -> PIL.Image:
        return to_pil_image(gaussian_blur_image_tensor(to_tensor(img), kernel_size=kernel_size, sigma=sigma))
@datumbox (Contributor) commented:

@vfdev-5 using to_tensor is heavily discouraged in favour of pil_to_tensor; we made this change a few months back. Could you send a separate PR that modifies gaussian_blur in the main area and this gaussian_blur_image_pil kernel in the prototype to use the new method?
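
A minimal sketch of the suggested change, assuming the gaussian_blur_image_tensor kernel from this PR; unlike to_tensor, pil_to_tensor returns a uint8 tensor without rescaling to [0, 1]:

    from typing import List, Optional

    import PIL.Image
    from torchvision.transforms.functional import pil_to_tensor, to_pil_image

    def gaussian_blur_image_pil(img: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None) -> PIL.Image.Image:
        # PIL -> uint8 tensor, blur with the tensor kernel, convert back
        output = gaussian_blur_image_tensor(pil_to_tensor(img), kernel_size=kernel_size, sigma=sigma)
        return to_pil_image(output)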



    def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
        return grayscale.expand(3, -1, -1)
@datumbox (Contributor) commented:

Non-blocking TODO: this kernel assumes the tensor is a single image and not a batch. We shouldn't make this assumption; we should fix it in a separate PR.
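
A hypothetical batch-aware variant (not part of this PR) that expands the channel dimension while leaving any leading batch dimensions untouched:

    import torch

    def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
        # works for (1, H, W) single images as well as (*, 1, H, W) batches
        return grayscale.expand(*grayscale.shape[:-3], 3, -1, -1)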

        expand: bool = False,
        fill: Optional[List[float]] = None,
        center: Optional[List[float]] = None,
    ) -> torch.Tensor:
@datumbox (Contributor) commented:

@vfdev-5 FYI, some transforms have been copy-pasted here to make the refactoring easier, so any changes in the main area also need to be made here.

@datumbox (Contributor) left a comment

I think this is a major improvement over the previous proposal.

We should merge. I'll follow up with a more thorough review of the code-base that we have in the prototype transforms.

@pmeier merged commit 7251769 into pytorch:main Feb 25, 2022
@pmeier deleted the transforms-without-dispatcher branch February 25, 2022 13:15
facebook-github-bot pushed a commit that referenced this pull request Mar 2, 2022
Summary:
* add prototype transforms that don't need dispatchers

* cleanup

* remove legacy_transform decorator

* remove legacy classes

* remove explicit param passing

* streamline extra_repr

* remove obsolete ._supports() method

* cleanup

* remove Query

* cleanup

* fix tests

* kernels -> functional

* move image size and num channels extraction to functional

* extend legacy function to extract image size and num channels

* implement dispatching for auto augment

* fix auto augment dispatch

* revert some naming changes

* remove ability to pass params to autoaugment

* fix legacy image size extraction

* align prototype.transforms.functional with transforms.functional

* cleanup

* fix image size and channels extraction

* fix affine and rotate

* revert image size to (width, height)

* Minor corrections

Reviewed By: datumbox

Differential Revision: D34579512

fbshipit-source-id: 2044269d771b3488010b62ff611cf4f75ef75ed4

Co-authored-by: Vasilis Vryniotis <[email protected]>