-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add a lookup table op and a CUDA helper. #3620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0f3b9e4
1795e57
c91e542
31f59d2
9bc1a1a
a8d072c
f188e22
d8ea560
fe480b9
068ddca
aafeff0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/operators/lookup_table_op.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| // Forward operator of lookup_table. Its only job here is shape inference: | ||
| // "Out" is sized from the index tensor "Ids" and the table "W". | ||
| class LookupTableOp : public framework::OperatorWithKernel { | ||
|  public: | ||
|   using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
|  protected: | ||
|   // Resizes "Out" to {Ids.dims()[0], W.dims()[1]}. | ||
|   // NOTE(review): only the first dimension of "Ids" is used for the row | ||
|   // count -- presumably "Ids" is expected to be a vector/column; confirm. | ||
|   void InferShape(const framework::InferShapeContext &context) const override { | ||
|     auto ids = context.Input<Tensor>("Ids"); | ||
|     auto table = context.Input<Tensor>("W"); | ||
|     auto out = context.Output<Tensor>("Out"); | ||
|
|
||
|     auto num_rows = ids->dims()[0]; | ||
|     auto row_width = table->dims()[1]; | ||
|     out->Resize({num_rows, row_width}); | ||
|   } | ||
| }; | ||
|
|
||
| // Declares the inputs, the output and the user-facing description of the | ||
| // lookup_table operator. | ||
| class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { | ||
|  public: | ||
|   // proto / op_checker are forwarded unchanged to the base class. | ||
|   LookupTableOpMaker(framework::OpProto *proto, | ||
|                      framework::OpAttrChecker *op_checker) | ||
|       : OpProtoAndCheckerMaker(proto, op_checker) { | ||
|     // Adjacent string literals are joined verbatim, so every continuation | ||
|     // piece must carry its own leading space. | ||
|     AddInput("W", | ||
|              "An input represents embedding tensors," | ||
|              " which is a learnable parameter."); | ||
|     // NOTE(review): the text advertises int64, but the kernels registered | ||
|     // for this op read "Ids" as int32_t only -- confirm the intended types. | ||
|     AddInput("Ids", | ||
|              "An input with type int32 or int64" | ||
|              " contains the ids to be looked up in W."); | ||
|     AddOutput("Out", "The lookup results, which have the same type with W."); | ||
|     AddComment( | ||
|         "This operator is used to perform lookups on the parameter W," | ||
|         " then concatenated into a dense tensor."); | ||
|   } | ||
| }; | ||
|
|
||
| // Gradient operator of lookup_table: the gradient of "W" is given exactly | ||
| // the dimensions of "W" itself. | ||
| class LookupTableOpGrad : public framework::OperatorWithKernel { | ||
|  public: | ||
|   using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
|  protected: | ||
|   // Copies the dims of input "W" onto the output named GradVarName("W"). | ||
|   void InferShape(const framework::InferShapeContext &context) const override { | ||
|     auto weights = context.Input<Tensor>("W"); | ||
|     auto weights_grad = context.Output<Tensor>(framework::GradVarName("W")); | ||
|     const auto &weight_dims = weights->dims(); | ||
|     weights_grad->Resize(weight_dims); | ||
|   } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| // Short alias so the registration macros below stay on one or two lines. | ||
| namespace ops = paddle::operators; | ||
| // Registers the forward op together with its proto maker, and names | ||
| // lookup_table_grad / LookupTableOpGrad as the matching gradient op. | ||
| REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker, | ||
| lookup_table_grad, ops::LookupTableOpGrad); | ||
|
|
||
| // CPU kernels for both ops; only the float instantiation is registered. | ||
| REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>); | ||
| REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/framework/eigen.h" | ||
| #include "paddle/framework/op_registry.h" | ||
| #include "paddle/platform/assert.h" | ||
| #include "paddle/platform/cuda_helper.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using Tensor = framework::Tensor; | ||
|
|
||
| // Row gather: for every k in [0, K), output[k * D .. k * D + D) receives | ||
| // table[ids[k] * D .. ids[k] * D + D). | ||
| // | ||
| // Launch layout: the indexing below uses the template constants instead of | ||
| // blockDim / gridDim, so the kernel must be launched with | ||
| // blockDim = (BlockDimX, BlockDimY) and gridDim.x = GridDimX. | ||
| // threadIdx.x walks the D elements of one row; (blockIdx.x, threadIdx.y) | ||
| // selects the row, advancing by BlockDimY * GridDimX rows per iteration. | ||
| // | ||
| //   output : K * D elements, written | ||
| //   table  : N * D elements, read | ||
| //   ids    : K row indices, each asserted to lie in [0, N) | ||
| template <typename T, int BlockDimX, int BlockDimY, int GridDimX> | ||
| __global__ void LookupTable(T* output, const T* table, const int32_t* ids, | ||
|                             const int N, const int K, const int D) { | ||
|   int idx = threadIdx.x; | ||
|   int idy = blockIdx.x + threadIdx.y * GridDimX; | ||
|
|
||
|   while (idy < K) { | ||
|     int id = ids[idy]; | ||
|     PADDLE_ASSERT(id >= 0); | ||
|     PADDLE_ASSERT(id < N); | ||
|     // Widen before multiplying: row * D in 32-bit int overflows as soon as | ||
|     // the table or the output holds more than 2^31 elements. | ||
|     T* out = output + static_cast<int64_t>(idy) * D; | ||
|     const T* tab = table + static_cast<int64_t>(id) * D; | ||
|     for (int i = idx; i < D; i += BlockDimX) { | ||
|       out[i] = tab[i]; | ||
|     } | ||
|     idy += BlockDimY * GridDimX; | ||
|   } | ||
| } | ||
|
|
||
| // Scatter-add used by the backward pass: for every k in [0, K), row k of | ||
| // `output` (D elements) is added onto row ids[k] of `table`. | ||
| // | ||
| // The additions are atomic because several entries of `ids` may name the | ||
| // same table row and are then processed by different threads. The kernel | ||
| // only accumulates; `table` must hold its initial values before launch. | ||
| // | ||
| // Launch layout: the indexing uses the template constants, so the kernel | ||
| // must be launched with blockDim = (BlockDimX, BlockDimY) and | ||
| // gridDim.x = GridDimX. | ||
| // | ||
| //   table  : N * D elements, accumulated into | ||
| //   output : K * D elements, read | ||
| //   ids    : K row indices, each asserted to lie in [0, N) | ||
| template <typename T, int BlockDimX, int BlockDimY, int GridDimX> | ||
| __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids, | ||
|                                 const int N, const int K, const int D) { | ||
|   int idx = threadIdx.x; | ||
|   int idy = blockIdx.x + threadIdx.y * GridDimX; | ||
|
|
||
|   while (idy < K) { | ||
|     int id = ids[idy]; | ||
|     PADDLE_ASSERT(id >= 0); | ||
|     PADDLE_ASSERT(id < N); | ||
|     // Widen before multiplying: row * D in 32-bit int overflows as soon as | ||
|     // the table or the gradient holds more than 2^31 elements. | ||
|     const T* out = output + static_cast<int64_t>(idy) * D; | ||
|     T* tab = table + static_cast<int64_t>(id) * D; | ||
|     for (int i = idx; i < D; i += BlockDimX) { | ||
|       paddle::platform::CudaAtomicAdd(&tab[i], out[i]); | ||
|     } | ||
|     idy += BlockDimY * GridDimX; | ||
|   } | ||
| } | ||
|
|
||
| // GPU forward kernel of lookup_table: launches LookupTable so that the rows | ||
| // of "W" selected by "Ids" are copied into "Out". | ||
| template <typename T> | ||
| class LookupTableCUDAKernel : public framework::OpKernel { | ||
|  public: | ||
|   void Compute(const framework::ExecutionContext& context) const override { | ||
|     // Launch shape. The same numbers are passed to the kernel as template | ||
|     // arguments, which is what its index arithmetic relies on. | ||
|     const int kBlockDimX = 128; | ||
|     const int kBlockDimY = 8; | ||
|     const int kGridDimX = 8; | ||
|
|
||
|     auto weights = context.Input<Tensor>("W"); | ||
|     auto indices = context.Input<Tensor>("Ids"); | ||
|     auto result = context.Output<Tensor>("Out"); | ||
|
|
||
|     size_t table_rows = weights->dims()[0]; | ||
|     size_t row_width = weights->dims()[1]; | ||
|     size_t num_ids = product(indices->dims()); | ||
|     auto ids_ptr = indices->data<int32_t>(); | ||
|     auto table_ptr = weights->data<T>(); | ||
|     auto out_ptr = result->mutable_data<T>(context.GetPlace()); | ||
|
|
||
|     dim3 block(kBlockDimX, kBlockDimY); | ||
|     dim3 grid(kGridDimX, 1); | ||
|     // The size_t extents are converted to the kernel's int parameters here. | ||
|     LookupTable<T, kBlockDimX, kBlockDimY, kGridDimX><<<grid, block>>>( | ||
|         out_ptr, table_ptr, ids_ptr, table_rows, num_ids, row_width); | ||
|   } | ||
| }; | ||
|
|
||
| // GPU backward kernel of lookup_table: zero-fills the gradient of "W" and | ||
| // then launches LookupTableGrad to add each row of the gradient of "Out" | ||
| // onto the table row named by the matching entry of "Ids". | ||
| template <typename T> | ||
| class LookupTableGradCUDAKernel : public framework::OpKernel { | ||
|  public: | ||
|   void Compute(const framework::ExecutionContext& context) const override { | ||
|     // Launch shape, shared with the kernel's template arguments. | ||
|     const int kBlockDimX = 128; | ||
|     const int kBlockDimY = 8; | ||
|     const int kGridDimX = 8; | ||
|
|
||
|     auto indices = context.Input<Tensor>("Ids"); | ||
|     auto out_grad = context.Input<Tensor>(framework::GradVarName("Out")); | ||
|     auto table_grad = context.Output<Tensor>(framework::GradVarName("W")); | ||
|
|
||
|     int table_rows = table_grad->dims()[0]; | ||
|     int row_width = table_grad->dims()[1]; | ||
|     int num_ids = product(indices->dims()); | ||
|     const int32_t* ids_ptr = indices->data<int32_t>(); | ||
|     const T* out_grad_ptr = out_grad->data<T>(); | ||
|     T* table_grad_ptr = table_grad->mutable_data<T>(context.GetPlace()); | ||
|
|
||
|     // Clear the whole gradient first; the kernel below only accumulates. | ||
|     auto flat_grad = framework::EigenVector<T>::Flatten(*table_grad); | ||
|     flat_grad.device(context.GetEigenDevice<platform::GPUPlace>()) = | ||
|         flat_grad.constant(static_cast<T>(0)); | ||
|
|
||
|     // NOTE(review): the launch passes no stream argument -- confirm it is | ||
|     // ordered after the Eigen zero-fill issued on the context's device. | ||
|     dim3 block(kBlockDimX, kBlockDimY); | ||
|     dim3 grid(kGridDimX, 1); | ||
|     LookupTableGrad<T, kBlockDimX, kBlockDimY, kGridDimX><<<grid, block>>>( | ||
|         table_grad_ptr, out_grad_ptr, ids_ptr, table_rows, num_ids, | ||
|         row_width); | ||
|   } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| // Short alias used by the registration macros below. | ||
| namespace ops = paddle::operators; | ||
| // GPU kernels for the forward and gradient ops; only float is registered. | ||
| REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>); | ||
| REGISTER_OP_GPU_KERNEL(lookup_table_grad, | ||
| ops::LookupTableGradCUDAKernel<float>); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "paddle/framework/eigen.h" | ||
| #include "paddle/framework/op_registry.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using Tensor = framework::Tensor; | ||
|
|
||
| // CPU forward kernel of lookup_table: for each index i, copies row ids[i] of | ||
| // the table "W" (D elements) into row i of "Out". | ||
| // Every id is enforced to lie in [0, N), N being the number of table rows. | ||
| template <typename T> | ||
| class LookupTableKernel : public framework::OpKernel { | ||
|  public: | ||
|   void Compute(const framework::ExecutionContext& context) const override { | ||
|     auto table_t = context.Input<Tensor>("W");      // float tensor | ||
|     auto ids_t = context.Input<Tensor>("Ids");      // int tensor | ||
|     auto output_t = context.Output<Tensor>("Out");  // float tensor | ||
|
|
||
|     size_t N = table_t->dims()[0]; | ||
|     size_t D = table_t->dims()[1]; | ||
|     // Number of ids, computed once instead of on every loop iteration. | ||
|     size_t K = product(ids_t->dims()); | ||
|     auto ids = ids_t->data<int32_t>(); | ||
|     auto table = table_t->data<T>(); | ||
|     auto output = output_t->mutable_data<T>(context.GetPlace()); | ||
|     for (size_t i = 0; i < K; ++i) { | ||
|       // Lower bound first, so a negative id fails the check written for it | ||
|       // rather than the comparison against the unsigned row count. | ||
|       PADDLE_ENFORCE_GE(ids[i], 0); | ||
|       PADDLE_ENFORCE_LT(ids[i], N); | ||
|       memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); | ||
|     } | ||
|   } | ||
| }; | ||
|
|
||
| // CPU backward kernel of lookup_table: zero-fills the gradient of "W", then | ||
| // adds row i of the gradient of "Out" onto row ids[i] of the gradient of "W". | ||
| // Every id is enforced to lie in [0, N), N being the number of table rows. | ||
| template <typename T> | ||
| class LookupTableGradKernel : public framework::OpKernel { | ||
|  public: | ||
|   void Compute(const framework::ExecutionContext& context) const override { | ||
|     auto ids_t = context.Input<Tensor>("Ids"); | ||
|     auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out")); | ||
|     auto d_table_t = context.Output<Tensor>(framework::GradVarName("W")); | ||
|
|
||
|     size_t N = d_table_t->dims()[0]; | ||
|     size_t D = d_table_t->dims()[1]; | ||
|     // Number of ids, computed once instead of on every loop iteration. | ||
|     size_t K = product(ids_t->dims()); | ||
|     auto ids = ids_t->data<int32_t>(); | ||
|     const T* d_output = d_output_t->data<T>(); | ||
|     T* d_table = d_table_t->mutable_data<T>(context.GetPlace()); | ||
|
|
||
|     // Start from an all-zero gradient; the loop below only accumulates. | ||
|     auto t = framework::EigenVector<T>::Flatten(*d_table_t); | ||
|     t.device(context.GetEigenDevice<platform::CPUPlace>()) = | ||
|         t.constant(static_cast<T>(0)); | ||
|
|
||
|     for (size_t i = 0; i < K; ++i) { | ||
|       // Lower bound first, so a negative id fails the check written for it | ||
|       // rather than the comparison against the unsigned row count. | ||
|       PADDLE_ENFORCE_GE(ids[i], 0); | ||
|       PADDLE_ENFORCE_LT(ids[i], N); | ||
|       const T* src_row = d_output + i * D; | ||
|       T* dst_row = d_table + ids[i] * D; | ||
|       for (size_t j = 0; j < D; ++j) { | ||
|         dst_row[j] += src_row[j]; | ||
|       } | ||
|     } | ||
|   } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
| #include <cuda.h> | ||
|
|
||
| namespace paddle { | ||
| namespace platform { | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think this header is necessary. Using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In this header, the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @reyoung remove set functor and remove CUDA_1D_KERNEL_LOOP. |
||
| // Expands to the signature of a device-only, force-inlined function named | ||
| // CudaAtomic<op> that takes a T* address and a T value and returns T. | ||
| #define CUDA_ATOMIC_WRAPPER(op, T) \ | ||
| __device__ __forceinline__ T CudaAtomic##op(T* address, const T val) | ||
|
|
||
| // Defines CudaAtomic<op> for T by forwarding straight to the CUDA built-in | ||
| // atomic<op>(address, val). | ||
| #define USE_CUDA_ATOMIC(op, T) \ | ||
| CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); } | ||
|
|
||
| // For atomicAdd. | ||
| USE_CUDA_ATOMIC(Add, float); | ||
|
|
||
| // The built-in atomicAdd on double is used only when compiling for | ||
| // compute capability 6.0 or newer; otherwise it is emulated with atomicCAS. | ||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 | ||
| USE_CUDA_ATOMIC(Add, double); | ||
| #else | ||
| CUDA_ATOMIC_WRAPPER(Add, double) { | ||
| // View the double as a 64-bit integer so atomicCAS can operate on it. | ||
| unsigned long long int* address_as_ull = | ||
| reinterpret_cast<unsigned long long int*>(address); | ||
| unsigned long long int old = *address_as_ull, assumed; | ||
|
|
||
| // Retry until the value seen before the addition is still the value in | ||
| // memory when the compare-and-swap runs. | ||
| do { | ||
| assumed = old; | ||
| old = atomicCAS(address_as_ull, assumed, | ||
| __double_as_longlong(val + __longlong_as_double(assumed))); | ||
|
|
||
| // Note: uses integer comparison to avoid hang in case of NaN | ||
| } while (assumed != old); | ||
|
|
||
| // Returns the value that was stored at address before this addition. | ||
| return __longlong_as_double(old); | ||
| } | ||
| #endif | ||
|
|
||
| } // namespace platform | ||
| } // namespace paddle | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cuda_helper.h => cuda.h ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CUDA also has a `cuda.h` header. In caffe2, the file with similar function is called `common_gpu.h`, and in TensorFlow, it is called `cuda_kernel_helper.h`.