-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Op variant inputs #2901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Op variant inputs #2901
Changes from 5 commits
937b099
bec798b
5a0233c
e58ddea
f0f029d
2c3473a
771db04
edd0d01
84eef6f
d57d9bc
0113c4c
a8e3c8a
010a1d8
5459382
39be0ba
b2d77a2
0327d6c
d7f5de9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,11 +12,68 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include <algorithm> | ||
|
|
||
| #include "paddle/framework/operator.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
|
|
||
| void OperatorBase::CreateArgumentOffsetMap(const OpProto& proto) { | ||
|
||
| for (int i = 0; i < proto.inputs_size(); i++) { | ||
| const auto& name = proto.inputs()[i].name(); | ||
| arg_idxs_[name] = i; | ||
|
||
| } | ||
| for (int i = 0; i < proto.outputs_size(); i++) { | ||
| const auto& name = proto.outputs()[i].name(); | ||
| arg_idxs_[name] = i; | ||
| } | ||
| } | ||
|
|
||
| const std::string& OperatorBase::Input(const std::string& name) const { | ||
| auto it = arg_idxs_.find(name); | ||
| PADDLE_ENFORCE(it != arg_idxs_.end(), "no key [%d] in arg_idxs_", name); | ||
|
||
|
|
||
| if (attrs_.count("input_format") == 0) { | ||
| return inputs_[it->second]; | ||
| } else { | ||
| const auto& input_format = GetAttr<std::vector<int>>("input_format"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| int idx = input_format[it->second]; | ||
| return inputs_.at(idx); | ||
| } | ||
| } | ||
|
|
||
| std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Possible optimization here. Return a vector_view instead of a vector to prevent memory copy. It is not critical for now. But we can still have a TODO comments here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added in header done |
||
| auto input_format = GetAttr<std::vector<int>>("input_format"); | ||
| auto offset = arg_idxs_.at(name); | ||
|
|
||
| return std::vector<std::string>{ | ||
| inputs_.begin() + input_format.at(offset), | ||
| inputs_.begin() + input_format.at(offset + 1)}; | ||
| } | ||
|
|
||
| const std::string& OperatorBase::Output(const std::string& name) const { | ||
| auto it = arg_idxs_.find(name); | ||
| PADDLE_ENFORCE(it != arg_idxs_.end(), "no key [%d] in arg_idxs_", name); | ||
|
||
|
|
||
| if (attrs_.count("output_format") == 0) { | ||
| return outputs_[it->second]; | ||
| } else { | ||
| const auto& output_format = GetAttr<std::vector<int>>("output_format"); | ||
| int idx = output_format[it->second]; | ||
| return outputs_.at(idx); | ||
| } | ||
| } | ||
|
|
||
| std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { | ||
| auto output_format = GetAttr<std::vector<int>>("output_format"); | ||
| auto offset = arg_idxs_.at(name); | ||
|
|
||
| return std::vector<std::string>{ | ||
| outputs_.begin() + output_format.at(offset), | ||
| outputs_.begin() + output_format.at(offset + 1)}; | ||
| } | ||
|
|
||
| std::string OperatorBase::DebugString() const { | ||
| std::stringstream ss; | ||
| ss << "=================\n"; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,18 +14,20 @@ limitations under the License. */ | |
|
|
||
| #pragma once | ||
|
|
||
| #include <paddle/framework/attr_checker.h> | ||
| #include <paddle/framework/op_desc.pb.h> | ||
| #include <paddle/framework/scope.h> | ||
| #include <paddle/framework/tensor.h> | ||
| #include <paddle/platform/device_context.h> | ||
| #include <paddle/platform/place.h> | ||
| #include <paddle/utils/Error.h> | ||
| #include <boost/variant.hpp> | ||
| #include <string> | ||
| #include <unordered_map> | ||
| #include <vector> | ||
|
|
||
| #include "paddle/framework/attr_checker.h" | ||
| #include "paddle/framework/op_desc.pb.h" | ||
| #include "paddle/framework/op_proto.pb.h" | ||
| #include "paddle/framework/scope.h" | ||
| #include "paddle/framework/tensor.h" | ||
| #include "paddle/platform/device_context.h" | ||
| #include "paddle/platform/place.h" | ||
| #include "paddle/utils/Error.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
|
|
||
|
|
@@ -62,6 +64,18 @@ class OperatorBase { | |
| virtual void Run(const ScopePtr& scope, | ||
| const platform::DeviceContext& dev_ctx) const = 0; | ||
|
|
||
| // Get the variable name bound to the single-variable input argument `name`, as declared in the op's proto. | ||
|
||
| const std::string& Input(const std::string& name) const; | ||
| // Get the variable names bound to the multi-variable input argument `name` (returned by value, i.e. copied). | ||
| std::vector<std::string> Inputs(const std::string& name) const; | ||
| // Get the variable name bound to the single-variable output argument `name`, as declared in the op's proto. | ||
| const std::string& Output(const std::string& name) const; | ||
| // Get the variable names bound to the multi-variable output argument `name` (returned by value, i.e. copied). | ||
| std::vector<std::string> Outputs(const std::string& name) const; | ||
|
|
||
| // Fill arg_idxs_ with the index of every argument name declared in `proto`, so later lookups by name are fast. | ||
| void CreateArgumentOffsetMap(const OpProto& proto); | ||
|
|
||
| protected: | ||
| std::string Type() const { return desc_.type(); } | ||
|
|
||
|
|
@@ -70,6 +84,53 @@ class OperatorBase { | |
| std::vector<std::string> inputs_; | ||
| std::vector<std::string> outputs_; | ||
| AttributeMap attrs_; | ||
| // store the arguments' offset described in op_desc. | ||
| std::unordered_map<std::string, int> arg_idxs_; | ||
| }; | ||
|
|
||
| class KernelContext { | ||
| public: | ||
| KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope, | ||
| const platform::DeviceContext& device_context) | ||
| : op_(*op), scope_(scope), device_context_(device_context) {} | ||
|
|
||
| const Variable* Input(int index) const { | ||
| return scope_->GetVariable(op_.inputs_[index]); | ||
| } | ||
|
|
||
| Variable* Output(int index) const { | ||
| return scope_->GetVariable(op_.outputs_[index]); | ||
| } | ||
|
|
||
| const Variable* Input(const std::string& name) const { | ||
| return scope_->GetVariable(op_.Input(name)); | ||
| } | ||
|
|
||
| const Variable* Output(const std::string& name) const { | ||
| return scope_->GetVariable(op_.Output(name)); | ||
| } | ||
|
|
||
| const std::vector<const Variable*> Inputs(const std::string& name) const { | ||
| auto names = op_.Inputs(name); | ||
| std::vector<const Variable*> res; | ||
| std::transform( | ||
| names.begin(), names.end(), res.begin(), | ||
| [this](const std::string& name) { return scope_->GetVariable(name); }); | ||
| return res; | ||
| } | ||
|
|
||
| const std::vector<const Variable*> Outputs(const std::string& name) const { | ||
| auto names = op_.Outputs(name); | ||
| std::vector<const Variable*> res; | ||
| std::transform( | ||
| names.begin(), names.end(), res.begin(), | ||
| [this](const std::string& name) { return scope_->GetVariable(name); }); | ||
| return res; | ||
| } | ||
|
|
||
| const OperatorBase& op_; | ||
| const std::shared_ptr<Scope>& scope_; | ||
| const platform::DeviceContext& device_context_; | ||
| }; | ||
|
|
||
| class OpKernel { | ||
|
|
@@ -80,25 +141,6 @@ class OpKernel { | |
| * device resource such as CUDA stream, cublas handle, etc. from | ||
| * KernelContext. User should construct it before run the Operator. | ||
| */ | ||
| class KernelContext { | ||
| public: | ||
| KernelContext(const OperatorBase* op, const ScopePtr& scope, | ||
| const platform::DeviceContext& device_context) | ||
| : op_(*op), scope_(scope), device_context_(device_context) {} | ||
|
|
||
| const Variable* Input(int index) const { | ||
| return scope_->GetVariable(op_.inputs_[index]); | ||
| } | ||
|
|
||
| Variable* Output(int index) const { | ||
| return scope_->GetVariable(op_.outputs_[index]); | ||
| } | ||
|
|
||
| const OperatorBase& op_; | ||
| const ScopePtr& scope_; | ||
| const platform::DeviceContext& device_context_; | ||
| }; | ||
|
|
||
| virtual void Compute(const KernelContext& context) const = 0; | ||
|
|
||
| virtual ~OpKernel() {} | ||
|
|
@@ -143,14 +185,15 @@ class OperatorWithKernel : public OperatorBase { | |
| void Run(const ScopePtr& scope, | ||
| const platform::DeviceContext& dev_ctx) const final { | ||
| auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx)); | ||
| opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx)); | ||
| opKernel->Compute(KernelContext(this, scope, dev_ctx)); | ||
| } | ||
|
|
||
| static std::unordered_map<std::string /* op_type */, OpKernelMap>& | ||
| AllOpKernels() { | ||
| static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels; | ||
| return g_all_op_kernels; | ||
| } | ||
|
|
||
| void InferShape(const std::shared_ptr<Scope>& scope) const final { | ||
| std::vector<const Tensor*> ins; | ||
| VarNamesToTensors(scope, inputs_, &ins); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -84,8 +84,8 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | |
| public: | ||
| OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInput("input", "input of test op"); | ||
| AddOutput("output", "output of test op"); | ||
| AddInput("x", "input of test op"); | ||
| AddOutput("y", "output of test op"); | ||
| AddAttr<float>("scale", "scale of cosine op") | ||
| .SetDefault(1.0) | ||
| .LargerThan(0.0); | ||
|
|
@@ -101,11 +101,68 @@ class OpWithKernelTest : public OperatorWithKernel { | |
|
|
||
| class CPUKernelTest : public OpKernel { | ||
| public: | ||
| void Compute(const KernelContext& context) const { | ||
| float scale = context.op_.GetAttr<float>("scale"); | ||
| void Compute(const KernelContext& ctx) const { | ||
| float scale = ctx.op_.GetAttr<float>("scale"); | ||
| ASSERT_NEAR(scale, 3.14, 1e-5); | ||
| std::cout << "this is cpu kernel" << std::endl; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. cpu_kernel_run_num++ should be added back |
||
| std::cout << context.op_.DebugString() << std::endl; | ||
| std::cout << ctx.op_.DebugString() << std::endl; | ||
| ASSERT_EQ(ctx.op_.Input("x"), "IN1"); | ||
| ASSERT_EQ(ctx.op_.Output("y"), "OUT1"); | ||
| } | ||
| }; | ||
|
|
||
| // multiple inputs test | ||
| class OperatorMultiInputsTest : public OperatorBase { | ||
| public: | ||
| void Init() override { x = 1; } | ||
| void InferShape(const std::shared_ptr<Scope>& scope) const override {} | ||
| void Run(const std::shared_ptr<Scope>& scope, | ||
| const platform::DeviceContext& dev_ctx) const override { | ||
| float scale = GetAttr<float>("scale"); | ||
| ASSERT_NEAR(scale, 3.14, 1e-5); | ||
|
||
| ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); | ||
| ASSERT_EQ(x, 1); | ||
| ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); | ||
| ASSERT_EQ(Input("x"), "IN1"); | ||
| ASSERT_EQ(Input("y"), "OUT1"); | ||
| } | ||
|
|
||
| public: | ||
| float x = 0; | ||
| }; | ||
|
|
||
| class OpKernelTestMultiInputsProtoAndCheckerMaker | ||
| : public OpProtoAndCheckerMaker { | ||
| public: | ||
| OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, | ||
| OpAttrChecker* op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInputs("xs", "input of test op"); | ||
|
||
| AddInput("k", "input of test op"); | ||
| AddOutputs("ys", "output of test op"); | ||
|
||
| AddAttr<float>("scale", "scale of cosine op") | ||
| .SetDefault(1.0) | ||
| .LargerThan(0.0); | ||
| AddComment("This is test op"); | ||
| } | ||
| }; | ||
|
|
||
| class CPUKernalMultiInputsTest : public OpKernel { | ||
| public: | ||
| void Compute(const KernelContext& ctx) const { | ||
| auto xs = ctx.op_.Inputs("xs"); | ||
| ASSERT_EQ(xs.size(), 3UL); | ||
| ASSERT_EQ(xs[0], "x0"); | ||
| ASSERT_EQ(xs[1], "x1"); | ||
| ASSERT_EQ(xs[2], "x2"); | ||
|
|
||
| auto k = ctx.op_.Input("k"); | ||
| ASSERT_EQ(k, "k0"); | ||
|
|
||
| auto ys = ctx.op_.Outputs("ys"); | ||
| ASSERT_EQ(ys.size(), 2UL); | ||
| ASSERT_EQ(ys[0], "y0"); | ||
| ASSERT_EQ(ys[1], "y1"); | ||
| } | ||
| }; | ||
|
|
||
|
|
@@ -116,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, | |
| paddle::framework::OpKernelTestProtoAndCheckerMaker); | ||
| REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); | ||
|
|
||
| // test with single input | ||
| TEST(OpKernel, all) { | ||
| paddle::framework::OpDesc op_desc; | ||
| op_desc.set_type("op_with_kernel"); | ||
|
|
@@ -133,3 +191,47 @@ TEST(OpKernel, all) { | |
| paddle::framework::OpRegistry::CreateOp(op_desc); | ||
| op->Run(scope, cpu_device_context); | ||
| } | ||
|
|
||
| REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest, | ||
| paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker); | ||
| REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, | ||
| paddle::framework::CPUKernalMultiInputsTest); | ||
|
|
||
| // Test an op whose arguments each span a range of variables, described by the input_format/output_format offset attributes. | ||
| TEST(OpKernel, multi_inputs) { | ||
| using namespace paddle::framework; | ||
|
|
||
| OpDesc op_desc; | ||
| op_desc.set_type("op_multi_inputs_with_kernel"); | ||
| *op_desc.mutable_inputs()->Add() = "x0"; | ||
| *op_desc.mutable_inputs()->Add() = "x1"; | ||
| *op_desc.mutable_inputs()->Add() = "x2"; | ||
| *op_desc.mutable_inputs()->Add() = "k0"; | ||
| *op_desc.mutable_outputs()->Add() = "y0"; | ||
| *op_desc.mutable_outputs()->Add() = "y1"; | ||
| auto attr = op_desc.mutable_attrs()->Add(); | ||
| attr->set_name("scale"); | ||
| attr->set_type(paddle::framework::AttrType::FLOAT); | ||
| attr->set_f(3.14); | ||
|
|
||
| auto attr0 = op_desc.mutable_attrs()->Add(); | ||
| attr0->set_name("input_format"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. input_starts or input_offsets?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| attr0->set_type(paddle::framework::AttrType::INTS); | ||
| auto input_format = attr0->mutable_ints(); | ||
| input_format->Add(0); // start offset of `xs` (x0, x1, x2) | ||
| input_format->Add(3); // start offset of `k` (k0); also one past the end of `xs` | ||
| input_format->Add(4); // end offset: one past the last input | ||
|
|
||
| auto attr1 = op_desc.mutable_attrs()->Add(); | ||
| attr1->set_name("output_format"); | ||
| attr1->set_type(paddle::framework::AttrType::INTS); | ||
| auto output_format = attr1->mutable_ints(); | ||
| output_format->Add(0); // start offset of `ys` (y0, y1) | ||
| output_format->Add(2); // end offset: one past the last output (not y1 itself) | ||
|
|
||
| paddle::platform::CPUDeviceContext cpu_device_context; | ||
| auto scope = std::make_shared<Scope>(); | ||
|
|
||
| OperatorPtr op(paddle::framework::OpRegistry::CreateOp(op_desc)); | ||
| op->Run(scope, cpu_device_context); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why need this include?