Skip to content

Commit c4f0481

Browse files
DenisVieriu97 authored and pytorchbot committed
Add index.Tensor and aten.logical_not (#3221)
Summary: Add missing llama ops for MPS delegate: - `index.Tensor` - `logical_not` `index.put` works correctly for generating 1 token, but gives incorrect results on 2nd token. This remains disabled. Summary of changes: - Adds missing llama2 ops - Adds support for launching Metal kernels instead of MPSGraph ops (if MPSGraph doesn't have the support) cc cccclai , shoumikhin Pull Request resolved: #3221 Reviewed By: shoumikhin Differential Revision: D56447710 Pulled By: cccclai fbshipit-source-id: 778a485df5e67d1afd006b42f07b69c8a3961223 (cherry picked from commit 02a6b66)
1 parent 66783f4 commit c4f0481

15 files changed

+446
-10
lines changed

backends/apple/mps/mps_preprocess.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
1919
MPSGraph,
2020
MPSTensor,
21+
OpType,
2122
)
2223

2324
from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
@@ -65,6 +66,7 @@ def preprocess(
6566
input_ids=[],
6667
output_ids=[],
6768
constant_ids=[],
69+
graph_type=OpType.mps_graph,
6870
)
6971

7072
convert_model_to_fp16 = True
@@ -111,6 +113,16 @@ def handle_call_function(
111113
mps_graph: MPSGraph,
112114
) -> None:
113115
logging.info(f"Visiting: {node}, {node.target.__name__}")
116+
117+
if (
118+
"delegation_tag" in node.meta
119+
and "metal_kernel" in node.meta["delegation_tag"]
120+
):
121+
logging.info(
122+
f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!"
123+
)
124+
mps_graph.graph_type = OpType.metal_kernel
125+
114126
if node.target.__name__ in node_visitors:
115127
node_visitors[node.target.__name__].define_node(node, mps_graph)
116128
else:

backends/apple/mps/operators/indexing_ops.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Provided subject to the LICENSE file in the top level directory.
44
#
55

6-
from typing import cast
6+
from typing import cast, List
77

88
import torch
99
from executorch.backends.apple.mps.operators.node_visitor import (
@@ -13,9 +13,12 @@
1313
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
1414
MPSEmbedding,
1515
MPSGraph,
16+
MPSIndexPut,
1617
MPSIndexSelect,
18+
MPSIndexTensor,
1719
)
1820
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
21+
from executorch.backends.transforms import get_shape
1922
from executorch.exir.sym_util import eval_expr
2023

2124

@@ -40,6 +43,78 @@ def define_node(
4043
mps_graph.mps_nodes.append(mps_node)
4144

4245

46+
@register_node_visitor
47+
class IndexTensorVisitor(NodeVisitor):
48+
target = "aten.index.Tensor"
49+
50+
def __init__(self, *args) -> None:
51+
super().__init__(*args)
52+
53+
def define_node(
54+
self,
55+
node: torch.fx.Node,
56+
mps_graph: MPSGraph,
57+
) -> None:
58+
mps_node = self.create_unary_node(node, mps_graph, MPSIndexTensor)
59+
tensors = cast(List[torch.fx.Node], node.args[1])
60+
for tensor in tensors:
61+
mps_node.mpsnode_union.indices_id.append(
62+
self.define_tensor(tensor, mps_graph)
63+
)
64+
65+
mps_graph.mps_nodes.append(mps_node)
66+
67+
68+
# [MPS TODO]: Works on a single iteration of llama2, but subsequent tokens
69+
# are wrong when using Index put. Disabling it for now.
70+
@register_node_visitor
71+
class IndexPutVisitor(NodeVisitor):
72+
# target = "aten.index_put.default"
73+
target = "disabled"
74+
75+
def __init__(self, *args) -> None:
76+
super().__init__(*args)
77+
78+
def infer_sizes(self, a: List[int], b: List[int]):
79+
dimsA = len(a)
80+
dimsB = len(b)
81+
ndim = dimsA if dimsA > dimsB else dimsB
82+
expandedSizes = [0] * ndim
83+
for i in range(ndim - 1, -1, -1):
84+
offset = ndim - 1 - i
85+
dimA = dimsA - 1 - offset
86+
dimB = dimsB - 1 - offset
87+
sizeA = a[dimA] if dimA >= 0 else -1
88+
sizeB = b[dimB] if dimB >= 0 else -1
89+
expandedSizes[i] = sizeA if sizeB == -1 else sizeB
90+
91+
return expandedSizes
92+
93+
def define_node(
94+
self,
95+
node: torch.fx.Node,
96+
mps_graph: MPSGraph,
97+
) -> None:
98+
mps_node = self.create_unary_node(node, mps_graph, MPSIndexPut)
99+
updates_shape = get_shape(node.args[2])
100+
input_shape = get_shape(node.args[0])
101+
new_shape = []
102+
if len(updates_shape) != 1 and len(updates_shape) != len(input_shape):
103+
new_shape = self.infer_sizes(input_shape, updates_shape)
104+
mps_node.mpsnode_union.values_shape = new_shape
105+
106+
tensors = cast(List[torch.fx.Node], node.args[1])
107+
for tensor in tensors:
108+
mps_node.mpsnode_union.indices_id.append(
109+
self.define_tensor(tensor, mps_graph)
110+
)
111+
112+
mps_node.mpsnode_union.values_id = self.define_tensor(
113+
get_input_node(node, 2), mps_graph
114+
)
115+
mps_graph.mps_nodes.append(mps_node)
116+
117+
43118
@register_node_visitor
44119
class EmbeddingVisitor(NodeVisitor):
45120
target = "aten.embedding.default"

backends/apple/mps/operators/unary_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
MPSLog,
3131
MPSLog10,
3232
MPSLog2,
33+
MPSLogicalNot,
3334
MPSNeg,
3435
MPSReciprocal,
3536
MPSRound,
@@ -79,6 +80,7 @@ class UnaryOpVisitor(NodeVisitor):
7980
"aten.isnan.default",
8081
"aten.isinf.default",
8182
"aten.round.default",
83+
"aten.logical_not.default",
8284
]
8385

8486
def __init__(self, *args) -> None:
@@ -115,6 +117,7 @@ def __init__(self, *args) -> None:
115117
exir_ops.edge.aten.isnan.default: MPSIsnan,
116118
exir_ops.edge.aten.isinf.default: MPSIsinf,
117119
exir_ops.edge.aten.round.default: MPSRound,
120+
exir_ops.edge.aten.logical_not.default: MPSLogicalNot,
118121
}
119122

120123
def define_node(

backends/apple/mps/partition/mps_partitioner.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
#
55

66
import logging
7-
from typing import Any, Dict, List, Union
7+
from typing import Any, cast, Dict, List, Union
88

99
import torch
1010
from executorch.backends.apple.mps.mps_preprocess import MPSBackend
1111
from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors
1212
from executorch.backends.apple.mps.utils.mps_utils import is_parameter
13+
from executorch.backends.transforms import get_shape
1314
from executorch.exir.backend.backend_details import CompileSpec
1415
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
1516
generate_partitions_from_list_of_nodes,
@@ -20,6 +21,7 @@
2021
PartitionResult,
2122
)
2223
from executorch.exir.backend.utils import tag_constant_data
24+
from executorch.exir.dialects._ops import ops as exir_ops
2325
from torch.export.exported_program import ExportedProgram
2426
from torch.fx.passes.infra.partitioner import Partition
2527
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -28,6 +30,13 @@
2830
logging.basicConfig(level=logging.DEBUG, format=FORMAT)
2931

3032

33+
# ops implemented as Metal kernels.
34+
METAL_KERNELS = [
35+
exir_ops.edge.aten.index.Tensor,
36+
exir_ops.edge.aten.index_put.default,
37+
]
38+
39+
3140
class MPSOperatorSupport(OperatorSupportBase):
3241
def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
3342
self.node_visitors = get_node_visitors(edge_program)
@@ -65,10 +74,47 @@ def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]:
6574
op_support=self.supported_ops,
6675
)
6776

77+
def mps_graph_advanced_indexing_support(self, node: torch.fx.Node):
78+
num_indices = 0
79+
tensors = cast(List[torch.fx.Node], node.args[1])
80+
input = cast(torch.fx.Node, node.args[0])
81+
for t in tensors:
82+
if t is not None:
83+
num_indices += 1
84+
# Can dispatch to MPSGraph if the length of the slices is equal
85+
# to the number of dimensions of the sliced tensors, or only one
86+
# slice is present. All other cases will fallback to a Metal kernel.
87+
if num_indices == len(get_shape(input)) or num_indices == 1:
88+
return True
89+
90+
return False
91+
92+
def use_metal_kernel(self, node: torch.fx.Node):
93+
if node.target in METAL_KERNELS:
94+
if (
95+
node.target == exir_ops.edge.aten.index.Tensor
96+
or node.target == exir_ops.edge.aten.index_put.default
97+
):
98+
if not self.mps_graph_advanced_indexing_support(node):
99+
return True
100+
return False
101+
68102
def tag_nodes(self, partitions: List[Partition]) -> None:
69103
for partition in partitions:
70-
for node in partition.nodes:
104+
crt_partition_counter = 0
105+
for node in sorted(partition.nodes):
71106
delegation_tag = f"mps_{partition.id}"
107+
if self.use_metal_kernel(node):
108+
logging.warning(f"[WARNING] Using Metal kernel for op {node.name}!")
109+
# Partition the Metal kernel into a separate partition
110+
crt_partition_counter += 1
111+
delegation_tag = (
112+
f"{delegation_tag}_metal_kernel_{crt_partition_counter}"
113+
)
114+
crt_partition_counter += 1
115+
else:
116+
delegation_tag = f"{delegation_tag}_{crt_partition_counter}"
117+
72118
node.meta["delegation_tag"] = delegation_tag
73119
self.partition_tags[delegation_tag] = self.delegation_spec
74120

backends/apple/mps/runtime/MPSDevice.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,19 @@
55

66
#pragma once
77

8+
// Obj-C headers
89
#include <Foundation/Foundation.h>
910
#include <Metal/Metal.h>
11+
12+
// Runtime headers
13+
#include <executorch/runtime/backend/interface.h>
14+
15+
// MPS headers
1016
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
1117

18+
#include <unordered_map>
19+
#include <vector>
20+
1221
#define MB(x) (x * 1048576UL)
1322

1423
namespace torch {
@@ -25,6 +34,11 @@ enum class MacOSVersion : uint32_t {
2534
MACOS_VER_14_0_PLUS,
2635
};
2736

37+
enum class LibraryType : uint32_t {
38+
INDEXING_KERNELS = 0,
39+
MAX = INDEXING_KERNELS,
40+
};
41+
2842
class MPSDevice {
2943
public:
3044
/**
@@ -53,9 +67,18 @@ class MPSDevice {
5367

5468
~MPSDevice();
5569

70+
/**
71+
* Compile a PSO for a given library type.
72+
* Once compiled, the library and PSOs are cached.
73+
*/
74+
Error compilePSO(LibraryType libraryType, const char* kernelName);
75+
Error compileLibrary(LibraryType);
76+
5677
private:
5778
static MPSDevice* _device;
5879
id<MTLDevice> _mtl_device;
80+
std::unordered_map<LibraryType, id<MTLLibrary>> _m_library_cache;
81+
std::unordered_map<std::string, id<MTLComputePipelineState>> _m_pso_cache;
5982
MPSDevice();
6083
};
6184

backends/apple/mps/runtime/MPSDevice.mm

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,20 @@
1616
static std::unique_ptr<MPSDevice> mps_device;
1717
static std::once_flag mpsdev_init;
1818

19+
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
20+
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
21+
// host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+)
22+
MTLLanguageVersion languageVersion = MTLLanguageVersion2_3;
23+
#if defined(__MAC_13_0)
24+
if (macOS13Plus) {
25+
languageVersion = MTLLanguageVersion3_0;
26+
}
27+
#endif
28+
29+
ET_CHECK_MSG([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
30+
return languageVersion;
31+
}
32+
1933
MPSDevice::~MPSDevice() {
2034
[_mtl_device release];
2135
_mtl_device = nil;
@@ -79,6 +93,57 @@
7993
}
8094
}
8195

96+
const char* getLibraryCString(LibraryType libraryType) {
97+
switch (libraryType) {
98+
case LibraryType::INDEXING_KERNELS:
99+
return "TODO";
100+
default:
101+
ET_CHECK_MSG(false, "Unhandled library type!");
102+
}
103+
}
104+
105+
Error
106+
MPSDevice::compileLibrary(LibraryType libraryType) {
107+
Error err = Error::Ok;
108+
NSError* error = nil;
109+
MTLCompileOptions* options = [MTLCompileOptions new];
110+
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
111+
[options setFastMathEnabled:YES];
112+
id<MTLLibrary> lib =
113+
[_mtl_device newLibraryWithSource:[NSString stringWithCString:getLibraryCString(libraryType)
114+
encoding:NSASCIIStringEncoding]
115+
options:options
116+
error:&error];
117+
118+
ET_CHECK_OR_RETURN_ERROR(
119+
lib != nil,
120+
Internal,
121+
"Failed to create indexing library, error: %s", [[error description] UTF8String]
122+
);
123+
124+
_m_library_cache[libraryType] = lib;
125+
return err;
126+
}
127+
128+
Error
129+
MPSDevice::compilePSO(LibraryType libraryType, const char* kernelName) {
130+
Error err = Error::Ok;
131+
if (_m_library_cache.find(libraryType) == _m_library_cache.end()) {
132+
ET_LOG(Debug, "Compiling library type: %d", libraryType);
133+
err = compileLibrary(libraryType);
134+
ET_CHECK_OR_RETURN_ERROR(
135+
err == Error::Ok,
136+
Internal,
137+
"An error occured occured while compiling library %d", libraryType
138+
);
139+
}
140+
if (_m_pso_cache.find(kernelName) == _m_pso_cache.end()) {
141+
ET_LOG(Debug, "Compiling kernel: %s", kernelName);
142+
// err = compilePSO(libraryType, kernelName);
143+
}
144+
return err;
145+
}
146+
82147
bool isMacOS13OrNewer(MacOSVersion version) {
83148
return MPSDevice::getInstance()->isMacOS13Plus(version);
84149
}

0 commit comments

Comments (0)