Skip to content

Commit 4bf6700

Browse files
Add Module load API. (#9684)
Summary: #8363 Reviewed By: mergennachin Differential Revision: D71916550 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent da52a92 commit 4bf6700

File tree

3 files changed

+66
-2
lines changed

3 files changed

+66
-2
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchModule.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ typedef NS_ENUM(NSInteger, ExecuTorchModuleLoadMode) {
2222
ExecuTorchModuleLoadModeMmapUseMlockIgnoreErrors,
2323
} NS_SWIFT_NAME(ModuleLoadMode);
2424

25+
/**
26+
* Enum to define the verification level used when loading a module.
27+
* Values can be a subset, but must numerically match exactly those defined in
28+
* runtime/executor/program.h
29+
*/
30+
typedef NS_ENUM(uint8_t, ExecuTorchVerification) {
31+
ExecuTorchVerificationMinimal,
32+
ExecuTorchVerificationInternalConsistency,
33+
} NS_SWIFT_NAME(ModuleVerification);
34+
2535
/**
2636
* Represents a module that encapsulates an ExecuTorch program.
2737
* This class is a facade for loading programs and executing methods within them.
@@ -49,6 +59,33 @@ __attribute__((deprecated("This API is experimental.")))
4959
*/
5060
- (instancetype)initWithFilePath:(NSString *)filePath;
5161

62+
/**
63+
* Loads the module’s program using the specified verification level.
64+
*
65+
* @param verification The verification level to apply when loading the program.
66+
* @param error A pointer to an NSError pointer that will be set if an error occurs.
67+
* @return YES if the program was successfully loaded; otherwise, NO.
68+
*/
69+
- (BOOL)loadWithVerification:(ExecuTorchVerification)verification
70+
error:(NSError **)error;
71+
72+
/**
73+
* Loads the module’s program using minimal verification.
74+
*
75+
* This is a convenience overload that defaults the verification level to Minimal.
76+
*
77+
* @param error A pointer to an NSError pointer that will be set if an error occurs.
78+
* @return YES if the program was successfully loaded; otherwise, NO.
79+
*/
80+
- (BOOL)load:(NSError **)error;
81+
82+
/**
83+
* Checks if the module is loaded.
84+
*
85+
* @return YES if the module's program is loaded; otherwise, NO.
86+
*/
87+
- (BOOL)isLoaded;
88+
5289
+ (instancetype)new NS_UNAVAILABLE;
5390
- (instancetype)init NS_UNAVAILABLE;
5491

extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#import <executorch/extension/tensor/tensor.h>
1515

1616
using namespace executorch::extension;
17+
using namespace executorch::runtime;
1718

1819
@implementation ExecuTorchModule {
1920
std::unique_ptr<Module> _module;
@@ -35,4 +36,27 @@ - (instancetype)initWithFilePath:(NSString *)filePath {
3536
return [self initWithFilePath:filePath loadMode:ExecuTorchModuleLoadModeFile];
3637
}
3738

39+
- (BOOL)loadWithVerification:(ExecuTorchVerification)verification
40+
error:(NSError **)error {
41+
const auto errorCode = _module->load(static_cast<Program::Verification>(verification));
42+
if (errorCode != Error::Ok) {
43+
if (error) {
44+
*error = [NSError errorWithDomain:ExecuTorchErrorDomain
45+
code:(NSInteger)errorCode
46+
userInfo:nil];
47+
}
48+
return NO;
49+
}
50+
return YES;
51+
}
52+
53+
- (BOOL)load:(NSError **)error {
54+
return [self loadWithVerification:ExecuTorchVerificationMinimal
55+
error:error];
56+
}
57+
58+
- (BOOL)isLoaded {
59+
return _module->is_loaded();
60+
}
61+
3862
@end

extension/apple/ExecuTorch/__tests__/ModuleTest.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@ class ModuleTest: XCTestCase {
1818
return Bundle(for: type(of: self))
1919
#endif
2020
}
21-
22-
func test() throws {
21+
22+
func testLoad() {
2323
guard let modelPath = resourceBundle.path(forResource: "add", ofType: "pte") else {
2424
XCTFail("Couldn't find the model file")
2525
return
2626
}
27+
let module = Module(filePath: modelPath)
28+
XCTAssertNoThrow(try module.load())
29+
XCTAssertTrue(module.isLoaded())
2730
}
2831
}

0 commit comments

Comments
 (0)