
Commit 4bac0fd

[CustomDevice] adapt c_embedding to phi namespace for custom devices
1 parent 600fc2f commit 4bac0fd
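
This commit registers phi-namespace c_embedding and c_embedding_grad kernels for the Custom backend by composing existing phi C++ API ops (full_like, logical_and, cast, reshape, embedding, embedding_grad), so custom-device plugins get both ops without any device-specific code.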

File tree

3 files changed: +175 −1 lines changed

paddle/phi/kernels/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

@@ -242,7 +242,8 @@ set(cc_search_pattern
       "strings/cpu/*.cc"
       "fusion/*.cc"
       "stride/*.cc"
-      "fusion/cpu/*.cc")
+      "fusion/cpu/*.cc"
+      "custom/*.cc")

  if(WITH_MKLDNN)
    set(cc_search_pattern ${cc_search_pattern} "legacy/onednn/*.cc" "onednn/*.cc"
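
The new "custom/*.cc" entry extends the source glob so that kernels under paddle/phi/kernels/custom/ (including the two new c_embedding files below) are compiled into the phi kernel library alongside the CPU, fusion, and stride kernels.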
New file: 91 additions & 0 deletions (the path is not shown in this view; the "custom/*.cc" glob above places it under paddle/phi/kernels/custom/)

@@ -0,0 +1,91 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/c_embedding_grad_kernel.h"

#include "glog/logging.h"
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void CEmbeddingGradKernel(const Context& dev_ctx,
                          const DenseTensor& w,
                          const DenseTensor& ids,
                          const DenseTensor& out_grad,
                          int64_t start_index,
                          DenseTensor* w_grad) {
  w_grad->Resize(w.dims());
  dev_ctx.template Alloc(w_grad, w.dtype());
  const auto& index_type = ids.dtype();
  if (index_type == phi::DataType::INT32 ||
      index_type == phi::DataType::INT64) {
    auto K = ids.numel();
    auto N = w.dims()[0];
    auto D = w.dims()[1];

    auto x_tmp = std::make_shared<phi::DenseTensor>();
    x_tmp->ShareDataWith(ids).Resize({K});
    auto w_tmp = std::make_shared<phi::DenseTensor>();
    w_tmp->set_meta(w.meta());
    dev_ctx.Alloc(w_tmp.get(), w_tmp->dtype());
    auto out_grad_tmp = std::make_shared<phi::DenseTensor>();
    out_grad_tmp->ShareDataWith(out_grad).Resize({K, D});
    paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp),
        out_grad_tensor(out_grad_tmp);

    auto start_index_tensor = paddle::experimental::full_like(
        x_tensor, start_index, x_tensor.dtype(), x_tensor.place());
    auto end_index_tensor = paddle::experimental::full_like(
        x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
    auto ids_mask_tensor = paddle::experimental::logical_and(
        x_tensor.greater_equal(start_index_tensor),
        x_tensor.less_than(end_index_tensor));
    auto real_ids_tensor = (x_tensor - start_index_tensor)
                               .multiply(paddle::experimental::cast(
                                   ids_mask_tensor, x_tensor.dtype()));
    auto out_grad_tensor_mul_mask =
        paddle::experimental::reshape(out_grad_tensor, {K, D})
            .multiply(paddle::experimental::reshape(
                paddle::experimental::cast(ids_mask_tensor, w.dtype()),
                {K, 1}));
    paddle::Tensor w_grad_tensor;
    paddle::experimental::embedding_grad(real_ids_tensor,
                                         w_tensor,
                                         out_grad_tensor_mul_mask,
                                         -1,
                                         false,
                                         &w_grad_tensor);
    w_grad->ShareDataWith(
        *reinterpret_cast<phi::DenseTensor*>(w_grad_tensor.impl().get()));

  } else {
    PADDLE_THROW(phi::errors::Unavailable(
        "Custom Device c_embedding_grad ids only support int32 or int64."));
  }
}
}  // namespace phi

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(c_embedding_grad,
                   Custom,
                   ALL_LAYOUT,
                   phi::CEmbeddingGradKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
#endif
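
The grad kernel is built entirely from existing phi API ops: ids are masked to this rank's row range [start_index, start_index + N), shifted down by start_index, and passed to embedding_grad together with the mask-zeroed out_grad, so no device-specific gradient code is required. As a minimal standalone sketch of that masking trick on plain vectors (the function name and layout here are illustrative, not Paddle code):

// Reference semantics of the masked embedding gradient above: ids outside
// [start_index, start_index + N) are remapped to row 0, but their gradient
// contribution is zeroed by the mask, so the scatter-add stays in bounds
// and rows owned by other ranks receive nothing.
#include <cstddef>
#include <cstdint>
#include <vector>

void CEmbeddingGradRef(const std::vector<int64_t>& ids,     // K token ids
                       const std::vector<float>& out_grad,  // K x D, row-major
                       int64_t start_index, int64_t N, int64_t D,
                       std::vector<float>* w_grad) {        // N x D, row-major
  w_grad->assign(static_cast<std::size_t>(N * D), 0.0f);
  for (std::size_t k = 0; k < ids.size(); ++k) {
    const bool in_range = ids[k] >= start_index && ids[k] < start_index + N;
    const int64_t real_id = in_range ? ids[k] - start_index : 0;
    const float mask = in_range ? 1.0f : 0.0f;
    for (int64_t d = 0; d < D; ++d) {
      (*w_grad)[real_id * D + d] += mask * out_grad[k * D + d];
    }
  }
}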
New file: 82 additions & 0 deletions (the path is not shown in this view; the "custom/*.cc" glob above places it under paddle/phi/kernels/custom/)

@@ -0,0 +1,82 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/c_embedding_kernel.h"

#include "glog/logging.h"
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void CEmbeddingKernel(const Context& dev_ctx,
                      const DenseTensor& w,
                      const DenseTensor& ids,
                      int64_t start_index,
                      int64_t vocab_size,
                      DenseTensor* out) {
  const auto& index_type = ids.dtype();
  if (index_type == phi::DataType::INT32 ||
      index_type == phi::DataType::INT64) {
    auto out_dims = out->dims();
    auto K = ids.numel();
    auto N = w.dims()[0];
    auto D = w.dims()[1];

    auto x_tmp = std::make_shared<phi::DenseTensor>();
    x_tmp->ShareDataWith(ids).Resize({K});
    auto w_tmp = std::make_shared<phi::DenseTensor>();
    w_tmp->ShareDataWith(w).Resize({N, D});
    paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp);

    auto start_index_tensor = paddle::experimental::full_like(
        x_tensor, start_index, x_tensor.dtype(), x_tensor.place());
    auto end_index_tensor = paddle::experimental::full_like(
        x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
    auto ids_mask_tensor = paddle::experimental::logical_and(
        x_tensor.greater_equal(start_index_tensor),
        x_tensor.less_than(end_index_tensor));
    auto ids_tensor = (x_tensor - start_index_tensor)
                          .multiply(paddle::experimental::cast(
                              ids_mask_tensor, x_tensor.dtype()));
    auto out_tensor =
        paddle::experimental::reshape(
            paddle::experimental::cast(ids_mask_tensor, w_tensor.dtype()),
            {K, 1})
            .multiply(paddle::experimental::reshape(
                paddle::experimental::embedding(
                    ids_tensor, w_tensor, -1, false),
                {K, D}));
    out->ShareDataWith(
            *reinterpret_cast<phi::DenseTensor*>(out_tensor.impl().get()))
        .Resize(out_dims);
  } else {
    PADDLE_THROW(phi::errors::Unavailable(
        "Custom Device c_embedding ids only support int32 or int64."));
  }
}
}  // namespace phi

#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(c_embedding,
                   Custom,
                   ALL_LAYOUT,
                   phi::CEmbeddingKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
#endif
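
The forward kernel uses the same mask-shift-lookup composition, and rows for out-of-range ids come out all zero (in the model-parallel setting those rows are typically filled in by a cross-rank reduction after the lookup). Note that vocab_size is accepted but unused in this path. A minimal standalone sketch of the computed semantics, again with illustrative names rather than Paddle code:

// Reference semantics of the masked lookup above: each rank owns rows
// [start_index, start_index + N) of the full embedding table; ids outside
// that range produce zero output rows instead of touching other ranks' rows.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> CEmbeddingRef(const std::vector<float>& w,      // N x D
                                 const std::vector<int64_t>& ids,  // K ids
                                 int64_t start_index, int64_t N, int64_t D) {
  std::vector<float> out(ids.size() * D, 0.0f);
  for (std::size_t k = 0; k < ids.size(); ++k) {
    if (ids[k] < start_index || ids[k] >= start_index + N) {
      continue;  // masked: leave this output row at zero
    }
    const int64_t row = ids[k] - start_index;
    for (int64_t d = 0; d < D; ++d) {
      out[k * D + d] = w[row * D + d];
    }
  }
  return out;
}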
