-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Op varient inputs #2901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Op varient inputs #2901
Conversation
paddle/framework/operator.cc
Outdated
| namespace paddle { | ||
| namespace framework { | ||
|
|
||
| void OperatorBase::CreateArgumentOffsetMap(const OpProto& proto) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Argument in caffe2 has the same meaning with Attribute in Paddle. Should we name it CreateInOutOffsetMap or something else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this map stores both the inputs and outputs offests, so maybe CreateArgmentOffsetsMap ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
InOut instead of Input ~
paddle/framework/operator.h
Outdated
| virtual void Run(const ScopePtr& scope, | ||
| const platform::DeviceContext& dev_ctx) const = 0; | ||
|
|
||
| // Get a input with argument's name described in `op_desc` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
described in op_proto or op_proto_and_checker_maker
paddle/framework/operator.cc
Outdated
|
|
||
| const std::string& OperatorBase::Input(const std::string& name) const { | ||
| auto it = arg_idxs_.find(name); | ||
| PADDLE_ENFORCE(it != arg_idxs_.end(), "no key [%d] in arg_idxs_", name); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[%d] --> [%s]
paddle/framework/operator.cc
Outdated
|
|
||
| const std::string& OperatorBase::Output(const std::string& name) const { | ||
| auto it = arg_idxs_.find(name); | ||
| PADDLE_ENFORCE(it != arg_idxs_.end(), "no key [%d] in arg_idxs_", name); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[%d] --> [%s]
paddle/framework/operator_test.cc
Outdated
| void Run(const std::shared_ptr<Scope>& scope, | ||
| const platform::DeviceContext& dev_ctx) const override { | ||
| float scale = GetAttr<float>("scale"); | ||
| ASSERT_NEAR(scale, 3.14, 1e-5); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can remove the unnecessary assert
| if (attrs_.count("input_format") == 0) { | ||
| return inputs_[it->second]; | ||
| } else { | ||
| const auto& input_format = GetAttr<std::vector<int>>("input_format"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does the input_format mean? If it is the necessary attribute for dynamic-length inputs, should check it in the op_checker ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/framework/operator_test.cc
Outdated
| OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, | ||
| OpAttrChecker* op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInputs("xs", "input of test op"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inputs of test op
paddle/framework/operator_test.cc
Outdated
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInputs("xs", "input of test op"); | ||
| AddInput("k", "input of test op"); | ||
| AddOutputs("ys", "output of test op"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
outputs of test op
| attr->set_f(3.14); | ||
|
|
||
| auto attr0 = op_desc.mutable_attrs()->Add(); | ||
| attr0->set_name("input_format"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input_starts or input_offsets?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/framework/operator.h
Outdated
| void CreateInOutOffsetMap(const OpProto& proto); | ||
|
|
||
| protected: | ||
| std::string Type() const { return type_; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type_ is better.
paddle/framework/operator.cc
Outdated
| void OperatorBase::CreateInOutOffsetMap(const OpProto& proto) { | ||
| for (int i = 0; i < proto.inputs_size(); i++) { | ||
| const auto& name = proto.inputs()[i].name(); | ||
| arg_idxs_[name] = i; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in_out_inxs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually, index --> indices?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But it doesn't matter, we all understand what arg_idxs_ means.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为了保持一致,建议还是改成 in_out_idxs 之类的。
| ASSERT_EQ((int)context.op_.outputs_.size(), 1); | ||
| ASSERT_NEAR(context.op_.GetAttr<float>("scale"), 3.14, 1e-5); | ||
| void Compute(const KernelContext& ctx) const { | ||
| std::cout << "this is cpu kernel" << std::endl; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cpu_kernel_run_num++ should be added back
| } | ||
| } | ||
|
|
||
| std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possible optimization here. Return a vector_view instead of a vector to prevent memory copy. It is not critical for now. But we can still have a TODO comments here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added in header
done
paddle/framework/enforce.h
Outdated
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
| #include <glog/logging.h> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why need this include?
jacquesqiao
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! except for a little question: What is the purpose of move CreateInOutOffsetMap from OperatorBase to op registry
No description provided.