@@ -140,14 +140,14 @@ void split_embedding_backward_count_unique_indices_kernel
140
140
141
141
{% for vbe in [True, False] %}
142
142
{% set vdesc = " _vbe" if vbe else " " %}
143
- template <typename grad_t >
143
+ template <typename grad_t , typename offset_t >
144
144
__global__ __launch_bounds__ (kMaxThreads ) void grad_mean{{ vdesc }}_kernel(
145
145
pta::PackedTensorAccessor64<grad_t , 2 , at::RestrictPtrTraits>
146
146
grad_output_mean,
147
147
const pta::PackedTensorAccessor64<grad_t , 2 , at::RestrictPtrTraits>
148
148
grad_output,
149
149
const pta::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> D_offsets,
150
- const pta::PackedTensorAccessor32<int64_t , 1 , at::RestrictPtrTraits> offsets,
150
+ const pta::PackedTensorAccessor32<offset_t , 1 , at::RestrictPtrTraits> offsets,
151
151
{% if vbe %}
152
152
const pta::PackedTensorAccessor32<int64_t , 1 , at::RestrictPtrTraits> row_grad_offsets,
153
153
const pta::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> b_t_map,
@@ -212,15 +212,16 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
212
212
// //////////////////////////////////////////////////////////////////////////////
213
213
214
214
{% for grad_type in [' at::Half' , ' float' , ' at::BFloat16' ] %}
215
+ {% for offset_type in [' int32_t' , ' int64_t' ] %}
215
216
template __global__ __launch_bounds__ (kMaxThreads )
216
217
void grad_mean{{ vdesc }}_kernel
217
- <{{ grad_type }}> (
218
+ <{{ grad_type }}, {{ offset_type }} > (
218
219
pta::PackedTensorAccessor64<{{ grad_type }}, 2 , at::RestrictPtrTraits>
219
220
grad_output_mean,
220
221
const pta::PackedTensorAccessor64<{{ grad_type }}, 2 , at::RestrictPtrTraits>
221
222
grad_output,
222
223
const pta::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> D_offsets,
223
- const pta::PackedTensorAccessor32<int64_t , 1 , at::RestrictPtrTraits> offsets,
224
+ const pta::PackedTensorAccessor32<{{ offset_type }} , 1 , at::RestrictPtrTraits> offsets,
224
225
{% if vbe %}
225
226
const pta::PackedTensorAccessor32<int64_t , 1 , at::RestrictPtrTraits> row_grad_offsets,
226
227
const pta::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> b_t_map,
@@ -230,6 +231,7 @@ void grad_mean{{ vdesc }}_kernel
230
231
FixedDivisor fd_B
231
232
{% endif %}
232
233
);
234
+ {% endfor %} // for offset_type in ['int32_t', 'int64_t']
233
235
{% endfor %} // for grad_type in ['at::Half', 'float', 'at::BFloat16']
234
236
{% endfor %} // for vbe in [True, False]
235
237
0 commit comments