-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Enrich ConvShift to support sequence data input #2133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
6adf4ac
8cd2222
25cdee6
167ceb5
cec53b9
aa4ac87
a4e5e66
490e35d
eeccac1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -52,6 +52,9 @@ class ConvShiftLayer : public Layer { | |
|
|
||
| void forward(PassType passType) override; | ||
| void backward(const UpdateCallback& callback = nullptr) override; | ||
| bool isSeqType(); | ||
| void circularConvSeq(); | ||
| void circularConvSeqDerivative(); | ||
| }; | ||
|
|
||
| REGISTER_LAYER(conv_shift, ConvShiftLayer); | ||
|
|
@@ -66,42 +69,161 @@ bool ConvShiftLayer::init(const LayerMap& layerMap, | |
| return true; | ||
| } | ||
|
|
||
| bool ConvShiftLayer::isSeqType() { | ||
| const Argument& inLayer0 = getInput(0); | ||
| if (nullptr == inLayer0.sequenceStartPositions) | ||
| return false; | ||
| else | ||
| return true; | ||
| } | ||
|
|
||
| void ConvShiftLayer::circularConvSeq() { | ||
| const Argument& inLayer0 = getInput(0); | ||
| MatrixPtr in0 = inLayer0.value; | ||
| MatrixPtr in1 = getInputValue(1); | ||
| MatrixPtr out = getOutputValue(); | ||
| const ICpuGpuVectorPtr& sequenceStartPositions = | ||
| inLayer0.sequenceStartPositions; | ||
|
|
||
| size_t width0 = in0->getWidth(); | ||
| size_t numSeqs = sequenceStartPositions->getSize() - 1; | ||
| size_t height0 = in0->getHeight(); | ||
| size_t width1 = in1->getWidth(); | ||
| size_t height1 = in1->getHeight(); | ||
|
|
||
| CHECK_EQ(numSeqs, height1); | ||
| CHECK_EQ(width0, out->getWidth()); | ||
| CHECK_EQ(height0, out->getHeight()); | ||
|
|
||
| CHECK_EQ(width1 % 2, 1U); | ||
|
|
||
| real* inV0 = in0->getData(); | ||
| const int* startPosIntPtr = sequenceStartPositions->getData(false); | ||
| real* inV1 = in1->getData(); | ||
| real* outV = out->getData(); | ||
|
|
||
| int leftCtxLen = (width1 - 1) / 2; | ||
| for (size_t x = 0; x < numSeqs - 1; x++) { | ||
|
||
| int curSeqLen = startPosIntPtr[x + 1]; | ||
| size_t curSeqWidth = curSeqLen * width0; | ||
| for (size_t i = 0; i < curSeqWidth; i++) { | ||
| for (size_t j = 0; j < width1; ++j) { | ||
| int index = i + j - leftCtxLen; | ||
| index = (index + curSeqWidth) % curSeqWidth; | ||
|
||
| int outVRowOffset = i / width0; | ||
| int outVColOffset = i % width0; | ||
|
||
| int inV0RowOffset = index / width0; | ||
| int inV0ColOffset = index % width0; | ||
| (outV + outVRowOffset)[outVColOffset] += | ||
| (inV0 + inV0RowOffset)[inV0ColOffset] * inV1[j]; | ||
| } | ||
| } | ||
| outV += curSeqWidth; | ||
| inV0 += curSeqWidth; | ||
| inV1 += width1; | ||
| } | ||
| } | ||
|
|
||
| void ConvShiftLayer::circularConvSeqDerivative() { | ||
| const Argument& inLayer0 = getInput(0); | ||
| MatrixPtr in0 = inLayer0.value; | ||
| MatrixPtr in1 = getInputValue(1); | ||
| MatrixPtr inG0 = getInputGrad(0); | ||
| MatrixPtr inG1 = getInputGrad(1); | ||
| MatrixPtr outG = getOutputGrad(); | ||
| const ICpuGpuVectorPtr& sequenceStartPositions = | ||
| inLayer0.sequenceStartPositions; | ||
|
|
||
| size_t height0 = in0->getHeight(); | ||
| size_t height1 = in1->getHeight(); | ||
| size_t numSeqs = sequenceStartPositions->getSize() - 1; | ||
| size_t width0 = in0->getWidth(); | ||
| size_t width1 = in1->getWidth(); | ||
|
|
||
| CHECK_EQ(height1, numSeqs); | ||
| CHECK_EQ(height0, inG0->getHeight()); | ||
| CHECK_EQ(width0, inG0->getWidth()); | ||
| CHECK_EQ(height1, inG1->getHeight()); | ||
| CHECK_EQ(width1, inG1->getWidth()); | ||
| CHECK_EQ(height0, outG->getHeight()); | ||
| CHECK_EQ(width0, outG->getWidth()); | ||
|
||
|
|
||
| const int* startPosIntPtr = sequenceStartPositions->getData(false); | ||
| real* outGV = outG->getData(); | ||
| real* inV0 = in0->getData(); | ||
| real* inV1 = in1->getData(); | ||
| real* inGV0 = inG0->getData(); | ||
| real* inGV1 = inG1->getData(); | ||
|
||
|
|
||
| int leftCtxLen = (width1 - 1) / 2; | ||
| for (size_t x = 0; x < numSeqs - 1; x++) { | ||
| int curSeqLen = startPosIntPtr[x + 1]; | ||
| size_t curSeqWidth = curSeqLen * width0; | ||
| for (size_t j = 0; j < width1; j++) { | ||
| for (size_t i = 0; i < curSeqWidth; i++) { | ||
| int index = i + j - leftCtxLen; | ||
| index = (index + curSeqWidth) % curSeqWidth; | ||
| int inGV0RowOffset = index / width0; | ||
| int inGV0ColOffset = index % width0; | ||
| int outGVRowOffset = i / width0; | ||
| int outGVColOffset = i % width0; | ||
|
||
| (inGV0 + inGV0RowOffset)[inGV0ColOffset] += | ||
| (outGV + outGVRowOffset)[outGVColOffset] * inV1[j]; | ||
| inGV1[j] += (outGV + outGVRowOffset)[outGVColOffset] * | ||
| (inGV0 + inGV0RowOffset)[inGV0ColOffset]; | ||
| } | ||
| } | ||
| outGV += curSeqWidth; | ||
| inV0 += curSeqWidth; | ||
| inV1 += width1; | ||
| inGV0 += curSeqWidth; | ||
| inGV1 += width1; | ||
| } | ||
| } | ||
|
|
||
| void ConvShiftLayer::forward(PassType passType) { | ||
| Layer::forward(passType); | ||
|
|
||
| MatrixPtr inV0 = getInputValue(0); | ||
| MatrixPtr inV1 = getInputValue(1); | ||
|
|
||
| size_t batchSize = inV0->getHeight(); | ||
| size_t dataDim = inV0->getWidth(); | ||
|
|
||
| CHECK_EQ(batchSize, inV1->getHeight()); | ||
| CHECK_EQ(dataDim, getSize()); | ||
|
|
||
| { | ||
| REGISTER_TIMER_INFO("FwResetTimer", getName().c_str()); | ||
| resetOutput(batchSize, dataDim); | ||
| } | ||
|
|
||
| MatrixPtr outV = getOutputValue(); | ||
|
|
||
| REGISTER_TIMER_INFO("FwConvShiftTimer", getName().c_str()); | ||
| outV->circularConv(*inV0, *inV1); | ||
| if (!isSeqType()) { | ||
| MatrixPtr inV1 = getInputValue(1); | ||
| CHECK_EQ(batchSize, inV1->getHeight()); | ||
| MatrixPtr outV = getOutputValue(); | ||
| outV->circularConv(*inV0, *inV1); | ||
| } else { | ||
| circularConvSeq(); | ||
| } | ||
|
||
| } | ||
|
|
||
| void ConvShiftLayer::backward(const UpdateCallback& callback) { | ||
| MatrixPtr inV0 = getInputValue(0); | ||
| MatrixPtr inV1 = getInputValue(1); | ||
| MatrixPtr outG = getOutputGrad(); | ||
| MatrixPtr inG0 = getInputGrad(0); | ||
| MatrixPtr inG1 = getInputGrad(1); | ||
|
|
||
| REGISTER_TIMER_INFO("BwConvShiftTimer", getName().c_str()); | ||
|
|
||
| if (inG0 && inG1) { | ||
| if (!(inG0 && inG1)) { | ||
| CHECK(!inG0 || !inG1) << "Not supported"; | ||
| } | ||
|
|
||
| if (!isSeqType()) { | ||
| MatrixPtr inV0 = getInputValue(0); | ||
| MatrixPtr inV1 = getInputValue(1); | ||
| MatrixPtr outG = getOutputGrad(); | ||
| outG->circularConvDerivative(*outG, *inV0, *inV1, *inG0, *inG1); | ||
| } else { | ||
| CHECK(!inG0 || !inG1) << "Not supported"; | ||
| circularConvSeqDerivative(); | ||
|
||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
22 ~ 42行的注释也需要对应地修改。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经移到Matrix中