Skip to content

Commit df31a76

Browse files
spcyppt authored and facebook-github-bot committed
Enable int32_t support for reshape_vbe_offsets
Summary: X-link: facebookresearch/FBGEMM#866 As titled. Differential Revision: D70760386
1 parent de35b3c commit df31a76

File tree

3 files changed

+39
-20
lines changed

3 files changed

+39
-20
lines changed

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
#include "fbgemm_gpu/utils/ops_utils.h"
2727
#include "fbgemm_gpu/utils/dispatch_macros.h"
2828
#include "fbgemm_gpu/embedding_common.h"
29+
// #include <ATen/ATen.h>
30+
#include <ATen/Dispatch.h>
31+
#include <ATen/TensorUtils.h>
32+
// #include <ATen/Functions.h>
33+
// #include <ATen/TypeDefault.h>
34+
// #include "fbgemm_gpu/utils/tensor_utils.h"
2935
{%- if has_vbe_support %}
3036
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
3137
{%- endif %}
@@ -64,12 +70,15 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
6470
{%- endif %}
6571
) {
6672
{%- if vbe %}
67-
const auto offsets_ = reshape_vbe_offsets(
68-
offsets,
69-
vbe_B_offsets_rank_per_feature,
70-
max_B,
71-
D_offsets.numel() - 1
72-
);
73+
Tensor offsets_;
74+
AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_grad_indices", [&]() {
75+
offsets_ = reshape_vbe_offsets<index_t>(
76+
offsets,
77+
vbe_B_offsets_rank_per_feature,
78+
max_B,
79+
D_offsets.numel() - 1
80+
);
81+
});
7382
const auto grad_output_ = reshape_vbe_output(
7483
grad_output,
7584
max_B,
@@ -126,8 +135,11 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
126135
{%- endif %}
127136
const bool /*is_experimental = false*/,
128137
const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
138+
Tensor offsets_;
129139
{%- if vbe %}
130-
const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
140+
AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_forward", [&]() {
141+
offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
142+
});
131143
{%- endif %}
132144
static auto op =
133145
torch::Dispatcher::singleton()
@@ -226,7 +238,10 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
226238
{%- endif %})
227239
{
228240
{%- if vbe %}
229-
const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
241+
Tensor offsets_;
242+
AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_backward", [&]() {
243+
offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
244+
});
230245
const auto grad_output_ = reshape_vbe_output(grad_output, max_B, vbe_B_offsets_rank_per_feature, D_offsets);
231246
{%- endif %}
232247
{%- set backward_op = "split_embedding_backward_codegen_{}_cpu".format(

fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ void checked_memcpy(
113113
/// size(1) is number of ranks
114114
/// @param max_B Maximum batch size
115115
/// @param T Number of embedding tables (features)
116+
template <typename index_t>
116117
Tensor reshape_vbe_offsets(
117118
const Tensor& offsets,
118119
const Tensor& B_offsets_rank_per_feature,
@@ -125,12 +126,8 @@ Tensor reshape_vbe_offsets(
125126
B_offsets_rank_per_feature.accessor<int32_t, 2>();
126127
auto reshaped_offsets = at::empty({T * max_B + 1}, offsets.options());
127128
// TODO: support other types
128-
TORCH_CHECK(
129-
offsets.dtype() == at::kLong,
130-
"Expected offsets to be int64 but got ",
131-
offsets.dtype());
132-
auto reshaped_offsets_acc = reshaped_offsets.accessor<int64_t, 1>();
133-
auto offsets_acc = offsets.accessor<int64_t, 1>();
129+
auto reshaped_offsets_acc = reshaped_offsets.accessor<index_t, 1>();
130+
auto offsets_acc = offsets.accessor<index_t, 1>();
134131
auto begin = 0;
135132
for (int32_t t = 0; t < T; t++) {
136133
const auto batch_size =
@@ -167,4 +164,16 @@ Tensor reshape_vbe_offsets(
167164
return reshaped_offsets;
168165
}
169166

167+
template Tensor reshape_vbe_offsets<int32_t>(
168+
const Tensor& offsets,
169+
const Tensor& B_offsets_rank_per_feature,
170+
const int64_t max_B,
171+
const int32_t T);
172+
173+
template Tensor reshape_vbe_offsets<int64_t>(
174+
const Tensor& offsets,
175+
const Tensor& B_offsets_rank_per_feature,
176+
const int64_t max_B,
177+
const int32_t T);
178+
170179
} // namespace fbgemm_gpu

fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,6 @@
88

99
#include <ATen/ATen.h>
1010
#include <ATen/TypeDefault.h>
11-
// #include <ATen/core/op_registration/op_registration.h>
12-
// #include <torch/script.h>
13-
// #include "fbgemm_gpu/embedding_common.h"
14-
// #include "fbgemm_gpu/utils/dispatch_macros.h"
15-
// #include "fbgemm_gpu/utils/ops_utils.h"
16-
// #include "fbgemm_gpu/utils/tensor_utils.h"
1711

1812
using Tensor = at::Tensor;
1913

@@ -29,6 +23,7 @@ Tensor reshape_vbe_output(
2923
const Tensor& B_offsets_rank_per_feature,
3024
const Tensor& D_offsets);
3125

26+
template <typename index_t>
3227
Tensor reshape_vbe_offsets(
3328
const Tensor& offsets,
3429
const Tensor& B_offsets_rank_per_feature,

0 commit comments

Comments (0)