Skip to content

Commit 4076f93

Browse files
Tensor constructor to create with an array of scalars. (#9692)
Summary: #8366 Reviewed By: bsoyluoglu Differential Revision: D71929714 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent dd3122d commit 4076f93

File tree

3 files changed

+71
-0
lines changed

3 files changed

+71
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,4 +411,29 @@ __attribute__((deprecated("This API is experimental.")))
411411

412412
@end
413413

414+
#pragma mark - Scalars Category
415+
416+
@interface ExecuTorchTensor (Scalars)
417+
418+
/**
419+
* Initializes a tensor with an array of scalar values and full tensor properties.
420+
*
421+
* @param scalars An NSArray of NSNumber objects representing the scalar values.
422+
* @param shape An NSArray of NSNumber objects representing the desired tensor shape.
423+
* @param strides An NSArray of NSNumber objects representing the tensor strides.
424+
* @param dimensionOrder An NSArray of NSNumber objects indicating the order of dimensions.
425+
* @param dataType An ExecuTorchDataType value specifying the element type.
426+
* @param shapeDynamism An ExecuTorchShapeDynamism value indicating the shape dynamism.
427+
* @return An initialized ExecuTorchTensor instance containing the provided scalar values.
428+
*/
429+
- (instancetype)initWithScalars:(NSArray<NSNumber *> *)scalars
430+
shape:(NSArray<NSNumber *> *)shape
431+
strides:(NSArray<NSNumber *> *)strides
432+
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
433+
dataType:(ExecuTorchDataType)dataType
434+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
435+
NS_SWIFT_NAME(init(_:shape:strides:dimensionOrder:dataType:shapeDynamism:));
436+
437+
@end
438+
414439
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,36 @@ - (instancetype)initWithData:(NSData *)data
335335
}
336336

337337
@end
338+
339+
@implementation ExecuTorchTensor (Scalars)
340+
341+
- (instancetype)initWithScalars:(NSArray<NSNumber *> *)scalars
342+
shape:(NSArray<NSNumber *> *)shape
343+
strides:(NSArray<NSNumber *> *)strides
344+
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
345+
dataType:(ExecuTorchDataType)dataType
346+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
347+
const NSInteger count = scalars.count;
348+
ET_CHECK_MSG(count == ExecuTorchElementCountOfShape(shape),
349+
"Number of scalars does not match the shape");
350+
std::vector<uint8_t> data;
351+
data.resize(count * ExecuTorchSizeOfDataType(dataType));
352+
for (NSUInteger index = 0; index < count; ++index) {
353+
ET_SWITCH_REALHBBF16_AND_UINT_TYPES(
354+
static_cast<ScalarType>(dataType), nil, "initWithScalars", CTYPE, [&] {
355+
reinterpret_cast<CTYPE *>(data.data())[index] = utils::extractValue<CTYPE>(scalars[index]);
356+
}
357+
);
358+
}
359+
auto tensor = make_tensor_ptr(
360+
utils::toVector<SizesType>(shape),
361+
std::move(data),
362+
utils::toVector<DimOrderType>(dimensionOrder),
363+
utils::toVector<StridesType>(strides),
364+
static_cast<ScalarType>(dataType),
365+
static_cast<TensorShapeDynamism>(shapeDynamism)
366+
);
367+
return [self initWithNativeInstance:&tensor];
368+
}
369+
370+
@end

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,4 +223,17 @@ class TensorTest: XCTestCase {
223223
XCTAssertFalse(tensor1.isEqual(NSString(string: "Not a tensor")))
224224
XCTAssertFalse(tensor4.isEqual(tensor2.copy()))
225225
}
226+
227+
func testInitScalarsFloat() {
228+
let data: [Float] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
229+
let tensor = Tensor(data.map(NSNumber.init), shape: [2, 3], strides: [3, 1], dimensionOrder: [0, 1], dataType: .float, shapeDynamism: .dynamicBound)
230+
XCTAssertEqual(tensor.dataType, .float)
231+
XCTAssertEqual(tensor.shape, [2, 3])
232+
XCTAssertEqual(tensor.strides, [3, 1])
233+
XCTAssertEqual(tensor.dimensionOrder, [0, 1])
234+
XCTAssertEqual(tensor.count, 6)
235+
tensor.bytes { pointer, count, dataType in
236+
XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)), data)
237+
}
238+
}
226239
}

0 commit comments

Comments
 (0)