
Commit 859e924

mcremon-meta authored and facebook-github-bot committed

Update OSS repo (#2033)

Summary:
Pull Request resolved: #2033

Update the OSS Xtensa repo with more up-to-date compiler and quantizer components. Introduce a tests folder and a conv1d test.

Reviewed By: tarun292, cccclai

Differential Revision: D54034581

fbshipit-source-id: c2bf0c43897a2ef7dff291698370d2583433a6ba

1 parent 948760a commit 859e924

20 files changed: +1434 −157 lines

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def prepack_ref(self, ref: ValueRef) -> bool:
         else:
             return ref.supports_prepack and self.should_prepack
 
-    def create_value_for(self, ref: ValueRefList) -> str:
+    def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
         if isinstance(ref, list):
             ret_str = ""
             for r in ref:

docs/source/build-run-xtensa.md

Lines changed: 25 additions & 8 deletions
@@ -68,13 +68,14 @@ examples/xtensa/
 ├── aot
 ├── kernels
 ├── ops
+├── tests
 ├── third-party
 └── utils
 ```
 
 ***AoT (Ahead-of-Time) Components***:
 
-The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) defines a model and some example inputs (set to a vector of ones), and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implemetations in the other folders.
+The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) is an API that takes a model (nn.Module) and representative inputs and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implementations in the other folders.
 
 ***Operators***:
 
@@ -97,17 +98,31 @@ cd executorch
 python3 -m examples.portable.scripts.export --model_name="add"
 ```
 
-***Quantized Linear***:
+***Quantized Operators***:
 
-The second, more complex model is a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py#L88). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
+The other, more complex models are custom operators, including:
+- a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_linear_example.py#L28). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
+- a quantized [conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_conv1d_example.py#L36). Convolutions are important in wake word and many denoising models.
 
-The generated file is called `XtensaDemoModel.pte`.
+In both cases the generated file is called `XtensaDemoModel.pte`.
+
+```bash
+cd executorch
+python3 -m examples.xtensa.tests.quantized_<linear,conv1d>_example
+```
+
+***Small Model: RNNT predictor***:
+
+The torchaudio [RNNT-emformer](https://pytorch.org/audio/stable/tutorials/online_asr_tutorial.html) model is an Automatic Speech Recognition (ASR) model, comprised of three different submodels: an encoder, a predictor and a joiner.
+The predictor is a sequence of basic ops (embedding, ReLU, linear, layer norm) and can be exported using:
 
 ```bash
 cd executorch
-python3 -m examples.xtensa.aot.export_example
+python3 -m examples.xtensa.tests.rnnt_predictor_quantized_example
 ```
 
+The generated file is called `XtensaDemoModel.pte`.
+
 ### Runtime
 
 **Building the DSP firmware image**
@@ -139,12 +154,14 @@ cmake -DBUCK2=buck2 \
     -DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/examples/xtensa/xtensa.cmake \
     -DCMAKE_INSTALL_PREFIX=cmake-out \
     -DCMAKE_BUILD_TYPE=Debug \
+    -DPYTHON_EXECUTABLE=python3 \
+    -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
     -DEXECUTORCH_BUILD_HOST_TARGETS=ON \
     -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \
+    -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \
+    -DEXECUTORCH_BUILD_CPUINFO=OFF \
     -DEXECUTORCH_BUILD_FLATC=OFF \
     -DFLATC_EXECUTABLE="$(which flatc)" \
-    -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
-    -DPYTHON_EXECUTABLE=python3 \
     -Bcmake-out .
 
 cmake --build cmake-out -j8 --target install --config Debug
@@ -196,6 +213,6 @@ First 20 elements of output 0
 
 In this tutorial, you have learned how to export a quantized operation, build the ExecuTorch runtime and run this model on the Xtensa HiFi4 DSP chip.
 
-The model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model in [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels).
+The (quantized linear) model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model as a new test and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels).
 
 Other models can be created following the same structure, always assuming that operators and kernels are available.
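
For orientation, a minimal sketch of what a test module like the ones referenced above might look like. It is not part of this commit; it assumes the `export_xtensa_model(model, example_inputs)` API added in `examples/xtensa/aot/export_example.py` (shown later in this diff), that the script is run as a module from the executorch repo root, and illustrative layer sizes. The real `quantized_conv1d_example.py` may be structured differently.

```python
# Hypothetical sketch of a conv1d export test; the actual
# examples/xtensa/tests/quantized_conv1d_example.py may differ.
import torch

# export_xtensa_model is added in examples/xtensa/aot/export_example.py (see below);
# this import assumes the script is run as a module from the executorch repo root.
from examples.xtensa.aot.export_example import export_xtensa_model


class SmallConv1d(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Illustrative sizes only.
        self.conv = torch.nn.Conv1d(in_channels=8, out_channels=16, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


if __name__ == "__main__":
    model = SmallConv1d()
    # One representative input: (batch, channels, length).
    example_inputs = (torch.randn(1, 8, 32),)
    # Quantizes the model, lowers it to edge IR, and saves XtensaDemoModel.pte.
    export_xtensa_model(model, example_inputs)
```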

examples/xtensa/aot/compiler.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Any, Callable
+
+import torch
+
+from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
+
+from torch.export import export
+from torch.export.exported_program import ExportedProgram
+
+
+def export_program(
+    model: Callable,
+    inputs: Any,
+    pt2_quant: bool = False,
+) -> ExportedProgram:
+    # we don't support training mode. Make it eval
+    if hasattr(model, "eval"):
+        if pt2_quant:
+            # pyre-fixme[6]: Incompatible parameter type.
+            torch.ao.quantization.move_exported_model_to_eval(model)
+        else:
+            # pyre-fixme[16]: Anonymous callable has no attribute `eval`.
+            model.eval()
+
+    # if it's already an ExportedProgram, just return it
+    if isinstance(model, ExportedProgram):
+        return model
+
+    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
+
+    # Prevent mkldnn decompositions
+    torch._C._set_mkldnn_enabled(False)
+
+    # else: capture the model and return it.
+    return export(model, inputs)
+
+
+# Export the model and lower it to edge IR.
+def export_to_edge(
+    model: Callable,
+    inputs: Any,
+    pt2_quant: bool = False,
+    dump_graphs: bool = False,
+) -> EdgeProgramManager:
+    # Export the model into an ExportedProgram.
+    expo_program = export_program(model, inputs, pt2_quant)
+
+    if dump_graphs:
+        logging.info(f"Exported graph:\n{expo_program.graph_module.graph}")
+
+    # Call to_edge to convert the graph to edge IR.
+    edge_prog_manager = to_edge(
+        expo_program, compile_config=EdgeCompileConfig(_check_ir_validity=False)
+    )
+
+    if dump_graphs:
+        logging.info(
+            f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph}"
+        )
+
+    return edge_prog_manager
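
As a rough usage sketch of the new helper (not part of the diff): `export_to_edge` on a plain fp32 module returns an `EdgeProgramManager` that can then be lowered to an ExecuTorch program. The `TinyLinear` module and its sizes are illustrative.

```python
# Hypothetical usage of export_to_edge from examples/xtensa/aot/compiler.py.
import torch

from examples.xtensa.aot.compiler import export_to_edge


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


edge_prog_manager = export_to_edge(
    TinyLinear(),
    (torch.ones(64, 32),),
    pt2_quant=False,   # fp32 path: export_program() simply calls model.eval()
    dump_graphs=True,  # logs the exported graph and the edge IR graph
)
# EdgeProgramManager exposes to_executorch() for the final program.
exec_prog = edge_prog_manager.to_executorch()
```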

examples/xtensa/aot/export_example.py

Lines changed: 15 additions & 44 deletions
@@ -10,47 +10,27 @@
 
 from .meta_registrations import *  # noqa
 
-import torch
-from executorch.exir import EdgeCompileConfig
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from ...portable.utils import export_to_edge, save_pte_program
+from ...portable.utils import save_pte_program
 
+from .compiler import export_to_edge
 from .quantizer import (
     QuantFusion,
     ReplacePT2DequantWithXtensaDequant,
     ReplacePT2QuantWithXtensaQuant,
-    XtensaQuantizer,
+    XtensaBaseQuantizer,
 )
 
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 
 
-if __name__ == "__main__":
-    in_features = 32
-    out_features = 16
-    bias = True
-    shape = [64, in_features]
-
-    class QuantizedLinear(torch.nn.Module):
-        def __init__(self, in_features: int, out_features: int, bias: bool):
-            super().__init__()
-            self.output_linear = torch.nn.Linear(in_features, out_features, bias=bias)
-
-        def forward(self, x: torch.Tensor):
-            output_linear_out = self.output_linear(x)
-            return output_linear_out
-
-    model = QuantizedLinear(in_features, out_features, bias)
-    model.eval()
-
-    example_inputs = (torch.ones(shape),)
-
+def export_xtensa_model(model, example_inputs):
     # Quantizer
-    quantizer = XtensaQuantizer()
+    quantizer = XtensaBaseQuantizer()
 
     # Export
     model_exp = capture_pre_autograd_graph(model, example_inputs)
@@ -66,29 +46,20 @@ def forward(self, x: torch.Tensor):
     patterns = [q.pattern for q in quantizer.quantizers]
     QuantFusion(patterns)(converted_model)
 
-    # pre-autograd export. eventually this will become torch.export
-    converted_model_exp = capture_pre_autograd_graph(converted_model, example_inputs)
+    # Get edge program (note: the name will change to export_to_xtensa in future PRs)
+    edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True)
 
-    converted_model_exp = torch.ao.quantization.move_exported_model_to_eval(
-        converted_model_exp
+    # Run a couple required passes for quant/dequant ops
+    xtensa_prog_manager = edge_prog_manager.transform(
+        [ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()],
+        check_ir_validity=False,
     )
 
-    exec_prog = (
-        export_to_edge(
-            converted_model_exp,
-            example_inputs,
-            edge_compile_config=EdgeCompileConfig(
-                _check_ir_validity=False,
-            ),
-        )
-        .transform(
-            [ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()],
-            check_ir_validity=False,
-        )
-        .to_executorch()
-    )
+    exec_prog = xtensa_prog_manager.to_executorch()
 
-    logging.info(f"Final exported graph:\n{exec_prog.exported_program().graph}")
+    logging.info(
+        f"Final exported graph module:\n{exec_prog.exported_program().graph_module}"
+    )
 
     # Save the program as XtensaDemoModel.pte
     save_pte_program(exec_prog, "XtensaDemoModel")
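
The `__main__` block removed above (a `QuantizedLinear` module with `[64, 32]` inputs) now has a home in the tests folder. A minimal sketch of how such a test might call the new entry point, reusing the removed module; the real `examples/xtensa/tests/quantized_linear_example.py` may be organized differently.

```python
# Hypothetical caller mirroring the removed __main__ block; the actual
# quantized_linear_example.py in examples/xtensa/tests/ may differ.
import torch

from examples.xtensa.aot.export_example import export_xtensa_model


class QuantizedLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool):
        super().__init__()
        self.output_linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.output_linear(x)


if __name__ == "__main__":
    # Same shapes as the removed example: [64, 32] inputs, 16 output features.
    model = QuantizedLinear(in_features=32, out_features=16, bias=True)
    example_inputs = (torch.ones(64, 32),)
    export_xtensa_model(model, example_inputs)
```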

examples/xtensa/aot/meta_registrations.py

Lines changed: 88 additions & 9 deletions
@@ -4,10 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional, Tuple
+
 import torch
 from executorch.exir.scalar_type import ScalarType
 from torch.library import impl, Library
 
+from .utils import get_conv1d_output_size
+
 lib = Library("xtensa", "DEF")
 
 lib.define(
@@ -25,10 +29,31 @@
 )
 
 lib.define(
-    "quantized_linear_pt2(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
+    "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
+)
+
+lib.define(
+    "quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
+
+lib.define(
+    "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+lib.define("quantized_relu(Tensor X, Tensor X_zero_point) -> (Tensor Y)")
+
+lib.define(
+    "quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
+
+lib.define(
+    "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
 )
 lib.define(
-    "quantized_linear_pt2.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
 m = Library("xtensa", "IMPL", "Meta")
@@ -58,18 +83,17 @@ def dequantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=torch.float)
 
 
-@impl(m, "quantized_linear_pt2")
-def quantized_linear_pt2_meta(
+@impl(m, "quantized_linear")
+def quantized_linear_meta(
     src: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
-    in_scale: float,
     in_zero_point: int,
-    weight_scale: float,
-    weight_zero_point: int,
-    out_multiplier: int,
-    out_shift: int,
+    weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
     out_zero_point: int,
+    offset: Optional[torch.Tensor],
 ):
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
@@ -79,3 +103,58 @@ def quantized_linear_pt2_meta(
     assert len(weight_size) == 2
     out_size[-1] = weight_size[0]
     return src.new_empty(out_size, dtype=torch.uint8)
+
+
+@impl(m, "quantized_conv")
+def quantized_conv_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    bias_scale: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    channel_last: bool = False,
+):
+    out_channels, _in_channels, *kernel_size = weight.shape
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 dimensions, and at most 6
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = get_conv1d_output_size(
+        in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@impl(m, "quantized_layer_norm")
+def quantized_layer_norm_meta(
+    input: torch.Tensor,
+    X_scale: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    normalized_shape: int,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+    output_scale: float,
+    output_zero_point: int,
+):
+    return input.new_empty(input.size(), dtype=torch.uint8)
+
+
+@impl(m, "quantized_relu")
+def quantized_relu_meta(
+    X: torch.Tensor,
+    X_zero_point: torch.Tensor,
+):
+    return X.new_empty(X.size(), dtype=torch.uint8)
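
The `quantized_conv` meta kernel relies on `get_conv1d_output_size` from `.utils`, which is not part of this diff. A plausible sketch of that helper, assuming the standard conv1d output-length formula; the real `examples/xtensa/aot/utils.py` may differ (for example in return type or in handling of channel-last layouts).

```python
# Hypothetical sketch only; the actual helper lives in examples/xtensa/aot/utils.py
# and is not shown in this commit's diff.
import torch


def get_conv1d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
) -> torch.Size:
    # Expect (batch, in_channels, length) for a conv1d input.
    assert len(in_size) == 3
    N, _C, L = in_size
    # Standard formula: L_out = floor((L + 2*pad - dilation*(k - 1) - 1) / stride) + 1
    L_out = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return torch.Size((N, out_channels, L_out))
```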
