
Commit 925b5a8

Pull request pytorch#41: [EIEX-86] Include header with tensor format into byte payload produced by backend

Merge in AITEC/executorch from feature/nxf93343/EIEX-86-tensor-format-in-payload to main-nxp

* commit '8070682be46492a4cea03fa1d8960cc17f0dd586':
  [NO-UPSTREAM] Add tests for tensor format in payload
  Include header with tensor format into byte payload produced by backend

2 parents 195b0f5 + 8070682

4 files changed: +157 -66 lines changed

backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py

Lines changed: 7 additions & 4 deletions
@@ -81,22 +81,25 @@ def append_operators(self, ops_to_add: list[tflite_model.Operator]):
 
             self.check_and_append_operator(op)
 
-    def assign_model_io_to_subgraph_and_get_io_formats(self, graph_signature) -> dict[str, TensorFormat]:
+    def assign_model_io_to_subgraph_and_get_io_formats(self, graph_signature) -> dict[str, dict]:
         """
         Assign model's inputs/outputs to SubGraph.
 
         :param graph_signature: Instance of GraphSignature.
         :returns: Mapping between IO tensors' names and their formats.
         """
-        io_formats = {}
+        io_formats = {
+            "inputs": {},
+            "outputs": {},
+        }
 
         self.get_sub_graph().inputs = tflite_model.SubGraphInputs()
         for input_name in graph_signature.user_inputs:
             tensor = self.tensor_for_name(input_name)
             assert input_name == tensor.name, ("Program's input name doesn't match with tensor name in TFLite. "
                                                "Input was probably redirected.")
             self.get_sub_graph().inputs.tmp_inputs.append(tensor)
-            io_formats[tensor.name] = tensor.tensor_format
+            io_formats["inputs"][tensor.name] = tensor.tensor_format
 
         self.get_sub_graph().outputs = tflite_model.SubGraphOutputs()
         for output_name in graph_signature.user_outputs:
@@ -105,6 +108,6 @@ def assign_model_io_to_subgraph_and_get_io_formats(self, graph_signature) -> dic
                                                "Output was probably redirected.")
             self.get_sub_graph().outputs.tmp_outputs.append(tensor)
 
-            io_formats[tensor.name] = tensor.tensor_format
+            io_formats["outputs"][tensor.name] = tensor.tensor_format
 
         return io_formats
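
For reference, after this change the returned mapping is grouped by direction instead of being flat, so downstream code can tell input formats from output formats apart. A minimal sketch of the structure (tensor names are hypothetical; FORMATLESS follows the naming used in the payload docstring further down):

    # Hypothetical example of the dict now returned by
    # assign_model_io_to_subgraph_and_get_io_formats().
    io_formats = {
        "inputs": {"x": TensorFormat.CHANNELS_LAST},
        "outputs": {"output": TensorFormat.FORMATLESS},
    }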

backends/nxp/neutron_node_extraction.py

Lines changed: 10 additions & 45 deletions
@@ -5,28 +5,24 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-import struct
+from dataclasses import dataclass
 
 import numpy as np
 
 from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import BuiltinOperator
 from executorch.backends.nxp.backend.ir.lib.tflite.Model import Model
-from executorch.exir.backend.backend_details import PreprocessResult
 
 
-def extract_artifacts_from_neutron_node(tflite_flatbuffer_or_path: bytes | str) -> PreprocessResult:
-    """ Extract the payload (microcode, weights, kernels) from the Neutron Node in the given TFLite model.
-        The model can be provided as a binary flatbuffer, or a path to a `.tflite` model.
+@dataclass
+class NeutronNodeArtifacts:
+    microcode: np.ndarray
+    weights: np.ndarray
+    kernels: np.ndarray
 
-        The return format is a `PreprocessResult` object, and its `processed_bytes` attribute contains the serialized
-        binary data of the following C struct:
-            struct NeutronBinary {
-                uint8[] microcode;
-                uint8[] weights;
-                uint8[] kernels;
-            }
 
-        The individual components must be aligned to 16 bytes.
+def extract_artifacts_from_neutron_node(tflite_flatbuffer_or_path: bytes | str) -> NeutronNodeArtifacts:
+    """ Extract the payload (microcode, weights, kernels) from the Neutron Node in the given TFLite model.
+        The model can be provided as a binary flatbuffer, or a path to a `.tflite` model.
     """
 
     if isinstance(tflite_flatbuffer_or_path, str):
@@ -77,35 +73,4 @@ def extract_artifacts_from_neutron_node(tflite_flatbuffer_or_path: bytes | str)
     assert microcode.dtype == weights.dtype == kernels.dtype == np.dtype('uint8'), \
         'The Neutron Node uses unexpected data types.'
 
-    # Align to 16B (according to commit 008bdc17670).
-    alignment = 16
-
-    def padding_format_string_for_array(array: np.ndarray) -> str:
-        """ Create a padding format string for the given array, which will add 0s at the end for correct alignment.
-            E.g. the string '10x' represents adding 10 bytes of '0' padding.
-        """
-        assert array.dtype == np.dtype('uint8')
-
-        overflow = array.size % alignment
-        if overflow == 0:
-            return ''
-
-        # Overflow 1 means padding 15, so use `alignment - overflow` padding.
-        return f'{alignment - overflow}x'
-
-    def format_string_for_array(array: np.ndarray) -> str:
-        """ Create a format string which will represent the provided array. It also handles the necessary alignment.
-            E.g. for array [1,2,3] we get '3s13x', because '3s' means string of 3 bytes, and `13x` means adding 13 bytes
-            of '0' padding at the end (for 16B alignment).
-        """
-        assert array.dtype == np.dtype('uint8')
-
-        return f'{array.size}s{padding_format_string_for_array(array)}'
-
-    # The resulting payload should be structured as a binary in the format defined in the function header.
-    payload = struct.pack(
-        format_string_for_array(microcode) + format_string_for_array(weights) + format_string_for_array(kernels),
-        microcode.tobytes(), weights.tobytes(), kernels.tobytes()
-    )
-
-    return PreprocessResult(processed_bytes=payload)
+    return NeutronNodeArtifacts(microcode, weights, kernels)
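
With the struct packing moved out of this module, callers now receive the raw uint8 arrays and decide how to serialize them. A minimal usage sketch, assuming a hypothetical model path:

    # Hypothetical usage of the refactored helper; the path is illustrative.
    artifacts = extract_artifacts_from_neutron_node("neutron_model.tflite")
    assert artifacts.microcode.dtype == artifacts.weights.dtype == artifacts.kernels.dtype
    print(artifacts.microcode.size, artifacts.weights.size, artifacts.kernels.size)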

backends/nxp/nxp_backend.py

Lines changed: 117 additions & 16 deletions
@@ -9,22 +9,20 @@
 #
 
 import logging
+import struct
 from typing import final, List, Optional
 
-from torch.export.exported_program import ExportedProgram
-
+import numpy as np
 import torch
+from torch.export.exported_program import ExportedProgram
 
 from executorch.backends.nxp.backend.edge_program_converter import EdgeProgramToIRConverter
 from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
 from executorch.backends.nxp.backend.neutron_converter_manager import NeutronConverterManager
-from executorch.backends.nxp.neutron_node_extraction import extract_artifacts_from_neutron_node
+from executorch.backends.nxp.neutron_node_extraction import extract_artifacts_from_neutron_node, NeutronNodeArtifacts
 from executorch.backends.xnnpack.passes import RemoveGetItemPass, XNNPACKPassManager
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-
-
-from torch.export.exported_program import ExportedProgram
 from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
 
 
@@ -132,13 +130,6 @@ def preprocess(
 
             # Convert the edge program to TFLite.
             tflite_model, io_formats = EdgeProgramToIRConverter().convert_program(edge_program)
-            for tensor, tensor_format in io_formats.items():
-                if tensor_format == TensorFormat.CHANNELS_LAST:
-                    channel_last_format = b'1'
-                else:
-                    channel_last_format = b'0'
-
-                compile_spec.append(CompileSpec(tensor, channel_last_format))
 
             # Call the neutron converter with the TFLite model.
             neutron_model = NeutronConverterManager().convert(tflite_model)
@@ -153,11 +144,121 @@ def preprocess(
                     f.write(bytes(neutron_model))
                 NeutronBackend.counter = NeutronBackend.counter + 1
 
-            # Extract the Neutron microcode, weights and kernels from the Neutron Node in the `neutron_model`.
-            payload = extract_artifacts_from_neutron_node(neutron_model)
-            binary = payload.processed_bytes
+            binary = PayloadComposer().get_binary_payload(io_formats, neutron_model)
 
         else:
             raise RuntimeError(f"Unknown format {output_format}")
 
         return PreprocessResult(processed_bytes=binary)
+
+
+class PayloadComposer:
+    ALIGNMENT = 16
+
+    def _padding_format_string_for_array(self, array: np.ndarray) -> str:
+        """ Create a padding format string for the given array, which will add 0s at the end for correct alignment.
+            E.g. the string '10x' represents adding 10 bytes of '0' padding.
+        """
+        assert array.dtype == np.dtype('uint8')
+
+        overflow = array.size % self.ALIGNMENT
+        if overflow == 0:
+            return ''
+
+        # Overflow 1 means padding 15, so use `alignment - overflow` padding.
+        return f'{self.ALIGNMENT - overflow}x'
+
+    def _format_string_for_array(self, array: np.ndarray) -> str:
+        """ Create a format string which will represent the provided array. It also handles the necessary alignment.
+            E.g. for array [1,2,3] we get '3s13x', because '3s' means string of 3 bytes, and `13x` means adding 13 bytes
+            of '0' padding at the end (for 16B alignment).
+        """
+        assert array.dtype == np.dtype('uint8')
+
+        return f'{array.size}s{self._padding_format_string_for_array(array)}'
+
+    def _create_payload_header(self, io_formats) -> np.ndarray:
+        """
+        Create bytes header for returned payload. It contains information about
+        input and output tensor formats. Tensors are ordered based on graph signature
+        of ExportedProgram. Header schema:
+
+        +----------------------------------+------------------------+---------------------------+
+        | Input TensorFormats length (1B)  | 1st tensor format (1B) | [nth* tensor format (1B)] |
+        +----------------------------------+------------------------+---------------------------+
+        | Output TensorFormats length (1B) | 1st tensor format (1B) | [nth* tensor format (1B)] |
+        +----------------------------------+------------------------+---------------------------+
+
+        :param io_formats: IO tensors formats.
+        :return: Bytes representation of payload header.
+        """
+        inputs = io_formats["inputs"]
+        outputs = io_formats["outputs"]
+
+        assert len(inputs) < 256, "Models with more than 255 inputs are not supported."
+        assert len(outputs) < 256, "Models with more than 255 outputs are not supported."
+
+        header_data = [len(inputs)]
+        for tensor, tensor_format in inputs.items():
+            header_data.append(1 if tensor_format == TensorFormat.CHANNELS_LAST else 0)
+
+        header_data.append(len(outputs))
+        for tensor, tensor_format in outputs.items():
+            header_data.append(1 if tensor_format == TensorFormat.CHANNELS_LAST else 0)
+
+        # noinspection PyTypeChecker
+        return np.array(header_data, dtype=np.uint8)
+
+    def _pack_with_alignment(self, header: np.ndarray, neutron_artifacts: NeutronNodeArtifacts) -> bytes:
+        """
+        Packs provided data into serialized binary data of the following C struct:
+            struct NeutronBinary {
+                uint8[] header;
+                uint8[] microcode;
+                uint8[] weights;
+                uint8[] kernels;
+            }
+        The individual components must be aligned to 16 bytes.
+        """
+
+        return struct.pack(
+            self._format_string_for_array(header) +
+            self._format_string_for_array(neutron_artifacts.microcode) +
+            self._format_string_for_array(neutron_artifacts.weights) +
+            self._format_string_for_array(neutron_artifacts.kernels),
+            header.tobytes(),
+            neutron_artifacts.microcode.tobytes(),
+            neutron_artifacts.weights.tobytes(),
+            neutron_artifacts.kernels.tobytes()
+        )
+
+    def get_binary_payload(self, io_formats, neutron_model) -> bytes:
+        """
+        Get binary payload for provided input/output tensor formats and neutron_model. Returned data have
+        following structure:
+
+        +---------------------------------------------------------------------------------------------------------------+
+        |                                           16 bytes aligned blocks                                              |
+        +===========================+===========================+============================+==========================+
+        | Input formats length (1B) | [nth* tensor format (1B)] | Output formats length (1B) | [nth* tensor format (1B)]|
+        +---------------------------+---------------------------+----------------------------+--------------------------+
+        |                                              Neutron microcode                                                 |
+        +---------------------------------------------------------------------------------------------------------------+
+        |                                              Neutron weights                                                   |
+        +---------------------------------------------------------------------------------------------------------------+
+        |                                              Neutron kernels                                                   |
+        +---------------------------------------------------------------------------------------------------------------+
+
+        Tensor format definition: '0x1' == CHANNELS_LAST, '0x0' == FORMATLESS (no format).
+
+        :param io_formats: Dictionary with keys 'inputs' and 'outputs' that contains dictionaries
+            mapping tensor name to TensorFormat.
+        :param neutron_model: Neutron model with single NeutronGraph node.
+        :return: 16 bytes aligned binary payload.
+        """
+        header = self._create_payload_header(io_formats)
+
+        # Extract the Neutron microcode, weights and kernels from the Neutron Node in the `neutron_model`.
+        neutron_artifacts = extract_artifacts_from_neutron_node(neutron_model)
+
+        return self._pack_with_alignment(header, neutron_artifacts)
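
The documented layout is enough to decode the header on the consumer side: byte 0 is the input count, followed by one format byte per input, then the output count and one format byte per output, with the header block zero-padded up to the next 16-byte boundary before the microcode begins. An illustrative parser sketch (not part of this commit; the function name is hypothetical):

    def parse_payload_header(payload: bytes) -> tuple[list[int], list[int], int]:
        # Hypothetical helper: decode the header written by PayloadComposer.get_binary_payload().
        num_inputs = payload[0]
        input_formats = list(payload[1:1 + num_inputs])             # 1 == CHANNELS_LAST, 0 == FORMATLESS
        num_outputs = payload[1 + num_inputs]
        output_formats = list(payload[2 + num_inputs:2 + num_inputs + num_outputs])

        header_len = 2 + num_inputs + num_outputs
        microcode_offset = -(-header_len // 16) * 16                # header block is padded to a multiple of 16 bytes
        return input_formats, output_formats, microcode_offset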

backends/nxp/tests/test_neutron_backend.py

Lines changed: 23 additions & 1 deletion
@@ -13,7 +13,7 @@
 from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
 from executorch.backends.nxp.tests.executors import TFLiteExecutor, EdgeProgramExecutor, convert_run_compare, \
     ToNHWCPreprocess
-from executorch.backends.nxp.tests.models import Conv2dModule
+from executorch.backends.nxp.tests.models import Conv2dModule, SoftmaxModule
 from executorch.backends.nxp.tests.models import ConvFCSoftmaxModule
 
 
@@ -22,6 +22,28 @@ def test_neutron_backend__single_conv_model():
     lowered_module = edge_program_manager.exported_program().graph_module.lowered_module_0
     assert len(lowered_module.processed_bytes) != 0  # The Neutron microcode, weights and kernels have been written here
 
+def test_neutron_backend__single_conv_model__payload_header():
+    edge_program_manager = to_quantized_edge_program(Conv2dModule(bias=False), (1, 4, 32, 32))
+    payload = edge_program_manager.exported_program().graph_module.lowered_module_0.processed_bytes
+
+    assert payload[0] == 0x1  # Single input
+    assert payload[1] == 0x1  # Channels last
+    assert payload[2] == 0x1  # Single output
+    assert payload[3] == 0x1  # Channels last
+    assert all(byte == 0x0 for byte in payload[4:16])  # Aligned to 16 bytes
+    assert payload[17] != 0x0  # Followed by non-zero content
+
+def test_neutron_backend__single_softmax_model__payload_header():
+    edge_program_manager = to_quantized_edge_program(SoftmaxModule(1), (1, 64))
+    payload = edge_program_manager.exported_program().graph_module.lowered_module_0.processed_bytes
+
+    assert payload[0] == 0x1  # Single input
+    assert payload[1] == 0x0  # Formatless
+    assert payload[2] == 0x1  # Single output
+    assert payload[3] == 0x0  # Formatless
+    assert all(byte == 0x0 for byte in payload[4:16])  # Aligned to 16 bytes
+    assert payload[17] != 0x0  # Followed by non-zero content
+
 
 
 def test_lowered_program_and_tflite_output_match__conv2d__no_bias(mocker):
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
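
The two new tests exercise exactly the header bytes described above: a 4-byte header (1 input count, 1 input format, 1 output count, 1 output format) padded with zeros to the 16-byte boundary, followed by the Neutron microcode. They can be run on their own with pytest's keyword filter, e.g. "pytest backends/nxp/tests/test_neutron_backend.py -k payload_header" (assuming pytest with pytest-mock, which the existing mocker fixture already requires).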
