Skip to content

Make Android Module thread-safe and prevent destruction during inference #9833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import java.io.IOException;
import java.io.File;
import java.io.FileOutputStream;
Expand All @@ -42,6 +44,7 @@ public class ModuleInstrumentationTest {
private static String FORWARD_METHOD = "forward";
private static String NONE_METHOD = "none";
private static int OK = 0x00;
private static int INVALID_STATE = 0x2;
private static int INVALID_ARGUMENT = 0x12;
private static int ACCESS_FAILED = 0x22;

Expand Down Expand Up @@ -124,4 +127,63 @@ public void testNonPteFile() throws IOException{
int loadMethod = module.loadMethod(FORWARD_METHOD);
assertEquals(loadMethod, INVALID_ARGUMENT);
}

@Test
public void testLoadOnDestroyedModule() throws IOException{
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));

module.destroy();

int loadMethod = module.loadMethod(FORWARD_METHOD);
assertEquals(loadMethod, INVALID_STATE);
}

@Test
public void testForwardOnDestroyedModule() throws IOException{
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));

int loadMethod = module.loadMethod(FORWARD_METHOD);
assertEquals(loadMethod, OK);

module.destroy();

EValue[] results = module.forward();
assertEquals(0, results.length);
}

@Test
public void testForwardFromMultipleThreads() throws InterruptedException, IOException {
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));

int numThreads = 100;
CountDownLatch latch = new CountDownLatch(numThreads);
AtomicInteger completed = new AtomicInteger(0);

Runnable runnable = new Runnable() {
@Override
public void run() {
try {
latch.countDown();
latch.await(5000, java.util.concurrent.TimeUnit.MILLISECONDS);
EValue[] results = module.forward();
assertTrue(results[0].isTensor());
completed.incrementAndGet();
} catch (InterruptedException e) {

}
}
};

Thread[] threads = new Thread[numThreads];
for (int i = 0; i < numThreads; i++) {
threads[i] = new Thread(runnable);
threads[i].start();
}

for (int i = 0; i < numThreads; i++) {
threads[i].join();
}

assertEquals(numThreads, completed.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@

package org.pytorch.executorch;

import android.util.Log;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.pytorch.executorch.annotations.Experimental;

/**
Expand All @@ -35,6 +38,9 @@ public class Module {
/** Reference to the NativePeer object of this module. */
private NativePeer mNativePeer;

/** Lock protecting the non-thread safe methods in NativePeer. */
private Lock mLock = new ReentrantLock();

/**
* Loads a serialized ExecuTorch module from the specified path on the disk.
*
Expand Down Expand Up @@ -72,7 +78,16 @@ public static Module load(final String modelPath) {
* @return return value from the 'forward' method.
*/
public EValue[] forward(EValue... inputs) {
return mNativePeer.forward(inputs);
try {
mLock.lock();
if (mNativePeer == null) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new EValue[0];
}
return mNativePeer.forward(inputs);
} finally {
mLock.unlock();
}
}

/**
Expand All @@ -83,7 +98,16 @@ public EValue[] forward(EValue... inputs) {
* @return return value from the method.
*/
public EValue[] execute(String methodName, EValue... inputs) {
return mNativePeer.execute(methodName, inputs);
try {
mLock.lock();
if (mNativePeer == null) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new EValue[0];
}
return mNativePeer.execute(methodName, inputs);
} finally {
mLock.unlock();
}
}

/**
Expand All @@ -96,7 +120,16 @@ public EValue[] execute(String methodName, EValue... inputs) {
* @return the Error code if there was an error loading the method
*/
public int loadMethod(String methodName) {
return mNativePeer.loadMethod(methodName);
try {
mLock.lock();
if (mNativePeer == null) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return 0x2; // InvalidState
}
return mNativePeer.loadMethod(methodName);
} finally {
mLock.unlock();
}
}

/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
Expand All @@ -111,6 +144,19 @@ public String[] readLogBuffer() {
* more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
*/
public void destroy() {
mNativePeer.resetNative();
if (mLock.tryLock()) {
try {
mNativePeer.resetNative();
} finally {
mNativePeer = null;
mLock.unlock();
}
} else {
mNativePeer = null;
Log.w(
"ExecuTorch",
"Destroy was called while the module was in use. Resources will not be immediately"
+ " released.");
}
}
}
Loading