diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h index 34789071caa..2e3e4e374ef 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h @@ -173,6 +173,55 @@ __attribute__((deprecated("This API is experimental."))) error:(NSError **)error NS_SWIFT_NAME(execute(_:_:)); +/** + * Executes the "forward" method with the provided input values. + * + * This is a convenience method that calls the executeMethod with "forward" as the method name. + * + * @param values An NSArray of ExecuTorchValue objects representing the inputs. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)forwardWithInputs:(NSArray *)values + error:(NSError **)error + NS_SWIFT_NAME(forward(_:)); + +/** + * Executes the "forward" method with the provided single input value. + * + * This is a convenience method that calls the executeMethod with "forward" as the method name. + * + * @param value An ExecuTorchValue object representing the input. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)forwardWithInput:(ExecuTorchValue *)value + error:(NSError **)error + NS_SWIFT_NAME(forward(_:)); + +/** + * Executes the "forward" method with no inputs. + * + * This is a convenience method that calls the executeMethod with "forward" as the method name. + * + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)forward:(NSError **)error; + +/** + * Executes the "forward" method with no inputs. + * + * This is a convenience method that calls the executeMethod with "forward" as the method name. + * + * @param tensors An NSArray of ExecuTorchTensor objects representing the inputs. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)forwardWithTensors:(NSArray *)tensors + error:(NSError **)error + NS_SWIFT_NAME(forward(_:)); + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm index 243ab3c159b..2e37cd30484 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm @@ -166,4 +166,35 @@ - (BOOL)isMethodLoaded:(NSString *)methodName { error:error]; } +- (nullable NSArray *)forwardWithInputs:(NSArray *)values + error:(NSError **)error { + return [self executeMethod:@"forward" + withInputs:values + error:error]; +} + +- (nullable NSArray *)forwardWithInput:(ExecuTorchValue *)value + error:(NSError **)error { + return [self executeMethod:@"forward" + withInputs:@[value] + error:error]; +} + +- (nullable NSArray *)forward:(NSError **)error { + return [self executeMethod:@"forward" + withInputs:@[] + error:error]; +} + +- (nullable NSArray *)forwardWithTensors:(NSArray *)tensors + error:(NSError **)error { + NSMutableArray *values = [NSMutableArray arrayWithCapacity:tensors.count]; + for (ExecuTorchTensor *tensor in tensors) { + [values addObject:[ExecuTorchValue valueWithTensor:tensor]]; + } + return [self executeMethod:@"forward" + withInputs:values + error:error]; +} + @end diff --git a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift index 87e35d510ce..51758dc221b 100644 --- a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift +++ b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift @@ -65,7 +65,7 @@ class ModuleTest: XCTestCase { } let inputs = [inputTensor, inputTensor] var outputs: [Value]? - XCTAssertNoThrow(outputs = try module.execute("forward", inputs)) + XCTAssertNoThrow(outputs = try module.forward(inputs)) var outputData: [Float] = [2.0] let outputTensor = outputData.withUnsafeMutableBytes { Tensor(bytesNoCopy: $0.baseAddress!, shape:[1], dataType: .float, shapeDynamism: .static)