@@ -26,6 +26,12 @@
 #include "fbgemm_gpu/utils/ops_utils.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 #include "fbgemm_gpu/embedding_common.h"
+// #include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/TensorUtils.h>
+// #include <ATen/Functions.h>
+// #include <ATen/TypeDefault.h>
+// #include "fbgemm_gpu/utils/tensor_utils.h"
 {%- if has_vbe_support %}
 #include "fbgemm_gpu/utils/pt2_autograd_utils.h"
 {%- endif %}
@@ -64,12 +70,15 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
 {%- endif %}
 ) {
 {%- if vbe %}
-  const auto offsets_ = reshape_vbe_offsets(
-      offsets,
-      vbe_B_offsets_rank_per_feature,
-      max_B,
-      D_offsets.numel() - 1
-  );
+  Tensor offsets_;
+  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_grad_indices", [&]() {
+    offsets_ = reshape_vbe_offsets<index_t>(
+        offsets,
+        vbe_B_offsets_rank_per_feature,
+        max_B,
+        D_offsets.numel() - 1
+    );
+  });
   const auto grad_output_ = reshape_vbe_output(
       grad_output,
       max_B,
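For context on the pattern used in each wrapper: AT_DISPATCH_INDEX_TYPES (from <ATen/Dispatch.h>) switches on the scalar type of the offsets tensor (int32 or int64) and runs the given lambda with index_t bound to the matching C++ type, which is why reshape_vbe_offsets is now called as reshape_vbe_offsets<index_t>. Because the result is assigned inside the lambda, offsets_ has to be declared as a plain Tensor up front rather than as const auto. Below is a minimal standalone sketch of this pattern, assuming a hypothetical helper reshape_offsets_sketch in place of the real templated reshape_vbe_offsets:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical stand-in for the templated reshape helper: it only shows
// that once the macro has resolved the concrete integer type, the raw
// offsets data can be read as index_t.
template <typename index_t>
at::Tensor reshape_offsets_sketch(const at::Tensor& offsets) {
  const index_t* data = offsets.data_ptr<index_t>();
  (void)data;  // real code would rebuild the VBE offsets from `data`
  return offsets.clone();
}

at::Tensor dispatch_example(const at::Tensor& offsets) {
  at::Tensor out;
  // Maps offsets.scalar_type() (kInt or kLong) to index_t and runs the lambda.
  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "dispatch_example", [&]() {
    out = reshape_offsets_sketch<index_t>(offsets);
  });
  return out;
}

Calling dispatch_example with int64 offsets takes the int64_t instantiation, int32 offsets take the int32_t path, and any other dtype fails inside the macro with a type-check error.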
@@ -126,8 +135,11 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
 {%- endif %}
     const bool /*is_experimental = false*/,
     const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
+  Tensor offsets_;
 {%- if vbe %}
-  const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_forward", [&]() {
+    offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+  });
 {%- endif %}
   static auto op =
       torch::Dispatcher::singleton()
@@ -226,7 +238,10 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
 {%- endif %})
 {
 {%- if vbe %}
-  const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+  Tensor offsets_;
+  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_backward", [&]() {
+    offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+  });
   const auto grad_output_ = reshape_vbe_output(grad_output, max_B, vbe_B_offsets_rank_per_feature, D_offsets);
 {%- endif %}
 {%- set backward_op = "split_embedding_backward_codegen_{}_cpu".format(