-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[WIP] Backward #2949
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Backward #2949
Changes from 3 commits
f35c8c4
7dc53ea
015ccd4
816b4c8
9890b23
8b80cf8
7f1533f
855cae6
e786746
bf4da3d
cb95587
94a6b1f
8bc4892
3dc70ff
73f4779
4876f35
e192d0f
14424f3
81a352a
6f05392
8a5ee46
b635af7
45452ac
088e220
99a5904
f41fcd4
9418717
4736b23
f85ccdd
0ab8f52
380227b
5f3bc2a
f4e2555
81df39f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/framework/fully_connected_op.h" | ||
| #include <iostream> | ||
| namespace paddle { | ||
| namespace framework { | ||
|
|
||
| void FCOp::Run(const ScopePtr& scope, | ||
| const platform::DeviceContext& dev_ctx) const override { | ||
| std::cout << "FC" << std::endl; | ||
| } | ||
|
|
||
| void FCOp::InferShape(const ScopePtr& scope) const override {} | ||
|
|
||
| void FCGradientOp::Run(const ScopePtr& scope, | ||
| const platform::DeviceContext& dev_ctx) const override { | ||
| std::cout << "FCGrad" << std::endl; | ||
| } | ||
|
|
||
| void FCGradientOp::InferShape(const ScopePtr& scope) const override {} | ||
|
|
||
| REGISTER_OP(my_fc, paddle::framework::FCOp, | ||
| paddle::framework::FCOpProtoAndCheckerMaker); | ||
| REGISTER_OP(my_fc_grad, paddle::framework::FCGradientOp, | ||
| paddle::framework::FCGradientOpProtoAndCheckerMaker); | ||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include <iostream> | ||
| #include "paddle/framework/op_registry.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
| class FCOp : public OperatorBase { | ||
| public: | ||
| void Run(const ScopePtr& scope, | ||
| const platform::DeviceContext& dev_ctx) const override { | ||
| std::cout << "FC" << std::endl; | ||
| }; | ||
| void InferShape(const ScopePtr& scope) const override{}; | ||
| }; | ||
|
|
||
| class FCOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | ||
| public: | ||
| FCOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInput("x", "input data"); | ||
| AddInput("w", "weights"); | ||
| AddInput("b", "bias"); | ||
| AddOutput("y", "output data"); | ||
| AddComment("Fully connnect op"); | ||
| } | ||
| }; | ||
|
|
||
| class FCGradientOp : public OperatorBase { | ||
| void Run(const ScopePtr& scope, | ||
| const platform::DeviceContext& dev_ctx) const override { | ||
| std::cout << "FCGrad" << std::endl; | ||
| }; | ||
| void InferShape(const ScopePtr& scope) const override{}; | ||
| }; | ||
|
|
||
| // class FCGradientOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {}; | ||
|
|
||
| } // namespace framework | ||
| } // namespace paddle |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -101,5 +101,7 @@ class PlainNet : public Net { | |
| } | ||
| }; | ||
|
|
||
| std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this method might be useless, we can directly invoke
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's true. I remembered that |
||
|
|
||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| #include "paddle/framework/op_desc.pb.h" | ||
| #include "paddle/framework/op_proto.pb.h" | ||
| #include "paddle/framework/operator.h" | ||
| #include "paddle/framework/scope.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
|
|
@@ -205,8 +206,8 @@ class OpRegistry { | |
| template <typename OpType, typename ProtoMakerType> | ||
| static void RegisterOp(const std::string& op_type) { | ||
| creators()[op_type] = [] { return new OpType; }; | ||
| OpProto& op_proto = protos()[op_type]; | ||
| OpAttrChecker& op_checker = op_checkers()[op_type]; | ||
| OpProto& op_proto = protos()[op_type]; | ||
| auto maker = ProtoMakerType(&op_proto, &op_checker); | ||
| maker.Validate(); | ||
| *op_proto.mutable_type() = op_type; | ||
|
|
@@ -255,6 +256,11 @@ class OpRegistry { | |
| return OperatorPtr(op); | ||
| } | ||
|
|
||
| template <typename OpType> | ||
| static void RegisterGradOp(const std::string& op_type) { | ||
| grad_creators()[op_type] = [] { return new OpType; }; | ||
| } | ||
|
|
||
| static OperatorPtr CreateOp(const OpDesc& op_desc) { | ||
| std::vector<std::string> inputs; | ||
| inputs.reserve((size_t)op_desc.inputs_size()); | ||
|
|
@@ -274,6 +280,21 @@ class OpRegistry { | |
| return CreateOp(op_desc.type(), inputs, outputs, attrs); | ||
| } | ||
|
|
||
| static OperatorPtr CreateGradOp(std::shared_ptr<OperatorBase> op) { | ||
| OperatorPtr op_grad(grad_creators().at(op->type_)()); | ||
| op_grad->type_ = op->type_; | ||
| op_grad->inputs_.reserve(op->inputs_.size()); | ||
| for (auto& input : op->inputs_) { | ||
| op_grad->inputs_.emplace_back(input); | ||
| op_grad->outputs_.emplace_back(input + "@grad"); | ||
| } | ||
| for (auto& output : op->outputs_) { | ||
| op_grad->inputs_.emplace_back(output); | ||
| op_grad->inputs_.emplace_back(output + "@grad"); | ||
| } | ||
| return op_grad; | ||
| } | ||
|
|
||
| static std::unordered_map<std::string, OpProto>& protos() { | ||
| static std::unordered_map<std::string, OpProto> protos_; | ||
| return protos_; | ||
|
|
@@ -306,6 +327,11 @@ class OpRegistry { | |
| static std::unordered_map<std::string, OpAttrChecker> op_checkers_; | ||
| return op_checkers_; | ||
| }; | ||
|
|
||
| static std::unordered_map<std::string, OpCreator>& grad_creators() { | ||
| static std::unordered_map<std::string, OpCreator> grad_creators_; | ||
| return grad_creators_; | ||
| } | ||
| }; | ||
|
|
||
| template <typename OpType, typename ProtoMakerType> | ||
|
|
@@ -316,6 +342,14 @@ class OpRegisterHelper { | |
| } | ||
| }; | ||
|
|
||
| template <typename OpType> | ||
| class GradOpRegisterHelper { | ||
| public: | ||
| GradOpRegisterHelper(const char* op_type) { | ||
| OpRegistry::RegisterGradOp<OpType>(op_type); | ||
| } | ||
| }; | ||
|
|
||
| /** | ||
| * check if MACRO is used in GLOBAL NAMESPACE. | ||
| */ | ||
|
|
@@ -335,6 +369,17 @@ class OpRegisterHelper { | |
| __op_register_##__op_type##__(#__op_type); \ | ||
| int __op_register_##__op_type##_handle__() { return 0; } | ||
|
|
||
| /** | ||
| * Macro to Register Gradient Operator. | ||
| */ | ||
| #define REGISTER_GRADIENT_OP(__op_type, __op_class) \ | ||
| STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
| __reg_op__##__op_type, \ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个名字可以再特殊一些,例如
否则比较容易和其他的REGISTER冲突。
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| "REGISTER_GRADIENT_OP must be in global namespace"); \ | ||
| static ::paddle::framework::GradOpRegisterHelper<__op_class> \ | ||
| __op_register_##__op_type##__(#__op_type); \ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同理,这个名字也可以特殊一些。。
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| int __op_register_##__op_type##_handle__() { return 0; } | ||
|
|
||
| /** | ||
| * Macro to Register OperatorKernel. | ||
| */ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.