From a4eb736f0454f9735b29953bd0435d0f20cea57a Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Wed, 26 Mar 2025 20:11:54 -0700 Subject: [PATCH] Add Tensor resizing API. Summary: https://github.com/pytorch/executorch/issues/8366 Reviewed By: mergennachin Differential Revision: D71909752 --- .../ExecuTorch/Exported/ExecuTorchTensor.h | 11 ++++++++ .../ExecuTorch/Exported/ExecuTorchTensor.mm | 20 +++++++++++++++ .../ExecuTorch/__tests__/TensorTest.swift | 25 +++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index 152a74b7cb2..dba0448cc1e 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -189,6 +189,17 @@ __attribute__((deprecated("This API is experimental."))) - (void)mutableBytesWithHandler:(void (^)(void *pointer, NSInteger count, ExecuTorchDataType dataType))handler NS_SWIFT_NAME(mutableBytes(_:)); +/** + * Resizes the tensor to a new shape. + * + * @param shape An NSArray of NSNumber objects representing the desired new shape. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return YES if the tensor was successfully resized; otherwise, NO. + */ +- (BOOL)resizeToShape:(NSArray *)shape + error:(NSError **)error + NS_SWIFT_NAME(resize(to:)); + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index 912bc4f59d2..b302f1614eb 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -15,6 +15,7 @@ using namespace executorch::aten; using namespace executorch::extension; +using namespace executorch::runtime; NSInteger ExecuTorchSizeOfDataType(ExecuTorchDataType dataType) { return elementSize(static_cast(dataType)); @@ -108,6 +109,25 @@ - (void)mutableBytesWithHandler:(void (^)(void *pointer, NSInteger count, ExecuT handler(_tensor->unsafeGetTensorImpl()->mutable_data(), self.count, self.dataType); } +- (BOOL)resizeToShape:(NSArray *)shape + error:(NSError **)error { + const auto resizeError = resize_tensor_ptr( + _tensor, utils::toVector(shape) + ); + if (resizeError != Error::Ok) { + if (error) { + *error = [NSError errorWithDomain:ExecuTorchErrorDomain + code:(NSInteger)resizeError + userInfo:nil]; + } + return NO; + } + _shape = nil; + _strides = nil; + _dimensionOrder = nil; + return YES; +} + @end @implementation ExecuTorchTensor (BytesNoCopy) diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index fef9da87906..33148d98e33 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -165,4 +165,29 @@ class TensorTest: XCTestCase { XCTAssertEqual(tensor1.dimensionOrder, tensor2.dimensionOrder) XCTAssertEqual(tensor1.count, tensor2.count) } + + func testResize() { + var data: [Int] = [1, 2, 3, 4] + let tensor = data.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape: [4, 1], dataType: .int) + } + XCTAssertNoThrow(try tensor.resize(to: [2, 2])) + XCTAssertEqual(tensor.dataType, .int) + XCTAssertEqual(tensor.shape, [2, 2]) + XCTAssertEqual(tensor.strides, [2, 1]) + XCTAssertEqual(tensor.dimensionOrder, [0, 1]) + XCTAssertEqual(tensor.count, 4) + + tensor.bytes { pointer, count, dataType in + XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int.self), count: count)), data) + } + } + + func testResizeError() { + var data: [Int] = [1, 2, 3, 4] + let tensor = data.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape: [4, 1], dataType: .int) + } + XCTAssertThrowsError(try tensor.resize(to: [2, 3])) + } }