Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/fluid/framework/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,16 @@ void DataTransform(const OpKernelType& expected_kernel_type,
}

void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var) {
Variable* out_var) {
if (in_var.IsType<LoDTensor>()) {
auto& in_lod_tensor = in_var.Get<LoDTensor>();
auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
auto* tran_lod_tensor = out_var->GetMutable<LoDTensor>();
tran_lod_tensor->set_lod(in_lod_tensor.lod());
tran_lod_tensor->set_layout(in_lod_tensor.layout());
tran_lod_tensor->ShareDataWith(tensor);
} else if (in_var.IsType<SelectedRows>()) {
auto& in_selected_rows = in_var.Get<SelectedRows>();
auto* trans_selected_rows = out_var.GetMutable<SelectedRows>();
auto* trans_selected_rows = out_var->GetMutable<SelectedRows>();
trans_selected_rows->set_height(in_selected_rows.height());
trans_selected_rows->set_rows(in_selected_rows.rows());
trans_selected_rows->mutable_value()->ShareDataWith(tensor);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/data_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void DataTransform(const OpKernelType& expected_kernel_type,
const Tensor& input_tensor, Tensor* out);

void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var);
Variable* out_var);

} // namespace framework
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
std::shared_ptr<Tensor> out(new Tensor);
DataTransform(expected_kernel_key, kernel_type_for_var, *tensor_in,
out.get());
CopyVariableWithTensor(*var, *(out.get()), *trans_var);
CopyVariableWithTensor(*var, *(out.get()), trans_var);
}
}
}
Expand Down
18 changes: 9 additions & 9 deletions paddle/fluid/framework/prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ limitations under the License. */

#include "paddle/fluid/framework/prune.h"

#include <glog/logging.h>

#include <algorithm>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include <glog/logging.h>

namespace paddle {
namespace framework {

const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
const char kFeedOpType[] = "feed";
const char kFetchOpType[] = "fetch";

bool HasDependentVar(const proto::OpDesc& op_desc,
const std::set<std::string>& dependent_vars) {
Expand Down Expand Up @@ -68,7 +68,7 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
// the child block to help pruning
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
int block_id, int parent_block_id,
std::set<std::string>& dependent_vars) {
std::set<std::string>* dependent_vars) {
auto& block = input.blocks(block_id);
auto& ops = block.ops();

Expand All @@ -90,11 +90,11 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
std::vector<bool> should_run;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) {
if (IsTarget(op_desc) || HasDependentVar(op_desc, *dependent_vars)) {
// insert its input to the dependency graph
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
dependent_vars.insert(argu);
dependent_vars->insert(argu);
}
}
should_run.push_back(true);
Expand Down Expand Up @@ -138,7 +138,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
// output_block_id is the idx of the current block in the output desc
prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
sub_block_dependent_vars);
&sub_block_dependent_vars);
}
}
}
Expand Down Expand Up @@ -181,7 +181,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
std::set<std::string> dependent_vars;
output->clear_blocks();
prune_impl(input, output, 0, -1, dependent_vars);
prune_impl(input, output, 0, -1, &dependent_vars);
}

void inference_optimize_impl(proto::ProgramDesc* input, int block_id) {
Expand Down