
Commit 4cb5bd9

Implementing the Adamax optimizer operator (#4538)
* Implementing the Adamax optimizer step operator
* Adding unit tests for adamax_op
* Changing learning rate and time step to inputs from attributes
* Changing learning rate and time step to input(tensors)
* Making the Adamax operator conform to naming convention
* Removing Tensor<float> from comments
* Rectifying the Adamax implementation
* Changing Unit Test values and adding comments
* Changing Unit Test to test multiple steps
1 parent f30a1f4 commit 4cb5bd9

File tree

4 files changed: +409 -0 lines changed


paddle/operators/adamax_op.cc

Lines changed: 139 additions & 0 deletions
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/adamax_op.h"

namespace paddle {
namespace operators {

class AdamaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContextBase *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Param"),
                   "Input(Param) of AdamaxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Grad"),
                   "Input(Grad) of AdamaxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Moment"),
                   "Input(Moment) of AdamaxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("InfNorm"),
                   "Input(InfNorm) of AdamaxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
                   "Input(LearningRate) of AdamaxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
                   "Input(Beta1Pow) of AdamaxOp should not be null.");

    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                   "Output(ParamOut) of AdamaxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
                   "Output(MomentOut) of AdamaxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("InfNormOut"),
                   "Output(InfNormOut) of AdamaxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Beta1PowOut"),
                   "Output(Beta1PowOut) of AdamaxOp should not be null.");

    auto lr_dims = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      "Learning rate should have 1 dimension");
    auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
    PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
                      "Beta1 power accumulator should have 1 dimension");
    auto param_dims = ctx->GetInputDim("Param");
    PADDLE_ENFORCE_EQ(
        param_dims, ctx->GetInputDim("Grad"),
        "Param and Grad input of AdamaxOp should have same dimension");
    PADDLE_ENFORCE_EQ(
        param_dims, ctx->GetInputDim("Moment"),
        "Param and Moment input of AdamaxOp should have same dimension");
    PADDLE_ENFORCE_EQ(
        param_dims, ctx->GetInputDim("InfNorm"),
        "Param and InfNorm input of AdamaxOp should have same dimension");

    ctx->SetOutputDim("ParamOut", param_dims);
    ctx->SetOutputDim("MomentOut", param_dims);
    ctx->SetOutputDim("InfNormOut", param_dims);
    ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims);
  }
};

class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  AdamaxOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Param", "(Tensor) Input parameter");
    AddInput("Grad", "(Tensor) Input gradient");
    AddInput("LearningRate", "(Tensor) Learning rate");
    AddInput("Moment", "(Tensor) First moment");
    AddInput("InfNorm",
             "(Tensor) "
             "Input exponentially weighted infinity norm");
    AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");

    AddOutput("ParamOut", "(Tensor) Output parameter");
    AddOutput("MomentOut", "(Tensor) Output first moment");
    AddOutput("InfNormOut",
              "(Tensor) "
              "Output exponentially weighted infinity norm");
    AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");

    AddAttr<float>("beta1",
                   "(float, default 0.9) "
                   "Exponential decay rate for the "
                   "1st moment estimates.")
        .SetDefault(0.9f);
    AddAttr<float>("beta2",
                   "(float, default 0.999) "
                   "exponential decay rate for the weighted "
                   "infinity norm estimates.")
        .SetDefault(0.999f);
    AddAttr<float>("epsilon",
                   "(float, default 1.0e-8) "
                   "Constant for numerical stability")
        .SetDefault(1.0e-8f);
    AddComment(R"DOC(
Adamax Updates Operator.

This implements the Adamax optimizer from Section 7 of the Adam
paper[1]. Adamax is a variant of the
Adam algorithm based on the infinity norm.

Adamax updates:

moment_out = beta1 * moment + (1 - beta1) * grad
inf_norm_out = max(beta2 * inf_norm + epsilon, abs(grad))
beta1_pow_out = beta1_pow * beta1
learning_rate_t = learning_rate/(1 - beta1_pow_out)
param_out = param - learning_rate_t * moment_out/inf_norm_out

The original paper does not have an epsilon attribute.
However, it is added here for numerical stability
by preventing divide by 0.

References:
  [1] Adam: A Method for Stochastic Optimization
      (https://arxiv.org/abs/1412.6980)

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker);
REGISTER_OP_CPU_KERNEL(adamax,
                       ops::AdamaxOpKernel<paddle::platform::CPUPlace, float>);
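To make the update rule in the AddComment block concrete, here is a minimal standalone sketch that applies the same equations to plain std::vector buffers for a single step. It does not use the Paddle framework; the parameter values, sizes, and the Beta1Pow starting value of 1.0 are illustrative assumptions, not taken from the commit.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Illustrative inputs; in the operator these arrive as Tensor inputs.
  std::vector<float> param{1.0f, 2.0f}, grad{0.1f, -0.2f};
  std::vector<float> moment(2, 0.0f), inf_norm(2, 0.0f);
  const float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1.0e-8f;
  float lr = 0.002f;
  float beta1_pow = 1.0f;  // assumed starting accumulator value

  // One Adamax step, following the "Adamax updates" pseudocode above.
  beta1_pow *= beta1;                          // beta1_pow_out = beta1_pow * beta1
  const float lr_t = lr / (1.0f - beta1_pow);  // bias-corrected learning rate
  for (std::size_t i = 0; i < param.size(); ++i) {
    moment[i] = beta1 * moment[i] + (1.0f - beta1) * grad[i];
    inf_norm[i] = std::max(beta2 * inf_norm[i] + epsilon, std::fabs(grad[i]));
    param[i] -= lr_t * moment[i] / inf_norm[i];  // param_out
  }

  for (float p : param) std::printf("%f ", p);
  std::printf("\n");
  return 0;
}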

paddle/operators/adamax_op.cu

Lines changed: 20 additions & 0 deletions
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/adamax_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adamax,
                       ops::AdamaxOpKernel<paddle::platform::GPUPlace, float>);

paddle/operators/adamax_op.h

Lines changed: 72 additions & 0 deletions
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class AdamaxOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
    auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
    auto inf_norm_out_tensor = ctx.Output<framework::Tensor>("InfNormOut");
    auto beta1_pow_out_tensor = ctx.Output<framework::Tensor>("Beta1PowOut");

    param_out_tensor->mutable_data<T>(ctx.GetPlace());
    moment_out_tensor->mutable_data<T>(ctx.GetPlace());
    inf_norm_out_tensor->mutable_data<T>(ctx.GetPlace());
    beta1_pow_out_tensor->mutable_data<T>(ctx.GetPlace());

    float beta1 = ctx.Attr<float>("beta1");
    float beta2 = ctx.Attr<float>("beta2");
    float epsilon = ctx.Attr<float>("epsilon");

    auto param = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Param"));
    auto grad = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Grad"));
    auto moment = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Moment"));
    auto inf_norm = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("InfNorm"));
    auto lr = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("LearningRate"));
    auto beta1_pow = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Beta1Pow"));
    auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
    auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
    auto inf_norm_out =
        framework::EigenVector<T>::Flatten(*inf_norm_out_tensor);
    auto beta1_pow_out =
        framework::EigenVector<T>::Flatten(*beta1_pow_out_tensor);
    auto place = ctx.GetEigenDevice<Place>();

    moment_out.device(place) = beta1 * moment + (1 - beta1) * grad;
    inf_norm_out.device(place) =
        grad.abs().cwiseMax((beta2 * inf_norm) + epsilon);
    beta1_pow_out.device(place) = beta1_pow * beta1;
    auto lr_t = lr / (1 - beta1_pow_out);
    Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
    param_out.device(place) =
        param - lr_t.broadcast(m_dsize) * (moment_out / inf_norm_out);
  }
};

}  // namespace operators
}  // namespace paddle
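One detail worth noting in the kernel above: LearningRate (and therefore lr_t) is a 1-element vector, so it is replicated with broadcast(m_dsize) before multiplying the flattened per-element update. The snippet below is a small standalone sketch of that Eigen Tensor pattern, independent of the Paddle framework; the sizes and values are illustrative assumptions only.

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // 1-element tensor, playing the role of the flattened LearningRate input.
  Eigen::Tensor<float, 1> lr_t(1);
  lr_t.setConstant(0.1f);

  // 3-element tensor, standing in for moment_out / inf_norm_out.
  Eigen::Tensor<float, 1> update(3);
  update.setValues({1.0f, 2.0f, 3.0f});

  // broadcast replicates lr_t along dimension 0 so the element-wise
  // product is well defined, matching lr_t.broadcast(m_dsize) in the kernel.
  Eigen::DSizes<int, 1> m_dsize(3);
  Eigen::Tensor<float, 1> scaled = lr_t.broadcast(m_dsize) * update;

  for (int i = 0; i < 3; ++i) std::cout << scaled(i) << " ";  // 0.1 0.2 0.3
  std::cout << std::endl;
  return 0;
}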
