
Commit 951c2dc

TroyGarden authored and facebook-github-bot committed
implementation of fbgemm op - permute_multi_embedding (#2738)
Summary:
X-link: pytorch/torchrec#2120

# context
* currently we have a working function `permute_pooled_embs_auto_grad` that does a full permute of KTs, including forward and backward
* it has several limitations:
  a) it has to be a full permute; duplicates are not supported
  b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient
  c) the function outputs a single KT, which then requires a split operation
* there has been some attempt to support duplicated outputs, but the backward pass doesn't work
* this diff adds a new kernel (named `permute_multi_embedding`) that supports a multiple-KT to multiple-KT mapping operation with backward support

# notes
* this diff focuses on the implementation and tests of the operator
* performance analysis and benchmarks are in the next diff

# operator example usage
* used in python

```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features); duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments for the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# op call
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values, permutes, in_lengths, out_lengths
)
```

* permutes

```
permutes = tensor(
    [
        [0, 0, 0, 0, 3, 4],   # f1
        [1, 0, 0, 3, 5, 0],   # f3
        [0, 1, 3, 0, 4, 0],   # f2
        [1, 2, 5, 0, 6, 0],   # f4
        [0, 2, 0, 6, 3, -6],  # f1
        [2, 2, 0, 9, 8, 0],   # f6
        [0, 3, 0, 0, 3, -8],  # f1
        [1, 3, 11, 3, 7, 0],  # f5
    ]
)
```

# details
1. from the example usage above, the operator takes in the following:
   a) values: List[torch.Tensor], the input KTs
   b) permutes: torch.Tensor, the permute information, explained below
   c) output_lengths_list: List[int], the lengths of the output tensors (KTs), needed to allocate memory on the device ahead of time
   d) in_lengths: torch.Tensor, lengths of the input tensors, resident on the device
   e) out_lengths: torch.Tensor, lengths of the output tensors, resident on the device
2. the operator returns a list of tensors representing the permuted KTs
3. `permutes` is the most critical argument of this operator:
   a) it is a 2-D tensor
   b) each row represents one key (feature) permute move
   c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
   d) jump is used in the backward pass when a key (feature) from an input tensor is mapped to multiple places in the output tensors

Differential Revision: D57055616
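For reference, the same regrouping can be expressed in plain PyTorch with slicing and `torch.cat`; the fused op is intended to produce the same outputs without the intermediate concat/split that the summary lists as a limitation. This is a hedged sketch against the example above; the helper name `regroup_reference` is illustrative and not part of this diff:

```
import torch

def regroup_reference(keys, lengths, values, groups):
    # Map each feature name to (input tensor index, start offset, length).
    offsets = {}
    for t, (ks, ls) in enumerate(zip(keys, lengths)):
        start = 0
        for k, l in zip(ks, ls):
            offsets[k] = (t, start, l)
            start += l
    # For each output group, slice the requested features and concatenate
    # along the feature dimension; duplicates simply get sliced again.
    outputs = []
    for group in groups:
        cols = []
        for k in group:
            t, start, l = offsets[k]
            cols.append(values[t][:, start : start + l])
        outputs.append(torch.cat(cols, dim=1))
    return outputs
```

Comparing these outputs against `torch.ops.fbgemm.permute_multi_embedding` is one way the operator could be sanity-checked.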
1 parent 7f77444 commit 951c2dc

File tree

7 files changed: +508 −0 lines changed


fbgemm_gpu/FbgemmGpu.cmake

Lines changed: 1 addition & 0 deletions
@@ -446,6 +446,7 @@ set(fbgemm_gpu_sources_static_cpu
     codegen/training/backward/embedding_backward_dense_host_cpu.cpp
     codegen/utils/embedding_bounds_check_host_cpu.cpp
     src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
+    src/permute_pooled_embedding_ops/permute_multi_embedding_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp

fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py

Lines changed: 6 additions & 0 deletions
@@ -19,10 +19,16 @@
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
 )
+torch.ops.load_library(
+    "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+)
 try:
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
     )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
+    )
 except OSError:
     # This is for forward compatibility (new torch.package + old backend)
     # We should be able to remove it after this diff is picked up by all backend
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/api/include/torch/types.h>
#include <torch/csrc/autograd/custom_function.h>

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/ops_utils.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

namespace fbgemm_gpu {

using Tensor = at::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

using Tensor = at::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

class PermuteMultiEmbeddingOp
    : public torch::autograd::Function<PermuteMultiEmbeddingOp> {
 public:
  static variable_list forward(
      AutogradContext* ctx,
      const at::TensorList& pooled_embs,
      const std::vector<int64_t>& permutes,
      const std::vector<int64_t>& in_lengths,
      const std::vector<int64_t>& out_lengths);

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output);
};

std::vector<Tensor> permute_multi_embedding_cpu(
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute);

std::vector<Tensor> permute_multi_embedding_meta(
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute);

std::vector<Tensor> permute_multi_embedding_gpu(
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute);
} // namespace fbgemm_gpu
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "fbgemm_gpu/permute_multi_embedding_function.h"
#include <cstdint>
#include <iostream>

namespace fbgemm_gpu {

using Tensor = at::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

variable_list PermuteMultiEmbeddingOp::forward(
    AutogradContext* ctx,
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths) {
  ctx->saved_data["permutes"] = permutes;
  ctx->saved_data["in_lengths"] = in_lengths;
  ctx->saved_data["out_lengths"] = out_lengths;

  /*
    select the correct dispatched (cpu/gpu) forward function;
    the cpu/gpu function needs to be registered in the dispatcher,
    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
  */
  const auto permute_op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
          .typed<decltype(permute_multi_embedding_cpu)>();

  return permute_op.call(pooled_embs, permutes, in_lengths, out_lengths, false);
}

variable_list PermuteMultiEmbeddingOp::backward(
    AutogradContext* ctx,
    variable_list grad_output) {
  const auto permutes = ctx->saved_data["permutes"].toIntVector();
  const auto in_lengths = ctx->saved_data["in_lengths"].toIntVector();
  const auto out_lengths = ctx->saved_data["out_lengths"].toIntVector();

  /*
    select the correct dispatched (cpu/gpu) backward function;
    the cpu/gpu function needs to be registered in the dispatcher,
    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
  */
  const auto permute_op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
          .typed<decltype(permute_multi_embedding_cpu)>();
  auto grad_input =
      permute_op.call(grad_output, permutes, out_lengths, in_lengths, true);
  grad_input.push_back(torch::autograd::Variable()); // permutes
  grad_input.push_back(torch::autograd::Variable()); // in_lengths
  grad_input.push_back(torch::autograd::Variable()); // out_lengths
  return grad_input;
}

} // namespace fbgemm_gpu
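As an orientation aid, here is a hedged Python sketch of the same autograd wiring: the forward calls the dispatched op, the backward calls it again with `reverse_permute=True` and the in/out lengths swapped, then returns empty gradients for the non-tensor arguments. The class name and the argument ordering (non-tensor arguments first, so the variadic tensors can come last) are illustrative; only `torch.ops.fbgemm.permute_multi_embedding_function` mirrors the schema referenced above, and this sketch assumes that schema has been registered.

```
import torch

class PermuteMultiEmbeddingPy(torch.autograd.Function):
    """Illustrative Python mirror of the C++ PermuteMultiEmbeddingOp above."""

    @staticmethod
    def forward(ctx, permutes, in_lengths, out_lengths, *pooled_embs):
        ctx.permutes = permutes
        ctx.in_lengths = in_lengths
        ctx.out_lengths = out_lengths
        # Forward: plain permute (reverse_permute=False).
        return tuple(
            torch.ops.fbgemm.permute_multi_embedding_function(
                list(pooled_embs), permutes, in_lengths, out_lengths, False
            )
        )

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Backward: same op with in/out roles swapped (reverse_permute=True);
        # duplicated features are accumulated via the "jump" field.
        grads = torch.ops.fbgemm.permute_multi_embedding_function(
            list(grad_outputs), ctx.permutes, ctx.out_lengths, ctx.in_lengths, True
        )
        # No gradients for the three non-tensor arguments.
        return (None, None, None, *grads)
```

Usage in this sketch would look like `PermuteMultiEmbeddingPy.apply(permutes, in_lengths, out_lengths, *values)`.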
Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <ostream>

#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/permute_multi_embedding_function.h"

using Tensor = at::Tensor;

namespace fbgemm_gpu {

// Kernel for the permute pooled embedding op.
// This kernel is moving D elements per warp.
template <typename scalar_t>
__global__ void permute_multi_embs_kernel(
    const scalar_t** __restrict__ inputs,
    const scalar_t** __restrict__ outputs,
    const int64_t* __restrict__ permutes,
    const int64_t* __restrict__ input_lengths,
    const int64_t* __restrict__ output_lengths,
    const int64_t batch_size,
    const int64_t permute_size,
    const bool reverse_permute) {
  // workers in a warp handle a feature
  const int32_t worker_id = threadIdx.x % warpSize;
  const int32_t worker_size = warpSize;
  const int32_t permute_id =
      blockIdx.x * (blockDim.x / warpSize) + threadIdx.x / warpSize;
  const int32_t batch_id = blockIdx.y + gridDim.y * blockIdx.z;
  if (batch_id >= batch_size) {
    return;
  }
  if (permute_id >= permute_size) {
    return;
  }

  // parse permutes
  const int64_t params = 6;
  int64_t in_tensor, out_tensor, in_start, out_start, length, jump;
  if (reverse_permute) {
    out_tensor = permutes[params * permute_id];
    in_tensor = permutes[params * permute_id + 1];
    out_start = permutes[params * permute_id + 2];
    in_start = permutes[params * permute_id + 3];
  } else {
    in_tensor = permutes[params * permute_id];
    out_tensor = permutes[params * permute_id + 1];
    in_start = permutes[params * permute_id + 2];
    out_start = permutes[params * permute_id + 3];
  }
  length = permutes[params * permute_id + 4];
  jump = permutes[params * permute_id + 5];

  if (worker_id >= length) {
    return;
  }
  if (reverse_permute && jump < 0) {
    return;
  }

  // locate the batch_id
  int64_t in_length = input_lengths[in_tensor];
  scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
  input_ptr += batch_id * in_length;

  int64_t out_length = output_lengths[out_tensor];
  scalar_t* output_ptr = (scalar_t*)outputs[out_tensor];
  output_ptr += batch_id * out_length;

  // printf( // debug print
  //     "input_tensors[%ld][%ld][%d] = %f\n",
  //     in_tensor,
  //     batch_id,
  //     in_start + worker_id,
  //     input_ptr[in_start + worker_id]);
  if (fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
          &output_ptr[out_start]) &&
      fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
          &input_ptr[in_start])) {
    const int32_t vec_size = 4;
    const int32_t loop_end = length / (vec_size) * (vec_size);
    for (int32_t i = worker_id * vec_size; i < loop_end;
         i += worker_size * vec_size) {
      fbgemm_gpu::Vec4T<scalar_t>::copy(
          &input_ptr[in_start + i], &output_ptr[out_start + i]);
    }
    // Use elementwise access for the last incomplete vector.
    for (int32_t i = loop_end + worker_id; i < length; i += worker_size) {
      output_ptr[out_start + i] = input_ptr[in_start + i];
    }
  } else { // Fallback if not aligned.
    for (int32_t i = worker_id; i < length; i += worker_size) {
      output_ptr[out_start + i] = input_ptr[in_start + i];
    }
  }

  // for reverse_permute (backward) with jump
  while (reverse_permute && jump > 0 && jump < permute_size) {
    in_tensor = permutes[params * jump + 1];
    in_start = permutes[params * jump + 3];
    length = permutes[params * jump + 4];
    jump = -permutes[params * jump + 5];

    int64_t in_length = input_lengths[in_tensor];
    scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
    input_ptr += batch_id * in_length;

    for (int32_t i = worker_id; i < length; i += worker_size) {
      output_ptr[out_start + i] += input_ptr[in_start + i];
    }
  }
}

template <typename index_t>
Tensor from_vec(const std::vector<index_t> input) {
  const auto int_opts =
      torch::TensorOptions().dtype(torch::kInt64).pinned_memory(true);
  Tensor output = at::empty({static_cast<index_t>(input.size())}, int_opts);
  // Ensure that output is contiguous
  TORCH_CHECK(output.is_contiguous());
  std::memcpy(
      output.data_ptr<index_t>(), input.data(), input.size() * sizeof(index_t));
  return output;
}

template <typename scalar_t>
Tensor tensors_ptr(const at::TensorList& tensors) {
  auto size = tensors.size();
  Tensor ptr_tensor = at::empty(
      {static_cast<long>(size * sizeof(scalar_t*))},
      at::TensorOptions().dtype(tensors[0].scalar_type()).pinned_memory(true));

  // Ensure that ptr_tensor is contiguous
  TORCH_CHECK(ptr_tensor.is_contiguous());
  auto tp = reinterpret_cast<scalar_t**>(ptr_tensor.data_ptr());
  for (int32_t i = 0; i < tensors.size(); i++) {
    tp[i] = tensors[i].data_ptr<scalar_t>();
  }
  // Ensure that ptr_tensor is contiguous
  TORCH_CHECK(ptr_tensor.is_contiguous());
  return ptr_tensor;
}

std::vector<Tensor> permute_multi_embedding_gpu(
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute) {
  const int64_t permute_param = 6;
  int64_t num_of_input_tensors = in_lengths.size();
  int64_t num_of_output_tensors = out_lengths.size();
  int64_t batch_size = pooled_embs[0].size(0);
  int64_t permute_size = permutes.size() / permute_param;

  // check input tensors
  std::vector<Tensor> inputs;
  inputs.reserve(pooled_embs.size());
  for (int32_t i = 0; i < num_of_input_tensors; i++) {
    Tensor cont_tensor = pooled_embs[i].contiguous();
    inputs.push_back(cont_tensor);
    TENSORS_ON_SAME_DEVICE(cont_tensor, pooled_embs[i]);
    TENSORS_ON_SAME_DEVICE(pooled_embs[i], pooled_embs[0]);
  }

  // initiate output tensors
  std::vector<Tensor> outputs;
  outputs.reserve(num_of_output_tensors);
  for (int32_t i = 0; i < num_of_output_tensors; i++) {
    Tensor output =
        at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options());
    outputs.push_back(output);
  }

  auto permutes_tensor = from_vec<int64_t>(permutes);
  auto in_lengths_tensor = from_vec<int64_t>(in_lengths);
  auto out_lengths_tensor = from_vec<int64_t>(out_lengths);

  auto device = pooled_embs[0].device();
  permutes_tensor = permutes_tensor.to(device, /*non_blocking=*/true);
  in_lengths_tensor = in_lengths_tensor.to(device, /*non_blocking=*/true);
  out_lengths_tensor = out_lengths_tensor.to(device, /*non_blocking=*/true);

  // This kernel is moving D elements per warp.
  // We are launching ( div_round_up(T, warp_per_block), B ) blocks.
  // The grid z dimension is also used by batch_size in case it's greater than
  // 65535.
  const int32_t warp_per_block =
      fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize;
  const int32_t max_grid_dim_y =
      32768; // The CUDA maximum is 65535, not a power of 2.
  const dim3 threads(fbgemm_gpu::kMaxThreads);
  const dim3 blocks(
      fbgemm_gpu::div_round_up(permute_size, warp_per_block),
      std::min(static_cast<int32_t>(batch_size), max_grid_dim_y),
      (batch_size + max_grid_dim_y - 1) / max_grid_dim_y);

  FBGEMM_DISPATCH_FLOATING_TYPES(
      pooled_embs[0].scalar_type(), "permute_multi_embedding", [&] {
        Tensor in_tensor = tensors_ptr<scalar_t>(inputs);
        Tensor out_tensor = tensors_ptr<scalar_t>(outputs);
        in_tensor = in_tensor.to(device, /*non_blocking=*/true);
        out_tensor = out_tensor.to(device, /*non_blocking=*/true);
        permute_multi_embs_kernel<scalar_t>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                (const scalar_t**)in_tensor.data_ptr(),
                (const scalar_t**)out_tensor.data_ptr(),
                permutes_tensor.data_ptr<int64_t>(),
                in_lengths_tensor.data_ptr<int64_t>(),
                out_lengths_tensor.data_ptr<int64_t>(),
                batch_size,
                permute_size,
                reverse_permute);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
  return outputs;
}

} // namespace fbgemm_gpu
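To make the kernel's data movement concrete, here is a hedged Python reference (the function name is illustrative, not part of this diff) of the forward copy that each permute row performs. In the backward pass the kernel runs the same rows with the input/output roles swapped and accumulates into the gradient wherever the `jump` field chains a feature that was duplicated across output groups.

```
import torch

def permute_rows_reference(values, permutes, out_lengths):
    # Forward reference: each permute row copies `length` columns from one
    # input KT into one output KT at the given offsets.
    batch_size = values[0].size(0)
    outputs = [values[0].new_empty(batch_size, n) for n in out_lengths]
    for in_t, out_t, in_start, out_start, length, _jump in permutes.tolist():
        outputs[out_t][:, out_start : out_start + length] = (
            values[in_t][:, in_start : in_start + length]
        )
    return outputs
```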
