Skip to content

Commit b2c4900

Browse files
q10 and facebook-github-bot
authored and committed
Add support for int32_t indices in TBE training (2D/N) (pytorch#3374)
Summary:
X-link: facebookresearch/FBGEMM#620
X-link: facebookresearch/FBGEMM#464
- Add `index_t` support to TBE training backward kernels
Reviewed By: basilwong
Differential Revision: D65930273
1 parent 5ee9c75 commit b2c4900

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -140,14 +140,14 @@ void split_embedding_backward_count_unique_indices_kernel
140140

141141
{% for vbe in [True, False] %}
142142
{% set vdesc = "_vbe" if vbe else "" %}
143-
template <typename grad_t>
143+
template <typename grad_t, typename offset_t>
144144
__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
145145
pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
146146
grad_output_mean,
147147
const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
148148
grad_output,
149149
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
150-
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
150+
const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits> offsets,
151151
{% if vbe %}
152152
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
153153
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -212,15 +212,16 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
212212
////////////////////////////////////////////////////////////////////////////////
213213

214214
{% for grad_type in ['at::Half', 'float', 'at::BFloat16'] %}
215+
{% for offset_type in ['int32_t', 'int64_t'] %}
215216
template __global__ __launch_bounds__(kMaxThreads)
216217
void grad_mean{{ vdesc }}_kernel
217-
<{{ grad_type }}> (
218+
<{{ grad_type }}, {{ offset_type }}> (
218219
pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits>
219220
grad_output_mean,
220221
const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits>
221222
grad_output,
222223
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
223-
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
224+
const pta::PackedTensorAccessor32<{{ offset_type }}, 1, at::RestrictPtrTraits> offsets,
224225
{% if vbe %}
225226
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
226227
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -230,6 +231,7 @@ void grad_mean{{ vdesc }}_kernel
230231
FixedDivisor fd_B
231232
{% endif %}
232233
);
234+
{% endfor %} // for offset_type in ['int32_t', 'int64_t']
233235
{% endfor %} // for grad_type in ['at::Half', 'float', 'at::BFloat16']
234236
{% endfor %} // for vbe in [True, False]
235237

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -308,13 +308,13 @@ split_embedding_backward_codegen_find_long_segments(
308308
const bool use_deterministic_algorithms);
309309

310310

311-
template <typename grad_t>
311+
template <typename grad_t, typename offset_t>
312312
__global__ __launch_bounds__(kMaxThreads) void
313313
grad_mean{{ vdesc }}_kernel(
314314
pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output_mean,
315315
const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output,
316316
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
317-
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
317+
const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits> offsets,
318318
{%- if vbe %}
319319
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
320320
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -950,7 +950,7 @@ Tensor {{ embedding_cuda_op }}(
950950
MAKE_PTA_WITH_NAME(func_name1, grad_output_mean, grad_t, 2, 64),
951951
MAKE_PTA_WITH_NAME(func_name1, grad_output_reshaped, grad_t, 2, 64),
952952
MAKE_PTA_WITH_NAME(func_name1, D_offsets, int32_t, 1, 32),
953-
MAKE_PTA_WITH_NAME(func_name1, offsets, int64_t, 1, 32),
953+
MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32),
954954
{%- if vbe %}
955955
MAKE_PTA_WITH_NAME(func_name1, vbe_row_output_offsets, int64_t, 1, 32),
956956
MAKE_PTA_WITH_NAME(func_name1, vbe_b_t_map, int32_t, 1, 32),

0 commit comments

Comments
 (0)