/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstring>
#include <ostream>

#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/permute_multi_embedding_function.h"

using Tensor = at::Tensor;

namespace fbgemm_gpu {

// Kernel for the permute_multi_embedding op.
// This kernel moves D elements per warp.
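//
// Each entry of `permutes` occupies 6 consecutive int64 values:
//   [in_tensor, out_tensor, in_start, out_start, length, jump]
// in_tensor/out_tensor index into `inputs`/`outputs`, in_start/out_start are
// column offsets within those tensors, length is the number of columns to
// copy, and jump chains entries that share the same input slice (only used
// in the backward pass; see below).
// Each warp handles one permute entry for one batch row: blockIdx.x covers
// the permute entries, and blockIdx.y/z together cover the batch dimension.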
template <typename scalar_t>
__global__ void permute_multi_embs_kernel(
    const scalar_t** __restrict__ inputs,
    const scalar_t** __restrict__ outputs,
    const int64_t* __restrict__ permutes,
    const int64_t* __restrict__ input_lengths,
    const int64_t* __restrict__ output_lengths,
    const int64_t batch_size,
    const int64_t permute_size,
    const bool reverse_permute) {
  // workers in a warp handle a feature
  const int32_t worker_id = threadIdx.x % warpSize;
  const int32_t worker_size = warpSize;
  const int32_t permute_id =
      blockIdx.x * (blockDim.x / warpSize) + threadIdx.x / warpSize;
  const int32_t batch_id = blockIdx.y + gridDim.y * blockIdx.z;
  if (batch_id >= batch_size) {
    return;
  }
  if (permute_id >= permute_size) {
    return;
  }

  // parse permutes
  const int64_t params = 6;
  int64_t in_tensor, out_tensor, in_start, out_start, length, jump;
  if (reverse_permute) {
    out_tensor = permutes[params * permute_id];
    in_tensor = permutes[params * permute_id + 1];
    out_start = permutes[params * permute_id + 2];
    in_start = permutes[params * permute_id + 3];
  } else {
    in_tensor = permutes[params * permute_id];
    out_tensor = permutes[params * permute_id + 1];
    in_start = permutes[params * permute_id + 2];
    out_start = permutes[params * permute_id + 3];
  }
  length = permutes[params * permute_id + 4];
  jump = permutes[params * permute_id + 5];

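  // Workers beyond the feature width have nothing to copy. In the backward
  // pass, an entry with a negative jump marks a duplicated slice; it is
  // skipped here because the head of its jump chain accumulates its gradient
  // in the while loop below.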
  if (worker_id >= length) {
    return;
  }
  if (reverse_permute && jump < 0) {
    return;
  }

  // locate the batch_id
  int64_t in_length = input_lengths[in_tensor];
  scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
  input_ptr += batch_id * in_length;

  int64_t out_length = output_lengths[out_tensor];
  scalar_t* output_ptr = (scalar_t*)outputs[out_tensor];
  output_ptr += batch_id * out_length;

  if (fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
          &output_ptr[out_start]) &&
      fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
          &input_ptr[in_start])) {
    const int32_t vec_size = 4;
    const int32_t loop_end = length / (vec_size) * (vec_size);
    for (int32_t i = worker_id * vec_size; i < loop_end;
         i += worker_size * vec_size) {
      fbgemm_gpu::Vec4T<scalar_t>::copy(
          &input_ptr[in_start + i], &output_ptr[out_start + i]);
    }
    // Use elementwise access for the last incomplete vector.
    for (int32_t i = loop_end + worker_id; i < length; i += worker_size) {
      output_ptr[out_start + i] = input_ptr[in_start + i];
    }
  } else { // Fallback if not aligned.
    for (int32_t i = worker_id; i < length; i += worker_size) {
      output_ptr[out_start + i] = input_ptr[in_start + i];
    }
  }

  // for reverse_permute (backward) with jump
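  // If the same input slice was copied to several output locations in the
  // forward pass, the head entry (jump > 0) walks its jump chain here and
  // sums the gradients coming from each duplicate into one output slice.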
  while (reverse_permute && jump > 0 && jump < permute_size) {
    in_tensor = permutes[params * jump + 1];
    in_start = permutes[params * jump + 3];
    length = permutes[params * jump + 4];
    jump = -permutes[params * jump + 5];

    int64_t in_length = input_lengths[in_tensor];
    scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
    input_ptr += batch_id * in_length;

    for (int32_t i = worker_id; i < length; i += worker_size) {
      output_ptr[out_start + i] += input_ptr[in_start + i];
    }
  }
}

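// Copies a small host std::vector into a pinned int64 host tensor so it can
// later be moved to the device with a non-blocking copy. The dtype is fixed
// to kInt64, so this helper is only intended to be instantiated with
// index_t = int64_t.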
template <typename index_t>
Tensor from_vec(const std::vector<index_t>& input) {
  const auto int_opts =
      torch::TensorOptions().dtype(torch::kInt64).pinned_memory(true);
  Tensor output = at::empty({static_cast<index_t>(input.size())}, int_opts);
  // Ensure that output is contiguous
  TORCH_CHECK(output.is_contiguous());
  std::memcpy(
      output.data_ptr<index_t>(), input.data(), input.size() * sizeof(index_t));
  return output;
}

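// Builds a pinned host tensor holding the raw data pointers of `tensors`, so
// that, once copied to the device, the kernel can look up each tensor's base
// address by index. The buffer is sized in elements of the tensors' dtype,
// `size * sizeof(scalar_t*)` of them, which is at least as many bytes as the
// pointer array requires.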
template <typename scalar_t>
Tensor tensors_ptr(const at::TensorList& tensors) {
  auto size = tensors.size();
  Tensor ptr_tensor = at::empty(
      {static_cast<long>(size * sizeof(scalar_t*))},
      at::TensorOptions().dtype(tensors[0].scalar_type()).pinned_memory(true));

  // Ensure that ptr_tensor is contiguous
  TORCH_CHECK(ptr_tensor.is_contiguous());
  auto tp = reinterpret_cast<scalar_t**>(ptr_tensor.data_ptr());
  for (size_t i = 0; i < tensors.size(); i++) {
    tp[i] = tensors[i].data_ptr<scalar_t>();
  }
  return ptr_tensor;
}

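// Permutes/splits the pooled embedding tensors according to the flattened
// `permutes` spec (6 values per entry, see the kernel above) and returns the
// permuted output tensors.
//
// Illustrative example (hypothetical shapes): concatenating two inputs of
// widths {8, 16} into one output of width 24 would use
//   permutes    = {0, 0, 0, 0,  8, 0,   // input 0 -> output 0, columns [0, 8)
//                  1, 0, 0, 8, 16, 0};  // input 1 -> output 0, columns [8, 24)
//   in_lengths  = {8, 16};
//   out_lengths = {24};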
std::vector<Tensor> permute_multi_embedding_gpu(
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute) {
  // Ensure the kernel launch below targets the device that holds the inputs.
  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(pooled_embs[0].get_device());

  const int64_t permute_param = 6;
  int64_t num_of_input_tensors = in_lengths.size();
  int64_t num_of_output_tensors = out_lengths.size();
  int64_t batch_size = pooled_embs[0].size(0);
  int64_t permute_size = permutes.size() / permute_param;

  // check input tensors
  std::vector<Tensor> inputs;
  inputs.reserve(pooled_embs.size());
  for (int32_t i = 0; i < num_of_input_tensors; i++) {
    Tensor cont_tensor = pooled_embs[i].contiguous();
    inputs.push_back(cont_tensor);
    TENSORS_ON_SAME_DEVICE(cont_tensor, pooled_embs[i]);
    TENSORS_ON_SAME_DEVICE(pooled_embs[i], pooled_embs[0]);
  }

  // initialize output tensors
  std::vector<Tensor> outputs;
  outputs.reserve(num_of_output_tensors);
  for (int32_t i = 0; i < num_of_output_tensors; i++) {
    Tensor output =
        at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options());
    outputs.push_back(output);
  }

  auto permutes_tensor = from_vec<int64_t>(permutes);
  auto in_lengths_tensor = from_vec<int64_t>(in_lengths);
  auto out_lengths_tensor = from_vec<int64_t>(out_lengths);

  auto device = pooled_embs[0].device();
  permutes_tensor = permutes_tensor.to(device, /*non_blocking=*/true);
  in_lengths_tensor = in_lengths_tensor.to(device, /*non_blocking=*/true);
  out_lengths_tensor = out_lengths_tensor.to(device, /*non_blocking=*/true);
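
  // The metadata was staged in pinned host memory (see from_vec above), which
  // is what allows these non_blocking copies to overlap with host work.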

  // This kernel moves D elements per warp.
  // We launch (div_round_up(permute_size, warp_per_block), B) blocks, where B
  // is the batch size; the batch is split across the grid y and z dimensions
  // so batch sizes larger than max_grid_dim_y are supported.
  const int32_t warp_per_block =
      fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize;
  const int32_t max_grid_dim_y =
      32768; // The CUDA maximum is 65535, not a power of 2.
  const dim3 threads(fbgemm_gpu::kMaxThreads);
  const dim3 blocks(
      fbgemm_gpu::div_round_up(permute_size, warp_per_block),
      std::min(static_cast<int32_t>(batch_size), max_grid_dim_y),
      (batch_size + max_grid_dim_y - 1) / max_grid_dim_y);

  FBGEMM_DISPATCH_FLOATING_TYPES(
      pooled_embs[0].scalar_type(), "permute_multi_embedding", [&] {
        Tensor in_tensor = tensors_ptr<scalar_t>(inputs);
        Tensor out_tensor = tensors_ptr<scalar_t>(outputs);
        in_tensor = in_tensor.to(device, /*non_blocking=*/true);
        out_tensor = out_tensor.to(device, /*non_blocking=*/true);
        permute_multi_embs_kernel<scalar_t>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                (const scalar_t**)in_tensor.data_ptr(),
                (const scalar_t**)out_tensor.data_ptr(),
                permutes_tensor.data_ptr<int64_t>(),
                in_lengths_tensor.data_ptr<int64_t>(),
                out_lengths_tensor.data_ptr<int64_t>(),
                batch_size,
                permute_size,
                reverse_permute);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
  return outputs;
}

} // namespace fbgemm_gpu