Skip to content

[NOMERGE] AugMix inplace mul #6861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Oct 28, 2022

Investigating if adding an inplace mul on AugMix is beneficial.

@datumbox datumbox changed the title [NOMERGE] AugMix inplace add [NOMERGE] AugMix inplace mul Oct 28, 2022
Copy link
Contributor Author

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't merge this PR, just highlighting 2 bits:

aug = aug.mul(weights)
else:
# We can do in-place on all other cases.
aug.mul_(weights)
Copy link
Contributor Author

@datumbox datumbox Oct 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out this is not beneficial and just complicates the implementation. See benchmark:

[-------------- AutoAugment cpu torch.float32 --------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 256, 256)      |    7 (+-  4) ms  |    6 (+-  0) ms
      (16, 3, 256, 256)  |  119 (+- 74) ms  |  120 (+- 72) ms
6 threads: --------------------------------------------------
      (3, 256, 256)      |    8 (+-  1) ms  |    8 (+-  1) ms
      (16, 3, 256, 256)  |  114 (+- 77) ms  |  117 (+- 78) ms

Times are in milliseconds (ms).

[-------------- AutoAugment cuda torch.float32 -------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 256, 256)      |    2 (+-  0) ms  |    2 (+-  0) ms
      (16, 3, 256, 256)  |    2 (+-  0) ms  |    2 (+-  0) ms
6 threads: --------------------------------------------------
      (3, 256, 256)      |    2 (+-  0) ms  |    2 (+-  0) ms
      (16, 3, 256, 256)  |    2 (+-  0) ms  |    2 (+-  0) ms

Times are in milliseconds (ms).

[--------------- AutoAugment cpu torch.uint8 ---------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 256, 256)      |    8 (+-  1) ms  |    8 (+-  1) ms
      (16, 3, 256, 256)  |  154 (+- 73) ms  |  145 (+- 78) ms
6 threads: --------------------------------------------------
      (3, 256, 256)      |   10 (+-  1) ms  |   10 (+-  1) ms
      (16, 3, 256, 256)  |  177 (+- 82) ms  |  179 (+- 90) ms

Times are in milliseconds (ms).

[--------------- AutoAugment cuda torch.uint8 --------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 256, 256)      |    2 (+-  0) ms  |    2 (+-  0) ms
      (16, 3, 256, 256)  |    2 (+-  0) ms  |    2 (+-  0) ms
6 threads: --------------------------------------------------
      (3, 256, 256)      |    2 (+-  0) ms  |    2 (+-  0) ms
      (16, 3, 256, 256)  |    2 (+-  0) ms  |    2 (+-  0) ms

Times are in milliseconds (ms).

[----------------- AutoAugment cpu pil -----------------]
                     |        old       |        new     
1 threads: ----------------------------------------------
      (3, 256, 256)  |    9 (+-  1) ms  |    9 (+-  1) ms
6 threads: ----------------------------------------------
      (3, 256, 256)  |   11 (+-  1) ms  |   11 (+-  1) ms

Times are in milliseconds (ms).

I think we should just remove the TODO. @vfdev-5 / @pmeier could you clip this on your next PR?

Comment on lines -321 to -324
bound = 1 if image.is_floating_point() else 255
if threshold > bound:
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmeier This check didn't allow me to run the benchmark for float types so I removed it:

    return solarize_image_tensor(inpt, threshold=threshold)
  File "./vision/torchvision/prototype/transforms/functional/_color.py", line 323, in solarize_image_tensor
    raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
TypeError: Threshold should be less or equal the maximum value of the dtype, but got 226.6666717529297

This is caused by the 255 values hardcoded on the Solarize linspace of AutoAugment. Now that we actually support float32 in all kernels, how do you propose to solve this? I think this should be disconnected from the PR at #6830. We could implement an approach to switch to 1.0 and 255 depending on the type of the input. Later if the aforementioned PR gets merged, we can update with more fine-grained max values that support all integers. Thoughts?

@datumbox datumbox marked this pull request as draft October 28, 2022 18:18
@datumbox datumbox closed this Oct 28, 2022
@datumbox datumbox deleted the prototype/augmix_inplace branch October 28, 2022 18:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants