diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py index aa0be8cfcd0..343d949c244 100644 --- a/backends/arm/operator_support/to_copy_support.py +++ b/backends/arm/operator_support/to_copy_support.py @@ -77,7 +77,6 @@ def is_node_tosa_supported( ) -> bool: assert node.target in self.targets - assert tosa_spec.support_integer() supported_dtypes = ( self.ALL_SUPPORTED_TYPES if tosa_spec.support_float() diff --git a/backends/arm/operators/op_amax.py b/backends/arm/operators/op_amax.py index 059f6c1e553..6bb9d563ca6 100644 --- a/backends/arm/operators/op_amax.py +++ b/backends/arm/operators/op_amax.py @@ -2,9 +2,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List +from typing import Any, List -import tosa_tools.v0_80.serializer.tosa_serializer as ts from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -15,19 +14,22 @@ @register_node_visitor -class MaxVisitor(NodeVisitor): +class MaxVisitor_0_80(NodeVisitor): target = "aten.amax.default" + tosa_specs = NodeVisitor.tosa_specs_0_80 + def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts input = inputs[0] dim = inputs[1].number @@ -49,3 +51,42 @@ def define_node( tosa_graph.addOperator( ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr ) + + +@register_node_visitor +class MaxVisitor(NodeVisitor): + target = "aten.amax.default" + + tosa_specs = NodeVisitor.tosa_specs_1_00 + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + + input = inputs[0] + dim = inputs[1].number + + if dim < 0: + tensor = get_first_fake_tensor(node) + rank = len(tensor.size()) + dim = rank + dim + + keep_dims = inputs[2].number + if not keep_dims: + raise RuntimeError( + "TOSA only supports keepdims == True; Did you run the convert_minmax pass?" + ) + + attr = ts.TosaSerializerAttribute() + attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1) + tosa_graph.addOperator( + ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr + ) diff --git a/backends/arm/operators/op_amin.py b/backends/arm/operators/op_amin.py index 85e43b76c4c..5c0fee5cfaf 100644 --- a/backends/arm/operators/op_amin.py +++ b/backends/arm/operators/op_amin.py @@ -2,9 +2,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import List +from typing import Any, List -import tosa_tools.v0_80.serializer.tosa_serializer as ts from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -15,19 +14,22 @@ @register_node_visitor -class MinVisitor(NodeVisitor): +class MinVisitor_0_80(NodeVisitor): target = "aten.amin.default" + tosa_specs = NodeVisitor.tosa_specs_0_80 + def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts input = inputs[0] dim = inputs[1].number @@ -49,3 +51,42 @@ def define_node( tosa_graph.addOperator( ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr ) + + +@register_node_visitor +class MinVisitor(NodeVisitor): + target = "aten.amin.default" + + tosa_specs = NodeVisitor.tosa_specs_1_00 + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + + input = inputs[0] + dim = inputs[1].number + + if dim < 0: + tensor = get_first_fake_tensor(node) + rank = len(tensor.size()) + dim = rank + dim + + keep_dims = inputs[2].number + if not keep_dims: + raise RuntimeError( + "TOSA only supports keepdims == True; Did you run the convert_minmax pass?" + ) + + attr = ts.TosaSerializerAttribute() + attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1) + tosa_graph.addOperator( + ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr + ) diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index b18ed640b5f..aedcc643e5d 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -8,9 +8,9 @@ from typing import Any, List, Tuple +import numpy as np import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -34,7 +34,7 @@ def __init__(self, *args): def _create_clamp_node( self, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, input_name: str, output_name: str, min_int: int, @@ -42,6 +42,8 @@ def _create_clamp_node( min_fp32: float, max_fp32: float, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + attr = ts.TosaSerializerAttribute() attr.ClampAttribute( tosa_graph.builder, @@ -81,7 +83,7 @@ def cast_type(value: Any) -> int | float: def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: @@ -122,10 +124,12 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if len(node.all_input_nodes) != 1: raise ValueError( f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" @@ -150,3 +154,118 @@ def define_node( min_fp32, max_fp32, ) + + +@register_node_visitor +class ClampVisitor_INT(NodeVisitor): + target = "aten.clamp.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def _get_min_max_arguments( + self, node: Node, dtype_min: int | 
float, dtype_max: int | float + ) -> Tuple[int | float, int | float]: + + def cast_type(value: Any) -> int | float: + if isinstance(value, int): + return value + else: + # Attempt to cast to float + return float(value) + + if len(node.args) != 2 and len(node.args) != 3: + raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}") + + min_arg = dtype_min + max_arg = dtype_max + + if node.args[1] is not None: + min_arg = cast_type(node.args[1]) + + if len(node.args) > 2: + if node.args[2] is not None: + max_arg = cast_type(node.args[2]) + + return min_arg, max_arg + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts # type: ignore + + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + + # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments + min_int8, max_int8 = self._get_min_max_arguments( + node, + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max, + ) + + attr = ts.TosaSerializerAttribute() + attr.ClampAttribute( + tosa_graph.builder, + np.int8(min_int8).tobytes(), + np.int8(max_int8).tobytes(), + nan_mode=1, + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr + ) + + +@register_node_visitor +class ClampVisitor_FP(ClampVisitor_INT): + # inheriting 'target' from INT class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts # type: ignore + + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + + min_fp32, max_fp32 = self._get_min_max_arguments( + node, + torch.finfo(torch.float32).min, + torch.finfo(torch.float32).max, + ) + + attr = ts.TosaSerializerAttribute() + attr.ClampAttribute( + tosa_graph.builder, + np.float32(min_fp32).tobytes(), + np.float32(max_fp32).tobytes(), + nan_mode=1, + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr + ) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 142ccb1d25a..979a10ecff1 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -5,8 +5,9 @@ # pyre-unsafe +from typing import Any + import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -16,19 +17,22 @@ @register_node_visitor -class RepeatVisitor(NodeVisitor): +class RepeatVisitor_0_80(NodeVisitor): target = "aten.repeat.default" + tosa_specs = NodeVisitor.tosa_specs_0_80 + def __init__(self, *args): super().__init__(*args) def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: list[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore multiples = inputs[1].special @@ -37,3 +41,41 @@ def define_node( tosa_graph.addOperator( ts.TosaOp.Op().TILE, [inputs[0].name], [output.name], attr ) + + +@register_node_visitor +class RepeatVisitor(NodeVisitor): + target = "aten.repeat.default" + + tosa_specs = NodeVisitor.tosa_specs_1_00 + + def 
__init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: Any,
+        inputs: list[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        import serializer.tosa_serializer as ts  # type: ignore
+
+        multiples = inputs[1].special
+
+        if len(multiples) == 0:
+            raise ValueError(f"Length of multiples argument is 0: {inputs[1]}!")
+
+        multiple_shapes = tosa_graph.addConst(
+            (len(multiples),),
+            ts.DType.SHAPE,
+            list(tosa_shape(multiples, output.dim_order)),
+            name=node.name + "_multiples",
+        )
+
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().TILE,
+            [inputs[0].name, multiple_shapes.name],
+            [output.name],
+            None,
+        )
diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py
index 27ae977a5bc..a8d326cfa9b 100644
--- a/backends/arm/operators/op_slice.py
+++ b/backends/arm/operators/op_slice.py
@@ -5,9 +5,8 @@
 
 # pyre-unsafe
 
-from typing import List
+from typing import Any, List
 
-import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -16,32 +15,37 @@
 from torch.fx import Node
 
 
+def _fixup_start(start, shape, dim):
+    if start.number < 0:
+        return start.number % shape[dim]
+    else:
+        return start.number
+
+
+def _fixup_end(end, shape, dim):
+    if end.number < 0:
+        return end.number % shape[dim]
+    else:
+        return min(end.number, shape[dim])
+
+
 @register_node_visitor
-class SliceVisitor(NodeVisitor):
+class SliceVisitor_080(NodeVisitor):
     target = "aten.slice_copy.Tensor"
 
+    tosa_specs = NodeVisitor.tosa_specs_0_80
+
     def __init__(self, *args):
         super().__init__(*args)
 
-    def _fixup_start(self, start, shape, dim):
-        if start.number < 0:
-            return start.number % shape[dim]
-        else:
-            return start.number
-
-    def _fixup_end(self, end, shape, dim):
-        if end.number < 0:
-            return end.number % shape[dim]
-        else:
-            return min(end.number, shape[dim])
-
     def define_node(
         self,
         node: Node,
-        tosa_graph: ts.TosaSerializer,
+        tosa_graph: Any,
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+        import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
         # See slice_copy_support.py
         if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)):
@@ -55,8 +59,8 @@ def define_node(
         shape = input_node.shape
         dim = dim.number
 
-        start_index = self._fixup_start(start, shape, dim)
-        end_index = self._fixup_end(end, shape, dim)
+        start_index = _fixup_start(start, shape, dim)
+        end_index = _fixup_end(end, shape, dim)
         size = end_index - start_index
 
         assert size > 0
@@ -66,7 +70,7 @@ def define_node(
 
         attr = ts.TosaSerializerAttribute()
         start_attr = [
-            self._fixup_start(start, shape, dim) if i == dim else 0
+            _fixup_start(start, shape, dim) if i == dim else 0
             for i in input_node.dim_order
         ]
         size_attr = [size if i == dim else shape[i] for i in input_node.dim_order]
@@ -75,3 +79,77 @@ def define_node(
         tosa_graph.addOperator(
             ts.TosaOp.Op().SLICE, [input_node.name], [output.name], attr
         )
+
+
+@register_node_visitor
+class SliceVisitor(NodeVisitor):
+    target = "aten.slice_copy.Tensor"
+
+    tosa_specs = NodeVisitor.tosa_specs_1_00
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        import serializer.tosa_serializer as ts  # type: ignore
+
+        # See slice_copy_support.py
+        if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)):
+            raise ValueError("Unsupported combination of inputs")
+
+        # aten.slice_copy supports slicing one dimension at a time.
+        # The arguments are the input tensor, the dimension to slice, the start index, the end index, and an optional step or stride.
+        input_node, dim, start, end = inputs
+
+        # Translate and check parameters in PyTorch dim order.
+        shape = input_node.shape
+        dim = dim.number
+
+        start_index = _fixup_start(start, shape, dim)
+        end_index = _fixup_end(end, shape, dim)
+        size = end_index - start_index
+
+        assert size > 0
+        assert size <= shape[dim]
+
+        # Convert the aten arguments to TOSA's start and size shape_t tensors, in TOSA dim order.
+        starts = [
+            _fixup_start(start, shape, dim) if i == dim else 0
+            for i in input_node.dim_order
+        ]
+
+        if len(starts) != 0:
+            starts_len = len(starts)
+        else:
+            starts_len = 1
+            starts = [0]
+
+        start_tensor = tosa_graph.addConst(
+            (starts_len,),
+            ts.DType.SHAPE,
+            starts,
+            node.name + "_start_shape",
+        )
+
+        sizes = [size if i == dim else shape[i] for i in input_node.dim_order]
+        if len(sizes) != 0:
+            sizes_len = len(sizes)
+        else:
+            sizes_len = 1
+            sizes = [0]
+        sizes_tensor = tosa_graph.addConst(
+            (sizes_len,), ts.DType.SHAPE, sizes, node.name + "_sizes_shape"
+        )
+
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().SLICE,
+            [input_node.name, start_tensor.name, sizes_tensor.name],
+            [output.name],
+            None,
+        )
diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py
index 6a2053bea0d..454aebecd5e 100644
--- a/backends/arm/operators/op_table.py
+++ b/backends/arm/operators/op_table.py
@@ -5,30 +5,34 @@
 
 # pyre-unsafe
 
-from typing import List
+from typing import Any, List
 
 import numpy as np
 
 import torch
-
-import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_specification import TosaSpecification
+
 
 @register_node_visitor
-class TableVisitor(NodeVisitor):
+class TableVisitor_0_80(NodeVisitor):
     target = "_table.default"
 
+    tosa_specs = NodeVisitor.tosa_specs_0_80
+
     def define_node(
         self,
         node: torch.fx.Node,
-        tosa_graph: ts.TosaSerializer,
+        tosa_graph: Any,
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+        import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
+
         if node.name not in self._exported_program.state_dict.keys():  # type: ignore[union-attr]
             raise RuntimeError(
                 f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}."
@@ -50,3 +54,50 @@ def define_node(
         tosa_graph.addOperator(
             ts.TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
         )
+
+
+@register_node_visitor
+class TableVisitor(NodeVisitor):
+    target = "_table.default"
+
+    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+INT")]
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        import serializer.tosa_serializer as ts  # type: ignore
+
+        if node.name not in self._exported_program.state_dict.keys():  # type: ignore[union-attr]
+            raise RuntimeError(
+                f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}."
+ ) + if inputs[0].dtype == ts.DType.INT8 and output.dtype != ts.DType.INT8: + raise ValueError(f"Int8 tables need int8 output, got {output.dtype=}.") + if inputs[0].dtype == ts.DType.INT16 and output.dtype != ts.DType.INT32: + raise ValueError(f"Int16 tables need int32 output, got {output.dtype=}.") + + if inputs[0].dtype not in (ts.DType.INT8, ts.DType.INT16): + raise ValueError( + f"TOSA.TABLE only supports int8 or int16 inputs, got {ts.DTypeNames[inputs[0].dtype]}" + ) + + table = self._exported_program.state_dict[node.name] + + table_tensor_name = node.name + "_table" + tosa_graph.addConst( + table.shape, + ts.DType.INT8 if inputs[0].dtype == ts.DType.INT8 else ts.DType.INT16, + table.detach().numpy(), + name=table_tensor_name, + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().TABLE, + [inputs[0].name, table_tensor_name], + [output.name], + None, + ) diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py index 90485b71d50..210bfd2f61f 100644 --- a/backends/arm/operators/op_to_copy.py +++ b/backends/arm/operators/op_to_copy.py @@ -4,13 +4,10 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore -import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -18,6 +15,33 @@ from executorch.backends.arm.tosa_mapping import TosaArg +@register_node_visitor +class ToCopyVisitor_0_80(NodeVisitor): + """ + Implement the type cast functionality of _to_copy. + + Other features like setting of the memory_format or moving a tensor to a + different device are not supported. + + Also note that the node should not be quantized. + """ + + target = "aten._to_copy.default" + + tosa_specs = NodeVisitor.tosa_specs_0_80 + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) + + @register_node_visitor class ToCopyVisitor(NodeVisitor): """ @@ -31,11 +55,15 @@ class ToCopyVisitor(NodeVisitor): target = "aten._to_copy.default" + tosa_specs = NodeVisitor.tosa_specs_1_00 + def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: - tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name]) + import serializer.tosa_serializer as ts # type: ignore + + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_to_dim_order_copy.py b/backends/arm/operators/op_to_dim_order_copy.py index f144beba29f..740576f2736 100644 --- a/backends/arm/operators/op_to_dim_order_copy.py +++ b/backends/arm/operators/op_to_dim_order_copy.py @@ -4,13 +4,10 @@ # LICENSE file in the root directory of this source tree. 
# pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore -import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -18,6 +15,33 @@ from executorch.backends.arm.tosa_mapping import TosaArg +@register_node_visitor +class ToDimOrderCopyVisitor_0_80(NodeVisitor): + """ + Implement the type cast functionality of _to_dim_order_copy. + + Other features like setting of the dim_order or moving a tensor to a + different device are not supported. + + Also note that the node should not be quantized. + """ + + target = "dim_order_ops._to_dim_order_copy.default" + + tosa_specs = NodeVisitor.tosa_specs_0_80 + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) + + @register_node_visitor class ToDimOrderCopyVisitor(NodeVisitor): """ @@ -31,11 +55,15 @@ class ToDimOrderCopyVisitor(NodeVisitor): target = "dim_order_ops._to_dim_order_copy.default" + tosa_specs = NodeVisitor.tosa_specs_1_00 + def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: - tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name]) + import serializer.tosa_serializer as ts # type: ignore + + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py index b909aef2ac9..ac98979c234 100644 --- a/backends/arm/operators/op_transpose.py +++ b/backends/arm/operators/op_transpose.py @@ -5,11 +5,10 @@ # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -17,6 +16,36 @@ from executorch.backends.arm.tosa_mapping import TosaArg +@register_node_visitor +class TransposeVisitor_0_80(NodeVisitor): + """ + This node visitor targets the _transpose op defined in the + passthrough_to_tosa library. Used when switching between tosa_dim_orders. + Inserts a TOSA TRANSPOSE. 
+ """ + + target = "_transpose.default" + + tosa_specs = NodeVisitor.tosa_specs_0_80 + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + output_rank = len(output.shape) + perms = [dim % output_rank for dim in inputs[1].special] + attr = ts.TosaSerializerAttribute() + attr.TransposeAttribute(perms) + tosa_graph.addOperator( + ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr + ) + + @register_node_visitor class TransposeVisitor(NodeVisitor): """ @@ -27,13 +56,17 @@ class TransposeVisitor(NodeVisitor): target = "_transpose.default" + tosa_specs = NodeVisitor.tosa_specs_1_00 + def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import serializer.tosa_serializer as ts # type: ignore + output_rank = len(output.shape) perms = [dim % output_rank for dim in inputs[1].special] attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index e063b8e39ec..3e965d8f6ff 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -4,13 +4,10 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, cast, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore -import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -20,22 +17,67 @@ @register_node_visitor -class ViewVisitor(NodeVisitor): +class ViewVisitor_0_80(NodeVisitor): target = "aten.view_copy.default" + tosa_specs = NodeVisitor.tosa_specs_0_80 + def __init__(self, *args): super().__init__(*args) def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts + attr = ts.TosaSerializerAttribute() new_shape = tosa_shape(inputs[1].special, output.dim_order) attr.ReshapeAttribute(new_shape) + tosa_graph = cast(ts.TosaSerializer, tosa_graph) + tosa_graph.addOperator( + ts.TosaOp.Op().RESHAPE, [inputs[0].name], [output.name], attr + ) + + +@register_node_visitor +class ViewVisitor(NodeVisitor): + target = "aten.view_copy.default" + + tosa_specs = NodeVisitor.tosa_specs_1_00 + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + + tosa_graph = cast(ts.TosaSerializer, tosa_graph) + + if len(output.shape) == 0: + raise ValueError(f"No output shape for {output}") + + shape_len = len(output.shape) + shape_data = list(tosa_shape(output.shape, output.dim_order)) + + shape = tosa_graph.addConst( + [shape_len], + ts.DType.SHAPE, + shape_data, + name=node.name + "_shape", + ) + + attr = ts.TosaSerializerAttribute() + attr.ReshapeAttribute() tosa_graph.addOperator( - TosaOp.Op().RESHAPE, [inputs[0].name], [output.name], attr + ts.TosaOp.Op().RESHAPE, [inputs[0].name, shape.name], [output.name], attr ) diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index 0c6527cf336..0c41e13d445 100644 --- a/backends/arm/operators/ops_identity.py +++ 
b/backends/arm/operators/ops_identity.py @@ -5,13 +5,11 @@ # pyre-unsafe -from typing import List +from typing import Any, List import torch import torch.fx -import tosa_tools.v0_80.serializer.tosa_serializer as ts - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -19,6 +17,38 @@ from executorch.backends.arm.tosa_mapping import TosaArg +def identity_operator_factory_v0_80(identity_target: str): + """ + Creates and registers NodeVisitors for operators that map directly + to a TOSA IDENTITY op. + """ + + class IdentityOperatorVisitor(NodeVisitor): + target = identity_target + + tosa_specs = NodeVisitor.tosa_specs_0_80 + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts + + # Simply add an identityOp + tosa_graph.addOperator( + ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name] + ) + + register_node_visitor(IdentityOperatorVisitor) + + +identity_operator_factory_v0_80("getitem") +identity_operator_factory_v0_80("aten.alias_copy.default") + + def identity_operator_factory(identity_target: str): """ Creates and registers NodeVisitors for operators that map directly @@ -28,13 +58,17 @@ def identity_operator_factory(identity_target: str): class IdentityOperatorVisitor(NodeVisitor): target = identity_target + tosa_specs = NodeVisitor.tosa_specs_1_00 + def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import serializer.tosa_serializer as ts + # Simply add an identityOp tosa_graph.addOperator( ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]