Skip to content

Commit 488937e

Browse files
author
Yang Yang
committed
add comments
1 parent 8467777 commit 488937e

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,15 @@ OperatorBase::OperatorBase(const std::string& type,
180180
CheckAllInputOutputSet();
181181
}
182182

183+
std::vector<std::string> OperatorBase::InputVars() const {
184+
std::vector<std::string> ret_val;
185+
for (auto& o : inputs_) {
186+
ret_val.reserve(ret_val.size() + o.second.size());
187+
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
188+
}
189+
return ret_val;
190+
}
191+
183192
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
184193
std::vector<std::string> ret_val;
185194
if (has_intermediate) {

paddle/fluid/framework/operator.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,15 @@ class OperatorBase {
109109
std::string Input(const std::string& name) const;
110110
//! Get a input which has multiple variables.
111111
const std::vector<std::string>& Inputs(const std::string& name) const;
112+
//! Get all inputs variable names
113+
std::vector<std::string> InputVars() const;
112114

113115
//! Get a output with argument's name described in `op_proto`
114116
std::string Output(const std::string& name) const;
115117
//! Get an output which has multiple variables.
116118
//! TODO add a vector_view to prevent memory copy.
117119
const std::vector<std::string>& Outputs(const std::string& name) const;
118-
120+
//! Get all outputs variable names
119121
virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
120122

121123
// Return a new operator instance, which is as same as this.

paddle/fluid/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ All parameter, weight, gradient are variables in Paddle.
402402
.def("output_vars",
403403
[](const OperatorBase &op) { return op.OutputVars(true); })
404404
.def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
405+
.def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
405406
.def("__str__", &OperatorBase::DebugString)
406407
.def("no_intermediate_outputs",
407408
[](const OperatorBase &op) { return op.OutputVars(false); })

0 commit comments

Comments
 (0)