Skip to content

Commit e8304bd

Browse files
authored
Merge pull request #2945 from reyoung/feature/grouped_ops
Skeleton Of fully connected operator
2 parents b90780c + 0a0b4ca commit e8304bd

File tree

9 files changed

+195
-26
lines changed

9 files changed

+195
-26
lines changed

paddle/framework/attr_checker.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <functional>
55
#include <string>
66
#include <unordered_map>
7+
#include <unordered_set>
78
#include <vector>
89
#include "paddle/framework/enforce.h"
910

@@ -41,6 +42,35 @@ class DefaultValueSetter {
4142
T default_value_;
4243
};
4344

45+
template <typename T>
46+
class EnumInContainer {
47+
public:
48+
explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
49+
void operator()(T& val) const {
50+
PADDLE_ENFORCE(container_.find(val) != container_.end(),
51+
"Value %s is not in enum container %s", val,
52+
ContainerDebugString());
53+
}
54+
55+
private:
56+
std::string ContainerDebugString() const {
57+
std::ostringstream sout;
58+
sout << "[";
59+
size_t cnt = 0;
60+
for (auto& v : container_) {
61+
sout << v;
62+
++cnt;
63+
if (cnt != container_.size()) {
64+
sout << " ,";
65+
}
66+
}
67+
sout << "]";
68+
return sout.str();
69+
}
70+
71+
std::unordered_set<T> container_;
72+
};
73+
4474
// check whether a certain attribute fit its limits
4575
// an attribute can have more than one limits
4676
template <typename T>
@@ -50,6 +80,11 @@ class TypedAttrChecker {
5080
public:
5181
TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}
5282

83+
TypedAttrChecker& InEnum(const std::unordered_set<T>& range) {
84+
value_checkers_.push_back(EnumInContainer<T>(range));
85+
return *this;
86+
}
87+
5388
TypedAttrChecker& LargerThan(const T& lower_bound) {
5489
value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
5590
return *this;

paddle/framework/net.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
namespace paddle {
2020
namespace framework {
2121

22-
void PlainNet::CompleteAddOp() {
22+
void PlainNet::CompleteAddOp(bool calc) {
23+
add_op_done_ = true;
24+
if (!calc) return;
25+
2326
std::unordered_set<std::string> input_set;
2427
std::unordered_set<std::string> output_set;
2528
std::unordered_set<std::string> temp_output;
@@ -52,7 +55,6 @@ void PlainNet::CompleteAddOp() {
5255
}
5356

5457
attrs_["temporary_index"] = tmp_index;
55-
add_op_done_ = true;
5658
}
5759

5860
std::string PlainNet::DebugString() const {

paddle/framework/net.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License. */
1616

1717
#include <paddle/framework/op_desc.pb.h>
1818
#include <paddle/framework/operator.h>
19-
#include "paddle/framework/net_proto.pb.h"
2019
#include "paddle/framework/op_proto.pb.h"
2120
#include "paddle/framework/op_registry.h"
2221
#include "paddle/framework/scope.h"
@@ -41,7 +40,7 @@ namespace framework {
4140
class Net : public OperatorBase {
4241
public:
4342
virtual void AddOp(const OperatorPtr& op) = 0;
44-
virtual void CompleteAddOp() = 0;
43+
virtual void CompleteAddOp(bool calc) = 0;
4544
};
4645

4746
using NetPtr = std::shared_ptr<Net>;
@@ -86,7 +85,7 @@ class PlainNet : public Net {
8685
ops_.push_back(op);
8786
}
8887

89-
void CompleteAddOp() override;
88+
void CompleteAddOp(bool calculate = true) override;
9089

9190
std::string DebugString() const override;
9291

paddle/operators/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ function(op_library TARGET)
2727
endif()
2828

2929
list(LENGTH cu_srcs cu_srcs_len)
30-
if (${cu_srcs_len} EQUAL 0)
30+
list(LENGTH op_library_DEPS dep_len)
31+
if (${cu_srcs_len} EQUAL 0 AND ${dep_len} EQUAL 0)
3132
message(WARNING "The op library ${TARGET} not support GPU!")
3233
endif()
3334

@@ -47,3 +48,6 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu)
4748
op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
4849
op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
4950
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
51+
52+
op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
53+
softmax_op net)

paddle/operators/fc_op.cc

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/net.h"
16+
#include "paddle/framework/op_registry.h"
17+
#include "paddle/framework/operator.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
class FullyConnectedOp : public framework::PlainNet {
23+
public:
24+
void Init() override {
25+
AddOp(framework::OpRegistry::CreateOp("mul",
26+
{
27+
Input("X"), Input("W"),
28+
},
29+
{Output("before_act")},
30+
{}));
31+
auto b = Input("b");
32+
if (b != framework::OperatorBase::EMPTY_VAR_NAME()) {
33+
AddOp(framework::OpRegistry::CreateOp("rowwise_add",
34+
{Output("before_act"), Input("b")},
35+
{Output("before_act")},
36+
{}));
37+
}
38+
39+
auto activation = GetAttr<std::string>("activation");
40+
AddOp(framework::OpRegistry::CreateOp(
41+
activation, {Output("before_act")}, {Output("Y")}, {}));
42+
CompleteAddOp(false);
43+
}
44+
};
45+
46+
class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
47+
public:
48+
FullyConnectedOpMaker(framework::OpProto *proto,
49+
framework::OpAttrChecker *op_checker)
50+
: OpProtoAndCheckerMaker(proto, op_checker) {
51+
AddInput("X", "the input of fc operator");
52+
AddInput("W", "the weight of fc operator");
53+
AddInput("b", "the bias of fc operator");
54+
55+
AddOutput("Y", "the output of fc operator");
56+
AddOutput(
57+
"before_act", "the before activation output of fc operator", true);
58+
AddAttr<std::string>("activation", "The activation key for fc layer")
59+
.SetDefault("sigmoid")
60+
.InEnum({"sigmoid", "softmax"});
61+
62+
//! TODO(yuyang18): Complete comment;
63+
AddComment("FullyConnected Operator");
64+
}
65+
};
66+
} // namespace operators
67+
} // namespace paddle
68+
69+
USE_OP(mul);
70+
USE_OP(rowwise_add);
71+
USE_OP(sigmoid);
72+
USE_OP(softmax);
73+
74+
REGISTER_OP(fc,
75+
paddle::operators::FullyConnectedOp,
76+
paddle::operators::FullyConnectedOpMaker);

paddle/pybind/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
2-
add_op mul_op rowwise_add_op sigmoid_op softmax_op)
2+
add_op fc_op)

paddle/pybind/pybind.cc

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include <Python.h>
1616
#include <paddle/framework/op_registry.h>
17+
#include <paddle/framework/operator.h>
1718
#include <paddle/framework/scope.h>
1819
#include <paddle/pybind/tensor_bind.h>
1920
#include <pybind11/numpy.h>
@@ -26,10 +27,7 @@ namespace py = pybind11;
2627
namespace pd = paddle::framework;
2728

2829
USE_OP(add_two);
29-
USE_OP(softmax);
30-
USE_OP(mul);
31-
USE_OP(rowwise_add);
32-
USE_OP(sigmoid);
30+
USE_OP_WITHOUT_KERNEL(fc);
3331

3432
PYBIND11_PLUGIN(core) {
3533
py::module m("core", "C++ core of Paddle Paddle");
@@ -53,7 +51,9 @@ PYBIND11_PLUGIN(core) {
5351
self.mutable_data<int>(paddle::platform::CPUPlace());
5452
})
5553
.def("set", paddle::pybind::PyTensorSetFromArray<float>)
56-
.def("set", paddle::pybind::PyTensorSetFromArray<int>);
54+
.def("set", paddle::pybind::PyTensorSetFromArray<int>)
55+
.def("shape",
56+
[](pd::Tensor& self) { return pd::vectorize(self.dims()); });
5757

5858
py::class_<pd::Variable>(m, "Variable", R"DOC(Variable Class.
5959
@@ -83,15 +83,16 @@ All parameter, weight, gradient are variables in Paddle.
8383

8484
//! @note: Be careful! PyBind will return std::string as an unicode, not
8585
//! Python str. If you want a str object, you should cast them in Python.
86-
m.def("get_all_op_protos", []() -> std::vector<std::string> {
86+
m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
8787
auto& protos = pd::OpRegistry::protos();
88-
std::vector<std::string> ret_values;
88+
std::vector<py::bytes> ret_values;
8989
for (auto it = protos.begin(); it != protos.end(); ++it) {
9090
PADDLE_ENFORCE(it->second.IsInitialized(),
9191
"OpProto must all be initialized");
92-
ret_values.emplace_back();
93-
PADDLE_ENFORCE(it->second.SerializeToString(&ret_values.back()),
92+
std::string str;
93+
PADDLE_ENFORCE(it->second.SerializeToString(&str),
9494
"Serialize OpProto Error. This could be a bug of Paddle.");
95+
ret_values.push_back(py::bytes(str));
9596
}
9697
return ret_values;
9798
});
@@ -101,17 +102,26 @@ All parameter, weight, gradient are variables in Paddle.
101102
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
102103
.def("temp", pd::OperatorBase::TMP_VAR_NAME);
103104

105+
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
106+
.def_static("cpu_context", []() -> paddle::platform::DeviceContext* {
107+
return new paddle::platform::CPUDeviceContext();
108+
});
109+
104110
py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
105111
.def("__str__", &pd::OperatorBase::DebugString)
106-
.def_static("create", [](const std::string& protobin) {
107-
pd::OpDesc desc;
108-
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
109-
"Cannot parse user input to OpDesc");
110-
PADDLE_ENFORCE(desc.IsInitialized(),
111-
"User OpDesc is not initialized, reason %s",
112-
desc.InitializationErrorString());
113-
return pd::OpRegistry::CreateOp(desc);
114-
});
112+
.def_static("create",
113+
[](py::bytes protobin) {
114+
pd::OpDesc desc;
115+
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
116+
"Cannot parse user input to OpDesc");
117+
PADDLE_ENFORCE(desc.IsInitialized(),
118+
"User OpDesc is not initialized, reason %s",
119+
desc.InitializationErrorString());
120+
return pd::OpRegistry::CreateOp(desc);
121+
})
122+
.def("infer_shape", &pd::OperatorBase::InferShape)
123+
.def("run", &pd::OperatorBase::Run)
124+
.def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; });
115125

116126
return m.ptr();
117127
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
add_python_test(test_framework test_protobuf.py test_scope.py
22
test_default_scope_funcs.py test_op_creation_methods.py
3-
test_tensor.py)
3+
test_tensor.py test_fc_op.py)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import paddle.v2.framework.core as core
2+
import unittest
3+
import numpy
4+
import paddle.v2.framework.create_op_creation_methods as creation
5+
6+
7+
class TestFc(unittest.TestCase):
8+
def test_fc(self):
9+
scope = core.Scope(None)
10+
x = scope.create_var("X")
11+
x_tensor = x.get_tensor()
12+
x_tensor.set_dims([1000, 784])
13+
x_tensor.alloc_float()
14+
15+
w = scope.create_var("W")
16+
w_tensor = w.get_tensor()
17+
w_tensor.set_dims([784, 100])
18+
w_tensor.alloc_float()
19+
20+
w_tensor.set(numpy.random.random((784, 100)).astype("float32"))
21+
22+
# Set a real numpy array here.
23+
# x_tensor.set(numpy.array([]))
24+
25+
op = creation.op_creations.fc(X="X", Y="Y", W="W")
26+
27+
for out in op.outputs():
28+
if scope.get_var(out) is None:
29+
scope.create_var(out).get_tensor()
30+
31+
tensor = scope.get_var("Y").get_tensor()
32+
op.infer_shape(scope)
33+
self.assertEqual([1000, 100], tensor.shape())
34+
35+
ctx = core.DeviceContext.cpu_context()
36+
37+
op.run(scope, ctx)
38+
39+
# After complete all ops, check Y is expect or not.
40+
41+
42+
if __name__ == '__main__':
43+
unittest.main()

0 commit comments

Comments
 (0)