Reimplement torch::flip based on advanced indexing #56713

Closed
wants to merge 23 commits into master from andfoy:improve_flip

Conversation

@andfoy (Collaborator) commented Apr 22, 2021

Rationale

This PR improves the performance of torch::flip by using TensorIterator in the same fashion as AdvancedIndexing, which means that this implementation is semantically equivalent to indexing a tensor with reversed indices, A[dim0_size - 1:0 ..., dimN_size - 1:0, ...].
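
For a quick sanity check of that equivalence (plain PyTorch, independent of this PR's kernel):

import torch

x = torch.arange(24).reshape(2, 3, 4)
i1 = torch.arange(x.size(1) - 1, -1, -1)  # reversed indices for dim 1
i2 = torch.arange(x.size(2) - 1, -1, -1)  # reversed indices for dim 2
assert torch.equal(torch.flip(x, [1, 2]), x[:, i1][:, :, i2])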

Benchmark results

The following benchmark compares the runtime of this implementation of flip against the current implementation, AdvancedIndexing with reversed indices, and the OpenCV one. The comparison scenarios consider a 4D tensor [B, C, H, W], where the flipped dimensions correspond to H (vertical flip) and W (horizontal flip), under float32 and uint8 datatypes.

The benchmark implementation details can be found in https://github.com/andfoy/flip-benchmarks/blob/main/5_Stable_implementation/benchmarks.py. Additionally, there are correctness tests against the current flip implementation in https://github.com/andfoy/flip-benchmarks/blob/main/5_Stable_implementation/main.cpp, which cover different layouts, datatypes, and contiguous/non-contiguous tensors.

The following plots correspond to the mean runtime of each operator over 100 samples. As can be observed, the latest implementation of flip has a runtime similar to the indexing-based one, with performance gains of up to 6x in some scenarios.

Horizontal flip (float)

[benchmark plot]

Horizontal flip (uint8)

[benchmark plot]

Vertical flip (float)

[benchmark plot]

Vertical flip (uint8)

[benchmark plot]

cc @fmassa @vfdev-5

@facebook-github-bot (Contributor) commented Apr 22, 2021

💊 CI failures summary and remediations

As of commit ab66825 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@ailzhang ailzhang requested a review from wenleix April 23, 2021 04:36
@ailzhang ailzhang added the triaged label Apr 23, 2021
@wenleix wenleix requested a review from ngimel April 27, 2021 06:08
@wenleix (Contributor) commented Apr 27, 2021

Thanks @andfoy, wondering why leveraging TensorIterator makes it fast?

@fmassa (Member) left a comment

This generally looks pretty good, thanks!

I've left a few comments.

Also, could you run the performance numbers once more but this time from the PyTorch version that you compiled, just to double-check that under the same compilation flags we get the same speed-up as reported?

Additionally, it might make sense to see if moving this code to the native/cpu/ folder would bring speed-ups, as it would be compiled with the -mavx and -mavx2 flags, potentially allowing for further compiler optimizations.

Comment on lines 29 to 31
int64_t offset = *(int64_t*)&indexers[0][idx * indexer_strides[0]];
for (int j = 1; j < num_indexers; j++) {
offset += *(int64_t*)&indexers[j][idx * indexer_strides[j]];

For the future: we could look into specializing this when num_indexers == 1; it could bring additional performance improvements.

andfoy added a commit to andfoy/flip-benchmarks that referenced this pull request Apr 29, 2021
@andfoy (Collaborator, Author) commented Apr 29, 2021

These are the benchmark results for commit 0b63646 against the current torch.flip implementation. The comparison was done by declaring both implementations in the same PyTorch build (master...andfoy:benchmark_flip). As can be observed, the results are similar to those presented initially.

Horizontal flip (float)

[benchmark plot]

Horizontal flip (uint8)

[benchmark plot]

Vertical flip (float)

[benchmark plot]

Vertical flip (uint8)

[benchmark plot]

@andfoy (Collaborator, Author) commented Apr 29, 2021

Wondering why leveraging TensorIterator makes it fast?

@wenleix My guess here would be that by precomputing the indices to flip, the per-element cost of this loop is removed:

for (int64_t d = 0; d < total_dims; d++) {
      int64_t temp = cur_indices;
      cur_indices = cur_indices / stride_contiguous_v[d];
      rem = temp - cur_indices * stride_contiguous_v[d];
      dst_offset += flip_dims_b[d] ? (sizes_v[d] - 1 - cur_indices) * strides_v[d] : cur_indices * strides_v[d];
      cur_indices = rem;
}

However, I could be wrong here.
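
A rough Python analogy of the two strategies (for intuition only; neither line below is the actual kernel):

import torch

x = torch.rand(4, 5, 6)

# Indexing/TensorIterator route: the reversed index for the flipped dim is built once...
rev = torch.arange(x.size(2) - 1, -1, -1)
# ...and the copy is then a plain gather along that dim, with no per-element div/mod bookkeeping.
out = x.index_select(2, rev)

assert torch.equal(out, torch.flip(x, [2]))

The current kernel instead runs the div/mod chain above for every output element to recover its multi-dimensional coordinate before mirroring it.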

@ngimel (Collaborator) commented Apr 29, 2021

@andfoy what do you think of this approach https://github.com/pytorch/pytorch/compare/master...ngimel:flip?expand=1, where TI is used directly, without indexing tensors?
The advantage is that it can very easily be extended to CUDA too (a short sketch of the stride-flipping idea follows the benchmarking script below). Perf benchmarks compared to the existing flip:
Before (time in us):

[-------------------- flip -------------------]
                       |   dim=1    |   dim=2  
1 threads: ------------------------------------
      (7, 112, 3)      |      80.4  |      73.2
      (28, 28, 3)      |      73.4  |      74.0
      (112, 7, 3)      |      71.6  |      74.5
      (8, 2048, 3)     |    1559.4  |    1501.0
      (128, 128, 3)    |    1487.5  |    1488.6
      (2048, 8, 3)     |    1489.0  |    1480.9
      (5, 102400, 3)   |   46704.5  |   46615.1
      (800, 640, 3)    |   46453.5  |   46702.9
      (128000, 4, 3)   |   46944.7  |   47343.2
      (4, 196608, 3)   |   72009.5  |   73965.6
      (1024, 768, 3)   |   72100.0  |   73281.6
      (262144, 3, 3)   |   70512.6  |   72408.6
      (16, 129600, 3)  |  189054.7  |  201812.7
      (1920, 1080, 3)  |  184979.7  |  224558.0
      (230400, 9, 3)   |  197751.2  |  195470.6

After

[------------------ flip -----------------]
                       |  dim=1   |  dim=2 
1 threads: --------------------------------
      (7, 112, 3)      |     3.7  |     4.0
      (28, 28, 3)      |     4.1  |     4.3
      (112, 7, 3)      |     4.8  |     5.0
      (8, 2048, 3)     |    20.1  |    27.6
      (128, 128, 3)    |    22.4  |    35.2
      (2048, 8, 3)     |    41.4  |    50.4
      (5, 102400, 3)   |   628.8  |   731.0
      (800, 640, 3)    |   651.5  |   939.0
      (128000, 4, 3)   |  1698.2  |  2766.2
      (4, 196608, 3)   |  1130.8  |  1145.4
      (1024, 768, 3)   |  1030.7  |  1643.5
      (262144, 3, 3)   |  3464.7  |  5177.9
      (16, 129600, 3)  |  3583.6  |  4012.7
      (1920, 1080, 3)  |  4049.7  |  4525.1
      (230400, 9, 3)   |  5824.7  |  6371.6

Benchmarking script:


import torch
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare
sizes = [
    (7, 112, 3),
    (28, 28, 3),
    (112, 7, 3),

    (8, 2048, 3),
    (128, 128, 3),
    (2048, 8, 3),

    (5, 102400, 3),
    (800, 640, 3),
    (128000, 4, 3),

    (4, 196608, 3),
    (1024, 768, 3),
    (262144, 3, 3),

    (16, 129600, 3),
    (1920, 1080, 3),
    (230400, 9, 3),

    # (16, 518400, 3),
    # (3840, 2160, 3),
    # (921600, 9, 3),
]
results = []

for size in sizes:
    H, W, C = size
    inp = torch.rand(C, H, W)
    t1 = Timer(stmt="torch.flip(inp, [1])", sub_label=f"{size}", description="dim=1", label="flip", globals=globals())
    t2 = Timer(stmt="torch.flip(inp, [2])", sub_label=f"{size}", description="dim=2", label="flip", globals=globals())
    for t in (t1, t2):
        results.append(t.blocked_autorange())

comparison = Compare(results)
comparison.print()
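
For intuition on what "flipping strides" buys: flipping a dim amounts to negating its stride and shifting the base offset. NumPy exposes this as a view; PyTorch does not allow negative strides, which is why flip has to write out a copy and why the stride flip is applied inside TI while producing the output. A minimal NumPy sketch of the idea:

import numpy as np

a = np.arange(12, dtype=np.float32).reshape(3, 4)
v = a[:, ::-1]                                      # last dim flipped, as a view
assert v.strides == (a.strides[0], -a.strides[1])   # only the stride sign changes
assert np.shares_memory(a, v)                       # no data is copied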


@andfoy (Collaborator, Author) commented Apr 29, 2021

Since it can scale to CUDA easily, and the changes are much simpler than the ones proposed here, I think it is a good option. Following that idea, would it basically be a loop over the dimensions to flip, called before launching the actual kernel?

for(int64_t i = 0; i < total_dims; i++) {
   if(flip_dims_b[i]) {
      iter.flip_strides(0, i);
   }
}

@ngimel (Collaborator) commented Apr 29, 2021

No, flip_strides has to be called only once, and its implementation in TensorIterator flips all the necessary dimensions. I agree that if we could do it in a loop like you propose it would be conceptually cleaner, but the reason it has to be done all at once is that after TI is built, the dims being flipped are no longer the dims that were originally specified, because TensorIterator coalesces dimensions that it can view as one larger dim.
Imagine there's a 3D tensor where you want to flip the last dim. If the input is contiguous, TensorIterator will know that it can't collapse the last dimension because it will be flipped, but it will collapse the first 2 dimensions and view the tensor as a 2D tensor of size (size0*size1, size2). So, when flipping strides, you should no longer flip the 2nd (0-based) stride, you need to flip the 1st! Luckily, we are sending a dummy tensor to TensorIterator that tracks which dimensions actually have to be flipped even after coalescing.
The code I'm proposing is very sparsely tested, so I won't be surprised if there are bugs, don't hold it against me :-)
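
The effect of that coalescing can be checked from Python: for a contiguous 3D tensor, merging the two leading dims commutes with flipping the last dim, so in the coalesced 2D view the dim to flip is dim 1 rather than dim 2 (a small sanity check, not the TensorIterator internals):

import torch

x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)                 # contiguous
via_3d = torch.flip(x, [2])
via_2d = torch.flip(x.reshape(6, 4), [1]).reshape(2, 3, 4)   # flip on the coalesced view
assert torch.equal(via_3d, via_2d)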

@andfoy (Collaborator, Author) commented Apr 29, 2021

Thanks for the clarification @ngimel! I'll do a run of your changes against the tests that I have on the other repo.

@fmassa (Member) commented Apr 29, 2021

@ngimel I like your approach of changing TensorIterator directly, but given how widely used it is, I wonder if it would be OK to extend its API for a single function to use it?

@andfoy (Collaborator, Author) commented May 4, 2021

I have a question regarding the quantized call for the new flip kernel: should I duplicate the code under quantized/cpu, or should I copy it back to TensorTransformations.cpp and remove FlipKernel.cpp?

@ngimel (Collaborator) commented May 4, 2021

Did you verify that duplicating code under native/cpu actually improves performance compared to keeping it in just native?

@andfoy (Collaborator, Author) commented May 4, 2021

@ngimel, let me check the performance comparison; if the performance is on par or the gains are marginal, then I'll keep the kernel under TensorTransformations.

@andfoy (Collaborator, Author) commented May 4, 2021

I checked the benchmark results, and the differences are not significant, which means that we can leave the kernel in TensorTransformations

@codecov bot commented May 4, 2021

Codecov Report

Merging #56713 (ab66825) into master (b587354) will increase coverage by 0.02%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #56713      +/-   ##
==========================================
+ Coverage   76.84%   76.86%   +0.02%     
==========================================
  Files        1986     1986              
  Lines      197354   197384      +30     
==========================================
+ Hits       151661   151728      +67     
+ Misses      45693    45656      -37     

@andfoy (Collaborator, Author) commented May 4, 2021

The error in ROCm seems to be unrelated to this PR

@andfoy (Collaborator, Author) commented May 5, 2021

@ngimel @fmassa @wenleix This one is ready for a final review

@facebook-github-bot (Contributor) commented:

@fmassa has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

// numbers get more balanced work load and a better cache location. The grain
// size here is chosen by the op benchmark to overcome the thread launch
// overhead. This value was taken from the AdvancedIndexing kernel.
const int index_parallel_grain_size = 3000;

@andfoy any gains using this value vs the default one?

@andfoy (Collaborator, Author) replied:

Let me check!

@andfoy (Collaborator, Author) commented May 10, 2021

These are the benchmark results for commit 118d256, where the default GRAIN_SIZE (32768) is compared against the custom value used in this PR (3000). The comparison was done by exposing the grain_size as a parameter to flip (master...andfoy:benchmark_grain_size). As can be observed, the custom value seems to lower the runtime relative to the default. All the benchmark values were computed with parallelism enabled.

Horizontal flip (float)

[benchmark plot]

Horizontal flip (uint8)

[benchmark plot]

Vertical flip (float)

[benchmark plot]

Vertical flip (uint8)

[benchmark plot]

@facebook-github-bot (Contributor) commented:

@fmassa has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@fmassa (Member) left a comment

This looks good to me, thanks!

I only have one comment which I think would be good to address; otherwise I think this is good for merge.

Let me know what you think

@@ -13,81 +13,145 @@ namespace native {

constexpr size_t dim_bitset_size = 64;

Tensor build_index(Tensor input, int64_t flip_dim) {

Can we add all these internal functions inside an anonymous namespace? Given that those names are very generic, there could potentially be conflicts with other files.

See for example how it's done in

So that build_index, build_indices_loop, make_index_iterator and Indexer all end up in the private namespace.

Thoughts?

@facebook-github-bot (Contributor) commented:

@fmassa has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@fmassa (Member) left a comment

Thanks!

@facebook-github-bot (Contributor) commented:

@fmassa merged this pull request in 30f26c5.

krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021

Pull Request resolved: pytorch#56713

Reviewed By: datumbox

Differential Revision: D28255088

Pulled By: fmassa

fbshipit-source-id: 5b8684812357c331e83a677b99cf0d78f0821678
@andfoy andfoy deleted the improve_flip branch May 24, 2021 23:12
Labels: cla signed, Merged, open source, triaged

8 participants