Skip to content

Commit df0e087

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Add Module execute API.
Summary: #8363 Reviewed By: bsoyluoglu Differential Revision: D71919658
1 parent 750b1f1 commit df0e087

File tree

3 files changed

+81
-0
lines changed

3 files changed

+81
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchModule.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,21 @@ __attribute__((deprecated("This API is experimental.")))
115115
*/
116116
- (nullable NSSet<NSString *> *)methodNames:(NSError **)error;
117117

118+
/**
119+
* Executes a specific method with the provided input values.
120+
*
121+
* The method is loaded on demand if not already loaded.
122+
*
123+
* @param methodName A string representing the method name.
124+
* @param values An NSArray of ExecuTorchValue objects representing the inputs.
125+
* @param error A pointer to an NSError pointer that is set if an error occurs.
126+
* @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error.
127+
*/
128+
- (nullable NSArray<ExecuTorchValue *> *)executeMethod:(NSString *)methodName
129+
withInputs:(NSArray<ExecuTorchValue *> *)values
130+
error:(NSError **)error
131+
NS_SWIFT_NAME(execute(_:_:));
132+
118133
+ (instancetype)new NS_UNAVAILABLE;
119134
- (instancetype)init NS_UNAVAILABLE;
120135

extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,27 @@
1616
using namespace executorch::extension;
1717
using namespace executorch::runtime;
1818

19+
static inline EValue toEValue(ExecuTorchValue *value) {
20+
if (value.isTensor) {
21+
auto *nativeTensorPtr = value.tensorValue.nativeInstance;
22+
ET_CHECK(nativeTensorPtr);
23+
auto nativeTensor = *reinterpret_cast<TensorPtr *>(nativeTensorPtr);
24+
ET_CHECK(nativeTensor);
25+
return *nativeTensor;
26+
}
27+
ET_CHECK_MSG(false, "Unsupported ExecuTorchValue type");
28+
return EValue();
29+
}
30+
31+
static inline ExecuTorchValue *toExecuTorchValue(EValue value) {
32+
if (value.isTensor()) {
33+
auto nativeInstance = make_tensor_ptr(value.toTensor());
34+
return [ExecuTorchValue valueWithTensor:[[ExecuTorchTensor alloc] initWithNativeInstance:&nativeInstance]];
35+
}
36+
ET_CHECK_MSG(false, "Unsupported EValue type");
37+
return [ExecuTorchValue new];
38+
}
39+
1940
@implementation ExecuTorchModule {
2041
std::unique_ptr<Module> _module;
2142
}
@@ -94,4 +115,28 @@ - (BOOL)isMethodLoaded:(NSString *)methodName {
94115
return methods;
95116
}
96117

118+
- (nullable NSArray<ExecuTorchValue *> *)executeMethod:(NSString *)methodName
119+
withInputs:(NSArray<ExecuTorchValue *> *)values
120+
error:(NSError **)error {
121+
std::vector<EValue> inputs;
122+
inputs.reserve(values.count);
123+
for (ExecuTorchValue *value in values) {
124+
inputs.push_back(toEValue(value));
125+
}
126+
const auto result = _module->execute(methodName.UTF8String, inputs);
127+
if (!result.ok()) {
128+
if (error) {
129+
*error = [NSError errorWithDomain:ExecuTorchErrorDomain
130+
code:(NSInteger)result.error()
131+
userInfo:nil];
132+
}
133+
return nil;
134+
}
135+
NSMutableArray<ExecuTorchValue *> *outputs = [NSMutableArray arrayWithCapacity:result->size()];
136+
for (const auto &value : *result) {
137+
[outputs addObject:toExecuTorchValue(value)];
138+
}
139+
return outputs;
140+
}
141+
97142
@end

extension/apple/ExecuTorch/__tests__/ModuleTest.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,25 @@ class ModuleTest: XCTestCase {
5151
XCTAssertNoThrow(methodNames = try module.methodNames())
5252
XCTAssertEqual(methodNames, Set(["forward"]))
5353
}
54+
55+
func testExecute() {
56+
let bundle = Bundle(for: type(of: self))
57+
guard let modelPath = bundle.path(forResource: "add", ofType: "pte") else {
58+
XCTFail("Couldn't find the model file")
59+
return
60+
}
61+
let module = Module(filePath: modelPath)
62+
var inputData: [Float] = [1.0]
63+
let inputTensor = inputData.withUnsafeMutableBytes {
64+
Tensor(bytesNoCopy: $0.baseAddress!, shape:[1], dataType: .float)
65+
}
66+
let inputs = [Value(inputTensor), Value(inputTensor)]
67+
var outputs: [Value]?
68+
XCTAssertNoThrow(outputs = try module.execute("forward", inputs))
69+
var outputData: [Float] = [2.0]
70+
let outputTensor = outputData.withUnsafeMutableBytes {
71+
Tensor(bytesNoCopy: $0.baseAddress!, shape:[1], dataType: .float, shapeDynamism: .static)
72+
}
73+
XCTAssertEqual(outputs?[0].tensor, outputTensor)
74+
}
5475
}

0 commit comments

Comments
 (0)