Changes from 4 commits
114 changes: 114 additions & 0 deletions paddle/operators/sequence_concat_op.cc
@@ -0,0 +1,114 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/sequence_concat_op.h"

namespace paddle {
namespace operators {

class SequenceConcatOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"),
"Inputs(X) of SequenceConcatOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceConcatOp should not be null.");
const size_t level = static_cast<size_t>(ctx->Attrs().Get<int>("level"));
const size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE(level == 0UL || level == 1UL,
"The sequence_concat operator only accepts sequence "
"or a nested sequence as its input.");
auto ins_dims = ctx->GetInputsDim("X");
framework::DDim out_dims = ins_dims[0];
const size_t n = ins_dims.size();
for (size_t i = 1; i < n; ++i) {
out_dims[axis] += ins_dims[i][axis];
}
ctx->SetOutputDim("Out", out_dims);
qingqing01 (Contributor), Oct 10, 2017:
When we discussed the framework design, the intent was that InferShape can infer the complete shape information, so the LoD check, set_lod, and concatLoD logic below may need to be moved here. @reyoung

Yancey0623 (Author), Oct 11, 2017:
Agreed that InferShape should infer all the shape information, but the current interface does not seem to expose a way to get the LoD yet:
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/shape_inference.h#L25
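
As a side note on the shape arithmetic InferShape performs above: only the concat axis grows, every other dimension is carried over from the first input. A minimal sketch in plain Python, with illustrative values taken from the Case1 example in the doc below:

# Dims from Case1 below: Dims(x0) = (2,3,4), Dims(x1) = (2,4,4), axis = 1.
ins_dims = [[2, 3, 4], [2, 4, 4]]
axis = 1

out_dims = list(ins_dims[0])
for dims in ins_dims[1:]:
    out_dims[axis] += dims[axis]  # sum only along the concat axis

assert out_dims == [2, 7, 4]  # matches Dims(Out) = (2,7,4)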

}
};

class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceConcatOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input Multip LoDTensors, which are variable-length "
"sequence or nested sequence.")
Contributor:
AddInput("X",
    "(A vector of LoDTensor), the input is a vector of LoDTensor, "
    "each of which is a variable-length sequence or nested sequence.")

Author:
Done.

.AsDuplicable();
AddOutput("Out",
"A LoDTensor, the variable-length output of "
"sequence_concat Op.");
Contributor:
AddOutput("Out",
     "(A LoDTensor), the variable-length output of "
     "sequence_concat Op.");

Author:
Done.

AddAttr<int>("axis",
"(int, default 0)"
"The axis which the inputs will be joined with."
"If axis is 0, the inputs will be joined with LoD index.")
.SetDefault(0);
AddAttr<int>("level",
"(int, default 0)"
"The level which the inputs will be joined with."
Contributor:
"The level at which the inputs will be joined."

Author:
Done.

"If level is 0, the inputs will be joined with "
"nested sequences."
Contributor:
If the level is 0, the inputs will be joined at the nested sequence level, which

Author:
Done.

"If level is 1, the inputs will be joined with sequences.")
.SetDefault(0);
AddComment(R"DOC(
The sequence_concat operator concatenates multiple LoDTensors.
It only supports sequences ( LoD Tensor with level=1)
Contributor:
  • ( LoD Tensor with level=1) --> (LoD Tensor with level=1), an extra space

Author:
Done.

or nested sequences (LoD tensor with level=0) as its inputs.
Contributor:
A nested sequence is a LoD tensor with level=2? Do I make a mistake?

Author:
Thanks, you're right. Maybe saying the level number is 2 would be more suitable? Since we also have an attribute level=x, this may cause some confusion.

Author:
Done.

- Case1:
If the axis is 1, level is 1, the LoD of Inputs are the same,
Contributor:
  • Inputs --> input
  • If the axis is other than 0 (here, axis is 1 and level is 1), each input should have the same LoD information, and the LoD information of the output stays the same as the input's.

Author:
Done.

LoD(x0) = {{0,2,4},{0,1,2,3,4}}; Dims(x0) = (2,3,4)
LoD(x1) = {{0,2,4},{0,1,2,3,4}}; Dims(x1) = (2,4,4)
LoD(Out) = {{0,2,4},{0,1,2,3,4}}; Dims(Out) = (2,7,4)
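
A NumPy sketch of Case1 above, in the style of the unit test at the end of this PR: because the LoDs are identical, each level-1 sub-sequence of x0 is concatenated with the matching sub-sequence of x1 along axis 1, and the LoD carries over unchanged. Illustrative only; the first dimension here is taken as 4 so that it agrees with the LoD's final offset.

import numpy as np

x0 = np.random.random((4, 3, 4)).astype('float32')
x1 = np.random.random((4, 4, 4)).astype('float32')
lod = [[0, 2, 4], [0, 1, 2, 3, 4]]  # shared by x0, x1, and Out
level, axis = 1, 1

outs = []
for i in range(len(lod[level]) - 1):
    s, e = lod[level][i], lod[level][i + 1]
    outs.append(np.concatenate((x0[s:e], x1[s:e]), axis=axis))
out = np.concatenate(outs, axis=0)
assert out.shape == (4, 7, 4)  # the concat axis grows: 3 + 4 = 7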
- Case2:
If the axis is 0, level is 1, the LoD of inputs are different,
Contributor:
If axis is 0 (here, level is 1), the inputs are concatenated along time steps, and the LoD information of the output needs to be re-computed.

Author:
Done.

LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (2,3,4)
LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (3,3,4)
LoD(Out) = {{0,5,9}, {0,1,2,4,5,6,7,8,9}}; Dims(Out) = (5,3,4)
Contributor:
LoD or Lod? I prefer to keep it consistent in the comment.

Contributor:
Add a NOTE to the doc: as I understand it, the level of all the inputs should be the same (if I am right).

Author:
Done.

Contributor:
  • The examples here are all with level=1; what will happen if level=0?
  • The above examples explain the role of axis, but it is still hard to understand how the attribute level works.

Author:
Done. Added a unit test with level=0.
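
Picking up the question above, a hedged NumPy sketch of what level does when axis is 0: the inputs are interleaved segment by segment, and level picks which row of the LoD supplies the segment boundaries (level=0 interleaves whole top-level sequences, level=1 interleaves the finer sub-sequences). Function name and the row-offset convention are illustrative, following the Case2 numbers; this is not the operator's actual code path.

import numpy as np

def interleave_axis0(xs, lods, level):
    # Sketch only: slice each input at the chosen LoD level and
    # append the i-th segment of every input in turn.
    outs = []
    for i in range(len(lods[0][level]) - 1):
        for x, lod in zip(xs, lods):
            outs.append(x[lod[level][i]:lod[level][i + 1]])
    return np.concatenate(outs, axis=0)

x0 = np.random.random((4, 3, 4)).astype('float32')  # 4 rows, per lod0 below
x1 = np.random.random((5, 3, 4)).astype('float32')  # 5 rows, per lod1 below
lod0 = [[0, 2, 4], [0, 1, 2, 3, 4]]
lod1 = [[0, 3, 5], [0, 1, 3, 4, 5]]

# level=1: interleave the four fine-grained sub-sequences (Case2 above)
assert interleave_axis0([x0, x1], [lod0, lod1], level=1).shape == (9, 3, 4)
# level=0: interleave the two top-level sequences instead
assert interleave_axis0([x0, x1], [lod0, lod1], level=0).shape == (9, 3, 4)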


NOTE: The level of all the inputs should be the same.
Contributor:
The level --> The levels.

Author:
Done.

)DOC");
}
};

class SequenceConcatGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
"The gradient of X should not be empty.");
Contributor:
  • "should not be empty." --> "should not be null." Keep the message the same as the check in line 95, since these two checks serve almost the same purpose.

Author:
Done.

ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(sequence_concat, ops::SequenceConcatOp, ops::SequenceConcatOpMaker,
sequence_concat_grad, ops::SequenceConcatGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_concat,
ops::SequenceConcatOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_concat_grad,
ops::SequenceConcatGradOpKernel<paddle::platform::CPUPlace, float>);
25 changes: 25 additions & 0 deletions paddle/operators/sequence_concat_op.cu
@@ -0,0 +1,25 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include "paddle/operators/sequence_concat_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_concat,
ops::SequenceConcatOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sequence_concat_grad,
ops::SequenceConcatGradOpKernel<paddle::platform::GPUPlace, float>);
167 changes: 167 additions & 0 deletions paddle/operators/sequence_concat_op.h
@@ -0,0 +1,167 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/strided_memcpy.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;

// Concat LoD, the initialized LoD of Output is lod(x0),
// if axis is not 0, the LoD(Out) will be the same as Inputs, if axis is 0:
// Case1:
// There is one level, the Output LoD will be modified:
// LoD(x0) = {{0,2,4}}
// LoD(x1) = {{0,1,5}}
// LoD(Out) = {{0,3,9}}
// Case2:
// There is two level, and concat level is 1,
// the Output LoD will be modified as followed:
// LoD(x0) = {{0,2,4}, {0,1,2,3,4}}
// LoD(x1) = {{0,3,5}, {0,1,3,4,5}}
// LoD(Out) = {{0,5,9}, {0,1,2,4,5,6,7,8,9}}
Contributor:
I think line 26 ~ line 38 can be deleted, because the examples in the operator comments explain the same logic well too.

Author:
Done.
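
For the one-level case the header comment describes (Case1 above), the merge amounts to an element-wise sum of the offset vectors; a tiny sketch, not the operator's actual code path:

lod_x0 = [0, 2, 4]
lod_x1 = [0, 1, 5]
out_lod = [a + b for a, b in zip(lod_x0, lod_x1)]  # segment lengths add up
assert out_lod == [0, 3, 9]  # matches LoD(Out) = {{0,3,9}} above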

template <typename T>
LoD concatLoD(const std::vector<const T*> ins, const size_t axis,
const size_t level) {
auto out_lod = ins[0]->lod();
const size_t n = ins.size();
if (axis == 0UL) {
if (level == 0) {
Contributor:
level == 0UL

Author:
Done.

for (size_t i = 1; i < n; ++i) {
for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) {
out_lod[0][j] += ins[i]->lod()[0][j];
}
}
} else if (level == 1) {
Contributor:
level == 1UL

Author:
Done.

PADDLE_ENFORCE_EQ(ins[0]->NumLevels(), 2UL,
"If the level is 1, all of the inputs "
"should be the the nested sequence.");
Contributor:
  • the the --> an extra "the"
  • nested sequence --> nested sequences.

Author:
Done.

for (size_t i = 1; i < n; ++i) {
for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) {
out_lod[0].push_back(ins[i]->lod()[0][j]);
}
for (size_t j = 0; j < ins[i]->lod()[1].size(); ++j) {
out_lod[1][j] += ins[i]->lod()[1][j];
}
}
}
}
return out_lod;
}

template <typename Place, typename T>
class SequenceConcatOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
const size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
const size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
const size_t n = ins.size();

for (size_t i = 1; i < n; ++i) {
PADDLE_ENFORCE_EQ(ins[0]->NumLevels(), ins[i]->NumLevels(),
"The level number of all the input LoDTensors "
Contributor:
The level number --> The levels

Author:
Done.

"should be the same.");
PADDLE_ENFORCE_EQ(ins[0]->dims().size(), ins[i]->dims().size(),
"The dimensions size of all the input LoDTensors "
Contributor:
The dimensions size --> The dimension size

Author:
Done.

"should be the same.");

const size_t dims_size = ins[i]->dims().size();
for (size_t j = 0; j < dims_size; ++j) {
if (j == axis) continue;
PADDLE_ENFORCE_EQ(ins[0]->dims()[j], ins[i]->dims()[j],
"The dimensions of all the input LoDTensors "
"except for the specify axis should be "
"matched exactly.");
Contributor:
Except for the dimension of the specified axis along which all the inputs are concatenated, the dimensions of all the other axes of the input LoDTensors should be the same.

Author:
Done.

}
}

out->mutable_data<T>(ctx.GetPlace());
auto out_lod = concatLoD<LoDTensor>(ins, axis, level);
out->set_lod(out_lod);

auto out_lod_level = out_lod[level];
for (size_t i = 0; i < out_lod_level.size() - 1; ++i) {
Tensor out_t = out->Slice<T>(static_cast<int>(out_lod_level[i]),
static_cast<int>(out_lod_level[i + 1]));
auto out_stride = framework::stride(out_t.dims());
size_t offset = 0;

for (size_t j = 0; j < n; ++j) {
auto in_lod_level = ins[j]->lod()[level];
auto in_stride = framework::stride(ins[j]->dims());
Tensor in_t = ins[j]->Slice<T>(static_cast<int>(in_lod_level[i]),
static_cast<int>(in_lod_level[i + 1]));
size_t axis_dim = in_t.dims()[axis];
StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(), in_stride,
in_t.dims(), out_stride, out_t.data<T>() + offset);
offset += axis_dim * in_stride[axis];
}
}
}
};
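
In NumPy terms, the Compute loop above does the following (a rough sketch assuming the row-offset LoD convention used in the examples; not Paddle API): walk the segments at the chosen level, concatenate the matching segment of every input along axis, and stack the per-segment results along axis 0. The gradient kernel below runs the same loop in reverse, scattering slices of Grad(Out) back into each Grad(X).

import numpy as np

def sequence_concat_forward(ins, lods, axis, level):
    # Sketch of SequenceConcatOpKernel::Compute (illustrative only).
    outs = []
    for i in range(len(lods[0][level]) - 1):
        parts = [x[lod[level][i]:lod[level][i + 1]]
                 for x, lod in zip(ins, lods)]
        outs.append(np.concatenate(parts, axis=axis))
    return np.concatenate(outs, axis=0)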

template <typename Place, typename T>
class SequenceConcatGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
auto* out_grad =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto x_grads =
ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
const size_t n = x_grads.size();

// Set Grad(X) LoD as X
for (size_t i = 0; i < n; i++) {
x_grads[i]->set_lod(ins[i]->lod());
x_grads[i]->mutable_data<T>(ctx.GetPlace());
}

auto out_lod = concatLoD<LoDTensor>(ins, axis, level);
auto out_lod_level = out_lod[level];

for (size_t i = 0; i < out_lod_level.size() - 1; ++i) {
Tensor out_grad_t =
out_grad->Slice<T>(static_cast<int>(out_lod_level[i]),
static_cast<int>(out_lod_level[i + 1]));
auto out_grad_stride = framework::stride(out_grad_t.dims());
size_t offset = 0;

for (size_t j = 0; j < n; ++j) {
auto x_grad_lod_level = x_grads[j]->lod()[level];
auto x_grad_stride = framework::stride(x_grads[j]->dims());
Tensor x_grad_t =
x_grads[j]->Slice<T>(static_cast<int>(x_grad_lod_level[i]),
static_cast<int>(x_grad_lod_level[i + 1]));
size_t axis_dim = x_grad_t.dims()[axis];
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>() + offset,
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
offset += axis_dim * out_grad_stride[axis];
}
}
}
};

} // namespace operators
} // namespace paddle
57 changes: 57 additions & 0 deletions python/paddle/v2/framework/tests/test_seq_concat_op.py
@@ -0,0 +1,57 @@
import unittest
import numpy as np
from op_test import OpTest


class TestConcatOp(OpTest):
def set_data(self):
# two level, batch size is 3
x0 = np.random.random((11, 6, 3)).astype('float32')
lod0 = [[0, 2, 5, 11], [0, 1, 2, 5, 7, 11]]
x1 = np.random.random((11, 8, 3)).astype('float32')
lod1 = [[0, 2, 5, 11], [0, 1, 2, 5, 7, 11]]
axis = 1
level = 1
self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]}
self.attrs = {'axis': axis, 'level': level}
outs = []
for i in range(5):
sub_x0 = x0[lod0[level][i]:lod0[level][i + 1], :]
sub_x1 = x1[lod1[level][i]:lod1[level][i + 1], :]
outs.append(np.concatenate((sub_x0, sub_x1), axis=axis))

self.outputs = {'Out': np.concatenate(outs, axis=0)}

def setUp(self):
self.op_type = "sequence_concat"
self.set_data()

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['x0'], 'Out')


class TestConcatOpDiffLod(TestConcatOp):
def set_data(self):
# two level, batch size is 3
x0 = np.random.random((12, 6, 3)).astype('float32')
lod0 = [[0, 3, 9, 12], [0, 2, 3, 5, 9, 12]]
x1 = np.random.random((11, 6, 3)).astype('float32')
lod1 = [[0, 2, 5, 11], [0, 1, 2, 5, 7, 11]]
axis = 0
level = 1
Contributor:
  • Only a personal question: in the comment and the unit test, only level=1 is tested. It is still hard for others to understand how the attribute level works.
  • Is it necessary to test level=0?

Author:
Done. Added a unit test with level=0.

self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]}
self.attrs = {'axis': axis, 'level': level}
outs = []
for i in range(5):
sub_x0 = x0[lod0[level][i]:lod0[level][i + 1], :]
sub_x1 = x1[lod1[level][i]:lod1[level][i + 1], :]
outs.append(np.concatenate((sub_x0, sub_x1), axis=axis))

self.outputs = {'Out': np.concatenate(outs, axis=0)}
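

The level=0 test mentioned in the thread above is not part of this commit range; a hypothetical sketch of what such a case could look like (class name and LoD values are made up for illustration):

class TestConcatOpLevelZero(TestConcatOp):
    def set_data(self):
        # hypothetical data: join the top-level sequences (level=0, axis=0)
        x0 = np.random.random((4, 6, 3)).astype('float32')
        lod0 = [[0, 2, 4], [0, 1, 2, 3, 4]]
        x1 = np.random.random((5, 6, 3)).astype('float32')
        lod1 = [[0, 3, 5], [0, 1, 3, 4, 5]]
        axis = 0
        level = 0
        self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]}
        self.attrs = {'axis': axis, 'level': level}
        outs = []
        for i in range(len(lod0[level]) - 1):
            sub_x0 = x0[lod0[level][i]:lod0[level][i + 1], :]
            sub_x1 = x1[lod1[level][i]:lod1[level][i + 1], :]
            outs.append(np.concatenate((sub_x0, sub_x1), axis=axis))
        self.outputs = {'Out': np.concatenate(outs, axis=0)}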


if __name__ == '__main__':
unittest.main()