-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add ctc edit distance operator #5300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
db69417
b7a4e3d
f5681f1
6bc6ccd
116687a
b82049b
c16d1ca
4745a0b
36ec3e9
2c1adb0
2e49fac
0250e54
f594ca4
a1935b2
f3dcd00
fe0ef91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/operators/ctc_edit_distance_op.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
// Operator definition for ctc_edit_distance: validates that the inputs X1 and
// X2 and the output Out are present, fixes the output dimensions, and selects
// the kernel type.
| class CTCEditDistanceOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
// Checks that X1, X2 and Out exist and sets the dims of Out to {1}.
// NOTE(review): the shapes of X1 and X2 are not validated here, although the
// op maker describes them as [M x 1] and [N x 1] -- consider enforcing that.
| void InferShape(framework::InferShapeContext *ctx) const override { | ||
| PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null."); | ||
| PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null."); | ||
| PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); | ||
| ctx->SetOutputDim("Out", {1}); | ||
| } | ||
|
|
||
| protected: | ||
// Always requests an FP32 kernel on the current device context, regardless of
// the data type of the inputs.
| framework::OpKernelType GetKernelType( | ||
| const framework::ExecutionContext &ctx) const override { | ||
| return framework::OpKernelType(framework::DataType::FP32, | ||
| ctx.device_context()); | ||
| } | ||
| }; | ||
|
|
||
// Declares the interface of the ctc_edit_distance operator: the inputs X1
// (hypothesis) and X2 (reference), the boolean attribute "normalized"
// (default false), the output Out, and the operator's documentation text.
| class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| CTCEditDistanceOpMaker(framework::OpProto *proto, | ||
| framework::OpAttrChecker *op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
// Inputs: index sequences of the hypothesis (X1) and the reference (X2).
| AddInput("X1", | ||
| "(2-D tensor with shape [M x 1]) The indices for " | ||
| "hypothesis string"); | ||
| AddInput("X2", | ||
| "(2-D tensor with shape [N x 1]) The indices " | ||
| "for reference string."); | ||
|
||
// Attribute: when true the distance is divided by the reference length.
| AddAttr<bool>("normalized", | ||
| "(bool, default false) Indicated whether " | ||
| "normalize the Output(Out) by the length of reference " | ||
| "string (X2).") | ||
| .SetDefault(false); | ||
|
||
// NOTE(review): CTCEditDistanceOp::InferShape sets the dims of Out to {1};
// confirm that this matches the "[1 x 1]" description below.
| AddOutput("Out", | ||
| "(2-D tensor with shape [1 x 1]) " | ||
| "The output distance of CTCEditDistance operator."); | ||
|
||
// NOTE(review): the documentation text below contains the typo "anthor"
// (meant: "another").
| AddComment(R"DOC( | ||
|
|
||
| CTCEditDistance operator computes the edit distance of two sequences, one named | ||
| hypothesis with length M and another named reference with length N. | ||
|
|
||
| Edit distance, also called Levenshtein distance, measures how dissimilar two strings | ||
| are by counting the minimum number of operations to transform one string into anthor. | ||
| Here the operations include insertion, deletion, and substitution. For example, | ||
| given hypothesis string A = "kitten" and reference B = "sitting", the edit distance | ||
| is 3 for A will be transformed into B at least after two substitutions and one | ||
| insertion: | ||
|
|
||
| "kitten" -> "sitten" -> "sittin" -> "sitting" | ||
|
|
||
| If Attr(normalized) is true, the edit distance will be divided by the length of | ||
| reference string N. | ||
| )DOC"); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| namespace ops = paddle::operators; | ||
|
|
||
// Registers the ctc_edit_distance operator (no gradient operator is attached)
// together with its op maker, then registers the CPU kernel instantiated for
// float.
| REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp, | ||
| ops::CTCEditDistanceOpMaker); | ||
| REGISTER_OP_CPU_KERNEL( | ||
| ctc_edit_distance, | ||
| ops::CTCEditDistanceKernel<paddle::platform::CPUPlace, float>); | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include <algorithm> | ||
| #include "paddle/framework/op_registry.h" | ||
| #include "paddle/platform/cuda_helper.h" | ||
| #include "paddle/platform/gpu_info.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using platform::PADDLE_CUDA_NUM_THREADS; | ||
|
|
||
| template <typename T> | ||
| __global__ void FillFirstRow(T* dist, const int N) { | ||
| int idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
| if (idx < N + 1) { | ||
| dist[idx] = idx; | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| __global__ void FillFirstColumn(T* dist, const int M, const int N) { | ||
| int idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
| if (idx < M + 1) { | ||
| dist[idx * (N + 1)] = idx; | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| __global__ void Levenshtein(T* dist, const int* x1, const int* x2, const int M, | ||
| const int N, const int start) { | ||
| int idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
| int offset = N; | ||
| int index = start + idx * offset; | ||
| int row = index / (N + 1); | ||
| int col = index % (N + 1); | ||
| if (row > 0 && col > 0 && row < M + 1 && col < N + 1) { | ||
| int cost = x1[row - 1] == x2[col - 1] ? 0 : 1; | ||
| int dels = dist[(row - 1) * (N + 1) + col] + 1; | ||
| int ins = dist[row * (N + 1) + col - 1] + 1; | ||
| int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost; | ||
| dist[index] = min(dels, min(ins, subs)); | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| __global__ void SetOutput(T* out, const T* dist, const int M, const int N, | ||
| bool normalized) { | ||
| int idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
| if (idx == 0) { | ||
| out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N]; | ||
| } | ||
| } | ||
|
|
||
| template <typename Place, typename T> | ||
| class CTCEditDistanceGPUKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& ctx) const { | ||
| auto* out_t = ctx.Output<framework::Tensor>("Out"); | ||
|
|
||
| auto* x1_t = ctx.Input<framework::Tensor>("X1"); | ||
| auto* x2_t = ctx.Input<framework::Tensor>("X2"); | ||
|
|
||
| out_t->mutable_data<T>(ctx.GetPlace()); | ||
| auto out = out_t->data<T>(); | ||
|
|
||
| auto normalized = ctx.Attr<bool>("normalized"); | ||
| auto stream = reinterpret_cast<const platform::CUDADeviceContext&>( | ||
| ctx.device_context()) | ||
| .stream(); | ||
|
|
||
| auto m = x1_t->numel(); | ||
| auto n = x2_t->numel(); | ||
| T distance = 0.0; | ||
| if (m == 0 || n == 0) { | ||
| distance = std::max(m, n); | ||
| if (normalized) { | ||
| distance = distance / n; | ||
| } | ||
| memory::Copy(boost::get<Place>(ctx.GetPlace()), out, platform::CPUPlace(), | ||
| &distance, sizeof(T), stream); | ||
| } else { | ||
| framework::Tensor dist_t; | ||
| dist_t.Resize({m + 1, n + 1}); | ||
| dist_t.mutable_data<T>(ctx.GetPlace()); | ||
| auto dist = dist_t.data<T>(); | ||
| auto x1 = x1_t->data<int>(); | ||
| auto x2 = x2_t->data<int>(); | ||
|
|
||
| FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS, | ||
| PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); | ||
|
|
||
| FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS, | ||
| PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); | ||
| // Compute the elements of distance matrix in the anti-diagonal diretion | ||
| for (int64_t slice = 2; slice < m + n + 1; ++slice) { | ||
| int z_m = slice < m + 1 ? 0 : slice - m; | ||
| int z_n = slice < n + 1 ? 0 : slice - n; | ||
| int size = slice - (z_m + z_n) + 1; // number of elments in the same | ||
| // anti-diagonal line to update | ||
| int start = slice < n + 1 ? slice : z_n * (n + 1) - 1; // start index | ||
|
|
||
| Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, | ||
| PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m, | ||
| n, start); | ||
| } | ||
| SetOutput<T><<<1, 1, 0, stream>>>(out, dist, m, n, normalized); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| namespace ops = paddle::operators; | ||
|
|
||
// Registers the GPU kernel of ctc_edit_distance, instantiated for float.
| REGISTER_OP_GPU_KERNEL( | ||
| ctc_edit_distance, | ||
| ops::CTCEditDistanceGPUKernel<paddle::platform::GPUPlace, float>); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
| #include <algorithm> | ||
| #include "paddle/framework/eigen.h" | ||
| #include "paddle/framework/op_registry.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| template <typename Place, typename T> | ||
| class CTCEditDistanceKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& ctx) const { | ||
| auto* out_t = ctx.Output<framework::Tensor>("Out"); | ||
|
|
||
| auto* x1_t = ctx.Input<framework::Tensor>("X1"); | ||
| auto* x2_t = ctx.Input<framework::Tensor>("X2"); | ||
|
|
||
| out_t->mutable_data<float>(ctx.GetPlace()); | ||
|
|
||
| auto normalized = ctx.Attr<bool>("normalized"); | ||
|
|
||
| auto m = x1_t->numel(); | ||
| auto n = x2_t->numel(); | ||
| T distance = 0.0; | ||
|
||
| if (m == 0) { | ||
| distance = n; | ||
| } else if (n == 0) { | ||
| distance = m; | ||
| } else { | ||
| framework::Tensor dist_t; | ||
| dist_t.Resize({m + 1, n + 1}); | ||
| dist_t.mutable_data<T>(ctx.GetPlace()); | ||
|
||
| auto dist = dist_t.data<T>(); | ||
| auto x1 = x1_t->data<int>(); | ||
| auto x2 = x2_t->data<int>(); | ||
| for (int64_t i = 0; i < m + 1; ++i) { | ||
| dist[i * (n + 1)] = i; | ||
| } | ||
| for (int64_t j = 0; j < n + 1; ++j) { | ||
| dist[j] = j; | ||
| } | ||
| for (int64_t i = 1; i < m + 1; ++i) { | ||
| for (int64_t j = 1; j < n + 1; ++j) { | ||
| int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; | ||
| int dels = dist[(i - 1) * (n + 1) + j] + 1; | ||
| int ins = dist[i * (n + 1) + (j - 1)] + 1; | ||
| int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost; | ||
| dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); | ||
| } | ||
| } | ||
| distance = dist[m * (n + 1) + n]; | ||
| } | ||
|
|
||
| if (normalized) { | ||
| distance = distance / n; | ||
| } | ||
| auto out = out_t->data<T>(); | ||
| out[0] = distance; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| import unittest | ||
| import numpy as np | ||
| from op_test import OpTest | ||
|
|
||
|
|
||
def Levenshtein(hyp, ref):
    """ Compute the Levenshtein distance between two strings.

    The distance is the minimum number of insertions, deletions and
    substitutions needed to turn ``hyp`` into ``ref``.

    :param hyp: hypothesis string in index
    :type hyp: list
    :param ref: reference string in index
    :type ref: list
    :return: the edit distance, always as a numpy float64 (also when one of
             the sequences is empty), so that dividing the result by a length
             never falls back to integer division
    :rtype: numpy.float64
    """
    m = len(hyp)
    n = len(ref)

    # dist[i][j] is the distance between hyp[:i] and ref[:j]. The first
    # column / row are the distances to the empty sequence, which also covers
    # the case of an empty hyp or ref without special-casing it.
    dist = np.zeros((m + 1, n + 1))
    dist[:, 0] = np.arange(m + 1)
    dist[0, :] = np.arange(n + 1)

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if hyp[i - 1] == ref[j - 1] else 1
            deletion = dist[i - 1][j] + 1
            insertion = dist[i][j - 1] + 1
            substitution = dist[i - 1][j - 1] + cost
            dist[i][j] = min(deletion, insertion, substitution)
    return dist[m][n]
|
|
||
|
|
||
class TestCTCEditDistanceOp(OpTest):
    """Checks the ctc_edit_distance operator against the Python reference
    implementation ``Levenshtein`` for one hypothesis/reference pair, with
    normalization by the reference length enabled."""

    def setUp(self):
        self.op_type = "ctc_edit_distance"
        normalized = True
        hypothesis = np.array([0, 12, 3, 5]).astype("int32")
        reference = np.array([0, 12, 4, 7, 8]).astype("int32")

        # Expected output computed by the reference implementation.
        expected = Levenshtein(hyp=hypothesis, ref=reference)
        if normalized is True:
            expected = expected / len(reference)

        self.inputs = {'X1': hypothesis, 'X2': reference}
        self.attrs = {'normalized': normalized}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output()
|
|
||
|
|
||
# Run the unit tests when this file is executed as a script.
| if __name__ == '__main__': | ||
| unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to check the shape of inputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done