Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions paddle/phi/kernels/cpu/lookup_table_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/lookup_table_grad_kernel.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要此头文件

#include <string>
#include <vector>

Expand All @@ -25,8 +25,6 @@

namespace phi {

constexpr int64_t kNoPadding = -1;

template <typename T, typename Context>
void LookupTableGradKernel(const Context &dev_ctx,
const DenseTensor &w,
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/kernels/cpu/lookup_table_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

#include <string>
#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/lookup_table_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/lookup_table_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/lookup_table_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/lookup_table_kernel.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
Expand Down
44 changes: 44 additions & 0 deletions paddle/phi/kernels/gpu/lookup_table_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"

namespace phi {

template <typename T, typename Context>
void LookupTableCUDAKernel(const Context &dev_ctx,
const DenseTensor &w,
const DenseTensor &ids_in,
bool is_sparse,
bool is_distributed,
int64_t padding_idx,
bool remote_prefetch,
const std::string &entry_config,
bool is_test,
const std::string &entry,
const std::string &table_class,
const std::vector<std::string> &table_names,
int trainer_id,
bool grad_inplace,
const std::vector<std::string> &epmap,
const std::vector<int64_t> &height_sections,
DenseTensor *out);

} // namespace phi
109 changes: 109 additions & 0 deletions paddle/phi/kernels/lookup_table_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"

namespace phi {

constexpr int64_t kNoPadding = -1;

template <typename T, typename Context>
void LookupTableGradCUDAKernel(const Context &dev_ctx,
const DenseTensor &w,
const DenseTensor &ids_in,
const DenseTensor &out_grad,
bool is_sparse,
bool is_distributed,
int64_t padding_idx,
bool remote_prefetch,
const std::string &entry_config,
bool is_test,
const std::string &entry,
const std::string &table_class,
const std::vector<std::string> &table_names,
int trainer_id,
bool grad_inplace,
const std::vector<std::string> &epmap,
const std::vector<int64_t> &height_sections,
DenseTensor *w_grad);

template <typename T, typename Context>
void LookupTableSparseGradCUDAKernel(
const Context &dev_ctx,
const DenseTensor &w,
const DenseTensor &ids_in,
const DenseTensor &out_grad,
bool is_sparse,
bool is_distributed,
int64_t padding_idx,
bool remote_prefetch,
const std::string &entry_config,
bool is_test,
const std::string &entry,
const std::string &table_class,
const std::vector<std::string> &table_names,
int trainer_id,
bool grad_inplace,
const std::vector<std::string> &epmap,
const std::vector<int64_t> &height_sections,
SelectedRows *w_grad);

template <typename T, typename Context>
void LookupTableGradKernel(const Context &dev_ctx,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cpu的kernel声明应当去掉,并且此文件应当放在gpu目录中

const DenseTensor &w,
const DenseTensor &ids_in,
const DenseTensor &out_grad,
bool is_sparse,
bool is_distributed,
int64_t padding_idx,
bool remote_prefetch,
const std::string &entry_config,
bool is_test,
const std::string &entry,
const std::string &table_class,
const std::vector<std::string> &table_names,
int trainer_id,
bool grad_inplace,
const std::vector<std::string> &epmap,
const std::vector<int64_t> &height_sections,
DenseTensor *w_grad);

template <typename T, typename Context>
void LookupTableSparseGradKernel(const Context &dev_ctx,
const DenseTensor &w,
const DenseTensor &ids_in,
const DenseTensor &out_grad,
bool is_sparse,
bool is_distributed,
int64_t padding_idx,
bool remote_prefetch,
const std::string &entry_config,
bool is_test,
const std::string &entry,
const std::string &table_class,
const std::vector<std::string> &table_names,
int trainer_id,
bool grad_inplace,
const std::vector<std::string> &epmap,
const std::vector<int64_t> &height_sections,
SelectedRows *w_grad);

} // namespace phi
Loading