Skip to content

Commit a6f628a

Browse files
guangy10 authored and facebook-github-bot committed
Add wav2letter model to examples (#71)
Summary: Pull Request resolved: #71 Bring in `Wav2Letter` model to `executorch/examples`. - General info about `Wav2Letter` model: https://ai.meta.com/tools/wav2letter/ - Info about `Wav2Letter` model being used in this example: https://pytorch.org/audio/stable/_modules/torchaudio/models/wav2letter.html Reviewed By: JacobSzwejbka Differential Revision: D48403704 fbshipit-source-id: 3e5e56861dc8a4a3b93744a35c3b6dc8f330c1ce
1 parent d85b2ad commit a6f628a

File tree

6 files changed

+72
-0
lines changed

6 files changed

+72
-0
lines changed

examples/export/test/test_export.py

+8
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,11 @@ def test_vit_export_to_executorch(self):
8181
self._assert_eager_lowered_same_result(
8282
eager_model, example_inputs, self.validate_tensor_allclose
8383
)
84+
85+
def test_w2l_export_to_executorch(self):
    """Export wav2letter to ExecuTorch and check it matches eager-mode output."""
    model, inputs = MODEL_NAME_TO_MODEL["w2l"]()
    model = model.eval()

    self._assert_eager_lowered_same_result(
        model, inputs, self.validate_tensor_allclose
    )

examples/models/TARGETS

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ python_library(
1111
"//executorch/examples/models/mobilenet_v2:mv2_export",
1212
"//executorch/examples/models/mobilenet_v3:mv3_export",
1313
"//executorch/examples/models/torchvision_vit:vit_export",
14+
"//executorch/examples/models/wav2letter:w2l_export",
1415
"//executorch/exir/backend:compile_spec_schema",
1516
],
1617
)

examples/models/models.py

+8
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ def gen_torchvision_vit_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
9595
return TorchVisionViTModel.get_model(), TorchVisionViTModel.get_example_inputs()
9696

9797

98+
def gen_wav2letter_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
    """Build the wav2letter example model and its example inputs.

    Returns:
        A ``(model, example_inputs)`` pair suitable for export testing.
    """
    # Imported lazily so torchaudio is only required when this model is used.
    from ..models.wav2letter import Wav2LetterModel

    wrapper = Wav2LetterModel()
    return wrapper.get_model(), wrapper.get_example_inputs()
103+
104+
98105
MODEL_NAME_TO_MODEL = {
99106
"mul": lambda: (MulModule(), MulModule.get_example_inputs()),
100107
"linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
@@ -103,4 +110,5 @@ def gen_torchvision_vit_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
103110
"mv2": gen_mobilenet_v2_model_inputs,
104111
"mv3": gen_mobilenet_v3_model_inputs,
105112
"vit": gen_torchvision_vit_model_and_inputs,
113+
"w2l": gen_wav2letter_model_and_inputs,
106114
}

examples/models/wav2letter/TARGETS

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

# Buck library exposing the wav2letter example-model wrapper
# (executorch.examples.models.wav2letter). Depends on torch and torchaudio,
# since the wrapper constructs torchaudio's Wav2Letter model.
python_library(
    name = "w2l_export",
    srcs = [
        "__init__.py",
        "export.py",
    ],
    base_module = "executorch.examples.models.wav2letter",
    deps = [
        "//caffe2:torch",
        "//pytorch/audio:torchaudio",
    ],
)
+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .export import Wav2LetterModel

# __all__ entries must be strings: a non-string item (the class object itself,
# as originally written) raises TypeError when a client does
# `from executorch.examples.models.wav2letter import *`.
__all__ = [
    "Wav2LetterModel",
]

examples/models/wav2letter/export.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from torchaudio import models
11+
12+
# Configure the root logger so example scripts print "[file.py:42] message".
# NOTE(review): basicConfig leaves the default WARNING level, so the
# logging.info calls below are silent unless a caller lowers the level.
FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(format=FORMAT)
14+
15+
16+
class Wav2LetterModel:
    """Wrapper that builds torchaudio's Wav2Letter model and example inputs."""

    def __init__(self):
        # Fixed dimensions used to shape the example input batch.
        self.batch_size = 10
        self.input_frames = 700
        self.vocab_size = 4096

    def get_model(self):
        """Construct and return a Wav2Letter with `vocab_size` output classes."""
        logging.info("loading wav2letter model")
        w2l = models.Wav2Letter(num_classes=self.vocab_size)
        logging.info("loaded wav2letter model")
        return w2l

    def get_example_inputs(self):
        """Return a 1-tuple holding a random (batch, 1, frames) waveform batch."""
        shape = (self.batch_size, 1, self.input_frames)
        return (torch.randn(shape),)

0 commit comments

Comments
 (0)