[NOMERGE] AugMix inplace mul #6861


Closed · wants to merge 2 commits
19 changes: 12 additions & 7 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -505,13 +505,18 @@ def forward(self, *inputs: Any) -> Any:
                 aug = self._apply_image_or_video_transform(
                     aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
                 )
-            mix.add_(
-                # The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()`.
-                # Currently we can't do this because `aug` has to be `uint8` to support ops like `equalize`.
-                # TODO: change this once all ops in `F` support floats. https://github.com/pytorch/vision/issues/6840
-                combined_weights[:, i].reshape(batch_dims)
-                * aug
-            )
+            weights = combined_weights[:, i].reshape(batch_dims)
+            # We use the `is` operator to cheaply check whether the `aug` reference we hold is the same object as `batch`.
+            # This can happen if we repeatedly sample the identity operator. Because all operators avoid in-place
+            # modifications, this is a cheap and safe way to detect it.
+            if aug is batch or not aug.is_floating_point():
+                # When `aug` is not floating point, or when it is the same tensor as the original batch,
+                # we can't do the multiplication in-place.
+                aug = aug.mul(weights)
+            else:
+                # In all other cases we can multiply in-place.
+                aug.mul_(weights)
datumbox (Contributor, Author) commented on Oct 28, 2022:
Turns out this is not beneficial and just complicates the implementation. See benchmark:

[-------------- AutoAugment cpu torch.float32 --------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 256, 256)      |    7 (+-  4) ms  |    6 (+-  0) ms
      (16, 3, 256, 256)  |  119 (+- 74) ms  |  120 (+- 72) ms
6 threads: --------------------------------------------------
      (3, 256, 256)      |    8 (+-  1) ms  |    8 (+-  1) ms
      (16, 3, 256, 256)  |  114 (+- 77) ms  |  117 (+- 78) ms

Times are in milliseconds (ms).

[-------------- AutoAugment cuda torch.float32 -------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 256, 256)      |    2 (+-  0) ms  |    2 (+-  0) ms
      (16, 3, 256, 256)  |    2 (+-  0) ms  |    2 (+-  0) ms
6 threads: --------------------------------------------------
      (3, 256, 256)      |    2 (+-  0) ms  |    2 (+-  0) ms
      (16, 3, 256, 256)  |    2 (+-  0) ms  |    2 (+-  0) ms

Times are in milliseconds (ms).

[--------------- AutoAugment cpu torch.uint8 ---------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 256, 256)      |    8 (+-  1) ms  |    8 (+-  1) ms
      (16, 3, 256, 256)  |  154 (+- 73) ms  |  145 (+- 78) ms
6 threads: --------------------------------------------------
      (3, 256, 256)      |   10 (+-  1) ms  |   10 (+-  1) ms
      (16, 3, 256, 256)  |  177 (+- 82) ms  |  179 (+- 90) ms

Times are in milliseconds (ms).

[--------------- AutoAugment cuda torch.uint8 --------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 256, 256)      |    2 (+-  0) ms  |    2 (+-  0) ms
      (16, 3, 256, 256)  |    2 (+-  0) ms  |    2 (+-  0) ms
6 threads: --------------------------------------------------
      (3, 256, 256)      |    2 (+-  0) ms  |    2 (+-  0) ms
      (16, 3, 256, 256)  |    2 (+-  0) ms  |    2 (+-  0) ms

Times are in milliseconds (ms).

[----------------- AutoAugment cpu pil -----------------]
                     |        old       |        new     
1 threads: ----------------------------------------------
      (3, 256, 256)  |    9 (+-  1) ms  |    9 (+-  1) ms
6 threads: ----------------------------------------------
      (3, 256, 256)  |   11 (+-  1) ms  |   11 (+-  1) ms

Times are in milliseconds (ms).

I think we should just remove the TODO. @vfdev-5 / @pmeier could you fold this into your next PR?
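
For reference, a rough sketch of how a comparison like the tables above might be produced with torch.utils.benchmark. The actual benchmark script is not included in the PR, so the uint8 inputs, shapes, and labels below are assumptions chosen to mirror one of the reported tables, not the script that generated these numbers:

import torch
import torch.utils.benchmark as benchmark
from torchvision.prototype import transforms

transform = transforms.AugMix()
results = []
for num_threads in (1, 6):
    for shape in ((3, 256, 256), (16, 3, 256, 256)):
        # uint8 inputs mirror the "AutoAugment cpu torch.uint8" table above.
        image = torch.randint(0, 256, shape, dtype=torch.uint8)
        results.append(
            benchmark.Timer(
                stmt="transform(image)",
                globals={"transform": transform, "image": image},
                num_threads=num_threads,
                label="AutoAugment cpu torch.uint8",
                sub_label=str(shape),
                description="new",
            ).blocked_autorange(min_run_time=1)
        )
benchmark.Compare(results).print()
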

+            mix.add_(aug)
         mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

         if isinstance(orig_image_or_video, (features.Image, features.Video)):
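
A minimal standalone sketch of the identity-check pattern in the hunk above, assuming toy inputs (the `mix_step` helper and the shapes are illustrative, not part of torchvision). It shows why the multiply must stay out-of-place when `aug` is still the same object as the input batch, or when the intermediate is not floating point:

import torch

def mix_step(batch: torch.Tensor, aug: torch.Tensor, weight: torch.Tensor, mix: torch.Tensor) -> torch.Tensor:
    # If no augmentation was applied, `aug` is literally the same tensor object as `batch`;
    # multiplying it in-place would corrupt the input used by later mixture branches.
    # Likewise, a uint8 tensor cannot hold the weighted float result in-place.
    if aug is batch or not aug.is_floating_point():
        aug = aug.mul(weight)
    else:
        aug.mul_(weight)  # safe: `aug` is a private floating-point intermediate
    return mix.add_(aug)

batch = torch.rand(2, 3, 8, 8)
mix = torch.zeros_like(batch)
mix = mix_step(batch, batch, torch.tensor(0.3), mix)           # identity branch: aug is batch
mix = mix_step(batch, batch.flip(-1), torch.tensor(0.7), mix)  # an actual augmentation
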
4 changes: 0 additions & 4 deletions torchvision/prototype/transforms/functional/_color.py
@@ -318,10 +318,6 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:


 def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
-    bound = 1 if image.is_floating_point() else 255
-    if threshold > bound:
-        raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
-
Comment on lines -321 to -324
datumbox (Contributor, Author) commented:

@pmeier This check didn't allow me to run the benchmark for float types, so I removed it:

    return solarize_image_tensor(inpt, threshold=threshold)
  File "./vision/torchvision/prototype/transforms/functional/_color.py", line 323, in solarize_image_tensor
    raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
TypeError: Threshold should be less or equal the maximum value of the dtype, but got 226.6666717529297

This is caused by the 255 values hardcoded in the Solarize linspace of AutoAugment. Now that we actually support float32 in all kernels, how do you propose to solve this? I think this should be decoupled from the PR at #6830. We could switch between 1.0 and 255 depending on the dtype of the input. Later, if the aforementioned PR gets merged, we can update to more fine-grained max values that support all integer dtypes. Thoughts?
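
One possible shape of the dtype-dependent idea floated above (a sketch, not an agreed solution; `_solarize_thresholds` is a hypothetical helper): derive the bound from the input dtype so the AutoAugment linspace stays within range for both uint8 and floating-point images.

import torch

def _solarize_thresholds(image: torch.Tensor, num_bins: int = 10) -> torch.Tensor:
    # Mirror the hardcoded 255-based linspace, but scale it to the value range of the input dtype.
    bound = 1.0 if image.is_floating_point() else 255.0
    return torch.linspace(bound, 0.0, num_bins)

print(_solarize_thresholds(torch.rand(3, 8, 8)))                      # float input -> thresholds in [0, 1]
print(_solarize_thresholds(torch.zeros(3, 8, 8, dtype=torch.uint8)))  # uint8 input -> thresholds in [0, 255]
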

     return torch.where(image >= threshold, invert_image_tensor(image), image)

