diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java index f5d33d0b71..3aca4871d6 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java @@ -73,4 +73,13 @@ public enum DType { DType(int jniCode) { this.jniCode = jniCode; } + + public static DType fromJniCode(int jniCode) { + for (DType dtype : values()) { + if (dtype.jniCode == jniCode) { + return dtype; + } + } + throw new IllegalArgumentException("No DType found for jniCode " + jniCode); + } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java index 6b32d90cda..1a30baba2f 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java @@ -8,6 +8,7 @@ package org.pytorch.executorch; +import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import java.nio.Buffer; @@ -630,6 +631,30 @@ public String toString() { } } + static class Tensor_unsupported extends Tensor { + private final ByteBuffer data; + private final DType mDtype; + + private Tensor_unsupported(ByteBuffer data, long[] shape, DType dtype) { + super(shape); + this.data = data; + this.mDtype = dtype; + Log.e( + "ExecuTorch", + toString() + " in Java. Please consider re-export the model with proper return type"); + } + + @Override + public DType dtype() { + return mDtype; + } + + @Override + public String toString() { + return String.format("Unsupported tensor(%s, dtype=%d)", Arrays.toString(shape), this.mDtype); + } + } + // region checks private static void checkArgument(boolean expression, String errorMessage, Object... args) { if (!expression) { @@ -675,7 +700,7 @@ private static Tensor nativeNewTensor( } else if (DType.INT8.jniCode == dtype) { tensor = new Tensor_int8(data, shape); } else { - throw new IllegalArgumentException("Unknown Tensor dtype"); + tensor = new Tensor_unsupported(data, shape, DType.fromJniCode(dtype)); } tensor.mHybridData = hybridData; return tensor;