Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/framework/enforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <glog/logging.h>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why need this include?

#include <paddle/string/printf.h>
#include <exception>
#include <sstream>
Expand Down
6 changes: 6 additions & 0 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,23 @@ class OpRegistry {
static OperatorPtr CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type();
OperatorPtr op(creators().at(op_type)());
const OpProto& op_proto = protos().at(op_type);
op->desc_ = op_desc;
// set op's inputs_ from desc.
op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(op->inputs_));
// set op's outputs_ from desc.
op->outputs_.reserve((size_t)op_desc.outputs_size());
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(op->outputs_));
// set op's attr;
for (auto& attr : op_desc.attrs()) {
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
}
op_checkers().at(op_type).Check(op->attrs_);
// set argument offsets stored in op.
op->CreateArgumentOffsetMap(op_proto);
op->Init();
return op;
}
Expand Down
57 changes: 57 additions & 0 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,68 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>

#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

void OperatorBase::CreateArgumentOffsetMap(const OpProto& proto) {
Copy link
Member

@jacquesqiao jacquesqiao Jul 16, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argument in caffe2 has the same meaning with Attribute in Paddle. Should we name it CreateInOutOffsetMap or something else?

Copy link
Contributor Author

@Superjomn Superjomn Jul 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this map stores both the inputs and outputs offests, so maybe CreateArgmentOffsetsMap ?

Copy link
Member

@jacquesqiao jacquesqiao Jul 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InOut instead of Input ~

for (int i = 0; i < proto.inputs_size(); i++) {
const auto& name = proto.inputs()[i].name();
arg_idxs_[name] = i;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in_out_inxs?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, index --> indices?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it doesn't matter, we all understand what arg_idxs_ means.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为了保持一致,建议还是改成 in_out_idxs 之类的。

}
for (int i = 0; i < proto.outputs_size(); i++) {
const auto& name = proto.outputs()[i].name();
arg_idxs_[name] = i;
}
}

const std::string& OperatorBase::Input(const std::string& name) const {
auto it = arg_idxs_.find(name);
PADDLE_ENFORCE(it != arg_idxs_.end(), "no key [%d] in arg_idxs_", name);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[%d] --> [%s]


if (attrs_.count("input_format") == 0) {
return inputs_[it->second];
} else {
const auto& input_format = GetAttr<std::vector<int>>("input_format");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does the input_format mean? If it is the necessary attribute for dynamic-length inputs, should check it in the op_checker ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

int idx = input_format[it->second];
return inputs_.at(idx);
}
}

std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible optimization here. Return a vector_view instead of a vector to prevent memory copy. It is not critical for now. But we can still have a TODO comments here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added in header

done

auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = arg_idxs_.at(name);

return std::vector<std::string>{
inputs_.begin() + input_format.at(offset),
inputs_.begin() + input_format.at(offset + 1)};
}

const std::string& OperatorBase::Output(const std::string& name) const {
auto it = arg_idxs_.find(name);
PADDLE_ENFORCE(it != arg_idxs_.end(), "no key [%d] in arg_idxs_", name);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[%d] --> [%s]


if (attrs_.count("output_format") == 0) {
return outputs_[it->second];
} else {
const auto& output_format = GetAttr<std::vector<int>>("output_format");
int idx = output_format[it->second];
return outputs_.at(idx);
}
}

std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = arg_idxs_.at(name);

return std::vector<std::string>{
outputs_.begin() + output_format.at(offset),
outputs_.begin() + output_format.at(offset + 1)};
}

std::string OperatorBase::DebugString() const {
std::stringstream ss;
ss << "=================\n";
Expand Down
97 changes: 70 additions & 27 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,20 @@ limitations under the License. */

#pragma once

#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/framework/tensor.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/utils/Error.h"

namespace paddle {
namespace framework {

Expand Down Expand Up @@ -62,6 +64,18 @@ class OperatorBase {
virtual void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const = 0;

// Get a input with argument's name described in `op_desc`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

described in op_proto or op_proto_and_checker_maker

const std::string& Input(const std::string& name) const;
// Get a input which has multiple variables.
std::vector<std::string> Inputs(const std::string& name) const;
// Get a output with argument's name described in `op_desc`
const std::string& Output(const std::string& name) const;
// Get an output which has multiple variables.
std::vector<std::string> Outputs(const std::string& name) const;

// init arg_idxs_ to accelerate argument's offset lookup.
void CreateArgumentOffsetMap(const OpProto& proto);

protected:
std::string Type() const { return desc_.type(); }

Expand All @@ -70,6 +84,53 @@ class OperatorBase {
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
AttributeMap attrs_;
// store the arguments' offset described in op_desc.
std::unordered_map<std::string, int> arg_idxs_;
};

class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}

const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}

Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}

const Variable* Input(const std::string& name) const {
return scope_->GetVariable(op_.Input(name));
}

const Variable* Output(const std::string& name) const {
return scope_->GetVariable(op_.Output(name));
}

const std::vector<const Variable*> Inputs(const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<const Variable*> res;
std::transform(
names.begin(), names.end(), res.begin(),
[this](const std::string& name) { return scope_->GetVariable(name); });
return res;
}

const std::vector<const Variable*> Outputs(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<const Variable*> res;
std::transform(
names.begin(), names.end(), res.begin(),
[this](const std::string& name) { return scope_->GetVariable(name); });
return res;
}

const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};

class OpKernel {
Expand All @@ -80,25 +141,6 @@ class OpKernel {
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const ScopePtr& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}

const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}

Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}

const OperatorBase& op_;
const ScopePtr& scope_;
const platform::DeviceContext& device_context_;
};

virtual void Compute(const KernelContext& context) const = 0;

virtual ~OpKernel() {}
Expand Down Expand Up @@ -143,14 +185,15 @@ class OperatorWithKernel : public OperatorBase {
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
opKernel->Compute(KernelContext(this, scope, dev_ctx));
}

static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
}

void InferShape(const std::shared_ptr<Scope>& scope) const final {
std::vector<const Tensor*> ins;
VarNamesToTensors(scope, inputs_, &ins);
Expand Down
112 changes: 107 additions & 5 deletions paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddInput("x", "input of test op");
AddOutput("y", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
Expand All @@ -101,11 +101,68 @@ class OpWithKernelTest : public OperatorWithKernel {

class CPUKernelTest : public OpKernel {
public:
void Compute(const KernelContext& context) const {
float scale = context.op_.GetAttr<float>("scale");
void Compute(const KernelContext& ctx) const {
float scale = ctx.op_.GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5);
std::cout << "this is cpu kernel" << std::endl;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cpu_kernel_run_num++ should be added back

std::cout << context.op_.DebugString() << std::endl;
std::cout << ctx.op_.DebugString() << std::endl;
ASSERT_EQ(ctx.op_.Input("x"), "IN1");
ASSERT_EQ(ctx.op_.Output("y"), "OUT1");
}
};

// multiple inputs test
class OperatorMultiInputsTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
float scale = GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can remove the unnecessary assert

ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
ASSERT_EQ(Input("x"), "IN1");
ASSERT_EQ(Input("y"), "OUT1");
}

public:
float x = 0;
};

class OpKernelTestMultiInputsProtoAndCheckerMaker
: public OpProtoAndCheckerMaker {
public:
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInputs("xs", "input of test op");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inputs of test op

AddInput("k", "input of test op");
AddOutputs("ys", "output of test op");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outputs of test op

AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddComment("This is test op");
}
};

class CPUKernalMultiInputsTest : public OpKernel {
public:
void Compute(const KernelContext& ctx) const {
auto xs = ctx.op_.Inputs("xs");
ASSERT_EQ(xs.size(), 3UL);
ASSERT_EQ(xs[0], "x0");
ASSERT_EQ(xs[1], "x1");
ASSERT_EQ(xs[2], "x2");

auto k = ctx.op_.Input("k");
ASSERT_EQ(k, "k0");

auto ys = ctx.op_.Outputs("ys");
ASSERT_EQ(ys.size(), 2UL);
ASSERT_EQ(ys[0], "y0");
ASSERT_EQ(ys[1], "y1");
}
};

Expand All @@ -116,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);

// test with single input
TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
Expand All @@ -133,3 +191,47 @@ TEST(OpKernel, all) {
paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
}

REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
paddle::framework::CPUKernalMultiInputsTest);

// test with multi inputs
TEST(OpKernel, multi_inputs) {
using namespace paddle::framework;

OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
*op_desc.mutable_inputs()->Add() = "x0";
*op_desc.mutable_inputs()->Add() = "x1";
*op_desc.mutable_inputs()->Add() = "x2";
*op_desc.mutable_inputs()->Add() = "k0";
*op_desc.mutable_outputs()->Add() = "y0";
*op_desc.mutable_outputs()->Add() = "y1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.14);

auto attr0 = op_desc.mutable_attrs()->Add();
attr0->set_name("input_format");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input_starts or input_offsets?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

attr0->set_type(paddle::framework::AttrType::INTS);
auto input_format = attr0->mutable_ints();
input_format->Add(0); // x0
input_format->Add(3); // k
input_format->Add(4); // end

auto attr1 = op_desc.mutable_attrs()->Add();
attr1->set_name("output_format");
attr1->set_type(paddle::framework::AttrType::INTS);
auto output_format = attr1->mutable_ints();
output_format->Add(0); // y0
output_format->Add(2); // y1

paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();

OperatorPtr op(paddle::framework::OpRegistry::CreateOp(op_desc));
op->Run(scope, cpu_device_context);
}