From 097db64cd648fbaebdcb5da9a7330f7ec3bbd7af Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Wed, 26 Mar 2025 20:28:18 -0700 Subject: [PATCH] Add Module loadMethod API. Summary: https://github.com/pytorch/executorch/issues/8363 Reviewed By: mergennachin Differential Revision: D71917236 --- .../ExecuTorch/Exported/ExecuTorchModule.h | 18 ++++++++++++++++++ .../ExecuTorch/Exported/ExecuTorchModule.mm | 18 ++++++++++++++++++ .../ExecuTorch/__tests__/ModuleTest.swift | 11 +++++++++++ 3 files changed, 47 insertions(+) diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h index b309b55dd75..a73379ff4c7 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h @@ -86,6 +86,24 @@ __attribute__((deprecated("This API is experimental."))) */ - (BOOL)isLoaded; +/** + * Loads a specific method from the program. + * + * @param methodName A string representing the name of the method to load. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return YES if the method was successfully loaded; otherwise, NO. + */ +- (BOOL)loadMethod:(NSString *)methodName + error:(NSError **)error NS_SWIFT_NAME(load(_:)); + +/** + * Checks if a specific method is loaded. + * + * @param methodName A string representing the method name. + * @return YES if the method is loaded; otherwise, NO. + */ +- (BOOL)isMethodLoaded:(NSString *)methodName NS_SWIFT_NAME(isLoaded(_:)); + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm index 1d62b9f2a1c..246d6324de0 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm @@ -59,4 +59,22 @@ - (BOOL)isLoaded { return _module->is_loaded(); } +- (BOOL)loadMethod:(NSString *)methodName + error:(NSError **)error { + const auto errorCode = _module->load_method(methodName.UTF8String); + if (errorCode != Error::Ok) { + if (error) { + *error = [NSError errorWithDomain:ExecuTorchErrorDomain + code:(NSInteger)errorCode + userInfo:nil]; + } + return NO; + } + return YES; +} + +- (BOOL)isMethodLoaded:(NSString *)methodName { + return _module->is_method_loaded(methodName.UTF8String); +} + @end diff --git a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift index f3b85f23ac1..e94820a43c3 100644 --- a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift +++ b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift @@ -28,4 +28,15 @@ class ModuleTest: XCTestCase { XCTAssertNoThrow(try module.load()) XCTAssertTrue(module.isLoaded()) } + + func testLoadMethod() { + let bundle = Bundle(for: type(of: self)) + guard let modelPath = bundle.path(forResource: "add", ofType: "pte") else { + XCTFail("Couldn't find the model file") + return + } + let module = Module(filePath: modelPath) + XCTAssertNoThrow(try module.load("forward")) + XCTAssertTrue(module.isLoaded("forward")) + } }