From c251e0530838e0281f5cbac85bd0a9e9e622ac23 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Wed, 26 Mar 2025 20:31:06 -0700 Subject: [PATCH] Add Module forward API with overloads. Summary: https://github.com/pytorch/executorch/issues/8363 Reviewed By: mergennachin Differential Revision: D71924237 --- .../ExecuTorch/Exported/ExecuTorchModule.h | 49 +++++++++++++++++++ .../ExecuTorch/Exported/ExecuTorchModule.mm | 31 ++++++++++++ .../ExecuTorch/__tests__/ModuleTest.swift | 2 +- 3 files changed, 81 insertions(+), 1 deletion(-) diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h index 34789071ca..2e3e4e374e 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h @@ -173,6 +173,55 @@ __attribute__((deprecated("This API is experimental."))) error:(NSError **)error NS_SWIFT_NAME(execute(_:_:)); +/** + * Executes the "forward" method with the provided input values. + * + * This is a convenience method that calls the executeMethod with "forward" as the method name. + * + * @param values An NSArray of ExecuTorchValue objects representing the inputs. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)forwardWithInputs:(NSArray *)values + error:(NSError **)error + NS_SWIFT_NAME(forward(_:)); + +/** + * Executes the "forward" method with the provided single input value. + * + * This is a convenience method that calls the executeMethod with "forward" as the method name. + * + * @param value An ExecuTorchValue object representing the input. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)forwardWithInput:(ExecuTorchValue *)value + error:(NSError **)error + NS_SWIFT_NAME(forward(_:)); + +/** + * Executes the "forward" method with no inputs. + * + * This is a convenience method that calls the executeMethod with "forward" as the method name. + * + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)forward:(NSError **)error; + +/** + * Executes the "forward" method with no inputs. + * + * This is a convenience method that calls the executeMethod with "forward" as the method name. + * + * @param tensors An NSArray of ExecuTorchTensor objects representing the inputs. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)forwardWithTensors:(NSArray *)tensors + error:(NSError **)error + NS_SWIFT_NAME(forward(_:)); + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm index 243ab3c159..2e37cd3048 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm @@ -166,4 +166,35 @@ - (BOOL)isMethodLoaded:(NSString *)methodName { error:error]; } +- (nullable NSArray *)forwardWithInputs:(NSArray *)values + error:(NSError **)error { + return [self executeMethod:@"forward" + withInputs:values + error:error]; +} + +- (nullable NSArray *)forwardWithInput:(ExecuTorchValue *)value + error:(NSError **)error { + return [self executeMethod:@"forward" + withInputs:@[value] + error:error]; +} + +- (nullable NSArray *)forward:(NSError **)error { + return [self executeMethod:@"forward" + withInputs:@[] + error:error]; +} + +- (nullable NSArray *)forwardWithTensors:(NSArray *)tensors + error:(NSError **)error { + NSMutableArray *values = [NSMutableArray arrayWithCapacity:tensors.count]; + for (ExecuTorchTensor *tensor in tensors) { + [values addObject:[ExecuTorchValue valueWithTensor:tensor]]; + } + return [self executeMethod:@"forward" + withInputs:values + error:error]; +} + @end diff --git a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift index 87e35d510c..51758dc221 100644 --- a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift +++ b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift @@ -65,7 +65,7 @@ class ModuleTest: XCTestCase { } let inputs = [inputTensor, inputTensor] var outputs: [Value]? - XCTAssertNoThrow(outputs = try module.execute("forward", inputs)) + XCTAssertNoThrow(outputs = try module.forward(inputs)) var outputData: [Float] = [2.0] let outputTensor = outputData.withUnsafeMutableBytes { Tensor(bytesNoCopy: $0.baseAddress!, shape:[1], dataType: .float, shapeDynamism: .static)