
Commit cfa1f9b

add sgd group
1 parent a9b9ec4 commit cfa1f9b

4 files changed: +285 -0 lines changed
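This commit adds an sgd_group operator that applies the vanilla SGD update to a whole list of parameters in one op; every parameter in the group carries its own single-element learning-rate tensor. As a reference for the semantics, here is a minimal NumPy sketch (illustrative only; the helper name is hypothetical and not part of the patch or the PaddlePaddle API):

import numpy as np

def sgd_group_reference(params, grads, lrs):
    # For each (param, grad, lr) triple: param_out = param - lr * grad.
    return [p - lr[0] * g for p, g, lr in zip(params, grads, lrs)]

# Usage mirroring one entry of the unit test at the end of the diff.
w = np.random.random((3, 24)).astype('float32')
g = np.random.random((3, 24)).astype('float32')
lr = np.array([0.2], dtype='float32')
out, = sgd_group_reference([w], [g], [lr])
assert np.allclose(out, w - 0.2 * g)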
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sgd_group_op.h"

namespace paddle {
namespace operators {

class SGDGroupOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInputs("Params"),
                   "Inputs(Params) of SGDGroupOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInputs("Grads"),
                   "Inputs(Grads) of SGDGroupOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInputs("LearningRates"),
                   "Inputs(LearningRates) of SGDGroupOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutputs("ParamOuts"),
                   "Outputs(ParamOuts) of SGDGroupOp should not be null.");

    auto params = ctx->GetInputsDim("Params");
    auto grads = ctx->GetInputsDim("Grads");
    auto learning_rates = ctx->GetInputsDim("LearningRates");

    auto param_num = params.size();

    PADDLE_ENFORCE_EQ(param_num, grads.size(),
                      "The number of Params and Grads should be equal.");
    PADDLE_ENFORCE_EQ(
        param_num, learning_rates.size(),
        "The number of Params and LearningRates should be equal.");

    for (size_t i = 0; i < param_num; ++i) {
      PADDLE_ENFORCE_EQ(framework::product(learning_rates[i]), 1,
                        "Each learning rate should have exactly 1 element.");
    }

    // TODO(qijun): check that the dimensions of each Param/Grad pair match
    // at compile time and at run time.
    ctx->SetOutputsDim("ParamOuts", params);
  }
};

class SGDGroupOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SGDGroupOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Params", "(vector<Tensor>) Input parameters").AsDuplicable();
    AddInput("LearningRates",
             "(vector<Tensor>) Learning rate of each parameter, one element "
             "per tensor")
        .AsDuplicable();
    AddInput("Grads", "(vector<Tensor>) Input gradients").AsDuplicable();
    AddOutput("ParamOuts", "(vector<Tensor>) Updated parameters")
        .AsDuplicable();
    AddComment(R"DOC(
SGDGroup operator

Applies the vanilla SGD update to a group of parameters in a single op:

    ParamOuts[i] = Params[i] - LearningRates[i] * Grads[i]

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sgd_group, ops::SGDGroupOp, ops::SGDGroupOpMaker);
REGISTER_OP_CPU_KERNEL(sgd_group, ops::SGDGroupOpKernel<float>,
                       ops::SGDGroupOpKernel<double>);
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/fluid/operators/sgd_group_op.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {

namespace {

// Element-wise SGD update for one tensor, using a grid-stride loop:
// p_out[i] = p[i] - lr * g[i].
template <typename T>
__global__ void SGDGroupKernel(const T* g, const T* p, const T* learning_rate,
                               const int num, T* p_out) {
  T lr = learning_rate[0];
  int grid_size = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += grid_size) {
    T g_data = g[i];
    T p_data = p[i];
    p_out[i] = p_data - lr * g_data;
  }
}

}  // namespace

template <typename T>
class SGDGroupOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto params = ctx.MultiInput<framework::Tensor>("Params");
    auto learning_rates = ctx.MultiInput<framework::Tensor>("LearningRates");
    auto grads = ctx.MultiInput<framework::Tensor>("Grads");

    auto param_outs = ctx.MultiOutput<framework::Tensor>("ParamOuts");

    auto grad_var = ctx.MultiInputVar("Grads");

    if (grad_var[0]->IsType<framework::LoDTensor>()) {
      // Launch one kernel per parameter tensor in the group.
      for (size_t j = 0; j < params.size(); ++j) {
        auto* param_out_data = param_outs[j]->mutable_data<T>(ctx.GetPlace());
        auto* grad_data = grads[j]->data<T>();
        auto* param_data = params[j]->data<T>();
        int param_num = params[j]->numel();
        int block = 512;
        int grid = (param_num + block - 1) / block;
        SGDGroupKernel<
            T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
            grad_data, param_data, learning_rates[j]->data<T>(), param_num,
            param_out_data);
      }
    } else {
      PADDLE_THROW("Unsupported Variable Type of Grad");
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sgd_group, ops::SGDGroupOpCUDAKernel<float>,
                        ops::SGDGroupOpCUDAKernel<double>);
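A note on the launch configuration above: grid = (param_num + block - 1) / block is the usual ceil-division pattern, so for example the (4, 104) parameter from the unit test (416 elements) runs as a single block of 512 threads, and each of the three tensors in the group gets its own kernel launch on the same CUDA stream.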
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"

namespace paddle {
namespace operators {

template <typename T>
class SGDGroupOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto params = ctx.MultiInput<framework::Tensor>("Params");
    auto learning_rates = ctx.MultiInput<framework::Tensor>("LearningRates");
    auto grads = ctx.MultiInput<framework::Tensor>("Grads");
    auto param_outs = ctx.MultiOutput<framework::Tensor>("ParamOuts");

    auto grad_var = ctx.MultiInputVar("Grads");
    // Only the dense (LoDTensor) gradient path is handled here; a sparse
    // SelectedRows path, as in the single-tensor sgd op, is left as a TODO.
    if (grad_var[0]->IsType<framework::LoDTensor>()) {
      for (size_t i = 0; i < params.size(); ++i) {
        param_outs[i]->mutable_data<T>(ctx.GetPlace());

        auto p = framework::EigenVector<T>::Flatten(*params[i]);
        auto g = framework::EigenVector<T>::Flatten(*grads[i]);
        auto o = framework::EigenVector<T>::Flatten(*param_outs[i]);
        auto* lr = learning_rates[i]->data<T>();

        o = p - lr[0] * g;
      }
    } else {
      PADDLE_THROW("Unsupported Variable Type of Grad");
    }
  }
};
}  // namespace operators
}  // namespace paddle
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest


class TestSGDGroupOp(OpTest):
    def setUp(self):
        self.op_type = "sgd_group"
        w0 = np.random.random((1, 124)).astype('float32')
        w1 = np.random.random((3, 24)).astype('float32')
        w2 = np.random.random((4, 104)).astype('float32')

        g0 = np.random.random((1, 124)).astype('float32')
        g1 = np.random.random((3, 24)).astype('float32')
        g2 = np.random.random((4, 104)).astype('float32')

        lr0 = np.array([0.1]).astype("float32")
        lr1 = np.array([0.2]).astype("float32")
        lr2 = np.array([0.3]).astype("float32")

        # Expected outputs: o_i = w_i - lr_i * g_i for each parameter.
        o0 = w0 - lr0 * g0
        o1 = w1 - lr1 * g1
        o2 = w2 - lr2 * g2

        self.inputs = {
            "Params": [("w0", w0), ("w1", w1), ("w2", w2)],
            "Grads": [("g0", g0), ("g1", g1), ("g2", g2)],
            "LearningRates": [("lr0", lr0), ("lr1", lr1), ("lr2", lr2)]
        }

        self.outputs = {"ParamOuts": [("o0", o0), ("o1", o1), ("o2", o2)]}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
