Skip to content

Commit 1d93152

Browse files
Add Module loadMethod API. (#9685)
Summary: #8363 Reviewed By: mergennachin Differential Revision: D71917236 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent 77fd564 commit 1d93152

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchModule.h

+18
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,24 @@ __attribute__((deprecated("This API is experimental.")))
8686
*/
8787
- (BOOL)isLoaded;
8888

89+
/**
90+
* Loads a specific method from the program.
91+
*
92+
* @param methodName A string representing the name of the method to load.
93+
* @param error A pointer to an NSError pointer that is set if an error occurs.
94+
* @return YES if the method was successfully loaded; otherwise, NO.
95+
*/
96+
- (BOOL)loadMethod:(NSString *)methodName
97+
error:(NSError **)error NS_SWIFT_NAME(load(_:));
98+
99+
/**
100+
* Checks if a specific method is loaded.
101+
*
102+
* @param methodName A string representing the method name.
103+
* @return YES if the method is loaded; otherwise, NO.
104+
*/
105+
- (BOOL)isMethodLoaded:(NSString *)methodName NS_SWIFT_NAME(isLoaded(_:));
106+
89107
+ (instancetype)new NS_UNAVAILABLE;
90108
- (instancetype)init NS_UNAVAILABLE;
91109

extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm

+18
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,22 @@ - (BOOL)isLoaded {
5959
return _module->is_loaded();
6060
}
6161

62+
- (BOOL)loadMethod:(NSString *)methodName
63+
error:(NSError **)error {
64+
const auto errorCode = _module->load_method(methodName.UTF8String);
65+
if (errorCode != Error::Ok) {
66+
if (error) {
67+
*error = [NSError errorWithDomain:ExecuTorchErrorDomain
68+
code:(NSInteger)errorCode
69+
userInfo:nil];
70+
}
71+
return NO;
72+
}
73+
return YES;
74+
}
75+
76+
- (BOOL)isMethodLoaded:(NSString *)methodName {
77+
return _module->is_method_loaded(methodName.UTF8String);
78+
}
79+
6280
@end

extension/apple/ExecuTorch/__tests__/ModuleTest.swift

+11
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,15 @@ class ModuleTest: XCTestCase {
2828
XCTAssertNoThrow(try module.load())
2929
XCTAssertTrue(module.isLoaded())
3030
}
31+
32+
func testLoadMethod() {
33+
let bundle = Bundle(for: type(of: self))
34+
guard let modelPath = bundle.path(forResource: "add", ofType: "pte") else {
35+
XCTFail("Couldn't find the model file")
36+
return
37+
}
38+
let module = Module(filePath: modelPath)
39+
XCTAssertNoThrow(try module.load("forward"))
40+
XCTAssertTrue(module.isLoaded("forward"))
41+
}
3142
}

0 commit comments

Comments
 (0)