Skip to content

Commit 6185e42

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
(WIP) Make Android Module thread-safe and prevent destruction during inference (#9833)
Summary: While the Android Module interface was originally not designed to be thread safe, we've seen a sizable number of issues pop up due to users not fully meeting the thread safety requirements that we impose on the caller. Empirically, this is not always obvious when writing app code and can sneak in in subtle ways. Common issues are calling forward from a different thread while one inference is already in progress and not synchronizing module cleanup with inference. Both have caused crashes that are sometimes difficult for users to debug. This PR attempts to mitigate these issues by adding explicit synchronization in the Java Module class. Both method load and execution are behind a lock, and destroy will warn and avoid immediate destruction if an inference is in progress. I'm hesitant to directly acquire the lock in destroy, since it can get called in certain cleanup paths. Instead, I'm just warning and setting the native peer to null so it should get GC'd once out of use. Differential Revision: D72273052
1 parent 1572381 commit 6185e42

File tree

2 files changed

+112
-4
lines changed

2 files changed

+112
-4
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java

+62
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.io.InputStream;
2626
import java.net.URI;
2727
import java.net.URISyntaxException;
28+
import java.util.concurrent.CountDownLatch;
29+
import java.util.concurrent.atomic.AtomicInteger;
2830
import java.io.IOException;
2931
import java.io.File;
3032
import java.io.FileOutputStream;
@@ -42,6 +44,7 @@ public class ModuleInstrumentationTest {
4244
private static String FORWARD_METHOD = "forward";
4345
private static String NONE_METHOD = "none";
4446
private static int OK = 0x00;
47+
private static int INVALID_STATE = 0x2;
4548
private static int INVALID_ARGUMENT = 0x12;
4649
private static int ACCESS_FAILED = 0x22;
4750

@@ -124,4 +127,63 @@ public void testNonPteFile() throws IOException{
124127
int loadMethod = module.loadMethod(FORWARD_METHOD);
125128
assertEquals(loadMethod, INVALID_ARGUMENT);
126129
}
130+
131+
@Test
132+
public void testLoadOnDestroyedModule() throws IOException{
133+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
134+
135+
module.destroy();
136+
137+
int loadMethod = module.loadMethod(FORWARD_METHOD);
138+
assertEquals(loadMethod, INVALID_STATE);
139+
}
140+
141+
@Test
142+
public void testForwardOnDestroyedModule() throws IOException{
143+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
144+
145+
int loadMethod = module.loadMethod(FORWARD_METHOD);
146+
assertEquals(loadMethod, OK);
147+
148+
module.destroy();
149+
150+
EValue[] results = module.forward();
151+
assertEquals(0, results.length);
152+
}
153+
154+
@Test
155+
public void testForwardFromMultipleThreads() throws InterruptedException, IOException {
156+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
157+
158+
int numThreads = 100;
159+
CountDownLatch latch = new CountDownLatch(numThreads);
160+
AtomicInteger completed = new AtomicInteger(0);
161+
162+
Runnable runnable = new Runnable() {
163+
@Override
164+
public void run() {
165+
try {
166+
latch.countDown();
167+
latch.await(5000, java.util.concurrent.TimeUnit.MILLISECONDS);
168+
EValue[] results = module.forward();
169+
assertTrue(results[0].isTensor());
170+
completed.incrementAndGet();
171+
} catch (InterruptedException e) {
172+
173+
}
174+
}
175+
};
176+
177+
Thread[] threads = new Thread[numThreads];
178+
for (int i = 0; i < numThreads; i++) {
179+
threads[i] = new Thread(runnable);
180+
threads[i].start();
181+
}
182+
183+
for (int i = 0; i < numThreads; i++) {
184+
threads[i].join();
185+
}
186+
187+
assertEquals(numThreads, completed.get());
188+
}
127189
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

+50-4
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88

99
package org.pytorch.executorch;
1010

11+
import android.util.Log;
1112
import com.facebook.soloader.nativeloader.NativeLoader;
1213
import com.facebook.soloader.nativeloader.SystemDelegate;
14+
import java.util.concurrent.locks.Lock;
15+
import java.util.concurrent.locks.ReentrantLock;
1316
import org.pytorch.executorch.annotations.Experimental;
1417

1518
/**
@@ -35,6 +38,9 @@ public class Module {
3538
/** Reference to the NativePeer object of this module. */
3639
private NativePeer mNativePeer;
3740

41+
/** Lock protecting the non-thread safe methods in NativePeer. */
42+
private Lock mLock = new ReentrantLock();
43+
3844
/**
3945
* Loads a serialized ExecuTorch module from the specified path on the disk.
4046
*
@@ -72,7 +78,16 @@ public static Module load(final String modelPath) {
7278
* @return return value from the 'forward' method.
7379
*/
7480
public EValue[] forward(EValue... inputs) {
75-
return mNativePeer.forward(inputs);
81+
try {
82+
mLock.lock();
83+
if (mNativePeer == null) {
84+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
85+
return new EValue[0];
86+
}
87+
return mNativePeer.forward(inputs);
88+
} finally {
89+
mLock.unlock();
90+
}
7691
}
7792

7893
/**
@@ -83,7 +98,16 @@ public EValue[] forward(EValue... inputs) {
8398
* @return return value from the method.
8499
*/
85100
public EValue[] execute(String methodName, EValue... inputs) {
86-
return mNativePeer.execute(methodName, inputs);
101+
try {
102+
mLock.lock();
103+
if (mNativePeer == null) {
104+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
105+
return new EValue[0];
106+
}
107+
return mNativePeer.execute(methodName, inputs);
108+
} finally {
109+
mLock.unlock();
110+
}
87111
}
88112

89113
/**
@@ -96,7 +120,16 @@ public EValue[] execute(String methodName, EValue... inputs) {
96120
* @return the Error code if there was an error loading the method
97121
*/
98122
public int loadMethod(String methodName) {
99-
return mNativePeer.loadMethod(methodName);
123+
try {
124+
mLock.lock();
125+
if (mNativePeer == null) {
126+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
127+
return 0x2; // InvalidState
128+
}
129+
return mNativePeer.loadMethod(methodName);
130+
} finally {
131+
mLock.unlock();
132+
}
100133
}
101134

102135
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
@@ -111,6 +144,19 @@ public String[] readLogBuffer() {
111144
* more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
112145
*/
113146
public void destroy() {
114-
mNativePeer.resetNative();
147+
if (mLock.tryLock()) {
148+
try {
149+
mNativePeer.resetNative();
150+
} finally {
151+
mNativePeer = null;
152+
mLock.unlock();
153+
}
154+
} else {
155+
mNativePeer = null;
156+
Log.w(
157+
"ExecuTorch",
158+
"Destroy was called while the module was in use. Resources will not be immediately"
159+
+ " released.");
160+
}
115161
}
116162
}

0 commit comments

Comments
 (0)