Skip to content

Commit df707d0

Browse files
authored
Merge pull request #2893 from reyoung/feature/op_creation_methods
Python Generate OpCreation Methods by OpProto
2 parents a0caf23 + 0e77b31 commit df707d0

File tree

7 files changed

+542
-20
lines changed

7 files changed

+542
-20
lines changed

paddle/framework/op_registry.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <algorithm>
4+
#include <atomic>
45
#include <type_traits>
56
#include <unordered_map>
67
#include <unordered_set>
@@ -214,25 +215,35 @@ class OpRegistry {
214215
}
215216

216217
static OperatorPtr CreateOp(const OpDesc& op_desc) {
  //! Build the operator instance from the creator registered for its type.
  const std::string op_type = op_desc.type();
  OperatorPtr op(creators().at(op_type)());

  //! The data members are filled in here instead of through a constructor,
  //! so that Op developers do not have to write one.
  const OpProto& op_proto = protos().at(op_type);
  op->type_ = op_desc.type();

  // Copy the input names out of the description.
  op->inputs_.reserve(static_cast<size_t>(op_desc.inputs_size()));
  op->inputs_.insert(op->inputs_.end(), op_desc.inputs().begin(),
                     op_desc.inputs().end());

  // Copy the output names out of the description.
  op->outputs_.reserve(static_cast<size_t>(op_desc.outputs_size()));
  op->outputs_.insert(op->outputs_.end(), op_desc.outputs().begin(),
                      op_desc.outputs().end());

  //! Collect the attributes, then run the checker registered for this type.
  for (const auto& attr_desc : op_desc.attrs()) {
    op->attrs_[attr_desc.name()] = AttrTypeHelper::GetAttrValue(attr_desc);
  }
  op_checkers().at(op_type).Check(op->attrs_);

  //! Replace temporary placeholder output names with unique names.
  GenerateTempVariableName(op.get());

  // Store the argument offsets in the op.
  CreateInOutOffsetMap(op, op_proto);

  //! Give the op a chance to run its own Init after all members are set.
  op->Init();
  return op;
}
@@ -248,6 +259,17 @@ class OpRegistry {
248259
};
249260

250261
private:
262+
//! Rewrite every output whose name equals OperatorBase::TMP_VAR_NAME() to
//! "<placeholder><op type>@<counter>". The counter is a function-local
//! atomic shared by all calls, and is advanced once per renamed output.
static void GenerateTempVariableName(OperatorBase* op) {
  static std::atomic<size_t> next_id(0UL);
  for (auto& name : op->outputs_) {
    if (!(name == OperatorBase::TMP_VAR_NAME())) {
      continue;  // ordinary outputs keep the name they were given
    }
    name += op->type_;
    name += "@";
    name += std::to_string(next_id++);
  }
}
272+
251273
static std::unordered_map<std::string, OpCreator>& creators() {
252274
static std::unordered_map<std::string, OpCreator> creators_;
253275
return creators_;

paddle/framework/operator.cc

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,23 +77,21 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
7777

7878
std::string OperatorBase::DebugString() const {
7979
std::stringstream ss;
80-
ss << "=================\n";
81-
ss << "type = " << type_ << "\n";
82-
ss << "inputs = [";
83-
for (auto& ipt : inputs_) {
84-
ss << ipt << ", ";
80+
ss << "Op(" << type_ << "), inputs:(";
81+
for (size_t i = 0; i < inputs_.size(); ++i) {
82+
ss << inputs_[i];
83+
if (i != inputs_.size() - 1) {
84+
ss << ", ";
85+
}
8586
}
86-
ss << "]\n";
87-
ss << "outputs = [";
88-
for (auto& opt : outputs_) {
89-
ss << opt << ", ";
87+
ss << "), outputs:(";
88+
for (size_t i = 0; i < outputs_.size(); ++i) {
89+
ss << outputs_[i];
90+
if (i != outputs_.size() - 1) {
91+
ss << ", ";
92+
}
9093
}
91-
ss << "]\n";
92-
ss << "attr_keys = [";
93-
for (auto& attr : attrs_) {
94-
ss << attr.first << ", ";
95-
}
96-
ss << "]\n";
94+
ss << ").";
9795
return ss.str();
9896
}
9997

paddle/framework/operator.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ using OperatorPtr = std::shared_ptr<OperatorBase>;
4141
*/
4242
class OperatorBase {
4343
public:
44+
/// If a variable is an empty variable, this name will be used.
45+
static std::string EMPTY_VAR_NAME() { return "@EMPTY@"; }
46+
47+
/// If a variable is a temporary variable, this name will be set in Python,
48+
/// but it will be converted to a unique name in scope after OpCreator.
49+
static std::string TMP_VAR_NAME() { return "@TEMP@"; }
50+
4451
virtual ~OperatorBase() {}
4552

4653
template <typename T>

paddle/pybind/pybind.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@ All parameter, weight, gradient are variables in Paddle.
6363
}
6464
return ret_values;
6565
});
66+
m.def_submodule(
67+
"var_names",
68+
"The module will return special predefined variable name in Paddle")
69+
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
70+
.def("temp", pd::OperatorBase::TMP_VAR_NAME);
71+
72+
py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
73+
.def("__str__", &pd::OperatorBase::DebugString)
74+
.def_static("create", [](const std::string& protobin) {
75+
pd::OpDesc desc;
76+
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
77+
"Cannot parse user input to OpDesc");
78+
PADDLE_ENFORCE(desc.IsInitialized(),
79+
"User OpDesc is not initialized, reason %s",
80+
desc.InitializationErrorString());
81+
return pd::OpRegistry::CreateOp(desc);
82+
});
6683

6784
return m.ptr();
6885
}
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,246 @@
11
import paddle.v2.framework.core as core
22
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
3+
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
4+
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
5+
import cStringIO
36

47

58
def get_all_op_protos():
    """
    Fetch every op proto registered in Paddle C++ and parse each one.

    :return: list of OpProto
    """
    return [
        op_proto_pb2.OpProto.FromString(str(serialized))
        for serialized in core.get_all_op_protos()
    ]
19+
20+
21+
class OpDescCreationMethod(object):
    """
    A functor that converts user input (keyword arguments) to an OpDesc,
    based on an OpProto.

    :param op_proto: The OpProto object.
    :type op_proto: op_proto_pb2.OpProto
    :raises TypeError: if op_proto is not an op_proto_pb2.OpProto
    """

    def __init__(self, op_proto):
        if not isinstance(op_proto, op_proto_pb2.OpProto):
            raise TypeError("Argument should be OpProto")
        self.__op_proto__ = op_proto

    def __call__(self, *args, **kwargs):
        """
        Convert user input to OpDesc. Only keyword arguments are supported.

        :return: OpDesc based on user input
        :rtype: op_desc_pb2.OpDesc
        :raises ValueError: if any positional argument is passed
        :raises NotImplementedError: if an attribute has an unhandled type
        """
        if len(args) != 0:
            raise ValueError("Only keyword arguments is supported by Paddle")
        op_desc = op_desc_pb2.OpDesc()

        # Inputs
        ipts, ipt_format, _ = OpDescCreationMethod.extract_input_or_output(
            "input", kwargs, self.__op_proto__.inputs)
        op_desc.inputs.extend(ipts)
        if ipt_format is not None:
            op_desc.attrs.extend([ipt_format])

        # Outputs
        outs, out_format, tmp_index = \
            OpDescCreationMethod.extract_input_or_output(
                "output", kwargs, self.__op_proto__.outputs)
        op_desc.outputs.extend(outs)
        if out_format is not None:
            op_desc.attrs.extend([out_format])
        if len(tmp_index) != 0:
            # Positions of the temporary outputs inside op_desc.outputs.
            tmp_index_attr = op_desc.attrs.add()
            tmp_index_attr.type = attr_type_pb2.INTS
            tmp_index_attr.name = "temporary_index"
            tmp_index_attr.ints.extend(tmp_index)

        # Types
        op_desc.type = self.__op_proto__.type

        # Attrs
        for attr in self.__op_proto__.attrs:
            if attr.generated:
                # Generated attributes are not taken from the keyword args.
                continue
            user_defined_attr = kwargs.get(attr.name, None)
            if user_defined_attr is None:
                continue
            new_attr = op_desc.attrs.add()
            new_attr.name = attr.name
            new_attr.type = attr.type
            if attr.type == attr_type_pb2.INT:
                new_attr.i = user_defined_attr
            elif attr.type == attr_type_pb2.FLOAT:
                new_attr.f = user_defined_attr
            elif attr.type == attr_type_pb2.STRING:
                new_attr.s = user_defined_attr
            elif attr.type == attr_type_pb2.INTS:
                new_attr.ints.extend(user_defined_attr)
            elif attr.type == attr_type_pb2.FLOATS:
                new_attr.floats.extend(user_defined_attr)
            elif attr.type == attr_type_pb2.STRINGS:
                new_attr.strings.extend(user_defined_attr)
            else:
                # attr.type is an integer enum value; it must be converted
                # before concatenation or a TypeError hides the real error.
                raise NotImplementedError("Not support attribute type " +
                                          str(attr.type))

        return op_desc

    @staticmethod
    def extract_input_or_output(in_out, kwargs, meta):
        """
        Extract input or output variable names from keyword arguments,
        driven by a list of VarProto.

        :param in_out: "input" or "output"
        :param kwargs: keyword arguments passed by the user.
        :param meta: a list of VarProto
        :return: Three objects: the variable names; the input_format or
            output_format attribute (None when no variable is multiple); and
            the list of positions, within the returned names, of temporary
            variables.
        """
        multiple = OpDescCreationMethod.any_is_true((m.multiple for m in meta))
        tmp_index = []
        retv = []
        if multiple:
            # var_format.ints holds cumulative offsets: entry i is where the
            # names of the i-th variable start inside retv.
            var_format = op_desc_pb2.AttrDesc()
            var_format.type = attr_type_pb2.INTS
            var_format.name = "%s_format" % in_out
            var_format.ints.append(0)

            for var in meta:
                if var.temporary:
                    var_name = [core.var_names.temp()]
                    tmp_index.append(len(retv))
                else:
                    var_name = kwargs.get(var.name, [])
                    if not isinstance(var_name, list):
                        var_name = [var_name]
                retv.extend(var_name)
                var_format.ints.append(len(var_name) + var_format.ints[-1])
            return retv, var_format, tmp_index
        else:
            for var in meta:
                if var.temporary:
                    # Record the position BEFORE appending so that the index
                    # is 0-based, exactly as in the multiple branch above.
                    tmp_index.append(len(retv))
                    retv.append(kwargs.get(var.name, core.var_names.temp()))
                else:
                    retv.append(kwargs.get(var.name, core.var_names.empty()))
            return retv, None, tmp_index

    @staticmethod
    def any_is_true(generator):
        """
        Reduce an iterable of flags to one bool: True if any flag is truthy,
        False otherwise (including for an empty iterable).
        """
        return any(generator)
147+
148+
149+
def get_docstring_from_op_proto(op_proto):
    """
    Generate a docstring from an OpProto.

    The result is the op comment followed by a ":param"/":type" pair for
    every input, output and attribute, each line terminated by a newline.

    :param op_proto: an OpProto instance.
    :type op_proto: op_proto_pb2.OpProto
    :return: docstring
    :raises TypeError: if op_proto is not an op_proto_pb2.OpProto
    :raises RuntimeError: if an attribute has an unhandled type
    """
    if not isinstance(op_proto, op_proto_pb2.OpProto):
        raise TypeError("Input must be OpProto")

    # Collect the pieces and join once at the end.
    pieces = [op_proto.comment, "\n"]

    def __append_param__(name, comment, type_name):
        pieces.extend([
            ":param ", name, ": ", comment, "\n",
            ":type ", name, ": ", type_name, "\n",
        ])

    for ipt in op_proto.inputs:
        __append_param__(ipt.name, ipt.comment, "list | basestr"
                         if ipt.multiple else "basestr")

    temp_var_prefix = \
        "This is a temporary variable. It does not have to set by user. "
    for opt in op_proto.outputs:
        __append_param__(opt.name, opt.comment if not opt.temporary else
                         temp_var_prefix + opt.comment, "list | basestr"
                         if opt.multiple else "basestr")

    for attr in op_proto.attrs:
        attr_type = None
        if attr.type == attr_type_pb2.INT:
            attr_type = "int"
        elif attr.type == attr_type_pb2.FLOAT:
            attr_type = "float"
        elif attr.type == attr_type_pb2.STRING:
            attr_type = "basestr"
        elif attr.type == attr_type_pb2.INTS:
            attr_type = "list of int"
        elif attr.type == attr_type_pb2.FLOATS:
            attr_type = "list of float"
        elif attr.type == attr_type_pb2.STRINGS:
            attr_type = "list of basestr"

        if attr_type is None:
            # attr.type is an integer enum value; convert it so the intended
            # RuntimeError is raised instead of a str + int TypeError.
            raise RuntimeError("Not supported attribute type " +
                               str(attr.type))

        __append_param__(attr.name, attr.comment, attr_type)

    return "".join(pieces)
207+
208+
209+
def create_op_creation_method(op_proto):
    """
    Build the op creation function for one OpProto.

    The returned function turns its arguments into an OpDesc, serializes it
    and hands the bytes to core.Operator.create. Its docstring is generated
    from the OpProto.
    """
    build_desc = OpDescCreationMethod(op_proto)

    def __impl__(*args, **kwargs):
        serialized = build_desc(*args, **kwargs).SerializeToString()
        return core.Operator.create(serialized)

    __impl__.__doc__ = get_docstring_from_op_proto(op_proto)
    return __impl__
221+
222+
223+
class OpCreationsHolder(object):
    """
    An object that holds all op creation methods as attributes.

    Use `op_creations.xxx_op` to access them.
    """


op_creations = OpCreationsHolder()
233+
234+
235+
def __bootstrap__():
    """
    Module bootstrap: create one op creation method per registered OpProto at
    import time and attach it to `op_creations` under the op's type name.
    """
    for proto in get_all_op_protos():
        creator = create_op_creation_method(proto)
        creator.__name__ = str(proto.type)
        setattr(op_creations, creator.__name__, creator)


__bootstrap__()

0 commit comments

Comments
 (0)