Skip to content

Commit c98e26f

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Add Tensor equality API.
Summary: #8366 Reviewed By: mergennachin Differential Revision: D71910685
1 parent 32b9449 commit c98e26f

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

+8
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,14 @@ __attribute__((deprecated("This API is experimental.")))
200200
error:(NSError **)error
201201
NS_SWIFT_NAME(resize(to:));
202202

203+
/**
204+
* Determines whether the current tensor is equal to another tensor.
205+
*
206+
* @param other Another ExecuTorchTensor instance to compare against.
207+
* @return YES if the tensors have the same type, shape, strides, and data; otherwise, NO.
208+
*/
209+
- (BOOL)isEqualToTensor:(nullable ExecuTorchTensor *)other;
210+
203211
+ (instancetype)new NS_UNAVAILABLE;
204212
- (instancetype)init NS_UNAVAILABLE;
205213

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

+26
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,32 @@ - (BOOL)resizeToShape:(NSArray<NSNumber *> *)shape
128128
return YES;
129129
}
130130

131+
- (BOOL)isEqualToTensor:(nullable ExecuTorchTensor *)other {
132+
if (!other) {
133+
return NO;
134+
}
135+
const auto *data = _tensor->unsafeGetTensorImpl()->data();
136+
const auto *otherData = other->_tensor->unsafeGetTensorImpl()->data();
137+
const auto size = self.count * ExecuTorchSizeOfDataType(self.dataType);
138+
return self.dataType == other.dataType &&
139+
self.count == other.count &&
140+
[self.shape isEqual:other.shape] &&
141+
[self.dimensionOrder isEqual:other.dimensionOrder] &&
142+
[self.strides isEqual:other.strides] &&
143+
self.shapeDynamism == other.shapeDynamism &&
144+
(data && otherData ? std::memcmp(data, otherData, size) == 0 : data == otherData);
145+
}
146+
147+
- (BOOL)isEqual:(nullable id)other {
148+
if (self == other) {
149+
return YES;
150+
}
151+
if (![other isKindOfClass:[ExecuTorchTensor class]]) {
152+
return NO;
153+
}
154+
return [self isEqualToTensor:(ExecuTorchTensor *)other];
155+
}
156+
131157
@end
132158

133159
@implementation ExecuTorchTensor (BytesNoCopy)

extension/apple/ExecuTorch/__tests__/TensorTest.swift

+23
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,27 @@ class TensorTest: XCTestCase {
190190
}
191191
XCTAssertThrowsError(try tensor.resize(to: [2, 3]))
192192
}
193+
194+
func testIsEqual() {
195+
var data: [Float] = [1.0, 2.0, 3.0, 4.0]
196+
let tensor1 = data.withUnsafeMutableBytes {
197+
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2], dataType: .float)
198+
}
199+
let tensor2 = Tensor(tensor1)
200+
XCTAssertTrue(tensor1.isEqual(tensor2))
201+
XCTAssertTrue(tensor2.isEqual(tensor1))
202+
203+
var dataModified: [Float] = [1.0, 2.0, 3.0, 5.0]
204+
let tensor3 = dataModified.withUnsafeMutableBytes {
205+
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2], dataType: .float)
206+
}
207+
XCTAssertFalse(tensor1.isEqual(tensor3))
208+
let tensor4 = data.withUnsafeMutableBytes {
209+
Tensor(bytesNoCopy: $0.baseAddress!, shape: [4, 1], dataType: .float)
210+
}
211+
XCTAssertFalse(tensor1.isEqual(tensor4))
212+
XCTAssertTrue(tensor1.isEqual(tensor1))
213+
XCTAssertFalse(tensor1.isEqual(NSString(string: "Not a tensor")))
214+
XCTAssertFalse(tensor4.isEqual(tensor2.copy()))
215+
}
193216
}

0 commit comments

Comments
 (0)