Skip to content
94 changes: 59 additions & 35 deletions paddle/fluid/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,15 @@ static bool has_feed_operators(
feed_count, feed_targets.size(),
"The number of feed operators should match 'feed_targets'");

// When feed operator are present, so should be feed_holder
auto var = block.FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
feed_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
"'%s' variable should be 'FEED_MINIBATCH' type",
feed_holder_name);
if (!feed_holder_name.empty()) {
// When feed operator are present, so should be feed_holder
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • When feed operators are present
  • so should be feed_holder缺少主语?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这似乎是一种英语句法:https://zhidao.baidu.com/question/6942604.html

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这似乎是一种英语句法:https://zhidao.baidu.com/question/6942604.html

auto var = block.FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
feed_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
"'%s' variable should be 'FEED_MINIBATCH' type",
feed_holder_name);
}
}

return feed_count > 0;
Expand Down Expand Up @@ -169,13 +171,15 @@ static bool has_fetch_operators(
fetch_count, fetch_targets.size(),
"The number of fetch operators should match 'fetch_targets'");

// When fetch operator are present, so should be fetch_holder
auto var = block.FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
fetch_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
"'%s' variable should be 'FETCH_LIST' type",
fetch_holder_name);
if (!fetch_holder_name.empty()) {
// When fetch operator are present, so should be fetch_holder
auto var = block.FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
fetch_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
"'%s' variable should be 'FETCH_LIST' type",
fetch_holder_name);
}
}

return fetch_count > 0;
Expand Down Expand Up @@ -222,16 +226,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}

// map the data of feed_targets to feed_holder
for (auto* op : global_block->AllOps()) {
if (op->Type() == kFeedOpType) {
std::string feed_target_name = op->Output("Out")[0];
int idx = boost::get<int>(op->GetAttr("col"));
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
idx);
}
}

if (!has_fetch_ops) {
// create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name);
Expand All @@ -255,17 +249,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}

Run(*copy_program, scope, 0, create_vars, create_vars);

// obtain the data of fetch_targets from fetch_holder
for (auto* op : global_block->AllOps()) {
if (op->Type() == kFetchOpType) {
std::string fetch_target_name = op->Input("X")[0];
int idx = boost::get<int>(op->GetAttr("col"));
*fetch_targets[fetch_target_name] =
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}
auto ctx = Prepare(*copy_program, 0);
RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
feed_holder_name, fetch_holder_name, create_vars);
}

std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
Expand Down Expand Up @@ -358,5 +344,43 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
}

void Executor::RunPreparedContext(
ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name, const std::string& fetch_holder_name,
bool create_vars) {
auto& global_block = ctx->prog_.Block(ctx->block_id_);

PADDLE_ENFORCE(
has_feed_operators(global_block, feed_targets, feed_holder_name),
"Program in ExecutorPrepareContext should has feed_ops.");
PADDLE_ENFORCE(
has_fetch_operators(global_block, fetch_targets, fetch_holder_name),
"Program in the prepared context should has fetch_ops.");

// map the data of feed_targets to feed_holder
for (auto* op : global_block.AllOps()) {
if (op->Type() == kFeedOpType) {
std::string feed_target_name = op->Output("Out")[0];
int idx = boost::get<int>(op->GetAttr("col"));
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
idx);
}
}

RunPreparedContext(ctx, scope, create_vars, create_vars);

// obtain the data of fetch_targets from fetch_holder
for (auto* op : global_block.AllOps()) {
if (op->Type() == kFetchOpType) {
std::string fetch_target_name = op->Input("X")[0];
int idx = boost::get<int>(op->GetAttr("col"));
*fetch_targets[fetch_target_name] =
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}
}

} // namespace framework
} // namespace paddle
7 changes: 7 additions & 0 deletions paddle/fluid/framework/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ class Executor {
bool create_local_scope = true,
bool create_vars = true);

void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch",
bool create_vars = true);

private:
const platform::Place place_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ TEST(inference, image_classification) {

// Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---";
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1,
FLAGS_repeat);
TestInference<paddle::platform::CPUPlace, true>(dirname, cpu_feeds,
cpu_fetchs1, FLAGS_repeat);
LOG(INFO) << output1.dims();

#ifdef PADDLE_WITH_CUDA
Expand All @@ -57,8 +57,8 @@ TEST(inference, image_classification) {

// Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs: ---";
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2,
FLAGS_repeat);
TestInference<paddle::platform::CUDAPlace, true>(dirname, cpu_feeds,
cpu_fetchs2, FLAGS_repeat);
LOG(INFO) << output2.dims();

CheckError<float>(output1, output2);
Expand Down
20 changes: 17 additions & 3 deletions paddle/fluid/inference/tests/test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
}

template <typename Place>
template <typename Place, bool PrepareContext = false>
void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
Expand Down Expand Up @@ -167,7 +167,14 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program
{
// Ignore the profiling results of the first run
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
if (PrepareContext) {
ctx = executor.Prepare(*inference_program, 0);
executor.RunPreparedContext(ctx.get(), scope, feed_targets,
fetch_targets);
} else {
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
}

// Enable the profiler
paddle::platform::EnableProfiler(state);
Expand All @@ -178,7 +185,14 @@ void TestInference(const std::string& dirname,
"run_inference",
paddle::platform::DeviceContextPool::Instance().Get(place));

executor.Run(*inference_program, scope, feed_targets, fetch_targets);
if (PrepareContext) {
// Note: if you changed the inference_program, you need to call
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed->change

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// executor.Prepare() again to get a new ExecutorPrepareContext.
executor.RunPreparedContext(ctx.get(), scope, feed_targets,
fetch_targets);
} else {
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
}
}

// Disable the profiler and print the timing information
Expand Down