diff --git a/backends/apple/mps/runtime/MPSDevice.mm b/backends/apple/mps/runtime/MPSDevice.mm index c34b571c3a9..7f4c0bde9e5 100644 --- a/backends/apple/mps/runtime/MPSDevice.mm +++ b/backends/apple/mps/runtime/MPSDevice.mm @@ -22,11 +22,11 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants) // host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+) MTLLanguageVersion languageVersion = MTLLanguageVersion2_3; -#if defined(__MAC_13_0) - if (macOS13Plus) { - languageVersion = MTLLanguageVersion3_0; + if (@available(iOS 16, macOS 13, *)) { + if (macOS13Plus) { + languageVersion = MTLLanguageVersion3_0; + } } -#endif ET_CHECK_MSG([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2"); return languageVersion; diff --git a/backends/apple/mps/runtime/operations/IndexingOps.mm b/backends/apple/mps/runtime/operations/IndexingOps.mm index d4015d10253..34a03851655 100644 --- a/backends/apple/mps/runtime/operations/IndexingOps.mm +++ b/backends/apple/mps/runtime/operations/IndexingOps.mm @@ -206,25 +206,32 @@ Error MPSGraphBuilder::mpsScatterOp(NodePtr nodePtr) { - auto graphNode = nodePtr->mpsnode_union_as_MPSScatter(); - ET_LOG( - Debug, "%s %d: %d", - __FUNCTION__, graphNode->input1_id(), graphNode->output_id() - ); + if (@available(iOS 15.4, macOS 12.3, *)) { + auto graphNode = nodePtr->mpsnode_union_as_MPSScatter(); + ET_LOG( + Debug, "%s %d: %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); - int64_t dim = graphNode->dim(); - MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id()); - MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->idx_id()); - MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->src_id()); + int64_t dim = graphNode->dim(); + MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id()); + MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->idx_id()); + MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->src_id()); - _idToMPSGraphTensor[graphNode->output_id()] = - [_mpsGraph scatterAlongAxis:dim - withDataTensor:inputTensor - updatesTensor:updatesTensor - indicesTensor:indicesTensor - mode:MPSGraphScatterModeSet - name:nil]; - return Error::Ok; + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph scatterAlongAxis:dim + withDataTensor:inputTensor + updatesTensor:updatesTensor + indicesTensor:indicesTensor + mode:MPSGraphScatterModeSet + name:nil]; + + return Error::Ok; + } else { + ET_LOG(Error, "MPS: scatter op is not supported on iOS < 15.4 and macOS < 12.3"); + + return Error::NotSupported; + } }