2 changes: 2 additions & 0 deletions mmdeploy/codebase/mmocr/models/text_recognition/__init__.py
@@ -4,6 +4,8 @@
from . import crnn_decoder # noqa: F401,F403
from . import encoder_decoder_recognizer # noqa: F401,F403
from . import lstm_layer # noqa: F401,F403
from . import nrtr_decoder # noqa: F401,F403
from . import sar_decoder # noqa: F401,F403
from . import sar_encoder # noqa: F401,F403
from . import satrn_encoder # noqa: F401,F403
from . import transformer_module # noqa: F401,F403
36 changes: 36 additions & 0 deletions mmdeploy/codebase/mmocr/models/text_recognition/nrtr_decoder.py
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence

import torch

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmocr.models.textrecog.NRTRDecoder._get_source_mask')
def nrtr_decoder___get_source_mask(
        self, src_seq: torch.Tensor,
        valid_ratios: Sequence[float]) -> torch.Tensor:
    """Generate mask for source sequence.

    Args:
        src_seq (torch.Tensor): Image sequence. Shape :math:`(N, T, C)`.
        valid_ratios (list[float]): The valid ratio of the input image. For
            example, if the width of the original image is w1 and the width
            after padding is w2, then valid_ratio = w1/w2. The source mask is
            used to cover the padding region.

    Returns:
        Tensor or None: Source mask of shape :math:`(N, T)`. The padding
            region is 0 and the rest is 1.
    """

    N, T, _ = src_seq.size()
    mask = None
    if len(valid_ratios) > 0:
        mask = src_seq.new_zeros((N, T), device=src_seq.device)
        valid_width = min(T, math.ceil(T * valid_ratios[0]))
        mask[:, :valid_width] = 1

    return mask
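For illustration, a minimal standalone sketch of the masking logic above, using hypothetical sizes (N=2, T=8) and a made-up valid ratio of 0.5; it is not part of the patch. The rewrite uses only the first entry of valid_ratios and broadcasts it across the batch, presumably to keep the traced ONNX graph free of per-sample Python loops.

import math
import torch

# Hypothetical sizes for illustration only.
N, T, C = 2, 8, 16
src_seq = torch.rand(N, T, C)
valid_ratios = [0.5]

mask = src_seq.new_zeros((N, T))
valid_width = min(T, math.ceil(T * valid_ratios[0]))  # ceil(8 * 0.5) = 4
mask[:, :valid_width] = 1
print(mask[0])  # tensor([1., 1., 1., 1., 0., 0., 0., 0.])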
42 changes: 42 additions & 0 deletions mmdeploy/codebase/mmocr/models/text_recognition/satrn_encoder.py
@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List

from mmocr.structures import TextRecogDataSample
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmocr.models.textrecog.SATRNEncoder.forward')
def satrn_encoder__forward(
        self,
        feat: Tensor,
        data_samples: List[TextRecogDataSample] = None) -> Tensor:
    """Forward propagation of encoder.

    Args:
        feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
        data_samples (list[TextRecogDataSample]): Batch of
            TextRecogDataSample, containing `valid_ratio` information.
            Defaults to None.

    Returns:
        Tensor: A tensor of shape :math:`(N, T, D_m)`.
    """
    valid_ratio = 1.0
    feat = self.position_enc(feat)
    n, c, h, w = feat.size()
    mask = feat.new_zeros((n, h, w))
    valid_width = min(w, math.ceil(w * valid_ratio))
    mask[:, :, :valid_width] = 1
    mask = mask.view(n, h * w)
    feat = feat.view(n, c, h * w)

    output = feat.permute(0, 2, 1).contiguous()
    for enc_layer in self.layer_stack:
        output = enc_layer(output, h, w, mask)
    output = self.layer_norm(output)

    return output
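A rough sketch of the shape flow in the rewritten forward, with hypothetical sizes not taken from the patch: because valid_ratio is pinned to 1.0, valid_width equals the full feature width and the mask degenerates to all ones, so every flattened spatial position is visible to the attention layers. The actual position encoding and encoder layers (self.position_enc, self.layer_stack, self.layer_norm) are omitted here.

import math
import torch

# Hypothetical feature map: N=1, d_model=32, H=8, W=25 (illustration only).
n, c, h, w = 1, 32, 8, 25
feat = torch.randn(n, c, h, w)

valid_ratio = 1.0
mask = feat.new_zeros((n, h, w))
valid_width = min(w, math.ceil(w * valid_ratio))  # == w, so the mask is all ones
mask[:, :, :valid_width] = 1
mask = mask.view(n, h * w)                        # (1, 200)
seq = feat.view(n, c, h * w).permute(0, 2, 1)     # (1, 200, 32), fed to the encoder layers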
82 changes: 81 additions & 1 deletion tests/test_codebase/test_mmocr/test_mmocr_models.py
@@ -10,7 +10,8 @@
from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.config_utils import load_config
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
from mmdeploy.utils.test import (WrapModel, check_backend, get_backend_outputs,
                                 get_model_outputs, get_onnx_model,
                                 get_rewrite_outputs)

try:
@@ -155,6 +156,85 @@ def test_bidirectionallstm(backend: Backend):
    assert rewrite_outputs is not None


@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_nrtr_decoder__get_source_mask(backend: Backend):
    from mmocr.models.textrecog import NRTRDecoder
    deploy_cfg = mmengine.Config(
        dict(
            onnx_config=dict(
                input_names=['input'],
                output_names=['output'],
                input_shape=None,
                dynamic_axes={
                    'input': {
                        0: 'batch',
                    },
                    'output': {
                        0: 'batch',
                    }
                }),
            backend_config=dict(type=backend.value, model_inputs=None),
            codebase_config=dict(type='mmocr', task='TextRecognition')))
    src_seq = torch.rand(1, 200, 256)
    batch_src_seq = src_seq.expand(3, 200, 256)
    decoder = NRTRDecoder(
        dictionary=dict(
            type='Dictionary',
            dict_file='tests/test_codebase/test_mmocr/'
            'data/lower_english_digits.txt',
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True))

    wrapped_model = WrapModel(decoder, '_get_source_mask')
    model_inputs = {'src_seq': src_seq, 'valid_ratios': torch.Tensor([1.0])}
    batch_model_inputs = {'input': batch_src_seq}
    ir_file_path = get_onnx_model(wrapped_model, model_inputs, deploy_cfg)
    backend_outputs = get_backend_outputs(ir_file_path, batch_model_inputs,
                                          deploy_cfg)[0].numpy()
    num_elements = np.prod(backend_outputs.shape[1:])
    # batch results should be same
    assert np.sum(backend_outputs[0] == backend_outputs[1]) == num_elements \
        and np.sum(backend_outputs[1] == backend_outputs[2]) == num_elements


@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_satrn_encoder__get_source_mask(backend: Backend):
    from mmocr.models.textrecog import SATRNEncoder

    deploy_cfg = mmengine.Config(
        dict(
            onnx_config=dict(
                input_names=['input'],
                output_names=['output'],
                input_shape=None,
                dynamic_axes={
                    'input': {
                        0: 'batch',
                    },
                    'output': {
                        0: 'batch',
                    }
                }),
            backend_config=dict(type=backend.value, model_inputs=None),
            codebase_config=dict(type='mmocr', task='TextRecognition')))
    encoder = SATRNEncoder(d_k=4, d_v=4, d_model=32, d_inner=32 * 4)
    feat = torch.randn(1, 32, 32, 32)
    batch_feat = feat.expand(3, 32, 32, 32)
    wrapped_model = WrapModel(encoder, 'forward')
    model_inputs = {'feat': feat}
    batch_model_inputs = {'input': batch_feat}
    ir_file_path = get_onnx_model(wrapped_model, model_inputs, deploy_cfg)
    backend_outputs = get_backend_outputs(ir_file_path, batch_model_inputs,
                                          deploy_cfg)[0].numpy()
    num_elements = np.prod(backend_outputs.shape[1:])
    # batch results should be same
    assert np.sum(backend_outputs[0] == backend_outputs[1]) == num_elements \
        and np.sum(backend_outputs[1] == backend_outputs[2]) == num_elements
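Both new tests above follow the same pattern: the wrapped module is exported to ONNX from a single-sample input, the backend then runs a batch of three identical samples expanded along the dynamic batch axis, and the output rows are compared elementwise. The equality check via np.sum could equivalently be written with np.array_equal, e.g.:

assert np.array_equal(backend_outputs[0], backend_outputs[1])
assert np.array_equal(backend_outputs[1], backend_outputs[2])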


@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_simple_test_of_single_stage_text_detector(backend: Backend):
"""Test simple_test single_stage_text_detector."""