Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions paddle/operators/ctc_edit_distance_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/ctc_edit_distance_op.h"

namespace paddle {
namespace operators {

class CTCEditDistanceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to check the shape of inputs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

ctx->SetOutputDim("Out", {1});
}

protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::DataType::FP32,
ctx.device_context());
}
};

class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CTCEditDistanceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X1",
"(2-D tensor with shape [M x 1]) The indices for "
"hypothesis string");
AddInput("X2",
"(2-D tensor with shape [N x 1]) The indices "
"for reference string.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • As I know, the inputs of this evaluator are predicted result and ground truth, please use more meaningful names.
  • What is the meaning of M and N? Please give more explanation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

AddAttr<bool>("normalized",
"(bool, default false) Indicated whether "
"normalize the Output(Out) by the length of reference "
"string (X2).")
.SetDefault(false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In ctc, the blank should be removed from the predicted result. In attention, the eos and sos should be removed. Maybe here, we need an attribute named uncare of type std::vector<int>, to set the elements that need to be removed.
It is not implemented in CTCErrorEvaluator.cpp, but is a real need from ocr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with that the operator needs to accept some uncared input tokens. But it would make the code dirty allowing for the CUDA kernel. How about implementing an independent op to remove the uncared tokens?

AddOutput("Out",
"(2-D tensor with shape [1 x 1]) "
"The output distance of CTCEditDistance operator.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

batch_size should be supported in this op, the output's shape should be [batch_size, 1].

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

AddComment(R"DOC(

CTCEditDistance operator computes the edit distance of two sequences, one named
hypothesis with length M and another named reference with length N.

Edit distance, also called Levenshtein distance, measures how dissimilar two strings
are by counting the minimum number of operations to transform one string into anthor.
Here the operations include insertion, deletion, and substitution. For example,
given hypothesis string A = "kitten" and reference B = "sitting", the edit distance
is 3 for A will be transformed into B at least after two substitutions and one
insertion:

"kitten" -> "sitten" -> "sittin" -> "sitting"

If Attr(normalized) is true, the edit distance will be divided by the length of
reference string N.
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp,
ops::CTCEditDistanceOpMaker);
REGISTER_OP_CPU_KERNEL(
ctc_edit_distance,
ops::CTCEditDistanceKernel<paddle::platform::CPUPlace, float>);
131 changes: 131 additions & 0 deletions paddle/operators/ctc_edit_distance_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/framework/op_registry.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;

template <typename T>
__global__ void FillFirstRow(T* dist, const int N) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx < N + 1) {
dist[idx] = idx;
}
}

template <typename T>
__global__ void FillFirstColumn(T* dist, const int M, const int N) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx < M + 1) {
dist[idx * (N + 1)] = idx;
}
}

template <typename T>
__global__ void Levenshtein(T* dist, const int* x1, const int* x2, const int M,
const int N, const int start) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = N;
int index = start + idx * offset;
int row = index / (N + 1);
int col = index % (N + 1);
if (row > 0 && col > 0 && row < M + 1 && col < N + 1) {
int cost = x1[row - 1] == x2[col - 1] ? 0 : 1;
int dels = dist[(row - 1) * (N + 1) + col] + 1;
int ins = dist[row * (N + 1) + col - 1] + 1;
int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost;
dist[index] = min(dels, min(ins, subs));
}
}

template <typename T>
__global__ void SetOutput(T* out, const T* dist, const int M, const int N,
bool normalized) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx == 0) {
out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N];
}
}

template <typename Place, typename T>
class CTCEditDistanceGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out");

auto* x1_t = ctx.Input<framework::Tensor>("X1");
auto* x2_t = ctx.Input<framework::Tensor>("X2");

out_t->mutable_data<T>(ctx.GetPlace());
auto out = out_t->data<T>();

auto normalized = ctx.Attr<bool>("normalized");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();

auto m = x1_t->numel();
auto n = x2_t->numel();
T distance = 0.0;
if (m == 0 || n == 0) {
distance = std::max(m, n);
if (normalized) {
distance = distance / n;
}
memory::Copy(boost::get<Place>(ctx.GetPlace()), out, platform::CPUPlace(),
&distance, sizeof(T), stream);
} else {
framework::Tensor dist_t;
dist_t.Resize({m + 1, n + 1});
dist_t.mutable_data<T>(ctx.GetPlace());
auto dist = dist_t.data<T>();
auto x1 = x1_t->data<int>();
auto x2 = x2_t->data<int>();

FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n);

FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n);
// Compute the elements of distance matrix in the anti-diagonal diretion
for (int64_t slice = 2; slice < m + n + 1; ++slice) {
int z_m = slice < m + 1 ? 0 : slice - m;
int z_n = slice < n + 1 ? 0 : slice - n;
int size = slice - (z_m + z_n) + 1; // number of elments in the same
// anti-diagonal line to update
int start = slice < n + 1 ? slice : z_n * (n + 1) - 1; // start index

Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m,
n, start);
}
SetOutput<T><<<1, 1, 0, stream>>>(out, dist, m, n, normalized);
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(
ctc_edit_distance,
ops::CTCEditDistanceGPUKernel<paddle::platform::GPUPlace, float>);
77 changes: 77 additions & 0 deletions paddle/operators/ctc_edit_distance_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class CTCEditDistanceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out");

auto* x1_t = ctx.Input<framework::Tensor>("X1");
auto* x2_t = ctx.Input<framework::Tensor>("X2");

out_t->mutable_data<float>(ctx.GetPlace());

auto normalized = ctx.Attr<bool>("normalized");

auto m = x1_t->numel();
auto n = x2_t->numel();
T distance = 0.0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old codes count the substitution, deletion, insertion error, why not implement it in the op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually these counts seem useless.

if (m == 0) {
distance = n;
} else if (n == 0) {
distance = m;
} else {
framework::Tensor dist_t;
dist_t.Resize({m + 1, n + 1});
dist_t.mutable_data<T>(ctx.GetPlace());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not necessary because filling the first row and column of this matrix is enough.

auto dist = dist_t.data<T>();
auto x1 = x1_t->data<int>();
auto x2 = x2_t->data<int>();
for (int64_t i = 0; i < m + 1; ++i) {
dist[i * (n + 1)] = i;
}
for (int64_t j = 0; j < n + 1; ++j) {
dist[j] = j;
}
for (int64_t i = 1; i < m + 1; ++i) {
for (int64_t j = 1; j < n + 1; ++j) {
int cost = x1[i - 1] == x2[j - 1] ? 0 : 1;
int dels = dist[(i - 1) * (n + 1) + j] + 1;
int ins = dist[i * (n + 1) + (j - 1)] + 1;
int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
}
}
distance = dist[m * (n + 1) + n];
}

if (normalized) {
distance = distance / n;
}
auto out = out_t->data<T>();
out[0] = distance;
}
};

} // namespace operators
} // namespace paddle
56 changes: 56 additions & 0 deletions python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import unittest
import numpy as np
from op_test import OpTest


def Levenshtein(hyp, ref):
""" Compute the Levenshtein distance between two strings.

:param hyp: hypothesis string in index
:type hyp: list
:param ref: reference string in index
:type ref: list
"""
m = len(hyp)
n = len(ref)
if m == 0:
return n
if n == 0:
return m

dist = np.zeros((m + 1, n + 1))
for i in range(0, m + 1):
dist[i][0] = i
for j in range(0, n + 1):
dist[0][j] = j

for i in range(1, m + 1):
for j in range(1, n + 1):
cost = 0 if hyp[i - 1] == ref[j - 1] else 1
deletion = dist[i - 1][j] + 1
insertion = dist[i][j - 1] + 1
substitution = dist[i - 1][j - 1] + cost
dist[i][j] = min(deletion, insertion, substitution)
return dist[m][n]


class TestCTCEditDistanceOp(OpTest):
def setUp(self):
self.op_type = "ctc_edit_distance"
normalized = True
x1 = np.array([0, 12, 3, 5]).astype("int32")
x2 = np.array([0, 12, 4, 7, 8]).astype("int32")

distance = Levenshtein(hyp=x1, ref=x2)
if normalized is True:
distance = distance / len(x2)
self.attrs = {'normalized': normalized}
self.inputs = {'X1': x1, 'X2': x2}
self.outputs = {'Out': distance}

def test_check_output(self):
self.check_output()


if __name__ == '__main__':
unittest.main()