From 7e2106c61b62c44cf4d699e8f882ba2adac26b8d Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Wed, 23 Apr 2025 14:31:48 -0700 Subject: [PATCH 1/2] [Android] Support 16bit data as raw data byte[] In java, when the returned dtype is fp16 or bf16, instead of crash, use byte[] to represent these raw data, and let user parse the byte[] --- .../java/org/pytorch/executorch/Tensor.java | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java index 6b32d90cda..9a9e1a3218 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java @@ -630,6 +630,44 @@ public String toString() { } } + /** + * A generic holder for 16b (fp16/bfloat16) tensor. Not intended for general use. User can only + * use #Tensor.getDataAsUnsignedByteArray() to get raw bytes. Users need to parse it. + */ + static class Tensor_raw_data_16b extends Tensor { + private final ByteBuffer data; + private final DType myDtype; + + private Tensor_raw_data_16b(ByteBuffer data, long[] shape, DType dtype) { + super(shape); + this.data = data; + this.myDtype = dtype; + } + + @Override + public DType dtype() { + return myDtype; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsUnsignedByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=%d)", Arrays.toString(shape), this.myDtype); + } + } + // region checks private static void checkArgument(boolean expression, String errorMessage, Object... args) { if (!expression) { @@ -674,6 +712,10 @@ private static Tensor nativeNewTensor( tensor = new Tensor_uint8(data, shape); } else if (DType.INT8.jniCode == dtype) { tensor = new Tensor_int8(data, shape); + } else if (DType.HALF.jniCode == dtype) { + tensor = new Tensor_raw_data_16b(data, shape, DType.HALF); + } else if (DType.HALF.jniCode == dtype) { + tensor = new Tensor_raw_data_16b(data, shape, DType.BFLOAT16); } else { throw new IllegalArgumentException("Unknown Tensor dtype"); } From 833d38838ccfe345a6505913dba65cd95981ac93 Mon Sep 17 00:00:00 2001 From: Hansong <107070759+kirklandsign@users.noreply.github.com> Date: Wed, 23 Apr 2025 15:48:26 -0700 Subject: [PATCH 2/2] Update Tensor.java --- .../src/main/java/org/pytorch/executorch/Tensor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java index 9a9e1a3218..5893fc5665 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java @@ -714,7 +714,7 @@ private static Tensor nativeNewTensor( tensor = new Tensor_int8(data, shape); } else if (DType.HALF.jniCode == dtype) { tensor = new Tensor_raw_data_16b(data, shape, DType.HALF); - } else if (DType.HALF.jniCode == dtype) { + } else if (DType.BFLOAT16.jniCode == dtype) { tensor = new Tensor_raw_data_16b(data, shape, DType.BFLOAT16); } else { throw new IllegalArgumentException("Unknown Tensor dtype");