Skip to content

Commit 15b8833

Browse files
committed
Pull request pytorch#42: Feature/EIEX-84 improve nodeformat inference to handle format change for op aten view copy
Merge in AITEC/executorch from feature/EIEX-84-improve-nodeformat-inference-to-handle-format-change-for-op-aten-view_copy to main-nxp * commit 'db884b4b15f4dafd930769305cc9e0d05b5856bf': [NO-UPSTREAM] Add tests for proper handling of node formats of `aten.view_copy` (Reshape) in TFLite. Update node format inference to properly support `aten.view_copy`.
2 parents 925b5a8 + db884b4 commit 15b8833

File tree

3 files changed

+185
-5
lines changed

3 files changed

+185
-5
lines changed

backends/nxp/backend/node_format_inference.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import logging
8-
from collections import abc
98
from enum import Enum
109

1110
from torch import Node
1211
from torch.export import ExportedProgram
13-
from typing_extensions import Tuple
1412

1513
# (TODO Lukas) Can we found ops somewhere else?
1614
from executorch.exir.dialects._ops import ops as exir_ops
@@ -45,7 +43,7 @@ class NodeFormatInference:
4543
# A set of Edge Aten ops, which have the ability to change the format (for example - input nodes
4644
# are channels first but output is formatless).
4745
ops_that_can_change_tensor_format = {
48-
# TODO ("transpose", "reshape", etc.)
46+
exir_ops.edge.aten.view_copy.default
4947
}
5048

5149
_node_format_mapping: dict[Node, NodeFormat]
@@ -86,13 +84,55 @@ def _infer_format_of_nodes(self, node: Node):
8684
if op_type in self.ops_with_channels_first_nodes:
8785
self._handle_node_which_uses_channels_first_format(node)
8886
elif op_type in self.ops_that_can_change_tensor_format:
89-
if op_type in ["transpose"]:
90-
self._assign_format_to_node(node, NodeFormat.FORMATLESS)
87+
if op_type == exir_ops.edge.aten.view_copy.default: # view_copy
88+
self._assign_format_to_node(self._node_outputs[node][0], NodeFormat.FORMATLESS)
9189
else:
9290
logger.error(f"Node format inference for node type: {op_type} not found!")
9391
else:
9492
self._handle_node_which_can_use_any_node_format(node)
9593

94+
def _infer_format_based_on_io_ranks(self, node: Node):
95+
""" Determine the format of the output tensor of given "reshape style operator" based on the ranks of its input
96+
and output.
97+
"""
98+
# noinspection PyBroadException
99+
try:
100+
main_input_rank = len(node.all_input_nodes[0].meta['val'].shape)
101+
main_output_rank = len(node.meta['val'].shape)
102+
103+
if main_output_rank == main_input_rank:
104+
# Operator maintains the number of dimensions -> try to propagate the format.
105+
self._match_formats_of_nodes(node, node.prev)
106+
107+
else:
108+
# Either the op 'flattens' the tensor, so output is formatless, or it scales it up, in which case the
109+
# format is assumed to be 'FORMATLESS', and may be back propagated as channels first later.
110+
self._assign_format_to_node(node, NodeFormat.FORMATLESS)
111+
112+
except:
113+
# Some shape data is not known, so we cannot be extra clever. Just set the output to `FORMATLESS` and
114+
# everything will be alright.
115+
self._assign_format_to_node(node, NodeFormat.FORMATLESS)
116+
117+
def _match_formats_of_nodes(self, node_1, node_2):
118+
""" If one of 'node_1' or 'node_2' is channels first, make the other channels first as well.
119+
If neither is channels first, make them both formatless.
120+
"""
121+
122+
format_1 = self._get_node_format(node_1)
123+
format_2 = self._get_node_format(node_2)
124+
125+
if format_1.is_channels_first() or format_2.is_channels_first():
126+
# At least 1 is channels first
127+
if not format_1.is_channels_first():
128+
self._assign_format_to_node(node_1, NodeFormat.CHANNELS_FIRST)
129+
elif not format_2.is_channels_first():
130+
self._assign_format_to_node(node_2, NodeFormat.CHANNELS_FIRST)
131+
132+
else:
133+
self._assign_format_to_node(node_1, NodeFormat.FORMATLESS)
134+
self._assign_format_to_node(node_2, NodeFormat.FORMATLESS)
135+
96136
def _assign_format_to_node(self, node: Node, node_format: NodeFormat):
97137
"""
98138
Assign format to node, but only if it's not channels first.

backends/nxp/tests/executorch_pipeline.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch
2+
from torch import nn
23
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
34

5+
from executorch import exir
46
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
57
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
68
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -41,3 +43,9 @@ def to_quantized_executorch_program(model: torch.nn.Module, input_shape: tuple)
4143
extract_delegate_segments=False, extract_constant_segment=False
4244
)
4345
)
46+
47+
48+
def to_edge_program(model: nn.Module, input_shape) -> EdgeProgramManager:
49+
example_input = (torch.ones(input_shape),)
50+
exir_program = torch.export.export(model, example_input)
51+
return exir.to_edge(exir_program)
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from typing import Sequence
2+
3+
import numpy as np
4+
import torch
5+
from torch import nn
6+
7+
from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ModelBuilder
8+
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.conv_2d_options import Conv2D
9+
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.reshape_options import Reshape
10+
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.transpose_options import Transpose
11+
from executorch.backends.nxp.tests.executorch_pipeline import to_edge_program
12+
from executorch.backends.nxp.tests.executors import convert_run_compare, ToNHWCPreprocess, ToNCHWPreprocess
13+
14+
15+
class FormatlessToChannelsFirstModule(nn.Module):
16+
def __init__(self, channels: int, new_shape: Sequence[int]):
17+
super().__init__()
18+
self.conv = nn.Conv2d(channels, channels, 2, bias=True)
19+
self.new_shape = new_shape
20+
21+
def forward(self, x):
22+
x = torch.reshape(x, self.new_shape)
23+
x = self.conv(x)
24+
return x
25+
26+
27+
class FormatlessToFormatlessModule(nn.Module):
28+
def __init__(self, new_shape: Sequence[int]):
29+
super().__init__()
30+
self.new_shape = new_shape
31+
32+
def forward(self, x):
33+
x = torch.reshape(x, self.new_shape)
34+
return x
35+
36+
37+
class ConvReshapeModule(nn.Module):
38+
def __init__(self, channels: int, new_shape: Sequence[int]):
39+
super().__init__()
40+
self.conv = nn.Conv2d(channels, channels, 2, bias=True)
41+
self.new_shape = new_shape
42+
43+
def forward(self, x):
44+
x = self.conv(x)
45+
x = torch.reshape(x, self.new_shape)
46+
return x
47+
48+
49+
def test__channels_first_to_2d(mocker):
50+
input_shape = [2, 4, 7, 9]
51+
new_shape = [12, 32] # Mix up the dimensions for a thorough test.
52+
53+
torch_model = ConvReshapeModule(channels=input_shape[1], new_shape=new_shape)
54+
edge_program = to_edge_program(torch_model, input_shape).exported_program()
55+
56+
torch.manual_seed(23)
57+
input_data = np.random.random(input_shape).astype('float32')
58+
59+
converter_spy = mocker.spy(ModelBuilder, "finish")
60+
61+
convert_run_compare(edge_program, input_data, tflite_input_preprocess=ToNHWCPreprocess())
62+
63+
tflite_model = converter_spy.spy_return
64+
ops = tflite_model.sub_graphs[0].operators.vector
65+
assert len(ops) == 3
66+
assert isinstance(ops[0].builtin_options, Conv2D)
67+
assert isinstance(ops[1].builtin_options, Transpose)
68+
assert isinstance(ops[2].builtin_options, Reshape)
69+
70+
71+
def test__channels_first_to_4d(mocker):
72+
input_shape = [2, 4, 6, 8]
73+
new_shape = [7, 4, 2, 5]
74+
75+
torch_model = ConvReshapeModule(channels=input_shape[1], new_shape=new_shape)
76+
edge_program = to_edge_program(torch_model, input_shape).exported_program()
77+
78+
torch.manual_seed(23)
79+
input_data = np.random.random(input_shape).astype('float32')
80+
81+
converter_spy = mocker.spy(ModelBuilder, "finish")
82+
83+
convert_run_compare(edge_program, input_data, tflite_input_preprocess=ToNHWCPreprocess())
84+
85+
tflite_model = converter_spy.spy_return
86+
ops = tflite_model.sub_graphs[0].operators.vector
87+
assert len(ops) == 3
88+
assert isinstance(ops[0].builtin_options, Conv2D)
89+
assert isinstance(ops[1].builtin_options, Transpose)
90+
assert isinstance(ops[2].builtin_options, Reshape)
91+
92+
93+
def test__formatless_to_channels_first(mocker):
94+
input_shape = [12, 32]
95+
new_shape = [2, 4, 6, 8] # Mix up the dimensions for a thorough test.
96+
97+
torch_model = FormatlessToChannelsFirstModule(channels=new_shape[1], new_shape=new_shape)
98+
edge_program = to_edge_program(torch_model, input_shape).exported_program()
99+
100+
torch.manual_seed(23)
101+
input_data = np.random.random(input_shape).astype('float32')
102+
103+
converter_spy = mocker.spy(ModelBuilder, "finish")
104+
105+
convert_run_compare(edge_program, input_data, tflite_output_preprocess=ToNCHWPreprocess())
106+
107+
tflite_model = converter_spy.spy_return
108+
ops = tflite_model.sub_graphs[0].operators.vector
109+
assert len(ops) == 3
110+
assert isinstance(ops[0].builtin_options, Reshape)
111+
assert isinstance(ops[1].builtin_options, Transpose)
112+
assert isinstance(ops[2].builtin_options, Conv2D)
113+
114+
115+
def test__formatless_to_formatless(mocker):
116+
input_shape = [12, 32]
117+
new_shape = [2, 4, 6, 8]
118+
119+
torch_model = FormatlessToFormatlessModule(new_shape=new_shape)
120+
edge_program = to_edge_program(torch_model, input_shape).exported_program()
121+
122+
torch.manual_seed(23)
123+
input_data = np.random.random(input_shape).astype('float32')
124+
125+
converter_spy = mocker.spy(ModelBuilder, "finish")
126+
127+
convert_run_compare(edge_program, input_data)
128+
129+
tflite_model = converter_spy.spy_return
130+
ops = tflite_model.sub_graphs[0].operators.vector
131+
assert len(ops) == 1 # No extra Transpose ops.
132+
assert isinstance(ops[0].builtin_options, Reshape)

0 commit comments

Comments
 (0)