44
55#include < cassert>
66
7- #ifdef __HIP_PLATFORM_AMD__
8- #include < torchaudio/csrc/rnnt/hip/kernel_utils.h>
9- #include < torchaudio/csrc/rnnt/hip/kernels.h>
10- #include < torchaudio/csrc/rnnt/hip/math_hip.cuh>
11- #else
127#include < torchaudio/csrc/rnnt/gpu/kernel_utils.h>
138#include < torchaudio/csrc/rnnt/gpu/kernels.h>
149#include < torchaudio/csrc/rnnt/gpu/math.cuh>
15- #endif
1610
1711namespace torchaudio {
1812namespace rnnt {
@@ -132,11 +126,7 @@ __device__ void ComputeAlphas(
132126
133127#pragma unroll
134128 for (int i = 1 ; i < warpSize ; i <<= 1 ) {
135- #ifdef __HIP_PLATFORM_AMD__
136- val = __shfl_up (skip_prob, i);
137- #else
138129 val = __shfl_up_sync (0xffffffff , skip_prob, i);
139- #endif
140130 if (i <= threadIdx .x ) {
141131 skip_prob = skip_prob + val;
142132 }
@@ -160,11 +150,7 @@ __device__ void ComputeAlphas(
160150 CAST_DTYPE out = val;
161151
162152 for (int i = 1 ; i < warpSize ; ++i) {
163- #ifdef __HIP_PLATFORM_AMD__
164- val = __shfl_up (val, 1 );
165- #else
166153 val = __shfl_up_sync (0xffffffff , val, 1 );
167- #endif
168154 if (i == threadIdx .x ) {
169155 val = math::lse (val + skip_prob, emit);
170156 out = val;
@@ -239,11 +225,7 @@ __device__ void ComputeBetasCosts(
239225
240226#pragma unroll
241227 for (int i = 1 ; i < warpSize ; i <<= 1 ) {
242- #ifdef __HIP_PLATFORM_AMD__
243- val = __shfl_up (skip_prob, i);
244- #else
245228 val = __shfl_up_sync (0xffffffff , skip_prob, i);
246- #endif
247229 if (i <= threadIdx .x ) {
248230 skip_prob = skip_prob + val;
249231 }
@@ -266,11 +248,7 @@ __device__ void ComputeBetasCosts(
266248 CAST_DTYPE out = val;
267249
268250 for (int i = 1 ; i < warpSize ; ++i) {
269- #ifdef __HIP_PLATFORM_AMD__
270- val = __shfl_up (val, 1 );
271- #else
272251 val = __shfl_up_sync (0xffffffff , val, 1 );
273- #endif
274252 if (i == threadIdx .x ) {
275253 val = math::lse (val + skip_prob, emit);
276254 out = val;
0 commit comments