Skip to content

Commit 552df54

Browse files
committed
[CustomDevice] adapt c_embedding to phi namespace for custom devices
1 parent 600fc2f commit 552df54

File tree

4 files changed

+190
-1
lines changed

4 files changed

+190
-1
lines changed

paddle/phi/kernels/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ set(cc_search_pattern
242242
"strings/cpu/*.cc"
243243
"fusion/*.cc"
244244
"stride/*.cc"
245-
"fusion/cpu/*.cc")
245+
"fusion/cpu/*.cc"
246+
"custom/*.cc")
246247

247248
if(WITH_MKLDNN)
248249
set(cc_search_pattern ${cc_search_pattern} "legacy/onednn/*.cc" "onednn/*.cc"
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/c_embedding_grad_kernel.h"
16+
#include "glog/logging.h"
17+
#include "paddle/phi/api/backward/backward_api.h"
18+
#include "paddle/phi/api/include/api.h"
19+
#include "paddle/phi/backends/all_context.h"
20+
#include "paddle/phi/common/float16.h"
21+
#include "paddle/phi/core/kernel_registry.h"
22+
23+
namespace phi {
24+
25+
template <typename T, typename Context>
26+
void CEmbeddingGradKernel(const Context& dev_ctx,
27+
const DenseTensor& w,
28+
const DenseTensor& ids,
29+
const DenseTensor& out_grad,
30+
int64_t start_index,
31+
DenseTensor* w_grad) {
32+
#ifdef PADDLE_WITH_CUSTOM_DEVICE
33+
w_grad->Resize(w.dims());
34+
dev_ctx.template Alloc(w_grad, w.dtype());
35+
const auto& index_type = ids.dtype();
36+
if (index_type == phi::DataType::INT32 ||
37+
index_type == phi::DataType::INT64) {
38+
auto K = ids.numel();
39+
auto N = w.dims()[0];
40+
auto D = w.dims()[1];
41+
42+
auto x_tmp = std::make_shared<phi::DenseTensor>();
43+
x_tmp->ShareDataWith(ids).Resize({K});
44+
auto w_tmp = std::make_shared<phi::DenseTensor>();
45+
w_tmp->set_meta(w.meta());
46+
dev_ctx.Alloc(w_tmp.get(), w_tmp->dtype());
47+
auto out_grad_tmp = std::make_shared<phi::DenseTensor>();
48+
out_grad_tmp->ShareDataWith(out_grad).Resize({K, D});
49+
paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp),
50+
out_grad_tensor(out_grad_tmp);
51+
52+
auto start_index_tensor = paddle::experimental::full_like(
53+
x_tensor, start_index, x_tensor.dtype(), x_tensor.place());
54+
auto end_index_tensor = paddle::experimental::full_like(
55+
x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
56+
auto ids_mask_tensor = paddle::experimental::logical_and(
57+
x_tensor.greater_equal(start_index_tensor),
58+
x_tensor.less_than(end_index_tensor));
59+
auto real_ids_tensor = (x_tensor - start_index_tensor)
60+
.multiply(paddle::experimental::cast(
61+
ids_mask_tensor, x_tensor.dtype()));
62+
auto out_grad_tensor_mul_mask =
63+
paddle::experimental::reshape(out_grad_tensor, {K, D})
64+
.multiply(paddle::experimental::reshape(
65+
paddle::experimental::cast(ids_mask_tensor, w.dtype()),
66+
{K, 1}));
67+
paddle::Tensor w_grad_tensor;
68+
paddle::experimental::embedding_grad(real_ids_tensor,
69+
w_tensor,
70+
out_grad_tensor_mul_mask,
71+
-1,
72+
false,
73+
&w_grad_tensor);
74+
w_grad->ShareDataWith(
75+
*reinterpret_cast<phi::DenseTensor*>(w_grad_tensor.impl().get()));
76+
77+
} else {
78+
PADDLE_THROW(phi::errors::Unavailable(
79+
"Custom Device c_embedding_grad ids only support int32 or int64."));
80+
}
81+
#else
82+
PADDLE_THROW(
83+
phi::errors::Unavailable("This kernel can only be functional when paddle "
84+
"is compiled with custom device."));
85+
#endif
86+
}
87+
} // namespace phi
88+
89+
// Register the backward kernel on the Custom backend for all layouts and the
// weight dtypes the forward kernel supports (float32, float16, bfloat16).
PD_REGISTER_KERNEL(c_embedding_grad,
                   Custom,
                   ALL_LAYOUT,
                   phi::CEmbeddingGradKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/c_embedding_kernel.h"
16+
#include "glog/logging.h"
17+
#include "paddle/phi/api/backward/backward_api.h"
18+
#include "paddle/phi/api/include/api.h"
19+
#include "paddle/phi/backends/all_context.h"
20+
#include "paddle/phi/common/float16.h"
21+
#include "paddle/phi/core/kernel_registry.h"
22+
23+
namespace phi {
24+
25+
template <typename T, typename Context>
26+
void CEmbeddingKernel(const Context& dev_ctx,
27+
const DenseTensor& w,
28+
const DenseTensor& ids,
29+
int64_t start_index,
30+
int64_t vocab_size,
31+
DenseTensor* out) {
32+
#ifdef PADDLE_WITH_CUSTOM_DEVICE
33+
const auto& index_type = ids.dtype();
34+
if (index_type == phi::DataType::INT32 ||
35+
index_type == phi::DataType::INT64) {
36+
auto out_dims = out->dims();
37+
auto K = ids.numel();
38+
auto N = w.dims()[0];
39+
auto D = w.dims()[1];
40+
41+
auto x_tmp = std::make_shared<phi::DenseTensor>();
42+
x_tmp->ShareDataWith(ids).Resize({K});
43+
auto w_tmp = std::make_shared<phi::DenseTensor>();
44+
w_tmp->ShareDataWith(w).Resize({N, D});
45+
paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp);
46+
47+
auto start_index_tensor = paddle::experimental::full_like(
48+
x_tensor, start_index, x_tensor.dtype(), x_tensor.place());
49+
auto end_index_tensor = paddle::experimental::full_like(
50+
x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
51+
auto ids_mask_tensor = paddle::experimental::logical_and(
52+
x_tensor.greater_equal(start_index_tensor),
53+
x_tensor.less_than(end_index_tensor));
54+
auto ids_tensor = (x_tensor - start_index_tensor)
55+
.multiply(paddle::experimental::cast(
56+
ids_mask_tensor, x_tensor.dtype()));
57+
auto out_tensor =
58+
paddle::experimental::reshape(
59+
paddle::experimental::cast(ids_mask_tensor, w_tensor.dtype()),
60+
{K, 1})
61+
.multiply(paddle::experimental::reshape(
62+
paddle::experimental::embedding(
63+
ids_tensor, w_tensor, -1, false),
64+
{K, D}));
65+
out->ShareDataWith(
66+
*reinterpret_cast<phi::DenseTensor*>(out_tensor.impl().get()))
67+
.Resize(out_dims);
68+
} else {
69+
PADDLE_THROW(phi::errors::Unavailable(
70+
"Custom Device c_embedding ids only support int32 or int64."));
71+
}
72+
#else
73+
PADDLE_THROW(
74+
phi::errors::Unavailable("This kernel can only be functional when paddle "
75+
"is compiled with custom device."));
76+
#endif
77+
}
78+
} // namespace phi
79+
80+
// Register the forward kernel on the Custom backend for all layouts and the
// supported weight dtypes (float32, float16, bfloat16); ids dtype is checked
// at run time inside the kernel.
PD_REGISTER_KERNEL(c_embedding,
                   Custom,
                   ALL_LAYOUT,
                   phi::CEmbeddingKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

test/legacy_test/c_embedding_op_base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,16 @@ def test_check_output(self):
8787
self.check_output_with_place(core.CUDAPlace(0))
8888
elif core.is_compiled_with_xpu():
8989
self.check_output_with_place(core.XPUPlace(0))
90+
elif core.is_compiled_with_custom_device():
91+
self.check_output_with_place(core.CustomPlace(0))
9092

9193
def test_check_grad(self):
    # Check d(Out)/d(W) numerically on whichever accelerator this build was
    # compiled for; CPU-only builds fall through and perform no check.
    if core.is_compiled_with_cuda():
        self.check_grad_with_place(core.CUDAPlace(0), ['W'], 'Out')
    elif core.is_compiled_with_xpu():
        self.check_grad_with_place(core.XPUPlace(0), ['W'], 'Out')
    elif core.is_compiled_with_custom_device():
        # NOTE(review): CustomPlace / is_compiled_with_custom_device are
        # usually parameterized by a device-type string — confirm these
        # zero-/one-argument forms exist in this branch of core.
        self.check_grad_with_place(core.CustomPlace(0), ['W'], 'Out')
96100

97101
def init_dtype(self):
98102
if core.is_compiled_with_cuda():
@@ -101,6 +105,9 @@ def init_dtype(self):
101105
elif core.is_compiled_with_xpu():
102106
self.dtype = "float32"
103107
self.ids_dtype = "int64"
108+
elif core.is_compiled_with_custom_device():
109+
self.dtype = "float32"
110+
self.ids_dtype = "int64"
104111

105112

106113
class TestCEmbeddingOpFP32(TestCEmbeddingOpBase):

0 commit comments

Comments
 (0)