Skip to content
142 changes: 132 additions & 10 deletions paddle/gserver/layers/ConvShiftLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class ConvShiftLayer : public Layer {

void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
bool isSeqType();
void circularConvSeq();
void circularConvSeqDerivative();
};

REGISTER_LAYER(conv_shift, ConvShiftLayer);
Expand All @@ -66,42 +69,161 @@ bool ConvShiftLayer::init(const LayerMap& layerMap,
return true;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

22 ~ 42行的注释也需要对应地修改。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经移到Matrix中

bool ConvShiftLayer::isSeqType() {
const Argument& inLayer0 = getInput(0);
if (nullptr == inLayer0.sequenceStartPositions)
return false;
else
return true;
}

void ConvShiftLayer::circularConvSeq() {
const Argument& inLayer0 = getInput(0);
MatrixPtr in0 = inLayer0.value;
MatrixPtr in1 = getInputValue(1);
MatrixPtr out = getOutputValue();
const ICpuGpuVectorPtr& sequenceStartPositions =
inLayer0.sequenceStartPositions;

size_t width0 = in0->getWidth();
size_t numSeqs = sequenceStartPositions->getSize() - 1;
size_t height0 = in0->getHeight();
size_t width1 = in1->getWidth();
size_t height1 = in1->getHeight();

CHECK_EQ(numSeqs, height1);
CHECK_EQ(width0, out->getWidth());
CHECK_EQ(height0, out->getHeight());

CHECK_EQ(width1 % 2, 1U);

real* inV0 = in0->getData();
const int* startPosIntPtr = sequenceStartPositions->getData(false);
real* inV1 = in1->getData();
real* outV = out->getData();

int leftCtxLen = (width1 - 1) / 2;
for (size_t x = 0; x < numSeqs - 1; x++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for循环最后的x++统一改成++x,下同

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

int curSeqLen = startPosIntPtr[x + 1];
size_t curSeqWidth = curSeqLen * width0;
for (size_t i = 0; i < curSeqWidth; i++) {
for (size_t j = 0; j < width1; ++j) {
int index = i + j - leftCtxLen;
index = (index + curSeqWidth) % curSeqWidth;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

111行和112行可以写在一块。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

int outVRowOffset = i / width0;
int outVColOffset = i % width0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

113和114行只和i有关,应该放在109行的for循环里面,不要放在110行的for循环里面,这样会多算很多次。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

int inV0RowOffset = index / width0;
int inV0ColOffset = index % width0;
(outV + outVRowOffset)[outVColOffset] +=
(inV0 + inV0RowOffset)[inV0ColOffset] * inV1[j];
}
}
outV += curSeqWidth;
inV0 += curSeqWidth;
inV1 += width1;
}
}

void ConvShiftLayer::circularConvSeqDerivative() {
const Argument& inLayer0 = getInput(0);
MatrixPtr in0 = inLayer0.value;
MatrixPtr in1 = getInputValue(1);
MatrixPtr inG0 = getInputGrad(0);
MatrixPtr inG1 = getInputGrad(1);
MatrixPtr outG = getOutputGrad();
const ICpuGpuVectorPtr& sequenceStartPositions =
inLayer0.sequenceStartPositions;

size_t height0 = in0->getHeight();
size_t height1 = in1->getHeight();
size_t numSeqs = sequenceStartPositions->getSize() - 1;
size_t width0 = in0->getWidth();
size_t width1 = in1->getWidth();

CHECK_EQ(height1, numSeqs);
CHECK_EQ(height0, inG0->getHeight());
CHECK_EQ(width0, inG0->getWidth());
CHECK_EQ(height1, inG1->getHeight());
CHECK_EQ(width1, inG1->getWidth());
CHECK_EQ(height0, outG->getHeight());
CHECK_EQ(width0, outG->getWidth());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cheack_eq有点多,可以只保留143,148,149三行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个check是必要的,梯度、输入等的维度必须严格一致


const int* startPosIntPtr = sequenceStartPositions->getData(false);
real* outGV = outG->getData();
real* inV0 = in0->getData();
real* inV1 = in1->getData();
real* inGV0 = inG0->getData();
real* inGV1 = inG1->getData();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名的时候,GV都可以改成G,下同。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

延伸了Matrix函数的命名方式,GV代表真实的矩阵,感觉比较合适


int leftCtxLen = (width1 - 1) / 2;
for (size_t x = 0; x < numSeqs - 1; x++) {
int curSeqLen = startPosIntPtr[x + 1];
size_t curSeqWidth = curSeqLen * width0;
for (size_t j = 0; j < width1; j++) {
for (size_t i = 0; i < curSeqWidth; i++) {
int index = i + j - leftCtxLen;
index = (index + curSeqWidth) % curSeqWidth;
int inGV0RowOffset = index / width0;
int inGV0ColOffset = index % width0;
int outGVRowOffset = i / width0;
int outGVColOffset = i % width0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改意见同forward

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

(inGV0 + inGV0RowOffset)[inGV0ColOffset] +=
(outGV + outGVRowOffset)[outGVColOffset] * inV1[j];
inGV1[j] += (outGV + outGVRowOffset)[outGVColOffset] *
(inGV0 + inGV0RowOffset)[inGV0ColOffset];
}
}
outGV += curSeqWidth;
inV0 += curSeqWidth;
inV1 += width1;
inGV0 += curSeqWidth;
inGV1 += width1;
}
}

void ConvShiftLayer::forward(PassType passType) {
Layer::forward(passType);

MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);

size_t batchSize = inV0->getHeight();
size_t dataDim = inV0->getWidth();

CHECK_EQ(batchSize, inV1->getHeight());
CHECK_EQ(dataDim, getSize());

{
REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
resetOutput(batchSize, dataDim);
}

MatrixPtr outV = getOutputValue();

REGISTER_TIMER_INFO("FwConvShiftTimer", getName().c_str());
outV->circularConv(*inV0, *inV1);
if (!isSeqType()) {
MatrixPtr inV1 = getInputValue(1);
CHECK_EQ(batchSize, inV1->getHeight());
MatrixPtr outV = getOutputValue();
outV->circularConv(*inV0, *inV1);
} else {
circularConvSeq();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

将circularConvSeq函数挪到matrix.cpp里面,这个函数的接口应该类似circularConv。即
outV->circularConv(*inV0, *inV1, *seqstartposition);

这样200-207行,可以修改为

if (inLayer0.sequenceStartPositions !=nullptr) {
} else {
}

去掉isSeqType这个函数。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}

void ConvShiftLayer::backward(const UpdateCallback& callback) {
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
MatrixPtr outG = getOutputGrad();
MatrixPtr inG0 = getInputGrad(0);
MatrixPtr inG1 = getInputGrad(1);

REGISTER_TIMER_INFO("BwConvShiftTimer", getName().c_str());

if (inG0 && inG1) {
if (!(inG0 && inG1)) {
CHECK(!inG0 || !inG1) << "Not supported";
}

if (!isSeqType()) {
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
MatrixPtr outG = getOutputGrad();
outG->circularConvDerivative(*outG, *inV0, *inV1, *inG0, *inG1);
} else {
CHECK(!inG0 || !inG1) << "Not supported";
circularConvSeqDerivative();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里修改同forward函数

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}

Expand Down