Skip to content

Commit 0031ea8

Browse files
Add Module forward API with overloads. (#9689)
Summary: #8363 Reviewed By: mergennachin Differential Revision: D71924237 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent 5f8eaf7 commit 0031ea8

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchModule.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,55 @@ __attribute__((deprecated("This API is experimental.")))
173173
error:(NSError **)error
174174
NS_SWIFT_NAME(execute(_:_:));
175175

176+
/**
177+
* Executes the "forward" method with the provided input values.
178+
*
179+
* This is a convenience method that calls the executeMethod with "forward" as the method name.
180+
*
181+
* @param values An NSArray of ExecuTorchValue objects representing the inputs.
182+
* @param error A pointer to an NSError pointer that is set if an error occurs.
183+
* @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error.
184+
*/
185+
- (nullable NSArray<ExecuTorchValue *> *)forwardWithInputs:(NSArray<ExecuTorchValue *> *)values
186+
error:(NSError **)error
187+
NS_SWIFT_NAME(forward(_:));
188+
189+
/**
190+
* Executes the "forward" method with the provided single input value.
191+
*
192+
* This is a convenience method that calls the executeMethod with "forward" as the method name.
193+
*
194+
* @param value An ExecuTorchValue object representing the input.
195+
* @param error A pointer to an NSError pointer that is set if an error occurs.
196+
* @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error.
197+
*/
198+
- (nullable NSArray<ExecuTorchValue *> *)forwardWithInput:(ExecuTorchValue *)value
199+
error:(NSError **)error
200+
NS_SWIFT_NAME(forward(_:));
201+
202+
/**
203+
* Executes the "forward" method with no inputs.
204+
*
205+
* This is a convenience method that calls the executeMethod with "forward" as the method name.
206+
*
207+
* @param error A pointer to an NSError pointer that is set if an error occurs.
208+
* @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error.
209+
*/
210+
- (nullable NSArray<ExecuTorchValue *> *)forward:(NSError **)error;
211+
212+
/**
213+
* Executes the "forward" method with no inputs.
214+
*
215+
* This is a convenience method that calls the executeMethod with "forward" as the method name.
216+
*
217+
* @param tensors An NSArray of ExecuTorchTensor objects representing the inputs.
218+
* @param error A pointer to an NSError pointer that is set if an error occurs.
219+
* @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error.
220+
*/
221+
- (nullable NSArray<ExecuTorchValue *> *)forwardWithTensors:(NSArray<ExecuTorchTensor *> *)tensors
222+
error:(NSError **)error
223+
NS_SWIFT_NAME(forward(_:));
224+
176225
+ (instancetype)new NS_UNAVAILABLE;
177226
- (instancetype)init NS_UNAVAILABLE;
178227

extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,35 @@ - (BOOL)isMethodLoaded:(NSString *)methodName {
166166
error:error];
167167
}
168168

169+
- (nullable NSArray<ExecuTorchValue *> *)forwardWithInputs:(NSArray<ExecuTorchValue *> *)values
170+
error:(NSError **)error {
171+
return [self executeMethod:@"forward"
172+
withInputs:values
173+
error:error];
174+
}
175+
176+
- (nullable NSArray<ExecuTorchValue *> *)forwardWithInput:(ExecuTorchValue *)value
177+
error:(NSError **)error {
178+
return [self executeMethod:@"forward"
179+
withInputs:@[value]
180+
error:error];
181+
}
182+
183+
- (nullable NSArray<ExecuTorchValue *> *)forward:(NSError **)error {
184+
return [self executeMethod:@"forward"
185+
withInputs:@[]
186+
error:error];
187+
}
188+
189+
- (nullable NSArray<ExecuTorchValue *> *)forwardWithTensors:(NSArray<ExecuTorchTensor *> *)tensors
190+
error:(NSError **)error {
191+
NSMutableArray<ExecuTorchValue *> *values = [NSMutableArray arrayWithCapacity:tensors.count];
192+
for (ExecuTorchTensor *tensor in tensors) {
193+
[values addObject:[ExecuTorchValue valueWithTensor:tensor]];
194+
}
195+
return [self executeMethod:@"forward"
196+
withInputs:values
197+
error:error];
198+
}
199+
169200
@end

extension/apple/ExecuTorch/__tests__/ModuleTest.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class ModuleTest: XCTestCase {
6565
}
6666
let inputs = [inputTensor, inputTensor]
6767
var outputs: [Value]?
68-
XCTAssertNoThrow(outputs = try module.execute("forward", inputs))
68+
XCTAssertNoThrow(outputs = try module.forward(inputs))
6969
var outputData: [Float] = [2.0]
7070
let outputTensor = outputData.withUnsafeMutableBytes {
7171
Tensor(bytesNoCopy: $0.baseAddress!, shape:[1], dataType: .float, shapeDynamism: .static)

0 commit comments

Comments
 (0)