Skip to content

Invoke AMD specific kernel reorder_batched_ad_indices_kernel_vec #4412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ __launch_bounds__(fbgemm_gpu::kMaxThreads) void reorder_batched_ad_indices_kerne
// Idea: we want to copy the entire segment of size sum_a(length_{b, t, a})
// from starting point (given by cat_ad_offsets[b, t])
// to end point (given by reordered_cat_ad_indices[t][b])
if (num_elements <= 64) {
if (num_elements <= 64 || !(sizeof(Dtype) == 4 || sizeof(Dtype) == 8)) {
for (auto i = threadIdx.x; i < input_segment_end - input_segment_start;
i += blockDim.x) {
// coalesced global memory access, can be optimized through ILP with the
Expand Down Expand Up @@ -465,7 +465,35 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
[&] {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "reorder_batched_ad_indices_kernel";
const auto func_name_amd = "reorder_batched_ad_indices_kernel";
#endif
#if defined __HIP_PLATFORM_AMD__
constexpr auto NUM_WARPS1 = 4;
const dim3 threads1(32, NUM_WARPS1); // 32 x 4
const dim3 blocks1(cuda_calc_xblock_count(B * T, NUM_WARPS1));
reorder_batched_ad_indices_kernel_vec<scalar_t, index_t>
<<<blocks1, threads1, 0, at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(
func_name_amd, cat_ad_offsets, index_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name_amd, cat_ad_indices, scalar_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name_amd,
reordered_cat_ad_offsets,
index_t,
1,
32),
MAKE_PTA_WITH_NAME(
func_name_amd,
reordered_cat_ad_indices,
scalar_t,
1,
32),
MAKE_PTA_WITH_NAME(
func_name_amd, batch_offsets, int32_t, 1, 32),
T,
broadcast_indices);
#else
reorder_batched_ad_indices_kernel<scalar_t, index_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(
Expand All @@ -480,6 +508,7 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
func_name, batch_offsets, int32_t, 1, 32),
T,
broadcast_indices);
#endif
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down
Loading