Merged
2 changes: 1 addition & 1 deletion paddle/framework/executor.cc
@@ -120,7 +120,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,

  for (auto& op_desc : block.AllOps()) {
    auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
-    VLOG(10) << op->DebugString();
+    VLOG(3) << op->DebugString();
    op->Run(*local_scope, *device);
  }
  if (create_local_scope) {
23 changes: 23 additions & 0 deletions paddle/framework/prune.cc
@@ -26,6 +26,8 @@ namespace framework {

const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
+const std::string kDropOutOpType = "dropout";
+const std::string kBatchNormOpType = "batch_norm";

bool HasDependentVar(const OpDesc& op_desc,
                     const std::set<std::string>& dependent_vars) {
@@ -106,5 +108,26 @@ void Prune(const ProgramDesc& input, ProgramDesc* output) {
  prune_impl(input, output, 0);
}

Review comment (Collaborator): The logic of inference_optimize_impl is quite simple. Maybe we can implement it in Python.

+void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output,
+                             int block_id) {
+  *output = input;
+  auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
+  for (auto& op_desc : *op_field) {
+    if (op_desc.type() == kDropOutOpType ||
+        op_desc.type() == kBatchNormOpType) {
+      for (auto& attr : *op_desc.mutable_attrs()) {
+        if (attr.name() == "is_test") {
+          attr.set_b(true);
+          break;
+        }
+      }
+    }
+  }
+}
+
+void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output) {
+  inference_optimize_impl(input, output, 0);
+}

} // namespace framework
} // namespace paddle
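Side note on the review comment above: a rough Python equivalent of this pass would operate on the serialized ProgramDesc protobuf directly. The sketch below is illustrative only — the proto module path and the function name are assumptions, not code from this PR:

from paddle.v2.fluid.proto import framework_pb2  # assumed module path

def inference_optimize_in_python(desc_str, block_id=0):
    # Parse the serialized ProgramDesc, flip is_test on dropout and
    # batch_norm ops in the given block, and serialize it back
    # (mirrors the C++ pass above).
    desc = framework_pb2.ProgramDesc()
    desc.ParseFromString(desc_str)
    for op in desc.blocks[block_id].ops:
        if op.type in ("dropout", "batch_norm"):
            for attr in op.attrs:
                if attr.name == "is_test":
                    attr.b = True
                    break
    return desc.SerializeToString()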
2 changes: 2 additions & 0 deletions paddle/framework/prune.h
@@ -22,5 +22,7 @@ namespace framework {

void Prune(const ProgramDesc& input, ProgramDesc* output);

+void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output);

} // namespace framework
} // namespace paddle
8 changes: 4 additions & 4 deletions paddle/operators/dropout_op.cc
@@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {

    auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", x_dims);
-    if (ctx->Attrs().Get<bool>("is_training") == true) {
+    if (ctx->Attrs().Get<bool>("is_test") == false) {
      ctx->SetOutputDim("Mask", x_dims);
    }
    ctx->ShareLoD("X", /*->*/ "Out");
@@ -49,7 +49,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {

    AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
        .SetDefault(.5f);
-    AddAttr<bool>("is_training", "True if in training phase.").SetDefault(true);
+    AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
    AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);

    AddComment(R"DOC(
@@ -71,8 +71,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
-                      "GradOp is only callable when is_training is true");
+    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
+                      "GradOp is only callable when is_test is false");

    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");
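Note that the rename from is_training to is_test only inverts the flag; the kernels' arithmetic is unchanged. In test mode the op skips mask generation and scales the input, which the updated unit tests at the bottom of this diff encode as Out = X * dropout_prob. A minimal numpy sketch of that test-mode expectation (illustrative, not library code):

import numpy as np

def dropout_infer(x, dropout_prob):
    # Mirrors the expected outputs in TestDropoutOp4/TestDropoutOp5 below:
    # no mask at test time, output is the input scaled by dropout_prob.
    return x * dropout_prob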
2 changes: 1 addition & 1 deletion paddle/operators/dropout_op.cu
@@ -59,7 +59,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
    auto Y = EigenMatrix<T>::Reshape(*y, 1);

    auto place = context.GetEigenDevice<Place>();
-    if (context.Attr<bool>("is_training")) {
+    if (!context.Attr<bool>("is_test")) {
      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
      int size = framework::product(mask->dims());
6 changes: 3 additions & 3 deletions paddle/operators/dropout_op.h
@@ -35,7 +35,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
    auto* y_data = y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");

-    if (context.Attr<bool>("is_training")) {
+    if (!context.Attr<bool>("is_test")) {
      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
      int seed = context.Attr<int>("seed");
@@ -65,8 +65,8 @@ template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
-    PADDLE_ENFORCE(context.Attr<bool>("is_training"),
-                   "GradOp is only callable when is_training is true");
+    PADDLE_ENFORCE(!context.Attr<bool>("is_test"),
+                   "GradOp is only callable when is_test is false");

    auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
5 changes: 5 additions & 0 deletions paddle/pybind/pybind.cc
@@ -293,6 +293,11 @@ All parameter, weight, gradient are variables in Paddle.
        Prune(*prog_with_targets.Proto(), &pruned_desc);
        return new ProgramDescBind(pruned_desc);
      });
+  m.def("inference_optimize", [](ProgramDescBind &origin) {
+    ProgramDesc pruned_desc;
+    InferenceOptimize(*(origin.Proto()), &pruned_desc);
+    return new ProgramDescBind(pruned_desc);
+  });
  m.def_submodule(
      "var_names",
      "The module will return special predefined variable name in Paddle")
3 changes: 3 additions & 0 deletions python/paddle/v2/fluid/evaluator.py
@@ -33,6 +33,9 @@ def __init__(self, name, **kwargs):
        else:
            self._main_program = g_main_program

+    def states(self):
+        return self._states
+
    def _update_ops(self, *args, **kwargs):
        """
        append update ops to the global states
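The new states() accessor exists so that callers can keep an evaluator's state variables alive through pruning: Program.prune() drops anything the targets do not depend on, so the book tests below pass the states in as extra targets. The pattern, with names taken from those tests:

# Keep the accuracy evaluator's state variables in the pruned program
# by listing them as targets alongside the fetch variables.
test_target = [avg_cost, test_acc_out] + test_accuracy.states().values()
inference_program = get_inference_program(test_target)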
7 changes: 7 additions & 0 deletions python/paddle/v2/fluid/framework.py
@@ -511,6 +511,13 @@ def prune(self, targets):
        res.sync_with_cpp()
        return res

+    def inference_optimize(self):
+        res = Program()
+        res.desc = core.inference_optimize(self.desc)
+        res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())]
+        res.sync_with_cpp()
+        return res
+
    @staticmethod
    def parse_from_string(binary_str):
        p = Program()
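prune() and inference_optimize() are designed to compose: first cut the program down to the evaluation targets, then flip dropout/batch_norm ops into test mode. A minimal sketch (variable names are illustrative; get_inference_program in io.py below wraps exactly this sequence):

# Prune to the targets, then rewrite the surviving ops for inference.
pruned_program = main_program.prune(targets=[avg_cost])
inference_program = pruned_program.inference_optimize()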
19 changes: 16 additions & 3 deletions python/paddle/v2/fluid/io.py
@@ -6,7 +6,8 @@

__all__ = [
    'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
-    'load_persistables', "save_inference_model", "load_inference_model"
+    'load_persistables', "save_inference_model", "load_inference_model",
+    "get_inference_program"
]


@@ -151,6 +152,17 @@ def load_persistables(executor, dirname, main_program=None):
                predicate=is_persistable)


+def get_inference_program(target_vars, main_program=None):
+    if main_program is None:
+        main_program = g_main_program
+    if not isinstance(target_vars, list):
+        target_vars = [target_vars]
+
+    pruned_program = main_program.prune(targets=target_vars)
+    inference_program = pruned_program.inference_optimize()
+    return inference_program
+
+
def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
@@ -177,13 +189,14 @@ def save_inference_model(dirname,
    if not os.path.isdir(dirname):
        os.makedirs(dirname)

-    pruned_program = main_program.prune(target_vars)
+    pruned_program = main_program.prune(targets=target_vars)
+    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    model_file_name = dirname + "/__model__"
    with open(model_file_name, "w") as f:
        pickle.dump({
-            "program_desc_str": pruned_program.desc.serialize_to_string(),
+            "program_desc_str": inference_program.desc.serialize_to_string(),
            "feed_var_names": feeded_var_names,
            "fetch_var_names": fetch_var_names
        }, f, -1)
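With this change, the ProgramDesc written to disk is already in test mode. For clarity, the __model__ file written above is a pickled dict; a reader-side sketch grounded in the keys used in save_inference_model (error handling omitted):

import pickle

# Read back the dict written by save_inference_model above.
with open(dirname + "/__model__", "r") as f:
    model = pickle.load(f)
# model["program_desc_str"]: serialized test-mode ProgramDesc
# model["feed_var_names"], model["fetch_var_names"]: I/O variable names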
python/paddle/v2/fluid/tests/book/test_image_classification_train.py
@@ -5,6 +5,7 @@
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.evaluator as evaluator
+from paddle.v2.fluid.io import get_inference_program
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.initializer import XavierInitializer
from paddle.v2.fluid.optimizer import AdamOptimizer
@@ -116,9 +117,11 @@ def conv_block(input, num_filter, groups, dropouts):

train_reader = paddle.batch(
    paddle.reader.shuffle(
-        paddle.dataset.cifar.train10(), buf_size=128 * 10),
+        paddle.dataset.cifar.train10(), buf_size=BATCH_SIZE * 10),
    batch_size=BATCH_SIZE)

+test_reader = paddle.batch(paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)
+
place = core.CPUPlace()
exe = Executor(place)

@@ -149,10 +152,41 @@ def conv_block(input, num_filter, groups, dropouts):
        loss = np.array(outs[0])
        acc = np.array(outs[1])
        pass_acc = accuracy.eval(exe)

+        batch_id = batch_id + 1
+
+        test_accuracy, test_acc_out = evaluator.accuracy(
+            input=predict, label=label)
+
+        test_target = [avg_cost, test_acc_out] + test_accuracy.states().values()
+        inference_program = get_inference_program(test_target)
+
+        test_accuracy.reset(exe)
+
+        for data in test_reader():
+            x_data = np.array(map(lambda x: x[0].reshape(data_shape),
+                                  data)).astype("float32")
+            y_data = np.array(map(lambda x: x[1], data)).astype("int64")
+            y_data = np.expand_dims(y_data, axis=1)
+
+            tensor_x = core.LoDTensor()
+            tensor_x.set(x_data, place)
+
+            tensor_y = core.LoDTensor()
+            tensor_y.set(y_data, place)
+
+            outs = exe.run(inference_program,
+                           feed={'pixel': tensor_x,
+                                 'label': tensor_y},
+                           fetch_list=[avg_cost, test_acc_out])
+            out = np.array(outs[0])
+            acc = np.array(outs[1])
+
+        test_pass_acc = test_accuracy.eval(exe)
+
        print("pass_id:" + str(pass_id) + " batch_id:" + str(batch_id) +
              " loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
-                  pass_acc))
-        batch_id = batch_id + 1
+                  pass_acc) + " test_pass_acc:" + str(test_pass_acc))

        if batch_id > 1:
            # this model is slow, so if we can train two mini-batches, we think it works properly.
37 changes: 34 additions & 3 deletions python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
@@ -4,6 +4,7 @@
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.evaluator as evaluator
+from paddle.v2.fluid.io import get_inference_program
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.initializer import UniformInitializer
from paddle.v2.fluid.optimizer import MomentumOptimizer
@@ -42,6 +43,8 @@
        paddle.dataset.mnist.train(), buf_size=8192),
    batch_size=BATCH_SIZE)

+test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
+
place = core.CPUPlace()
exe = Executor(place)

@@ -69,8 +72,36 @@
        acc = np.array(outs[1])
        pass_acc = accuracy.eval(exe)

-        if pass_acc > 0.7:
+        test_accuracy, test_acc_out = evaluator.accuracy(
+            input=predict, label=label)
+
+        test_target = [avg_cost, test_acc_out] + test_accuracy.states().values()
+        inference_program = get_inference_program(test_target)
+
+        test_accuracy.reset(exe)
+        for data in test_reader():
+            x_data = np.array(map(lambda x: x[0], data)).astype("float32")
+            y_data = np.array(map(lambda x: x[1], data)).astype("int64")
+            y_data = np.expand_dims(y_data, axis=1)
+
+            tensor_x = core.LoDTensor()
+            tensor_x.set(x_data, place)
+
+            tensor_y = core.LoDTensor()
+            tensor_y.set(y_data, place)
+
+            outs = exe.run(inference_program,
+                           feed={'x': tensor_x,
+                                 'y': tensor_y},
+                           fetch_list=[avg_cost, test_acc_out])
+            out = np.array(outs[0])
+            acc = np.array(outs[1])
+
+        test_pass_acc = test_accuracy.eval(exe)
+        print("pass_id=" + str(pass_id) + " train_cost=" + str(
+            out) + " train_acc=" + str(acc) + " train_pass_acc=" + str(pass_acc)
+              + " test_acc=" + str(test_pass_acc))
+
+        if test_pass_acc > 0.7:
            exit(0)
        # print("pass_id=" + str(pass_id) + " auc=" +
        #       str(acc) + " pass_acc=" + str(pass_acc))
exit(1)
10 changes: 5 additions & 5 deletions python/paddle/v2/fluid/tests/test_dropout_op.py
@@ -7,7 +7,7 @@ class TestDropoutOp(OpTest):
    def setUp(self):
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
-        self.attrs = {'dropout_prob': 0.0, 'is_training': True}
+        self.attrs = {'dropout_prob': 0.0, 'is_test': False}
        self.outputs = {
            'Out': self.inputs['X'],
            'Mask': np.ones((32, 64)).astype('float32')
@@ -24,7 +24,7 @@ class TestDropoutOp2(TestDropoutOp):
    def setUp(self):
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
-        self.attrs = {'dropout_prob': 1.0, 'is_training': True}
+        self.attrs = {'dropout_prob': 1.0, 'is_test': False}
        self.outputs = {
            'Out': np.zeros((32, 64)).astype('float32'),
            'Mask': np.zeros((32, 64)).astype('float32')
@@ -35,7 +35,7 @@ class TestDropoutOp3(TestDropoutOp):
    def setUp(self):
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
-        self.attrs = {'dropout_prob': 0.0, 'is_training': True}
+        self.attrs = {'dropout_prob': 0.0, 'is_test': False}
        self.outputs = {
            'Out': self.inputs['X'],
            'Mask': np.ones((32, 64, 2)).astype('float32')
@@ -46,7 +46,7 @@ class TestDropoutOp4(OpTest):
    def setUp(self):
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
-        self.attrs = {'dropout_prob': 0.35, 'is_training': False}
+        self.attrs = {'dropout_prob': 0.35, 'is_test': True}
        self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}

    def test_check_output(self):
@@ -57,7 +57,7 @@ class TestDropoutOp5(OpTest):
    def setUp(self):
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
-        self.attrs = {'dropout_prob': 0.75, 'is_training': False}
+        self.attrs = {'dropout_prob': 0.75, 'is_test': True}
        self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}

    def test_check_output(self):