@@ -14,6 +14,7 @@ limitations under the License. */
1414
1515#include < Python.h>
1616#include < paddle/framework/op_registry.h>
17+ #include < paddle/framework/operator.h>
1718#include < paddle/framework/scope.h>
1819#include < paddle/pybind/tensor_bind.h>
1920#include < pybind11/numpy.h>
@@ -26,10 +27,7 @@ namespace py = pybind11;
2627namespace pd = paddle::framework;
2728
2829USE_OP (add_two);
29- USE_OP (softmax);
30- USE_OP (mul);
31- USE_OP (rowwise_add);
32- USE_OP (sigmoid);
30+ USE_OP_WITHOUT_KERNEL (fc);
3331
3432PYBIND11_PLUGIN (core) {
3533 py::module m (" core" , " C++ core of Paddle Paddle" );
@@ -53,7 +51,9 @@ PYBIND11_PLUGIN(core) {
5351 self.mutable_data <int >(paddle::platform::CPUPlace ());
5452 })
5553 .def (" set" , paddle::pybind::PyTensorSetFromArray<float >)
56- .def (" set" , paddle::pybind::PyTensorSetFromArray<int >);
54+ .def (" set" , paddle::pybind::PyTensorSetFromArray<int >)
55+ .def (" shape" ,
56+ [](pd::Tensor& self) { return pd::vectorize (self.dims ()); });
5757
5858 py::class_<pd::Variable>(m, " Variable" , R"DOC( Variable Class.
5959
@@ -83,15 +83,16 @@ All parameter, weight, gradient are variables in Paddle.
8383
8484 // ! @note: Be careful! PyBind will return std::string as an unicode, not
8585 // ! Python str. If you want a str object, you should cast them in Python.
86- m.def (" get_all_op_protos" , []() -> std::vector<std::string > {
86+ m.def (" get_all_op_protos" , []() -> std::vector<py::bytes > {
8787 auto & protos = pd::OpRegistry::protos ();
88- std::vector<std::string > ret_values;
88+ std::vector<py::bytes > ret_values;
8989 for (auto it = protos.begin (); it != protos.end (); ++it) {
9090 PADDLE_ENFORCE (it->second .IsInitialized (),
9191 " OpProto must all be initialized" );
92- ret_values. emplace_back () ;
93- PADDLE_ENFORCE (it->second .SerializeToString (&ret_values. back () ),
92+ std::string str ;
93+ PADDLE_ENFORCE (it->second .SerializeToString (&str ),
9494 " Serialize OpProto Error. This could be a bug of Paddle." );
95+ ret_values.push_back (py::bytes (str));
9596 }
9697 return ret_values;
9798 });
@@ -101,17 +102,26 @@ All parameter, weight, gradient are variables in Paddle.
101102 .def (" empty" , pd::OperatorBase::EMPTY_VAR_NAME)
102103 .def (" temp" , pd::OperatorBase::TMP_VAR_NAME);
103104
105+ py::class_<paddle::platform::DeviceContext>(m, " DeviceContext" )
106+ .def_static (" cpu_context" , []() -> paddle::platform::DeviceContext* {
107+ return new paddle::platform::CPUDeviceContext ();
108+ });
109+
104110 py::class_<pd::OperatorBase, pd::OperatorPtr>(m, " Operator" )
105111 .def (" __str__" , &pd::OperatorBase::DebugString)
106- .def_static (" create" , [](const std::string& protobin) {
107- pd::OpDesc desc;
108- PADDLE_ENFORCE (desc.ParsePartialFromString (protobin),
109- " Cannot parse user input to OpDesc" );
110- PADDLE_ENFORCE (desc.IsInitialized (),
111- " User OpDesc is not initialized, reason %s" ,
112- desc.InitializationErrorString ());
113- return pd::OpRegistry::CreateOp (desc);
114- });
112+ .def_static (" create" ,
113+ [](py::bytes protobin) {
114+ pd::OpDesc desc;
115+ PADDLE_ENFORCE (desc.ParsePartialFromString (protobin),
116+ " Cannot parse user input to OpDesc" );
117+ PADDLE_ENFORCE (desc.IsInitialized (),
118+ " User OpDesc is not initialized, reason %s" ,
119+ desc.InitializationErrorString ());
120+ return pd::OpRegistry::CreateOp (desc);
121+ })
122+ .def (" infer_shape" , &pd::OperatorBase::InferShape)
123+ .def (" run" , &pd::OperatorBase::Run)
124+ .def (" outputs" , [](const pd::OperatorPtr& op) { return op->outputs_ ; });
115125
116126 return m.ptr ();
117127}
0 commit comments