
Commit 0221cd8

RabbitWhite1 authored and facebook-github-bot committed
split cpu parts in permute_pooled_embedding_ops for cpu_only (#987)
Summary: As specified in [CMakeLists.txt](https://github.com/pytorch/FBGEMM/blob/c3a26e17e7c514041ae3d08f39d0e19063614869/fbgemm_gpu/CMakeLists.txt#L239), "src/permute_pooled_embedding_ops_gpu.cpp" is compiled only when "NOT FBGEMM_CPU_ONLY", so the operator "permute_pooled_embs_auto_grad" is never registered in --cpu_only builds. However, torchrec's column-wise sharding depends on this operator. The PR mentioned in #950 does not work because it never calls m.def to define permute_pooled_embs_auto_grad. This change moves the CPU kernels and the m.def schema definitions into a new src/permute_pooled_embedding_ops_cpu.cpp that is compiled in every build flavor, and factors the shared autograd wrapper out into permute_pooled_embs_function.h.

Pull Request resolved: #987

Reviewed By: jianyuh

Differential Revision: D34984810

Pulled By: geyyer

fbshipit-source-id: f3730cc69f760a414e8a9dfcf4f843a545b15756
1 parent 356ca9c commit 0221cd8
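Below is a minimal, self-contained sketch of the registration split this commit performs, using a hypothetical operator toy_permute (illustration only, not an FBGEMM API). The point is that the schema definition (m.def) must live in a translation unit that is compiled in --cpu_only builds; the GPU file then only attaches an extra kernel to the already-defined schema.

```cpp
// Sketch under the stated assumptions; "toy_permute" is hypothetical.
#include <ATen/ATen.h>
#include <torch/library.h>

namespace {
at::Tensor toy_permute_cpu(const at::Tensor& x) {
  return x.clone(); // stand-in for a real CPU kernel
}
} // namespace

// In the *_cpu.cpp file, compiled in every build flavor:
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def("toy_permute(Tensor x) -> Tensor"); // the schema lives here
}
TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
  m.impl("toy_permute", toy_permute_cpu);
}

// In the *_gpu.cpp file, compiled only when NOT FBGEMM_CPU_ONLY, one would
// add just the kernel:
//   TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
//     m.impl("toy_permute", toy_permute_gpu);
//   }
```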

6 files changed: +157 −106 lines


fbgemm_gpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -228,6 +228,7 @@ set(fbgemm_gpu_sources_cpu
     src/jagged_tensor_ops_cpu.cpp
     src/input_combine_cpu.cpp
     src/layout_transform_ops_cpu.cpp
+    src/permute_pooled_embedding_ops_cpu.cpp
     src/quantize_ops_cpu.cpp
     src/sparse_ops_cpu.cpp)
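The new source file is added to fbgemm_gpu_sources_cpu, the list compiled in both CPU-only and GPU builds, so the operators it registers exist in every build flavor.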

fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,9 @@
     # pyre-ignore[21]
     from fbgemm_gpu import open_source  # noqa: F401
 except Exception:
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
+    )
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
     )
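The CPU ops library is loaded before the GPU one, so the operator schemas defined in permute_pooled_embedding_ops_cpu.cpp are registered before the GPU library attaches its CUDA kernels.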

fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embedding_ops.h

Lines changed: 14 additions & 0 deletions
@@ -23,4 +23,18 @@ at::Tensor permute_pooled_embs_gpu(
     const at::Tensor& permute_list,
     const at::Tensor& inv_offset_dim_list,
     const at::Tensor& inv_permute_list);
+
+at::Tensor permute_pooled_embs_auto_grad_cpu(
+    const at::Tensor& pooled_embs,
+    const at::Tensor& offset_dim_list,
+    const at::Tensor& permute_list,
+    const at::Tensor& inv_offset_dim_list,
+    const at::Tensor& inv_permute_list);
+
+at::Tensor permute_pooled_embs_auto_grad_gpu(
+    const at::Tensor& pooled_embs,
+    const at::Tensor& offset_dim_list,
+    const at::Tensor& permute_list,
+    const at::Tensor& inv_offset_dim_list,
+    const at::Tensor& inv_permute_list);
 } // namespace fbgemm_gpu
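Both wrappers share one signature, matching the single permute_pooled_embs_auto_grad schema that can now dispatch to either backend.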

fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embs_function.h

Lines changed: 67 additions & 0 deletions
@@ -4,3 +4,70 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
+
+#pragma once
+
+#include <ATen/ATen.h>
+#include <torch/script.h>
+
+namespace fbgemm_gpu {
+
+using torch::autograd::AutogradContext;
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+
+template <torch::autograd::Variable (*permute_pooled_embs_op)(
+    const at::Tensor&, // [B_local][Sum_T_global(D)]
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&)>
+class PermutePooledEmbsFunction
+    : public torch::autograd::Function<
+          PermutePooledEmbsFunction<permute_pooled_embs_op>> {
+ public:
+  static Variable forward(
+      AutogradContext* ctx,
+      const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
+      const at::Tensor& offset_dim_list,
+      const at::Tensor& permute_list,
+      const at::Tensor& inv_offset_dim_list,
+      const at::Tensor& inv_permute_list) {
+    ctx->saved_data["offset_dim_list"] = offset_dim_list;
+    ctx->saved_data["permute_list"] = permute_list;
+    ctx->saved_data["inv_offset_dim_list"] = inv_offset_dim_list;
+    ctx->saved_data["inv_permute_list"] = inv_permute_list;
+    TORCH_CHECK(
+        offset_dim_list.scalar_type() == at::ScalarType::Long,
+        "offset_dim_list needs to have long/int64 type");
+    TORCH_CHECK(
+        permute_list.scalar_type() == at::ScalarType::Long,
+        "permute_list needs to have long/int64 type");
+    return permute_pooled_embs_op(
+        pooled_embs,
+        offset_dim_list,
+        permute_list,
+        inv_offset_dim_list,
+        inv_permute_list);
+  }
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output) {
+    const auto& offset_dim_list = ctx->saved_data["offset_dim_list"].toTensor();
+    const auto& permute_list = ctx->saved_data["permute_list"].toTensor();
+    const auto& inv_offset_dim_list =
+        ctx->saved_data["inv_offset_dim_list"].toTensor();
+    const auto& inv_permute_list =
+        ctx->saved_data["inv_permute_list"].toTensor();
+    variable_list grad_inputs(5);
+    grad_inputs[0] = permute_pooled_embs_op(
+        grad_output[0],
+        inv_offset_dim_list,
+        inv_permute_list,
+        offset_dim_list,
+        permute_list);
+    return grad_inputs;
+  }
+};
+
+} // namespace fbgemm_gpu
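Note that backward reuses the same kernel with the offset and permute lists swapped for their inverses, and populates only grad_inputs[0]; the four index tensors are non-differentiable. A worked example of these semantics, written against plain ATen calls that mirror the CPU kernel below (illustration only, not part of the commit):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // One batch row, columns 0..5, split into blocks [0,1 | 2,3,4 | 5]
  // (widths 2, 3, 1; cumulative offsets [0, 2, 5, 6]).
  auto pooled = at::arange(6, at::kFloat).reshape({1, 6});

  // Forward with permute_list (2, 0, 1): split at interior offsets {2, 5}
  // and re-concatenate the blocks in permuted order.
  auto ts = pooled.tensor_split({2, 5}, 1);
  auto out = at::cat({ts[2], ts[0], ts[1]}, 1);
  std::cout << out << "\n"; // [[5, 0, 1, 2, 3, 4]]

  // backward uses the inverse metadata: permuted widths 1, 2, 3 give
  // inv_offset_dim_list [0, 1, 3, 6], and (2, 0, 1) inverts to (1, 2, 0).
  auto back = out.tensor_split({1, 3}, 1);
  auto restored = at::cat({back[1], back[2], back[0]}, 1);
  std::cout << restored << "\n"; // [[0, 1, 2, 3, 4, 5]]
}
```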

fbgemm_gpu/src/permute_pooled_embedding_ops_cpu.cpp

Lines changed: 71 additions & 0 deletions
@@ -4,3 +4,74 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
+#include <ATen/ATen.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <c10/util/irange.h>
+#include <torch/script.h>
+#include <vector>
+
+#include "fbgemm_gpu/permute_pooled_embedding_ops.h"
+#include "fbgemm_gpu/permute_pooled_embs_function.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
+
+using Tensor = at::Tensor;
+
+namespace fbgemm_gpu {
+
+using torch::autograd::AutogradContext;
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+
+Tensor permute_pooled_embs_cpu(
+    const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
+    const Tensor& offset_dim_list,
+    const Tensor& permute_list,
+    const Tensor& inv_offset_dim_list,
+    const Tensor& inv_permute_list) {
+  TORCH_CHECK(
+      offset_dim_list.scalar_type() == at::ScalarType::Long,
+      "offset_dim_list needs to have long/int64 type")
+  TORCH_CHECK(
+      permute_list.scalar_type() == at::ScalarType::Long,
+      "permute_list needs to have long/int64 type")
+  auto permute = permute_list.data_ptr<int64_t>();
+  const auto n = permute_list.numel();
+  std::vector<int64_t> dims;
+  dims.reserve(n - 1);
+  for (const auto i : c10::irange(1, n)) {
+    dims.push_back(offset_dim_list[i].item<int64_t>());
+  }
+  auto ts = pooled_embs.tensor_split(dims, 1);
+  std::vector<Tensor> permuted_ts;
+  permuted_ts.reserve(n);
+  for (const auto i : c10::irange(n)) {
+    permuted_ts.push_back(ts[permute[i]]);
+  }
+  return at::cat(permuted_ts, 1);
+}
+
+Tensor permute_pooled_embs_auto_grad_cpu(
+    const Tensor& pooled_embs,
+    const Tensor& offset_dim_list,
+    const Tensor& permute_list,
+    const Tensor& inv_offset_dim_list,
+    const Tensor& inv_permute_list) {
+  return PermutePooledEmbsFunction<permute_pooled_embs_cpu>::apply(
+      pooled_embs,
+      offset_dim_list,
+      permute_list,
+      inv_offset_dim_list,
+      inv_permute_list);
+}
+} // namespace fbgemm_gpu
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  m.def(
+      "permute_pooled_embs(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
+  DISPATCH_TO_CPU("permute_pooled_embs", fbgemm_gpu::permute_pooled_embs_cpu);
+  m.def(
+      "permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
+  DISPATCH_TO_CPU(
+      "permute_pooled_embs_auto_grad",
+      fbgemm_gpu::permute_pooled_embs_auto_grad_cpu);
+}
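A hedged usage sketch (not part of the commit) exercising the new CPU entry point and its autograd path; it assumes a CPU-only build that links the fbgemm_gpu CPU sources and headers:

```cpp
#include <iostream>
#include <torch/torch.h>
#include "fbgemm_gpu/permute_pooled_embedding_ops.h"

int main() {
  // Two rows; three blocks of widths 2, 3, 1 (cumulative offsets
  // [0, 2, 5, 6]) permuted as (2, 0, 1), plus the inverse metadata
  // that backward needs.
  auto pooled = torch::arange(12, torch::kFloat).reshape({2, 6});
  pooled.requires_grad_();
  auto offset_dim_list = torch::tensor({0, 2, 5, 6}, torch::kLong);
  auto permute_list = torch::tensor({2, 0, 1}, torch::kLong);
  auto inv_offset_dim_list = torch::tensor({0, 1, 3, 6}, torch::kLong);
  auto inv_permute_list = torch::tensor({1, 2, 0}, torch::kLong);

  auto out = fbgemm_gpu::permute_pooled_embs_auto_grad_cpu(
      pooled,
      offset_dim_list,
      permute_list,
      inv_offset_dim_list,
      inv_permute_list);
  out.sum().backward(); // flows back through the inverse permutation
  std::cout << out << "\n" << pooled.grad() << "\n";
}
```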

fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp

Lines changed: 1 addition & 106 deletions
@@ -11,98 +11,13 @@
 #include <vector>
 
 #include "fbgemm_gpu/permute_pooled_embedding_ops.h"
+#include "fbgemm_gpu/permute_pooled_embs_function.h"
 #include "fbgemm_gpu/sparse_ops_utils.h"
 
 using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
 
-Tensor permute_pooled_embs_cpu(
-    const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
-    const Tensor& offset_dim_list,
-    const Tensor& permute_list,
-    const Tensor& inv_offset_dim_list,
-    const Tensor& inv_permute_list) {
-  TORCH_CHECK(
-      offset_dim_list.scalar_type() == at::ScalarType::Long,
-      "offset_dim_list needs to have long/int64 type")
-  TORCH_CHECK(
-      permute_list.scalar_type() == at::ScalarType::Long,
-      "permute_list needs to have long/int64 type")
-  auto permute = permute_list.data_ptr<int64_t>();
-  const auto n = permute_list.numel();
-  std::vector<int64_t> dims;
-  dims.reserve(n - 1);
-  for (const auto i : c10::irange(1, n)) {
-    dims.push_back(offset_dim_list[i].item<int64_t>());
-  }
-  auto ts = pooled_embs.tensor_split(dims, 1);
-  std::vector<Tensor> permuted_ts;
-  permuted_ts.reserve(n);
-  for (const auto i : c10::irange(n)) {
-    permuted_ts.push_back(ts[permute[i]]);
-  }
-  return at::cat(permuted_ts, 1);
-}
-
-using torch::autograd::AutogradContext;
-using torch::autograd::Variable;
-using torch::autograd::variable_list;
-
-template <torch::autograd::Variable (*permute_pooled_embs_op)(
-    const Tensor&, // [B_local][Sum_T_global(D)]
-    const Tensor&,
-    const Tensor&,
-    const Tensor&,
-    const Tensor&)>
-class PermutePooledEmbsFunction
-    : public torch::autograd::Function<
-          PermutePooledEmbsFunction<permute_pooled_embs_op>> {
- public:
-  static Variable forward(
-      AutogradContext* ctx,
-      const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
-      const Tensor& offset_dim_list,
-      const Tensor& permute_list,
-      const Tensor& inv_offset_dim_list,
-      const Tensor& inv_permute_list) {
-    ctx->saved_data["offset_dim_list"] = offset_dim_list;
-    ctx->saved_data["permute_list"] = permute_list;
-    ctx->saved_data["inv_offset_dim_list"] = inv_offset_dim_list;
-    ctx->saved_data["inv_permute_list"] = inv_permute_list;
-    TORCH_CHECK(
-        offset_dim_list.scalar_type() == at::ScalarType::Long,
-        "offset_dim_list needs to have long/int64 type");
-    TORCH_CHECK(
-        permute_list.scalar_type() == at::ScalarType::Long,
-        "permute_list needs to have long/int64 type");
-    return permute_pooled_embs_op(
-        pooled_embs,
-        offset_dim_list,
-        permute_list,
-        inv_offset_dim_list,
-        inv_permute_list);
-  }
-  static variable_list backward(
-      AutogradContext* ctx,
-      variable_list grad_output) {
-    const auto& offset_dim_list = ctx->saved_data["offset_dim_list"].toTensor();
-    const auto& permute_list = ctx->saved_data["permute_list"].toTensor();
-    const auto& inv_offset_dim_list =
-        ctx->saved_data["inv_offset_dim_list"].toTensor();
-    const auto& inv_permute_list =
-        ctx->saved_data["inv_permute_list"].toTensor();
-    variable_list grad_inputs(5);
-    grad_inputs[0] = permute_pooled_embs_op(
-        grad_output[0],
-        inv_offset_dim_list,
-        inv_permute_list,
-        offset_dim_list,
-        permute_list);
-    return grad_inputs;
-  }
-};
-
 Tensor permute_pooled_embs_auto_grad_gpu(
     const Tensor& pooled_embs,
     const Tensor& offset_dim_list,
@@ -117,30 +32,10 @@ Tensor permute_pooled_embs_auto_grad_gpu(
       inv_permute_list);
 }
 
-Tensor permute_pooled_embs_auto_grad_cpu(
-    const Tensor& pooled_embs,
-    const Tensor& offset_dim_list,
-    const Tensor& permute_list,
-    const Tensor& inv_offset_dim_list,
-    const Tensor& inv_permute_list) {
-  return PermutePooledEmbsFunction<permute_pooled_embs_cpu>::apply(
-      pooled_embs,
-      offset_dim_list,
-      permute_list,
-      inv_offset_dim_list,
-      inv_permute_list);
-}
 } // namespace fbgemm_gpu
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  m.def(
-      "permute_pooled_embs(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
   DISPATCH_TO_CUDA("permute_pooled_embs", fbgemm_gpu::permute_pooled_embs_gpu);
-  m.def(
-      "permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
-  DISPATCH_TO_CPU(
-      "permute_pooled_embs_auto_grad",
-      fbgemm_gpu::permute_pooled_embs_auto_grad_cpu);
   DISPATCH_TO_CUDA(
       "permute_pooled_embs_auto_grad",
       fbgemm_gpu::permute_pooled_embs_auto_grad_gpu);