Skip to content

Commit 6bbc9a5

Browse files
authored
Merge pull request #3640 from qingqing01/pybind
Move pybind from package paddle/framework into paddle/pybind.
2 parents 6177c83 + bfcaf88 commit 6bbc9a5

File tree

5 files changed

+37
-32
lines changed

5 files changed

+37
-32
lines changed

paddle/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ if(Boost_FOUND)
1515
add_subdirectory(platform)
1616
add_subdirectory(framework)
1717
add_subdirectory(operators)
18+
add_subdirectory(pybind)
1819
endif()
1920

2021
if(WITH_C_API)

paddle/framework/CMakeLists.txt

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,3 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
3939

4040
cc_library(backward SRCS backward.cc DEPS net_op)
4141
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
42-
43-
if(WITH_PYTHON)
44-
cc_library(paddle_pybind SHARED
45-
SRCS pybind.cc
46-
DEPS pybind python backward
47-
sgd_op
48-
gather_op
49-
add_op
50-
mul_op
51-
rowwise_add_op
52-
sigmoid_op
53-
softmax_op
54-
mean_op
55-
cross_entropy_op
56-
recurrent_op
57-
uniform_random_op
58-
gaussian_random_op
59-
fill_zeros_like_op
60-
scale_op)
61-
endif(WITH_PYTHON)

paddle/pybind/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
if(WITH_PYTHON)
2+
cc_library(paddle_pybind SHARED
3+
SRCS pybind.cc
4+
DEPS pybind python backward
5+
sgd_op
6+
gather_op
7+
add_op
8+
mul_op
9+
rowwise_add_op
10+
sigmoid_op
11+
softmax_op
12+
mean_op
13+
cross_entropy_op
14+
recurrent_op
15+
uniform_random_op
16+
gaussian_random_op
17+
fill_zeros_like_op
18+
scale_op)
19+
endif(WITH_PYTHON)
Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ limitations under the License. */
1818

1919
#include "paddle/framework/backward.h"
2020
#include "paddle/framework/op_registry.h"
21-
#include "paddle/framework/tensor_py.h"
2221
#include "paddle/operators/net_op.h"
2322
#include "paddle/operators/recurrent_op.h"
2423
#include "paddle/platform/enforce.h"
2524
#include "paddle/platform/place.h"
25+
#include "paddle/pybind/tensor_py.h"
2626
#include "paddle/string/to_string.h"
2727
#include "pybind11/numpy.h"
2828
#include "pybind11/pybind11.h"
@@ -134,7 +134,8 @@ All parameter, weight, gradient are variables in Paddle.
134134
py::return_value_policy::reference)
135135
.def("find_var", &Scope::FindVar, py::return_value_policy::reference)
136136
.def(py::init<>())
137-
.def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
137+
.def("new_scope",
138+
[](Scope &self) -> Scope * { return &self.NewScope(); },
138139
py::return_value_policy::reference)
139140
.def("drop_kids", &Scope::DropKids);
140141

@@ -222,8 +223,10 @@ All parameter, weight, gradient are variables in Paddle.
222223
retv->SetType("plain_net");
223224
return retv;
224225
})
225-
.def("append_op", [](operators::NetOp &self,
226-
const OperatorBase &op) { self.AppendOp(op); })
226+
.def("append_op",
227+
[](operators::NetOp &self, const OperatorBase &op) {
228+
self.AppendOp(op);
229+
})
227230
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
228231
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
229232
self->CompleteAddOp();
@@ -243,10 +246,9 @@ All parameter, weight, gradient are variables in Paddle.
243246
auto rnn_op = OpRegistry::CreateOp(desc);
244247
return static_cast<operators::RecurrentOp *>(rnn_op.release());
245248
})
246-
.def("set_stepnet", [](operators::RecurrentOp &self,
247-
const operators::NetOp &net) -> void {
248-
self.set_stepnet(net.Clone());
249-
});
249+
.def("set_stepnet",
250+
[](operators::RecurrentOp &self, const operators::NetOp &net)
251+
-> void { self.set_stepnet(net.Clone()); });
250252

251253
m.def("unique_integer", UniqueIntegerGenerator);
252254

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,11 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
6363
}
6464
return py::buffer_info(
6565
dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()),
66-
sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(),
67-
(size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
66+
sizeof(CUR_TYPE),
67+
py::format_descriptor<CUR_TYPE>::format(),
68+
(size_t)framework::arity(dst_tensor.dims()),
69+
dims_outside,
70+
strides);
6871
} else {
6972
constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
7073
return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
@@ -107,8 +110,8 @@ void PyCUDATensorSetFromArray(
107110

108111
self.Resize(framework::make_ddim(dims));
109112
auto *dst = self.mutable_data<T>(place);
110-
paddle::platform::GpuMemcpySync(dst, array.data(), sizeof(T) * array.size(),
111-
cudaMemcpyHostToDevice);
113+
paddle::platform::GpuMemcpySync(
114+
dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice);
112115
}
113116
#endif
114117

0 commit comments

Comments
 (0)