diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java index a25c0bf6343..f71351ae6ae 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java @@ -25,6 +25,8 @@ import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; import java.io.IOException; import java.io.File; import java.io.FileOutputStream; @@ -42,6 +44,7 @@ public class ModuleInstrumentationTest { private static String FORWARD_METHOD = "forward"; private static String NONE_METHOD = "none"; private static int OK = 0x00; + private static int INVALID_STATE = 0x2; private static int INVALID_ARGUMENT = 0x12; private static int ACCESS_FAILED = 0x22; @@ -124,4 +127,63 @@ public void testNonPteFile() throws IOException{ int loadMethod = module.loadMethod(FORWARD_METHOD); assertEquals(loadMethod, INVALID_ARGUMENT); } + + @Test + public void testLoadOnDestroyedModule() throws IOException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + module.destroy(); + + int loadMethod = module.loadMethod(FORWARD_METHOD); + assertEquals(loadMethod, INVALID_STATE); + } + + @Test + public void testForwardOnDestroyedModule() throws IOException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + int loadMethod = module.loadMethod(FORWARD_METHOD); + assertEquals(loadMethod, OK); + + module.destroy(); + + EValue[] results = module.forward(); + assertEquals(0, results.length); + } + + @Test + public void testForwardFromMultipleThreads() throws InterruptedException, IOException { + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + int numThreads = 100; + CountDownLatch latch = new CountDownLatch(numThreads); + AtomicInteger completed = new AtomicInteger(0); + + Runnable runnable = new Runnable() { + @Override + public void run() { + try { + latch.countDown(); + latch.await(5000, java.util.concurrent.TimeUnit.MILLISECONDS); + EValue[] results = module.forward(); + assertTrue(results[0].isTensor()); + completed.incrementAndGet(); + } catch (InterruptedException e) { + + } + } + }; + + Thread[] threads = new Thread[numThreads]; + for (int i = 0; i < numThreads; i++) { + threads[i] = new Thread(runnable); + threads[i].start(); + } + + for (int i = 0; i < numThreads; i++) { + threads[i].join(); + } + + assertEquals(numThreads, completed.get()); + } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 879b88c5f2f..f3f543dc2a8 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -8,8 +8,11 @@ package org.pytorch.executorch; +import android.util.Log; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import org.pytorch.executorch.annotations.Experimental; /** @@ -35,6 +38,9 @@ public class Module { /** Reference to the NativePeer object of this module. */ private NativePeer mNativePeer; + /** Lock protecting the non-thread safe methods in NativePeer. */ + private Lock mLock = new ReentrantLock(); + /** * Loads a serialized ExecuTorch module from the specified path on the disk. * @@ -72,7 +78,16 @@ public static Module load(final String modelPath) { * @return return value from the 'forward' method. */ public EValue[] forward(EValue... inputs) { - return mNativePeer.forward(inputs); + try { + mLock.lock(); + if (mNativePeer == null) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new EValue[0]; + } + return mNativePeer.forward(inputs); + } finally { + mLock.unlock(); + } } /** @@ -83,7 +98,16 @@ public EValue[] forward(EValue... inputs) { * @return return value from the method. */ public EValue[] execute(String methodName, EValue... inputs) { - return mNativePeer.execute(methodName, inputs); + try { + mLock.lock(); + if (mNativePeer == null) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new EValue[0]; + } + return mNativePeer.execute(methodName, inputs); + } finally { + mLock.unlock(); + } } /** @@ -96,7 +120,16 @@ public EValue[] execute(String methodName, EValue... inputs) { * @return the Error code if there was an error loading the method */ public int loadMethod(String methodName) { - return mNativePeer.loadMethod(methodName); + try { + mLock.lock(); + if (mNativePeer == null) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return 0x2; // InvalidState + } + return mNativePeer.loadMethod(methodName); + } finally { + mLock.unlock(); + } } /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ @@ -111,6 +144,19 @@ public String[] readLogBuffer() { * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. */ public void destroy() { - mNativePeer.resetNative(); + if (mLock.tryLock()) { + try { + mNativePeer.resetNative(); + } finally { + mNativePeer = null; + mLock.unlock(); + } + } else { + mNativePeer = null; + Log.w( + "ExecuTorch", + "Destroy was called while the module was in use. Resources will not be immediately" + + " released."); + } } }