diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h index b309b55dd75..a73379ff4c7 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h @@ -86,6 +86,24 @@ __attribute__((deprecated("This API is experimental."))) */ - (BOOL)isLoaded; +/** + * Loads a specific method from the program. + * + * @param methodName A string representing the name of the method to load. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return YES if the method was successfully loaded; otherwise, NO. + */ +- (BOOL)loadMethod:(NSString *)methodName + error:(NSError **)error NS_SWIFT_NAME(load(_:)); + +/** + * Checks if a specific method is loaded. + * + * @param methodName A string representing the method name. + * @return YES if the method is loaded; otherwise, NO. + */ +- (BOOL)isMethodLoaded:(NSString *)methodName NS_SWIFT_NAME(isLoaded(_:)); + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm index 1d62b9f2a1c..246d6324de0 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm @@ -59,4 +59,22 @@ - (BOOL)isLoaded { return _module->is_loaded(); } +- (BOOL)loadMethod:(NSString *)methodName + error:(NSError **)error { + const auto errorCode = _module->load_method(methodName.UTF8String); + if (errorCode != Error::Ok) { + if (error) { + *error = [NSError errorWithDomain:ExecuTorchErrorDomain + code:(NSInteger)errorCode + userInfo:nil]; + } + return NO; + } + return YES; +} + +- (BOOL)isMethodLoaded:(NSString *)methodName { + return _module->is_method_loaded(methodName.UTF8String); +} + @end diff --git a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift index f3b85f23ac1..e94820a43c3 100644 --- a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift +++ b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift @@ -28,4 +28,15 @@ class ModuleTest: XCTestCase { XCTAssertNoThrow(try module.load()) XCTAssertTrue(module.isLoaded()) } + + func testLoadMethod() { + let bundle = Bundle(for: type(of: self)) + guard let modelPath = bundle.path(forResource: "add", ofType: "pte") else { + XCTFail("Couldn't find the model file") + return + } + let module = Module(filePath: modelPath) + XCTAssertNoThrow(try module.load("forward")) + XCTAssertTrue(module.isLoaded("forward")) + } }