@@ -23,16 +23,32 @@ class SoftmaxWithCrossEntropyOpMaker
   SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
                                  framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
+    // (TODO caoying) replace int with boolean
+    AddAttr<int>("soft_label",
+                 "(int, default 0), A flag to indicate whether to interpret "
+                 "the given labels as soft labels.")
+        .SetDefault(0);
     AddInput("Logits",
-             "The unscaled log probabilities which is a 2-D tensor<float> with"
-             "shape [N x K]. N is the batch_size, and K is the class number.")
+             "(Tensor, default Tensor<float>), The unscaled log probabilities "
+             "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
+             "and K is the class number.")
         .NotInGradient();
-    AddInput("Label", "The ground truth. A 1-D tensor<int> with shape N.");
-    AddOutput("Softmax",
-              "Store the outputs of softmax function, "
-              "which will be used in backward calculation.")
+    AddInput(
+        "Label",
+        "(Tensor, default Tensor<int>), The ground truth, which is "
+        "a 1-D or 2-D tensor. "
+        "If soft_label is set to 0, Label is a Tensor<int> with shape [N x 1]. "
+        "If soft_label is set to 1, Label is a Tensor<float/double> "
+        "with shape [N x K].");
+    AddOutput(
+        "Softmax",
+        "(Tensor, default Tensor<float>), A 2-D tensor with shape [N x K]. "
+        "The output values of the softmax activation for the given input "
+        "batch, which will be used in the backward calculation.")
         .AsIntermediate();
-    AddOutput("Out", "A 1-D tensor<float> with shape N.");
+    AddOutput("Loss",
+              "(Tensor, default Tensor<float>), A 1-D tensor. The cross "
+              "entropy loss with shape [N x 1].");
     AddComment(R"DOC(
 Cross entropy loss with softmax is used as the output layer extensively. This
 operator computes the softmax normalized values for each row of the input
@@ -46,25 +62,18 @@ which will produce incorrect results.
 This operator expects mutually exclusive hard labels: each sample in a batch
 is in exactly one class with probability 1. Each sample in the batch has one
 and only one label.
-)DOC");
-  }
-};
 
-class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
+Equation:
 
- protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@Grad) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
-                            "Input(Softmax) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Lable) should be not null.");
+1) hard label (one-hot label)
 
-    ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
-        ->Resize(ctx.Input<Tensor>("Softmax")->dims());
+Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=1}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., N
+
+2) soft label (a distribution over all classes)
+
+Loss_j = -\sum_{i=1}^{K}\text{Label}_i\left(\text{Logit}_i - \log\left(\sum_{k=1}^{K}\exp(\text{Logit}_k)\right)\right), j = 1, ..., N
+
+)DOC");
   }
 };
 
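The two loss equations in the DOC block above are easy to sanity-check outside the framework. Below is a minimal standalone C++ sketch (illustrative only, not Paddle code; the helper names `HardLabelLoss` and `SoftLabelLoss` are made up for this note) that evaluates both formulas for a single sample, using the numerically stable log-sum-exp form:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hard-label loss for one sample, per the DOC equation:
//   Loss = -Logit[label] + log(sum_i exp(Logit[i]))
// Subtracting the row max before exp() is the usual log-sum-exp trick;
// it guards against overflow without changing the result.
double HardLabelLoss(const std::vector<double>& logits, std::size_t label) {
  assert(label < logits.size());
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  double sum_exp = 0.0;
  for (double v : logits) sum_exp += std::exp(v - max_logit);
  return -(logits[label] - max_logit) + std::log(sum_exp);
}

// Soft-label loss for one sample, per the DOC equation:
//   Loss = -sum_i Label[i] * (Logit[i] - log(sum_k exp(Logit[k])))
double SoftLabelLoss(const std::vector<double>& logits,
                     const std::vector<double>& labels) {
  assert(logits.size() == labels.size());
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  double sum_exp = 0.0;
  for (double v : logits) sum_exp += std::exp(v - max_logit);
  const double log_z = max_logit + std::log(sum_exp);  // log(sum_k exp(Logit_k))
  double loss = 0.0;
  for (std::size_t i = 0; i < logits.size(); ++i)
    loss += -labels[i] * (logits[i] - log_z);
  return loss;
}
```

With a one-hot `labels` vector, `SoftLabelLoss` reduces to `HardLabelLoss`, which is why a single Softmax intermediate can serve both the soft_label = 0 and soft_label = 1 paths.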
@@ -82,7 +91,25 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
                    "The label should be a 1-d tensor.");
 
     ctx.Output<framework::LoDTensor>("Softmax")->Resize(logits->dims());
-    ctx.Output<framework::LoDTensor>("Out")->Resize({logits->dims()[0], 1});
+    ctx.Output<framework::LoDTensor>("Loss")->Resize({logits->dims()[0], 1});
+  }
+};
+
+class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext& ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
+                            "Input(Loss@Grad) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
+                            "Input(Softmax) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) should not be null.");
+
+    ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
+        ->Resize(ctx.Input<Tensor>("Softmax")->dims());
   }
 };
 
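One detail worth noting in the relocated SoftmaxWithCrossEntropyOpGrad: it requires the saved Softmax output rather than the raw Logits. That works because, for this fused op, the gradient of the loss with respect to the logits is simply the softmax output minus the (one-hot or soft) label distribution, so Logits@GRAD has exactly Softmax's dims, which is what InferShape enforces. A hedged standalone C++ sketch of that standard identity (not the actual kernel; the helper names are made up):

```cpp
#include <cstddef>
#include <vector>

// Hard-label case: d(Loss)/d(Logit_i) = Softmax_i - 1{i == label}.
// This identity is why the grad op only needs the saved Softmax
// output (plus the labels), not the original logits.
std::vector<double> HardLabelLogitGrad(const std::vector<double>& softmax,
                                       std::size_t label) {
  std::vector<double> grad = softmax;  // same dims as Softmax
  grad[label] -= 1.0;
  return grad;
}

// Soft-label case: d(Loss)/d(Logit_i) = Softmax_i - Label_i.
std::vector<double> SoftLabelLogitGrad(const std::vector<double>& softmax,
                                       const std::vector<double>& labels) {
  std::vector<double> grad(softmax.size());
  for (std::size_t i = 0; i < softmax.size(); ++i)
    grad[i] = softmax[i] - labels[i];
  return grad;
}
```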