-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Op varient inputs #2901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Op varient inputs #2901
Changes from all commits
937b099
bec798b
5a0233c
e58ddea
f0f029d
2c3473a
771db04
edd0d01
84eef6f
d57d9bc
0113c4c
a8e3c8a
010a1d8
5459382
39be0ba
b2d77a2
0327d6c
d7f5de9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,11 +12,69 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include <algorithm> | ||
|
|
||
| #include "paddle/framework/operator.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
|
|
||
| void OperatorBase::CreateInOutOffsetMap(const OpProto& proto) { | ||
| PADDLE_ENFORCE(in_out_idxs_.empty(), "duplicate call CreateInOutOffsetMap"); | ||
| for (int i = 0; i < proto.inputs_size(); i++) { | ||
| const auto& name = proto.inputs()[i].name(); | ||
| in_out_idxs_[name] = i; | ||
| } | ||
| for (int i = 0; i < proto.outputs_size(); i++) { | ||
| const auto& name = proto.outputs()[i].name(); | ||
| in_out_idxs_[name] = i; | ||
| } | ||
| } | ||
|
|
||
| const std::string& OperatorBase::Input(const std::string& name) const { | ||
| auto it = in_out_idxs_.find(name); | ||
| PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name); | ||
|
|
||
| if (attrs_.count("input_format") == 0) { | ||
| return inputs_[it->second]; | ||
| } else { | ||
| const auto& input_format = GetAttr<std::vector<int>>("input_format"); | ||
| int idx = input_format[it->second]; | ||
| return inputs_.at(idx); | ||
| } | ||
| } | ||
|
|
||
| std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Possible optimization here. Return a vector_view instead of a vector to prevent memory copy. It is not critical for now. But we can still have a TODO comments here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added in header done |
||
| auto input_format = GetAttr<std::vector<int>>("input_format"); | ||
| auto offset = in_out_idxs_.at(name); | ||
|
|
||
| return std::vector<std::string>{ | ||
| inputs_.begin() + input_format.at(offset), | ||
| inputs_.begin() + input_format.at(offset + 1)}; | ||
| } | ||
|
|
||
| const std::string& OperatorBase::Output(const std::string& name) const { | ||
| auto it = in_out_idxs_.find(name); | ||
| PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name); | ||
|
|
||
| if (attrs_.count("output_format") == 0) { | ||
| return outputs_[it->second]; | ||
| } else { | ||
| const auto& output_format = GetAttr<std::vector<int>>("output_format"); | ||
| int idx = output_format[it->second]; | ||
| return outputs_.at(idx); | ||
| } | ||
| } | ||
|
|
||
| std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { | ||
| auto output_format = GetAttr<std::vector<int>>("output_format"); | ||
| auto offset = in_out_idxs_.at(name); | ||
|
|
||
| return std::vector<std::string>{ | ||
| outputs_.begin() + output_format.at(offset), | ||
| outputs_.begin() + output_format.at(offset + 1)}; | ||
| } | ||
|
|
||
| std::string OperatorBase::DebugString() const { | ||
| std::stringstream ss; | ||
| ss << "=================\n"; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,6 @@ class OpWithoutKernelTest : public OperatorBase { | |
| op_run_num++; | ||
| ASSERT_EQ((int)inputs_.size(), 1); | ||
| ASSERT_EQ((int)outputs_.size(), 1); | ||
| ASSERT_NEAR(GetAttr<float>("scale"), 3.14, 1e-5); | ||
| ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); | ||
| ASSERT_EQ(x, 1); | ||
| ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); | ||
|
|
@@ -86,9 +85,11 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | |
| public: | ||
| OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInput("input", "input of test op"); | ||
| AddOutput("output", "output of test op"); | ||
| AddAttr<float>("scale", "scale of cosine op"); | ||
| AddInput("x", "input of test op"); | ||
| AddOutput("y", "output of test op"); | ||
| AddAttr<float>("scale", "scale of cosine op") | ||
| .SetDefault(1.0) | ||
| .LargerThan(0.0); | ||
| AddComment("This is test op"); | ||
| } | ||
| }; | ||
|
|
@@ -103,11 +104,65 @@ class OpWithKernelTest : public OperatorWithKernel { | |
|
|
||
| class CPUKernelTest : public OpKernel { | ||
| public: | ||
| void Compute(const KernelContext& context) const { | ||
| void Compute(const KernelContext& ctx) const { | ||
| std::cout << "this is cpu kernel" << std::endl; | ||
| std::cout << ctx.op_.DebugString() << std::endl; | ||
| cpu_kernel_run_num++; | ||
| ASSERT_EQ((int)context.op_.inputs_.size(), 1); | ||
| ASSERT_EQ((int)context.op_.outputs_.size(), 1); | ||
| ASSERT_NEAR(context.op_.GetAttr<float>("scale"), 3.14, 1e-5); | ||
| ASSERT_EQ(ctx.op_.Input("x"), "IN1"); | ||
| ASSERT_EQ(ctx.op_.Output("y"), "OUT1"); | ||
| } | ||
| }; | ||
|
|
||
| // multiple inputs test | ||
| class OperatorMultiInputsTest : public OperatorBase { | ||
| public: | ||
| void Init() override { x = 1; } | ||
| void InferShape(const std::shared_ptr<Scope>& scope) const override {} | ||
| void Run(const std::shared_ptr<Scope>& scope, | ||
| const platform::DeviceContext& dev_ctx) const override { | ||
| ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); | ||
| ASSERT_EQ(x, 1); | ||
| ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); | ||
| ASSERT_EQ(Input("x"), "IN1"); | ||
| ASSERT_EQ(Input("y"), "OUT1"); | ||
| } | ||
|
|
||
| public: | ||
| float x = 0; | ||
| }; | ||
|
|
||
| class OpKernelTestMultiInputsProtoAndCheckerMaker | ||
| : public OpProtoAndCheckerMaker { | ||
| public: | ||
| OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, | ||
| OpAttrChecker* op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInputs("xs", "inputs of test op"); | ||
| AddInput("k", "input of test op"); | ||
| AddOutputs("ys", "outputs of test op"); | ||
| AddAttr<float>("scale", "scale of cosine op") | ||
| .SetDefault(1.0) | ||
| .LargerThan(0.0); | ||
| AddComment("This is test op"); | ||
| } | ||
| }; | ||
|
|
||
| class CPUKernalMultiInputsTest : public OpKernel { | ||
| public: | ||
| void Compute(const KernelContext& ctx) const { | ||
| auto xs = ctx.op_.Inputs("xs"); | ||
| ASSERT_EQ(xs.size(), 3UL); | ||
| ASSERT_EQ(xs[0], "x0"); | ||
| ASSERT_EQ(xs[1], "x1"); | ||
| ASSERT_EQ(xs[2], "x2"); | ||
|
|
||
| auto k = ctx.op_.Input("k"); | ||
| ASSERT_EQ(k, "k0"); | ||
|
|
||
| auto ys = ctx.op_.Outputs("ys"); | ||
| ASSERT_EQ(ys.size(), 2UL); | ||
| ASSERT_EQ(ys[0], "y0"); | ||
| ASSERT_EQ(ys[1], "y1"); | ||
| } | ||
| }; | ||
|
|
||
|
|
@@ -118,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, | |
| paddle::framework::OpKernelTestProtoAndCheckerMaker); | ||
| REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); | ||
|
|
||
| // test with single input | ||
| TEST(OpKernel, all) { | ||
| paddle::framework::OpDesc op_desc; | ||
| op_desc.set_type("op_with_kernel"); | ||
|
|
@@ -137,3 +193,47 @@ TEST(OpKernel, all) { | |
| op->Run(scope, cpu_device_context); | ||
| ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); | ||
| } | ||
|
|
||
| REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest, | ||
| paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker); | ||
| REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, | ||
| paddle::framework::CPUKernalMultiInputsTest); | ||
|
|
||
| // test with multi inputs | ||
| TEST(OpKernel, multi_inputs) { | ||
| using namespace paddle::framework; | ||
|
|
||
| OpDesc op_desc; | ||
| op_desc.set_type("op_multi_inputs_with_kernel"); | ||
| *op_desc.mutable_inputs()->Add() = "x0"; | ||
| *op_desc.mutable_inputs()->Add() = "x1"; | ||
| *op_desc.mutable_inputs()->Add() = "x2"; | ||
| *op_desc.mutable_inputs()->Add() = "k0"; | ||
| *op_desc.mutable_outputs()->Add() = "y0"; | ||
| *op_desc.mutable_outputs()->Add() = "y1"; | ||
| auto attr = op_desc.mutable_attrs()->Add(); | ||
| attr->set_name("scale"); | ||
| attr->set_type(paddle::framework::AttrType::FLOAT); | ||
| attr->set_f(3.14); | ||
|
|
||
| auto attr0 = op_desc.mutable_attrs()->Add(); | ||
| attr0->set_name("input_format"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. input_starts or input_offsets?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| attr0->set_type(paddle::framework::AttrType::INTS); | ||
| auto input_format = attr0->mutable_ints(); | ||
| input_format->Add(0); // x0 | ||
| input_format->Add(3); // k | ||
| input_format->Add(4); // end | ||
|
|
||
| auto attr1 = op_desc.mutable_attrs()->Add(); | ||
| attr1->set_name("output_format"); | ||
| attr1->set_type(paddle::framework::AttrType::INTS); | ||
| auto output_format = attr1->mutable_ints(); | ||
| output_format->Add(0); // y0 | ||
| output_format->Add(2); // y1 | ||
|
|
||
| paddle::platform::CPUDeviceContext cpu_device_context; | ||
| auto scope = std::make_shared<Scope>(); | ||
|
|
||
| OperatorPtr op(paddle::framework::OpRegistry::CreateOp(op_desc)); | ||
| op->Run(scope, cpu_device_context); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does the
input_formatmean? If it is the necessary attribute for dynamic-length inputs, should check it in theop_checker?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done