|
1 | | -/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. |
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
2 | 2 |
|
3 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | you may not use this file except in compliance with the License. |
|
13 | 13 | limitations under the License. */ |
14 | 14 |
|
15 | 15 | #define EIGEN_USE_GPU |
16 | | -#include "softmax_with_cross_entropy_op.h" |
| 16 | +#include "paddle/framework/op_registry.h" |
| 17 | +#include "paddle/operators/math/softmax_function.h" |
| 18 | +#include "paddle/operators/math/utils.h" |
17 | 19 |
|
18 | | -namespace ops = paddle::operators; |
| 20 | +namespace paddle { |
| 21 | +namespace operators { |
| 22 | + |
| 23 | +using Tensor = framework::Tensor; |
| 24 | + |
| 25 | +template <typename T> |
| 26 | +__global__ void CrossEntropyKernel(T* out, const T* softmax_out, |
| 27 | + const int* label, const int batch_size, |
| 28 | + const int class_num) { |
| 29 | + int i = blockIdx.x * blockDim.x + threadIdx.x; |
| 30 | + if (i >= batch_size) return; |
| 31 | + PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num); |
| 32 | + out[i] = -math::tolerable_value(log(softmax_out[i * class_num + label[i]])); |
| 33 | +} |
| 34 | + |
| 35 | +template <typename T> |
| 36 | +__global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out, |
| 37 | + const int* label, |
| 38 | + const int batch_size, |
| 39 | + const int class_num) { |
| 40 | + int i = blockIdx.x * blockDim.x + threadIdx.x; |
| 41 | + if (i >= batch_size) return; |
| 42 | + |
| 43 | + PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num); |
| 44 | + softmax_out[i * class_num + label[i]] -= 1.; |
| 45 | +} |
| 46 | + |
| 47 | +template <typename T> |
| 48 | +class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { |
| 49 | + public: |
| 50 | + void Compute(const framework::ExecutionContext& context) const override { |
| 51 | + PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), |
| 52 | + "This kernel only runs on GPU device."); |
| 53 | + |
| 54 | + // Calculate ths softmax outputs. |
| 55 | + const Tensor* logits = context.Input<Tensor>("Logits"); |
| 56 | + Tensor* softmax = context.Output<Tensor>("Softmax"); |
| 57 | + softmax->mutable_data<T>(context.GetPlace()); |
| 58 | + math::SoftmaxFunctor<platform::GPUPlace, T>()(logits, softmax, context); |
| 59 | + T* softmax_out = softmax->data<T>(); |
| 60 | + |
| 61 | + // Calculate the cross entropy loss based on hard labels. |
| 62 | + const int* label_data = context.Input<Tensor>("Label")->data<int>(); |
| 63 | + Tensor* loss = context.Output<Tensor>("Out"); |
| 64 | + loss->mutable_data<T>(context.GetPlace()); |
| 65 | + T* loss_data = loss->data<T>(); |
| 66 | + |
| 67 | + const int batch_size = logits->dims()[0]; |
| 68 | + const int class_num = logits->dims()[1]; |
| 69 | + int block = 512; |
| 70 | + int grid = (batch_size + block - 1) / block; |
19 | 71 |
|
20 | | -// TODO(caoying) add GPU kernel |
| 72 | + CrossEntropyKernel<T><<<grid, block>>>(loss_data, softmax_out, label_data, |
| 73 | + batch_size, class_num); |
| 74 | + } |
| 75 | +}; |
| 76 | + |
| 77 | +template <typename T> |
| 78 | +class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { |
| 79 | + public: |
| 80 | + void Compute(const framework::ExecutionContext& context) const override { |
| 81 | + PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), |
| 82 | + "This kernel only runs on GPU device."); |
| 83 | + |
| 84 | + Tensor* logit_grad = |
| 85 | + context.Output<Tensor>(framework::GradVarName("Logits")); |
| 86 | + logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax")); |
| 87 | + T* logit_grad_data = logit_grad->data<T>(); |
| 88 | + |
| 89 | + const int batch_size = logit_grad->dims()[0]; |
| 90 | + const int class_num = logit_grad->dims()[1]; |
| 91 | + |
| 92 | + const int* label_data = context.Input<Tensor>("Label")->data<int>(); |
| 93 | + |
| 94 | + const int block = 512; |
| 95 | + const int grid = (batch_size + block - 1) / block; |
| 96 | + |
| 97 | + CrossEntropyWithSoftmaxGradKernel<T><<<grid, block>>>( |
| 98 | + logit_grad_data, label_data, batch_size, class_num); |
| 99 | + } |
| 100 | +}; |
| 101 | + |
| 102 | +} // namespace operators |
| 103 | +} // namespace paddle |
| 104 | + |
| 105 | +namespace ops = paddle::operators; |
| 106 | +REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy, |
| 107 | + ops::SoftmaxWithCrossEntropyCUDAKernel<float>); |
| 108 | +REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy_grad, |
| 109 | + ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>); |
0 commit comments