Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,6 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
return ss.str();
}

void OperatorBase::Rename(const std::string& old_name,
const std::string& new_name) {
for (auto& input : inputs_) {
std::replace(input.second.begin(), input.second.end(), old_name, new_name);
}
for (auto& output : outputs_) {
std::replace(output.second.begin(), output.second.end(), old_name,
new_name);
}
}

OperatorBase::OperatorBase(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
Expand Down Expand Up @@ -327,7 +316,6 @@ bool OpSupportGPU(const std::string& op_type) {
auto it = all_kernels.find(op_type);
if (it == all_kernels.end()) {
// All control operator must support GPU

return true;
}
for (auto& kern_pair : it->second) {
Expand Down
51 changes: 15 additions & 36 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,31 +79,28 @@ class OperatorBase {

virtual ~OperatorBase() {}

template <typename T>
inline const T& Attr(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
return boost::get<T>(attrs_.at(name));
}

/// if scope is not null, also show dimensions of arguments
virtual std::string DebugStringEx(const Scope* scope) const;

std::string DebugString() const { return DebugStringEx(nullptr); }

/// Net will call this interface function to Run an op.
/// Executor will call this interface function to Run an op.
// The implementation should be written at RunImpl
void Run(const Scope& scope, const platform::Place& place);

// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
virtual void Stop() {}

virtual bool IsNetOp() const { return false; }
/// if scope is not null, also show dimensions of arguments
virtual std::string DebugStringEx(const Scope* scope) const;
std::string DebugString() const { return DebugStringEx(nullptr); }

virtual bool SupportGPU() const { return false; }

/// rename inputs outputs name
void Rename(const std::string& old_name, const std::string& new_name);
const std::string& Type() const { return type_; }

template <typename T>
inline const T& Attr(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
return boost::get<T>(attrs_.at(name));
}
const AttributeMap& Attrs() const { return attrs_; }

const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
Expand All @@ -112,21 +109,17 @@ class OperatorBase {
std::string Input(const std::string& name) const;
//! Get a input which has multiple variables.
const std::vector<std::string>& Inputs(const std::string& name) const;

//! Get all inputs variable names
std::vector<std::string> InputVars() const;

//! Get a output with argument's name described in `op_proto`
std::string Output(const std::string& name) const;
//! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
const std::vector<std::string>& Outputs(const std::string& name) const;

//! Get all outputs variable names
virtual std::vector<std::string> OutputVars(bool has_intermediate) const;

const std::string& Type() const { return type_; }
void SetType(const std::string& type) { type_ = type; }
const AttributeMap& Attrs() const { return attrs_; }

// Return a new operator instance, which is as same as this.
// Use unique_ptr to prevent caller forget to delete this pointer.
virtual std::unique_ptr<OperatorBase> Clone() const = 0;
Expand Down Expand Up @@ -278,20 +271,6 @@ class ExecutionContext {
return res;
}

void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const {
PADDLE_ENFORCE_LT(i, InputSize(in));
PADDLE_ENFORCE_LT(j, OutputSize(out));
auto* in_var = MultiInputVar(in)[i];
auto* out_var = MultiOutputVar(out)[j];
if (!in_var->IsType<LoDTensor>()) return;
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
}

platform::Place GetPlace() const { return device_context_.GetPlace(); }

template <typename DeviceContextType>
Expand Down