Skip to content

Commit c1b1109

Browse files
Franco Melonifacebook-github-bot
Franco Meloni
authored andcommitted
Add MPS Backend (#9095)
Summary: Reviewed By: cccclai Differential Revision: D70795041
1 parent 4c54bab commit c1b1109

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

backends/apple/mps/runtime/MPSDevice.mm

+4-4
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
2222
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
2323
// host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+)
2424
MTLLanguageVersion languageVersion = MTLLanguageVersion2_3;
25-
#if defined(__MAC_13_0)
26-
if (macOS13Plus) {
27-
languageVersion = MTLLanguageVersion3_0;
25+
if (@available(iOS 16, macOS 13, *)) {
26+
if (macOS13Plus) {
27+
languageVersion = MTLLanguageVersion3_0;
28+
}
2829
}
29-
#endif
3030

3131
ET_CHECK_MSG([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
3232
return languageVersion;

backends/apple/mps/runtime/operations/IndexingOps.mm

+24-17
Original file line numberDiff line numberDiff line change
@@ -206,25 +206,32 @@
206206

207207
Error
208208
MPSGraphBuilder::mpsScatterOp(NodePtr nodePtr) {
209-
auto graphNode = nodePtr->mpsnode_union_as_MPSScatter();
210-
ET_LOG(
211-
Debug, "%s %d: %d",
212-
__FUNCTION__, graphNode->input1_id(), graphNode->output_id()
213-
);
209+
if (@available(iOS 15.4, macOS 12.3, *)) {
210+
auto graphNode = nodePtr->mpsnode_union_as_MPSScatter();
211+
ET_LOG(
212+
Debug, "%s %d: %d",
213+
__FUNCTION__, graphNode->input1_id(), graphNode->output_id()
214+
);
214215

215-
int64_t dim = graphNode->dim();
216-
MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
217-
MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->idx_id());
218-
MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->src_id());
216+
int64_t dim = graphNode->dim();
217+
MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
218+
MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->idx_id());
219+
MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->src_id());
219220

220-
_idToMPSGraphTensor[graphNode->output_id()] =
221-
[_mpsGraph scatterAlongAxis:dim
222-
withDataTensor:inputTensor
223-
updatesTensor:updatesTensor
224-
indicesTensor:indicesTensor
225-
mode:MPSGraphScatterModeSet
226-
name:nil];
227-
return Error::Ok;
221+
_idToMPSGraphTensor[graphNode->output_id()] =
222+
[_mpsGraph scatterAlongAxis:dim
223+
withDataTensor:inputTensor
224+
updatesTensor:updatesTensor
225+
indicesTensor:indicesTensor
226+
mode:MPSGraphScatterModeSet
227+
name:nil];
228+
229+
return Error::Ok;
230+
} else {
231+
ET_LOG(Error, "MPS: scatter op is not supported on iOS < 15.4 and macOS < 12.3");
232+
233+
return Error::NotSupported;
234+
}
228235
}
229236

230237

0 commit comments

Comments
 (0)