Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions paddle/operators/identity_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
namespace paddle {
namespace operators {

// identity is a alias of scale op. This is also a example for creating a alias
// operator.
// The identity operator is an alias of the scale operator. This is also an
// example for creating an alias for an existing operator.
template <typename AttrType>
class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IdentityOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "input tensor of identity op");
AddOutput("Out", "output tensor of identity op");
AddComment("identity operator. Just a alias of scale op which scale = 1.0");
AddInput("X", "The input tensor of identity operator.");
AddOutput("Out", "The output tensor of identity operator.");
AddComment(R"DOC(
The identity operator is an alias of the scale operator
with the attribute scale fixed to 1.0.
)DOC");
}
};

Expand Down
6 changes: 4 additions & 2 deletions paddle/operators/scale_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {

The equation is: Out = scale*X
)DOC");
AddAttr<AttrType>("scale", "scale of scale operator.").SetDefault(1.0);
AddAttr<AttrType>("scale", "The scaling factor of the scale operator.")
.SetDefault(1.0);
}
};

// Scale Op's gradient is scale op, too.
// The operator to calculate gradients of a scale operator is just the scale
// operator itself.
// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
template <typename AttrType>
class ScaleGradOp : public NetOp {
Expand Down
15 changes: 8 additions & 7 deletions paddle/operators/softmax_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ the other dimensions in the K-dimensional vector input. Then the ratio of the
exponential of the given dimension and the sum of exponential values of all
the other dimensions is the output of the softmax operator.

For each row `i` and each column `j` in X, we have:
For each row `i` and each column `j` in input X, we have:
Y[i, j] = exp(X[i, j]) / sum_j(exp(X[i, j]))

)DOC");
Expand All @@ -64,14 +64,15 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) should not be null");
PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
"the shape of Input(0) and Input(1) should be the same");
"Input(Y@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Y")->dims(),
ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
"Input(Y) and its gradients should have a same shape.");

ctx.Output<Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("Y")->dims());
->Resize(ctx.Input<Tensor>("X")->dims());
}
};

Expand Down
10 changes: 5 additions & 5 deletions paddle/operators/softmax_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ template <typename Place, typename T>
class SoftmaxKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>("X");
auto output = context.Output<Tensor>("Y");
output->mutable_data<T>(context.GetPlace());
auto X = context.Input<Tensor>("X");
auto Y = context.Output<Tensor>("Y");
Y->mutable_data<T>(context.GetPlace());

auto logits = EigenMatrix<T>::From(*input);
auto softmax = EigenMatrix<T>::From(*output);
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);

const int kBatchDim = 0;
const int kClassDim = 1;
Expand Down
59 changes: 33 additions & 26 deletions python/paddle/v2/framework/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

def get_all_op_protos():
"""
Get all registered op proto from Paddle C++
:return: list of OpProto
Get all registered op proto from PaddlePaddle C++ end.
:return: A list of registered OpProto.
"""
protostrs = core.get_all_op_protos()
ret_values = []
Expand All @@ -21,26 +21,27 @@ def is_str(s):

class OpDescCreationMethod(object):
"""
A Functor object to convert user input(use key word args) to OpDesc based on
OpProto.
Convert the user's input(only keyword arguments are supported) to OpDesc
based on the OpProto.

:param op_proto: The OpProto object.
:type op_proto: op_proto_pb2.OpProto
"""

def __init__(self, op_proto):
if not isinstance(op_proto, framework_pb2.OpProto):
raise TypeError("Argument should be OpProto")
raise TypeError(
"Type of op_proto should be OpProto in PaddlePaddle.")
self.__op_proto__ = op_proto

def __call__(self, *args, **kwargs):
"""
Convert user input to OpDesc. Only key-word args are supported.
:return: OpDesc based on user input
Convert user's input to OpDesc. Only keyword arguments are supported.
:return: The OpDesc based on user input.
:rtype: op_desc_pb2.OpDesc
"""
if len(args) != 0:
raise ValueError("Only keyword arguments is supported by Paddle")
raise ValueError("Only keyword arguments are supported.")
op_desc = framework_pb2.OpDesc()

for input_parameter in self.__op_proto__.inputs:
Expand All @@ -49,8 +50,9 @@ def __call__(self, *args, **kwargs):
input_arguments = [input_arguments]

if not input_parameter.duplicable and len(input_arguments) > 1:
raise ValueError("Input %s only accepts one input, but give %d"
% (input_parameter.name, len(input_arguments)))
raise ValueError(
"Input %s expects only one input, but %d are given." %
(input_parameter.name, len(input_arguments)))

ipt = op_desc.inputs.add()
ipt.parameter = input_parameter.name
Expand All @@ -63,7 +65,7 @@ def __call__(self, *args, **kwargs):

if not output_parameter.duplicable and len(output_arguments) > 1:
raise ValueError(
"Output %s only accepts one output, but give %d" %
"Output %s expects only one output, but %d are given." %
(output_parameter.name, len(output_arguments)))

out = op_desc.outputs.add()
Expand Down Expand Up @@ -100,15 +102,17 @@ def __call__(self, *args, **kwargs):
pair.first = p[0]
pair.second = p[1]
else:
raise NotImplementedError("Not support attribute type " +
str(attr.type))
raise NotImplementedError(
"A not supported attribute type: %s." % (
str(attr.type)))

return op_desc

@staticmethod
def any_is_true(generator):
"""
Reduce a bool array to one. If any of them is True, then return True.
Reduce a boolean array to a single boolean parameter. If any element in
the array is True, this function will return True, otherwise False.
"""
for flag in generator:
if flag:
Expand All @@ -127,7 +131,7 @@ def __init__(self, name, method, inputs, outputs, attrs):

def create_op_creation_method(op_proto):
"""
Generate op creation method for an OpProto
Generate op creation method for an OpProto.
"""
method = OpDescCreationMethod(op_proto)

Expand All @@ -146,20 +150,23 @@ def __impl__(*args, **kwargs):
class OperatorFactory(object):
def __init__(self):
self.op_methods = dict()

for op_proto in get_all_op_protos():
method = create_op_creation_method(op_proto)
self.op_methods[method.name] = method

def __call__(self, *args, **kwargs):
if 'type' in kwargs:
if "type" in kwargs:
if len(args) != 0:
raise ValueError("All Paddle argument should be key-word "
"argument except type")
t = kwargs.pop('type')
raise ValueError(
"Except the argument \"type\","
"all of the other arguments should be keyword arguments.")
t = kwargs.pop("type")
else:
if len(args) != 1:
raise ValueError("All Paddle argument should be key-word "
"argument except type")
raise ValueError(
"Except the argument \"type\","
"all of the other arguments should be keyword arguments.")
t = args[0]

return self.get_op_info(t).method(**kwargs)
Expand All @@ -169,7 +176,7 @@ def types(self):

def get_op_info(self, t):
if t not in self.op_methods:
raise ValueError("operator %s is not registered", t)
raise ValueError("The operator: %s is not registered." % t)
return self.op_methods.get(t)

def get_op_input_names(self, type):
Expand All @@ -184,7 +191,7 @@ def get_op_attr_names(self, type):

class __RecurrentOp__(object):
__proto__ = None
type = 'recurrent'
type = "recurrent"

def __init__(self):
# cache recurrent_op's proto
Expand All @@ -194,14 +201,14 @@ def __init__(self):
self.__proto__ = op_proto

def __call__(self, *args, **kwargs):
if self.type not in args and 'type' not in kwargs:
kwargs['type'] = self.type
if self.type not in args and "type" not in kwargs:
kwargs["type"] = self.type
# create proto
create_method = OpDescCreationMethod(self.__proto__)
proto = create_method(*args, **kwargs)
# create rnnop
return core.RecurrentOp.create(proto.SerializeToString())


Operator = OperatorFactory() # Default global factory
Operator = OperatorFactory() # The default global factory
RecurrentOp = __RecurrentOp__()
35 changes: 19 additions & 16 deletions python/paddle/v2/framework/tests/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def feed_data(name, data):
assert isinstance(data, numpy.ndarray)
tensor = scope.find_var(name).get_tensor()
tensor.set_dims(data.shape)
if data.dtype == numpy.dtype('int32'):
if data.dtype == numpy.dtype("int32"):
tensor.alloc_int(place)
elif data.dtype == numpy.dtype('float32'):
elif data.dtype == numpy.dtype("float32"):
tensor.alloc_float(place)
else:
raise ValueError("data type not supported")
Expand Down Expand Up @@ -74,22 +74,25 @@ def init_param(net, param_name, dims):
# fc_layer
def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
"""
Add a fc layer to net
The fully connected layer.

:param input: input variable name.
:param input: The name of input variable.
:type input: str
:param size: fully connected layer size.
:param act: activation name
:param param: parameter attribute, used for initialize parameters.
:param bias: bias attribute. False will not have a bias.
:param name: the name of fc layer. If not set, model will generate a
readable name
:return: output variable name.
:param size: The size of fully connected layer.
:param act: The name of activation.
:param param: The attribute of learnable parameter which can be used to
modify initialization mean and std of the parameter.
:param bias: The attribute of bias. If set False, this layer does not have
a bias.
:param name: The name of this layer. If it is not set explictly, a name
will be generated automatically.
:return: The name of the output variable.
"""

if name is None:
name = 'fc_%d' % uniq_id()
name = "fc_%d" % uniq_id()
if not isinstance(name, str):
raise ValueError("name should be string")
raise ValueError("The name of a layer should be a string.")

input_dims = scope.find_var(input).get_tensor().get_dims()

Expand Down Expand Up @@ -123,7 +126,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):


def cross_entropy_layer(net, input, label):
cost_name = 'cross_entropy_%d' % uniq_id()
cost_name = "cross_entropy_%d" % uniq_id()
cross_entropy_op = Operator(
"onehot_cross_entropy", X=input, label=label, Y=cost_name)
net.append_op(cross_entropy_op)
Expand Down Expand Up @@ -177,8 +180,8 @@ def error_rate(predict, label):
return error_num / float(len(label))


images = data_layer(name='pixel', dims=[BATCH_SIZE, 784])
labels = data_layer(name='label', dims=[BATCH_SIZE])
images = data_layer(name="pixel", dims=[BATCH_SIZE, 784])
labels = data_layer(name="label", dims=[BATCH_SIZE])
fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid")
fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid")
predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax")
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/v2/framework/tests/test_gradient_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

class GetNumericGradientTest(unittest.TestCase):
def test_add_op(self):
add_op = Operator('add', X="X", Y="Y", Out="Z")
add_op = Operator("add", X="X", Y="Y", Out="Z")
x = numpy.random.random((10, 1)).astype("float32")
y = numpy.random.random((10, 1)).astype("float32")

arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
arr = get_numeric_gradient(add_op, {"X": x, "Y": y}, "Z", "X")
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4)

def test_softmax_op(self):
Expand All @@ -35,9 +35,9 @@ def label_softmax_grad(Y, dY):
dY = numpy.ones(Y.shape)
dX = label_softmax_grad(Y, dY)

arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X')
arr = get_numeric_gradient(softmax_op, {"X": X}, "Y", "X")
numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
20 changes: 12 additions & 8 deletions python/paddle/v2/framework/tests/test_softmax_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,22 @@ class TestSoftmaxOp(unittest.TestCase):

def setUp(self):
self.type = "softmax"
self.inputs = {'X': np.random.random((32, 100)).astype("float32")}
self.inputs = {"X": np.random.random((10, 10)).astype("float32")}
self.outputs = {
'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
"Y": np.apply_along_axis(stable_softmax, 1, self.inputs["X"])
}


class SoftmaxGradOpTest(GradientChecker):
def test_softmax(self):
op = create_op("softmax")
inputs = {"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32")}
self.check_grad(op, inputs, set("X"), "Y")
class TestSoftmaxGradOp(GradientChecker):
def setUp(self):
self.op = create_op("softmax")
self.inputs = {
"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32")
}

def test_softmax_grad(self):
self.check_grad(self.op, self.inputs, ["X"], "Y")


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()