-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[NOMERGE] AugMix inplace mul #6861
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't merge this PR, just highlighting 2 bits:
aug = aug.mul(weights) | ||
else: | ||
# We can do in-place on all other cases. | ||
aug.mul_(weights) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Turns out this is not beneficial and just complicates the implementation. See benchmark:
[-------------- AutoAugment cpu torch.float32 --------------]
| old | new
1 threads: --------------------------------------------------
(3, 256, 256) | 7 (+- 4) ms | 6 (+- 0) ms
(16, 3, 256, 256) | 119 (+- 74) ms | 120 (+- 72) ms
6 threads: --------------------------------------------------
(3, 256, 256) | 8 (+- 1) ms | 8 (+- 1) ms
(16, 3, 256, 256) | 114 (+- 77) ms | 117 (+- 78) ms
Times are in milliseconds (ms).
[-------------- AutoAugment cuda torch.float32 -------------]
| old | new
1 threads: --------------------------------------------------
(3, 256, 256) | 2 (+- 0) ms | 2 (+- 0) ms
(16, 3, 256, 256) | 2 (+- 0) ms | 2 (+- 0) ms
6 threads: --------------------------------------------------
(3, 256, 256) | 2 (+- 0) ms | 2 (+- 0) ms
(16, 3, 256, 256) | 2 (+- 0) ms | 2 (+- 0) ms
Times are in milliseconds (ms).
[--------------- AutoAugment cpu torch.uint8 ---------------]
| old | new
1 threads: --------------------------------------------------
(3, 256, 256) | 8 (+- 1) ms | 8 (+- 1) ms
(16, 3, 256, 256) | 154 (+- 73) ms | 145 (+- 78) ms
6 threads: --------------------------------------------------
(3, 256, 256) | 10 (+- 1) ms | 10 (+- 1) ms
(16, 3, 256, 256) | 177 (+- 82) ms | 179 (+- 90) ms
Times are in milliseconds (ms).
[--------------- AutoAugment cuda torch.uint8 --------------]
| old | new
1 threads: --------------------------------------------------
(3, 256, 256) | 2 (+- 0) ms | 2 (+- 0) ms
(16, 3, 256, 256) | 2 (+- 0) ms | 2 (+- 0) ms
6 threads: --------------------------------------------------
(3, 256, 256) | 2 (+- 0) ms | 2 (+- 0) ms
(16, 3, 256, 256) | 2 (+- 0) ms | 2 (+- 0) ms
Times are in milliseconds (ms).
[----------------- AutoAugment cpu pil -----------------]
| old | new
1 threads: ----------------------------------------------
(3, 256, 256) | 9 (+- 1) ms | 9 (+- 1) ms
6 threads: ----------------------------------------------
(3, 256, 256) | 11 (+- 1) ms | 11 (+- 1) ms
Times are in milliseconds (ms).
I think we should just remove the TODO. @vfdev-5 / @pmeier could you clip this on your next PR?
bound = 1 if image.is_floating_point() else 255 | ||
if threshold > bound: | ||
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pmeier This check didn't allow me to run the benchmark for float types so I removed it:
return solarize_image_tensor(inpt, threshold=threshold)
File "./vision/torchvision/prototype/transforms/functional/_color.py", line 323, in solarize_image_tensor
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
TypeError: Threshold should be less or equal the maximum value of the dtype, but got 226.6666717529297
This is caused by the 255 values hardcoded on the Solarize linspace of AutoAugment. Now that we actually support float32 in all kernels, how do you propose to solve this? I think this should be disconnected from the PR at #6830. We could implement an approach to switch to 1.0 and 255 depending on the type of the input. Later if the aforementioned PR gets merged, we can update with more fine-grained max values that support all integers. Thoughts?
Investigating if adding an inplace mul on AugMix is beneficial.