
Commit 8f8ea00

fix implementations.
1 parent 1fb5f12 commit 8f8ea00

7 files changed: +151 additions, -35 deletions

paddle/operators/math/utils.h

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/platform/assert.h"
+#include "paddle/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+T HOSTDEVICE tolerable_value(const T x) {
+  PADDLE_ASSERT(std::is_floating_point<T>::value);
+
+  const T kApproInf = 1e20;
+
+  if (x == INFINITY) {
+    return kApproInf;
+  }
+
+  if (x == -INFINITY) {
+    return -kApproInf;
+  }
+
+  return x;
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
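
Note: math::tolerable_value clamps ±infinity to a large finite surrogate (±1e20), so a -log(p) loss stays finite when a probability underflows to zero. A standalone host-side sketch of the behavior (illustrative only, not part of the commit; the real helper is HOSTDEVICE, so the same code also runs inside CUDA kernels):

#include <cmath>
#include <cstdio>

// Host-only restatement of the helper above, for illustration.
template <typename T>
T tolerable_value(const T x) {
  const T kApproInf = 1e20;
  if (x == INFINITY) return kApproInf;    // clamp +inf
  if (x == -INFINITY) return -kApproInf;  // clamp -inf
  return x;                               // finite values pass through
}

int main() {
  // std::log(0.0) returns -inf; the helper maps it to -1e20, so the loss
  // -log(0) becomes 1e20 instead of inf and later reductions stay finite.
  std::printf("%g\n", -tolerable_value(std::log(0.0)));  // prints 1e+20
  std::printf("%g\n", -tolerable_value(std::log(0.5)));  // prints ~0.693147
  return 0;
}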

paddle/operators/onehot_cross_entropy_op.cu

Lines changed: 4 additions & 16 deletions
@@ -13,27 +13,14 @@
 limitations under the License. */
 
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/utils.h"
 #include "paddle/platform/assert.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 
-template <typename T>
-__host__ __device__ T clipping_log(const T x) {
-  PADDLE_ASSERT(std::is_floating_point<T>::value);
-  const T kApproInf = 1e20;
-  T v = log(x);
-  if (v == INFINITY) {
-    return kApproInf;
-  }
-  if (v == -INFINITY) {
-    return -kApproInf;
-  }
-  return v;
-}
-
 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                    const int N, const int D) {
@@ -42,7 +29,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -clipping_log(X[i * D + label[i]]);
+    Y[i] = -math::tolerable_value(log(X[i * D + label[i]]));
   }
 }
 
@@ -73,7 +60,7 @@ class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use GPUPlace.");
+                   "This kernel only runs on GPU device.");
 
     auto X = ctx.Input<Tensor>("X");
     const T* Xdata = X->data<T>();
@@ -86,6 +73,7 @@ class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
     int D = X->dims()[1];
     int block = 512;
     int grid = (N + block - 1) / block;
+
     // TODO(qingqing) launch kernel on specified stream
     // base on ExecutionContext.
     CrossEntropyKernel<T><<<grid, block>>>(Ydata, Xdata, label_data, N, D);
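
Note: the kernel walks samples with a grid-stride loop, so a fixed <<<grid, block>>> launch covers any N: each thread starts at its global index blockIdx.x * blockDim.x + threadIdx.x and strides by the total thread count blockDim.x * gridDim.x. Per sample i with hard label y_i over D classes it computes the one-hot cross entropy

L_i = -\log X_{i,\,y_i}

with math::tolerable_value substituting ±1e20 only when the log overflows to ±infinity (e.g. log 0 = -inf), replacing the bespoke clipping_log that duplicated this logic.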

paddle/operators/softmax_with_cross_entropy_op.cc

Lines changed: 6 additions & 6 deletions
@@ -32,7 +32,7 @@ class SoftmaxWithCrossEntropyOpMaker
               "Store the outputs of softmax function, "
               "which will be used in backward calculation.")
         .AsIntermediate();
-    AddOutput("Loss", "A 1-D tensor<float> with shape N.");
+    AddOutput("Out", "A 1-D tensor<float> with shape N.");
     AddComment(R"DOC(
 Cross entropy loss with softmax are used as the output layer extensively. This
 operator computes the softmax normalized values for each row of the input
@@ -56,14 +56,14 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
-                            "Input(Loss@Grad) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@Grad) should not be null");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
                             "Input(Softmax) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
                             "Input(Lable) should be not null.");
 
-    ctx.Output<Tensor>(framework::GradVarName("Logits"))
+    ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
         ->Resize(ctx.Input<Tensor>("Softmax")->dims());
   }
 };
@@ -81,8 +81,8 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 1UL,
                    "The label should be a 1-d tensor.");
 
-    ctx.Output<Tensor>("Softmax")->Resize(logits->dims());
-    ctx.Output<Tensor>("Loss")->Resize({logits->dims()[0], 1});
+    ctx.Output<framework::LoDTensor>("Softmax")->Resize(logits->dims());
+    ctx.Output<framework::LoDTensor>("Out")->Resize({logits->dims()[0], 1});
   }
 };
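
Note: for a row of logits z_i in R^D with hard label y_i, the operator described in the DOC block computes

\mathrm{Softmax}_{ij} = \frac{e^{z_{ij}}}{\sum_{k=1}^{D} e^{z_{ik}}}, \qquad \mathrm{Out}_i = -\log \mathrm{Softmax}_{i,\,y_i}

so "Softmax" keeps the logits' shape [N, D] while the renamed "Out" output is [N, 1], matching the Resize calls above.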

paddle/operators/softmax_with_cross_entropy_op.cu

Lines changed: 93 additions & 4 deletions
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,8 +13,97 @@
 limitations under the License. */
 
 #define EIGEN_USE_GPU
-#include "softmax_with_cross_entropy_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/softmax_function.h"
+#include "paddle/operators/math/utils.h"
 
-namespace ops = paddle::operators;
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+__global__ void CrossEntropyKernel(T* out, const T* softmax_out,
+                                   const int* label, const int batch_size,
+                                   const int class_num) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i >= batch_size) return;
+  PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
+  out[i] = -math::tolerable_value(log(softmax_out[i * class_num + label[i]]));
+}
+
+template <typename T>
+__global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out,
+                                                  const int* label,
+                                                  const int batch_size,
+                                                  const int class_num) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i >= batch_size) return;
+
+  PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
+  softmax_out[i * class_num + label[i]] -= 1.;
+}
+
+template <typename T>
+class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+                   "This kernel only runs on GPU device.");
+
+    // Calculate ths softmax outputs.
+    const Tensor* logits = context.Input<Tensor>("Logits");
+    Tensor* softmax = context.Output<Tensor>("Softmax");
+    softmax->mutable_data<T>(context.GetPlace());
+    math::SoftmaxFunctor<platform::GPUPlace, T>()(logits, softmax, context);
+    T* softmax_out = softmax->data<T>();
+
+    // Calculate the cross entropy loss based on hard labels.
+    const int* label_data = context.Input<Tensor>("Label")->data<int>();
+    Tensor* loss = context.Output<Tensor>("Out");
+    loss->mutable_data<T>(context.GetPlace());
+    T* loss_data = loss->data<T>();
+
+    const int batch_size = logits->dims()[0];
+    const int class_num = logits->dims()[1];
+    int block = 512;
+    int grid = (batch_size + block - 1) / block;
 
-// TODO(caoying) add GPU kernel
+    CrossEntropyKernel<T><<<grid, block>>>(loss_data, softmax_out, label_data,
+                                           batch_size, class_num);
+  }
+};
+
+template <typename T>
+class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+                   "This kernel only runs on GPU device.");
+
+    Tensor* logit_grad =
+        context.Output<Tensor>(framework::GradVarName("Logits"));
+    logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
+    T* logit_grad_data = logit_grad->data<T>();
+
+    const int batch_size = logit_grad->dims()[0];
+    const int class_num = logit_grad->dims()[1];
+
+    const int* label_data = context.Input<Tensor>("Label")->data<int>();
+
+    const int block = 512;
+    const int grid = (batch_size + block - 1) / block;
+
+    CrossEntropyWithSoftmaxGradKernel<T><<<grid, block>>>(
+        logit_grad_data, label_data, batch_size, class_num);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy,
+                       ops::SoftmaxWithCrossEntropyCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy_grad,
+                       ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>);
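
Note: the one-line subtraction in CrossEntropyWithSoftmaxGradKernel is the standard softmax-with-cross-entropy gradient. With p_i = softmax(z_i) and hard label y_i,

\frac{\partial \mathrm{Out}_i}{\partial z_{ij}} = p_{ij} - \mathbf{1}[j = y_i]

which is why the logit gradient starts as a copy of the softmax output (via ShareDataWith) and then has exactly 1 subtracted at each sample's label column; as committed, the kernel does not additionally scale by the incoming Out gradient.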

paddle/operators/softmax_with_cross_entropy_op.h

Lines changed: 3 additions & 4 deletions
@@ -30,8 +30,7 @@ template <typename T>
 class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto place = context.GetPlace();
-    PADDLE_ENFORCE(platform::is_cpu_place(place),
+    PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()),
                    "This kernel only runs on CPU.");
 
     // Calculate ths softmax outputs.
@@ -45,7 +44,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
     T* softmax_out = softmax->data<T>();
     const int* label_data = context.Input<Tensor>("Label")->data<int>();
 
-    Tensor* loss = context.Output<Tensor>("Loss");
+    Tensor* loss = context.Output<Tensor>("Out");
     loss->mutable_data<T>(context.GetPlace());
     T* loss_data = loss->data<T>();
 
@@ -74,7 +73,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel {
     const int* label_data = context.Input<Tensor>("Label")->data<int>();
     for (int i = 0; i < batch_size; ++i) {
       int index = i * class_num + label_data[i];
-      logit_grad_data[index] -= .1;
+      logit_grad_data[index] -= 1.;
     }
   }
 };
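
Note: the change from "-= .1" to "-= 1." fixes the same identity on the CPU path: the subtrahend at the label column must be exactly 1 so that the buffer ends up holding softmax(z) - one_hot(y). A minimal standalone sketch of the corrected loop (plain C++; the function name and raw-pointer signature are hypothetical, framework types omitted):

#include <cstddef>

// grad enters as a copy of the softmax probabilities, laid out row-major
// as [batch_size x class_num]; subtracting 1 at each sample's label column
// leaves grad holding softmax(z) - one_hot(y).
void SoftmaxCrossEntropyGrad(float* grad, const int* label,
                             std::size_t batch_size, std::size_t class_num) {
  for (std::size_t i = 0; i < batch_size; ++i) {
    grad[i * class_num + label[i]] -= 1.0f;  // was "-= .1", an off-by-10x bug
  }
}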

python/paddle/v2/framework/tests/test_cross_entropy_op.py

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 import unittest
 import numpy
 from op_test import OpTest
-import pdb
 
 
 class TestCrossEntropy(OpTest):

python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py

Lines changed: 3 additions & 4 deletions
@@ -1,6 +1,5 @@
 import unittest
 import numpy as np
-import pdb
 
 from op_test import OpTest
 from test_softmax_op import stable_softmax
@@ -11,7 +10,7 @@ def setUp(self):
         self.op_type = "softmax_with_cross_entropy"
 
         MAX_BATCH_SIZE = 23
-        MAX_CLASS_NUM = 10
+        MAX_CLASS_NUM = 17
 
         batch_size = np.random.randint(1, MAX_BATCH_SIZE, 1)[0]
         class_num = np.random.randint(2, MAX_CLASS_NUM, 1)[0]
@@ -26,13 +25,13 @@ def setUp(self):
             dtype="float32")
 
         self.inputs = {"Logits": logits, "Label": labels}
-        self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
+        self.outputs = {"Softmax": softmax, "Out": cross_entropy}
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(["Logits"], "Loss")
+        self.check_grad(["Logits"], "Out", max_relative_error=0.05)
 
 
 if __name__ == "__main__":
