Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f35c8c4
remove simple_op_design.md
JiayiFeng Jul 2, 2017
7dc53ea
renew simple_op_design.md
JiayiFeng Jul 3, 2017
015ccd4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiayiFeng Jul 3, 2017
816b4c8
"add backward Op"
dzhwinter Jul 18, 2017
9890b23
fix conflict
dzhwinter Jul 18, 2017
8b80cf8
"add net op testing"
dzhwinter Jul 18, 2017
7f1533f
test collaborating
JiayiFeng Jul 18, 2017
855cae6
move unused file
dzhwinter Jul 18, 2017
e786746
Merge branch 'backward' of https://github.com/dzhwinter/Paddle into b…
JiayiFeng Jul 18, 2017
bf4da3d
Refactor Rigistry::CreateGradOp()
JiayiFeng Jul 19, 2017
cb95587
"ignore some gradient of specific op"
dzhwinter Jul 19, 2017
94a6b1f
rename a macro
JiayiFeng Jul 19, 2017
8bc4892
"fix comment "
dzhwinter Jul 19, 2017
3dc70ff
Merge branch 'backward' of https://github.com/dzhwinter/Paddle into b…
JiayiFeng Jul 19, 2017
73f4779
Merge remote-tracking branch 'origin/develop' into backward2
dzhwinter Jul 19, 2017
4876f35
"make plainNet shared"
dzhwinter Jul 19, 2017
e192d0f
Refactor the implementation of gradient Op creating
JiayiFeng Jul 19, 2017
14424f3
"use built-in operator"
dzhwinter Jul 20, 2017
81a352a
"test fc without gradient"
dzhwinter Jul 20, 2017
6f05392
Merge remote-tracking branch 'origin/develop' into backward2
dzhwinter Jul 20, 2017
8a5ee46
Fix some compile errors
JiayiFeng Jul 20, 2017
b635af7
Fix some compile error
JiayiFeng Jul 20, 2017
45452ac
Merge branch 'backward' of https://github.com/dzhwinter/Paddle into b…
JiayiFeng Jul 20, 2017
088e220
"remove unused fake fc op"
dzhwinter Jul 20, 2017
99a5904
Merge branch 'backward_dev' into backward
JiayiFeng Jul 20, 2017
f41fcd4
Merge branch 'backward' of https://github.com/dzhwinter/Paddle into b…
JiayiFeng Jul 20, 2017
9418717
Fix compile errors
JiayiFeng Jul 20, 2017
4736b23
Add a simple test for grad_op_creator
JiayiFeng Jul 21, 2017
f85ccdd
Renew CMakeList dependence
JiayiFeng Jul 24, 2017
0ab8f52
Merge branch 'backward' of https://github.com/Canpio/Paddle into back…
JiayiFeng Jul 24, 2017
380227b
Renew CMakeList dependence
JiayiFeng Jul 24, 2017
5f3bc2a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiayiFeng Jul 24, 2017
f4e2555
Fix compile error
JiayiFeng Jul 24, 2017
81df39f
fix compile error
JiayiFeng Jul 24, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)

# cc_library(fc_op SRCS fully_connected_op.cc DEPS operator)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc enforce)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)

Expand All @@ -26,5 +27,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies(framework_py_proto framework_py_proto_init)

proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net)
54 changes: 54 additions & 0 deletions paddle/framework/fully_connected_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <iostream>
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {
// Placeholder fully-connected forward op: no real kernel is implemented yet,
// so Run() only logs that the op executed and InferShape() is a no-op.
class FCOp : public OperatorBase {
 public:
  void Run(const ScopePtr& scope,
           const platform::DeviceContext& dev_ctx) const override {
    std::cout << "FC" << std::endl;
  }  // removed stray ';' after function body
  // Intentionally empty: shape inference is not implemented for this stub.
  void InferShape(const ScopePtr& scope) const override {}
};

// Declares the FC op's interface (inputs x/w/b, output y) for the op-proto
// registry; consumed by REGISTER_OP below.
class FCOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  FCOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("x", "input data");
    AddInput("w", "weights");
    AddInput("b", "bias");
    AddOutput("y", "output data");
    // Fixed typo in the user-visible description ("connnect" -> "connected").
    AddComment("Fully connected op");
  }
};

// Placeholder gradient op for FCOp (registered as "my_fc_grad" below).
// Mirrors FCOp: Run() only logs, InferShape() is a no-op.
class FCGradientOp : public OperatorBase {
 public:  // was implicitly private; made public for consistency with FCOp
  void Run(const ScopePtr& scope,
           const platform::DeviceContext& dev_ctx) const override {
    std::cout << "FCGrad" << std::endl;
  }  // removed stray ';' after function body
  void InferShape(const ScopePtr& scope) const override {}
};

// class FCGradientOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {};
REGISTER_OP(my_fc, FCOp, FCOpProtoAndCheckerMaker);
REGISTER_GRADIENT_OP(my_fc_grad, FCGradientOp);

} // namespace framework
} // namespace paddle
14 changes: 14 additions & 0 deletions paddle/framework/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,24 @@
*/

#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {

std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
// NetPtr->reset(new PlainNet);
// NetPtr grad_ops = new PlainNet;
std::shared_ptr<PlainNet> grad_ops;
grad_ops.reset(new PlainNet);
for (auto& op : ForwardOps->ops_) {
auto op_grad = OpRegistry::CreateGradOp(op);
grad_ops->AddOp(op_grad);
}
grad_ops->CompleteAddOp();
return grad_ops;
}

void PlainNet::CompleteAddOp() {
std::unordered_set<std::string> input_set;
std::unordered_set<std::string> output_set;
Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,7 @@ class PlainNet : public Net {
}
};

std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
Copy link
Collaborator

@reyoung reyoung Jul 19, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this method might be useless, we can directly invoke OpRegistry::CreateGradOp(...) even that Op is a net.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true. As I remember, AddBackwardOp is an interface for PlainNet that generates the backward ops so they can be used from Python. Removing it directly is fine with me.


} // namespace framework
} // namespace paddle
38 changes: 27 additions & 11 deletions paddle/framework/net_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
#include <paddle/framework/net.h>
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
#include "paddle/framework/fully_connected_op.h"

namespace pd = paddle::framework;
namespace paddle {
namespace framework {

static int infer_shape_cnt = 0;
static int run_cnt = 0;

class TestOp : public pd::OperatorBase {
class TestOp : public OperatorBase {
public:
void InferShape(const paddle::framework::ScopePtr& scope) const override {
++infer_shape_cnt;
}
void Run(const paddle::framework::ScopePtr& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {
void InferShape(const ScopePtr& scope) const override { ++infer_shape_cnt; }
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {
++run_cnt;
}
};
Expand All @@ -33,7 +33,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}

TEST(OpKernel, all) {
auto net = std::make_shared<paddle::framework::PlainNet>();
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);

auto op1 = std::make_shared<TestOp>();
Expand All @@ -55,13 +55,29 @@ TEST(OpKernel, all) {
ASSERT_EQ(1UL, tmp_idx.size());
ASSERT_EQ("y", net->outputs_[tmp_idx[0]]);

auto scope = std::make_shared<pd::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>();
platform::CPUDeviceContext dev_ctx;

net->InferShape(scope);
net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);

ASSERT_THROW(net->AddOp(op2), paddle::framework::EnforceNotMet);
ASSERT_THROW(net->AddOp(op2), EnforceNotMet);
}

// Verifies that AddBackwardOp produces a non-empty backward net for a
// single-op forward net. The original test only printed DebugString()
// and asserted nothing about the result.
TEST(AddBackwardOp, TestGradOp) {
  auto net = std::make_shared<PlainNet>();
  ASSERT_NE(net, nullptr);
  auto op1 = std::make_shared<FCOp>();
  op1->inputs_ = {"x", "w1", "b1"};
  op1->outputs_ = {"y"};
  net->AddOp(op1);
  auto grad_ops = AddBackwardOp(net);
  ASSERT_NE(grad_ops, nullptr);
  // One gradient op must be generated per forward op.
  ASSERT_EQ(1UL, grad_ops->ops_.size());
  for (auto& op : grad_ops->ops_) {
    op->DebugString();  // smoke-check: must not crash
  }
}

} // namespace framework
} // namespace paddle
24 changes: 0 additions & 24 deletions paddle/framework/net_test.cc

This file was deleted.

Loading