Skip to content

Commit 6de630b

Browse files
authored
HADOOP-19235. IPC client uses CompletableFuture to support asynchronous operations. (#6888)
1 parent 78204d9 commit 6de630b

File tree

2 files changed

+173
-74
lines changed

2 files changed

+173
-74
lines changed

hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java

Lines changed: 82 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ public class Client implements AutoCloseable {
9696
private static final ThreadLocal<Integer> retryCount = new ThreadLocal<Integer>();
9797
private static final ThreadLocal<Object> EXTERNAL_CALL_HANDLER
9898
= new ThreadLocal<>();
99-
private static final ThreadLocal<AsyncGet<? extends Writable, IOException>>
100-
ASYNC_RPC_RESPONSE = new ThreadLocal<>();
99+
private static final ThreadLocal<CompletableFuture<Writable>> ASYNC_RPC_RESPONSE
100+
= new ThreadLocal<>();
101101
private static final ThreadLocal<Boolean> asynchronousMode =
102102
new ThreadLocal<Boolean>() {
103103
@Override
@@ -110,7 +110,46 @@ protected Boolean initialValue() {
110110
@Unstable
111111
public static <T extends Writable> AsyncGet<T, IOException>
112112
getAsyncRpcResponse() {
113-
return (AsyncGet<T, IOException>) ASYNC_RPC_RESPONSE.get();
113+
CompletableFuture<Writable> responseFuture = ASYNC_RPC_RESPONSE.get();
114+
return new AsyncGet<T, IOException>() {
115+
@Override
116+
public T get(long timeout, TimeUnit unit)
117+
throws IOException, TimeoutException, InterruptedException {
118+
try {
119+
if (unit == null || timeout < 0) {
120+
return (T) responseFuture.get();
121+
}
122+
return (T) responseFuture.get(timeout, unit);
123+
} catch (ExecutionException e) {
124+
Throwable cause = e.getCause();
125+
if (cause instanceof IOException) {
126+
throw (IOException) cause;
127+
}
128+
throw new IllegalStateException(e);
129+
}
130+
}
131+
132+
@Override
133+
public boolean isDone() {
134+
return responseFuture.isDone();
135+
}
136+
};
137+
}
138+
139+
/**
140+
* Retrieves the current response future from the thread-local storage.
141+
*
142+
* @return A {@link CompletableFuture} of type T that represents the
143+
* asynchronous operation. If no response future is present in
144+
* the thread-local storage, this method returns {@code null}.
145+
* @param <T> The type of the value completed by the returned
146+
* {@link CompletableFuture}. It must be a subclass of
147+
* {@link Writable}.
148+
* @see CompletableFuture
149+
* @see Writable
150+
*/
151+
public static <T extends Writable> CompletableFuture<T> getResponseFuture() {
152+
return (CompletableFuture<T>) ASYNC_RPC_RESPONSE.get();
114153
}
115154

116155
/**
@@ -277,10 +316,8 @@ static class Call {
277316
final int id; // call id
278317
final int retry; // retry count
279318
final Writable rpcRequest; // the serialized rpc request
280-
Writable rpcResponse; // null if rpc has error
281-
IOException error; // exception, null if success
319+
private final CompletableFuture<Writable> rpcResponseFuture;
282320
final RPC.RpcKind rpcKind; // Rpc EngineKind
283-
boolean done; // true when call is done
284321
private final Object externalHandler;
285322
private AlignmentContext alignmentContext;
286323

@@ -304,6 +341,7 @@ private Call(RPC.RpcKind rpcKind, Writable param) {
304341
}
305342

306343
this.externalHandler = EXTERNAL_CALL_HANDLER.get();
344+
this.rpcResponseFuture = new CompletableFuture<>();
307345
}
308346

309347
@Override
@@ -314,9 +352,6 @@ public String toString() {
314352
/** Indicate when the call is complete and the
315353
* value or error are available. Notifies by default. */
316354
protected synchronized void callComplete() {
317-
this.done = true;
318-
notify(); // notify caller
319-
320355
if (externalHandler != null) {
321356
synchronized (externalHandler) {
322357
externalHandler.notify();
@@ -339,7 +374,7 @@ public synchronized void setAlignmentContext(AlignmentContext ac) {
339374
* @param error exception thrown by the call; either local or remote
340375
*/
341376
public synchronized void setException(IOException error) {
342-
this.error = error;
377+
rpcResponseFuture.completeExceptionally(error);
343378
callComplete();
344379
}
345380

@@ -349,13 +384,9 @@ public synchronized void setException(IOException error) {
349384
* @param rpcResponse return value of the rpc call.
350385
*/
351386
public synchronized void setRpcResponse(Writable rpcResponse) {
352-
this.rpcResponse = rpcResponse;
387+
rpcResponseFuture.complete(rpcResponse);
353388
callComplete();
354389
}
355-
356-
public synchronized Writable getRpcResponse() {
357-
return rpcResponse;
358-
}
359390
}
360391

361392
/** Thread that reads responses and notifies callers. Each connection owns a
@@ -1495,39 +1526,19 @@ Writable call(RPC.RpcKind rpcKind, Writable rpcRequest,
14951526
}
14961527

14971528
if (isAsynchronousMode()) {
1498-
final AsyncGet<Writable, IOException> asyncGet
1499-
= new AsyncGet<Writable, IOException>() {
1500-
@Override
1501-
public Writable get(long timeout, TimeUnit unit)
1502-
throws IOException, TimeoutException{
1503-
boolean done = true;
1504-
try {
1505-
final Writable w = getRpcResponse(call, connection, timeout, unit);
1506-
if (w == null) {
1507-
done = false;
1508-
throw new TimeoutException(call + " timed out "
1509-
+ timeout + " " + unit);
1510-
}
1511-
return w;
1512-
} finally {
1513-
if (done) {
1514-
releaseAsyncCall();
1529+
CompletableFuture<Writable> result = call.rpcResponseFuture.handle(
1530+
(rpcResponse, e) -> {
1531+
releaseAsyncCall();
1532+
if (e != null) {
1533+
IOException ioe = (IOException) e;
1534+
throw new CompletionException(warpIOException(ioe, connection));
15151535
}
1516-
}
1517-
}
1518-
1519-
@Override
1520-
public boolean isDone() {
1521-
synchronized (call) {
1522-
return call.done;
1523-
}
1524-
}
1525-
};
1526-
1527-
ASYNC_RPC_RESPONSE.set(asyncGet);
1536+
return rpcResponse;
1537+
});
1538+
ASYNC_RPC_RESPONSE.set(result);
15281539
return null;
15291540
} else {
1530-
return getRpcResponse(call, connection, -1, null);
1541+
return getRpcResponse(call, connection);
15311542
}
15321543
}
15331544

@@ -1564,37 +1575,34 @@ int getAsyncCallCount() {
15641575
}
15651576

15661577
/** @return the rpc response or, in case of timeout, null. */
1567-
private Writable getRpcResponse(final Call call, final Connection connection,
1568-
final long timeout, final TimeUnit unit) throws IOException {
1569-
synchronized (call) {
1570-
while (!call.done) {
1571-
try {
1572-
AsyncGet.Util.wait(call, timeout, unit);
1573-
if (timeout >= 0 && !call.done) {
1574-
return null;
1575-
}
1576-
} catch (InterruptedException ie) {
1577-
Thread.currentThread().interrupt();
1578-
throw new InterruptedIOException("Call interrupted");
1579-
}
1578+
private Writable getRpcResponse(final Call call, final Connection connection)
1579+
throws IOException {
1580+
try {
1581+
return call.rpcResponseFuture.get();
1582+
} catch (InterruptedException ie) {
1583+
Thread.currentThread().interrupt();
1584+
throw new InterruptedIOException("Call interrupted");
1585+
} catch (ExecutionException e) {
1586+
Throwable cause = e.getCause();
1587+
if (cause instanceof IOException) {
1588+
throw warpIOException((IOException) cause, connection);
15801589
}
1590+
throw new IllegalStateException(e);
1591+
}
1592+
}
15811593

1582-
if (call.error != null) {
1583-
if (call.error instanceof RemoteException ||
1584-
call.error instanceof SaslException) {
1585-
call.error.fillInStackTrace();
1586-
throw call.error;
1587-
} else { // local exception
1588-
InetSocketAddress address = connection.getRemoteAddress();
1589-
throw NetUtils.wrapException(address.getHostName(),
1590-
address.getPort(),
1591-
NetUtils.getHostname(),
1592-
0,
1593-
call.error);
1594-
}
1595-
} else {
1596-
return call.getRpcResponse();
1597-
}
1594+
private IOException warpIOException(IOException ioe, Connection connection) {
1595+
if (ioe instanceof RemoteException ||
1596+
ioe instanceof SaslException) {
1597+
ioe.fillInStackTrace();
1598+
return ioe;
1599+
} else { // local exception
1600+
InetSocketAddress address = connection.getRemoteAddress();
1601+
return NetUtils.wrapException(address.getHostName(),
1602+
address.getPort(),
1603+
NetUtils.getHostname(),
1604+
0,
1605+
ioe);
15981606
}
15991607
}
16001608

hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestAsyncIPC.java

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
2929
import org.apache.hadoop.net.NetUtils;
3030
import org.apache.hadoop.util.StringUtils;
31+
import org.apache.hadoop.util.Time;
3132
import org.apache.hadoop.util.concurrent.AsyncGetFuture;
3233
import org.junit.Assert;
3334
import org.junit.Before;
@@ -38,13 +39,16 @@
3839
import java.io.IOException;
3940
import java.net.InetSocketAddress;
4041
import java.util.*;
42+
import java.util.concurrent.CompletableFuture;
4143
import java.util.concurrent.ExecutionException;
4244
import java.util.concurrent.Future;
4345
import java.util.concurrent.TimeUnit;
4446
import java.util.concurrent.TimeoutException;
4547

4648
import static org.junit.Assert.assertEquals;
4749
import static org.junit.Assert.assertFalse;
50+
import static org.junit.Assert.assertTrue;
51+
import static org.junit.Assert.fail;
4852

4953
public class TestAsyncIPC {
5054

@@ -137,6 +141,60 @@ void assertReturnValues(long timeout, TimeUnit unit)
137141
}
138142
}
139143

144+
/**
145+
* For testing the asynchronous calls of the RPC client
146+
* implemented with CompletableFuture.
147+
*/
148+
static class AsyncCompletableFutureCaller extends Thread {
149+
private final Client client;
150+
private final InetSocketAddress server;
151+
private final int count;
152+
private final List<CompletableFuture<Writable>> completableFutures;
153+
private final List<Long> expectedValues;
154+
155+
AsyncCompletableFutureCaller(Client client, InetSocketAddress server, int count) {
156+
this.client = client;
157+
this.server = server;
158+
this.count = count;
159+
this.completableFutures = new ArrayList<>(count);
160+
this.expectedValues = new ArrayList<>(count);
161+
setName("Async CompletableFuture Caller");
162+
}
163+
164+
@Override
165+
public void run() {
166+
// Set the RPC client to use asynchronous mode.
167+
Client.setAsynchronousMode(true);
168+
long startTime = Time.monotonicNow();
169+
try {
170+
for (int i = 0; i < count; i++) {
171+
final long param = TestIPC.RANDOM.nextLong();
172+
TestIPC.call(client, param, server, conf);
173+
expectedValues.add(param);
174+
completableFutures.add(Client.getResponseFuture());
175+
}
176+
// Since the run method is asynchronous,
177+
// it does not need to wait for a response after sending a request,
178+
// so the time taken by the run method is less than count * 100
179+
// (where 100 is the time taken by the server to process a request).
180+
long cost = Time.monotonicNow() - startTime;
181+
assertTrue(cost < count * 100L);
182+
LOG.info("[{}] run cost {}ms", Thread.currentThread().getName(), cost);
183+
} catch (Exception e) {
184+
fail();
185+
}
186+
}
187+
188+
public void assertReturnValues()
189+
throws InterruptedException, ExecutionException {
190+
for (int i = 0; i < count; i++) {
191+
LongWritable value = (LongWritable) completableFutures.get(i).get();
192+
Assert.assertEquals("call" + i + " failed.",
193+
expectedValues.get(i).longValue(), value.get());
194+
}
195+
}
196+
}
197+
140198
static class AsyncLimitlCaller extends Thread {
141199
private Client client;
142200
private InetSocketAddress server;
@@ -538,4 +596,37 @@ public void run() {
538596
assertEquals(startID + i, callIds.get(i).intValue());
539597
}
540598
}
599+
600+
@Test(timeout = 60000)
601+
public void testAsyncCallWithCompletableFuture() throws IOException,
602+
InterruptedException, ExecutionException {
603+
// Override client to store the call id
604+
final Client client = new Client(LongWritable.class, conf);
605+
606+
// Construct an RPC server, which includes a handler thread.
607+
final TestServer server = new TestIPC.TestServer(1, false, conf);
608+
server.callListener = () -> {
609+
try {
610+
// The server requires at least 100 milliseconds to process a request.
611+
Thread.sleep(100);
612+
} catch (InterruptedException e) {
613+
throw new RuntimeException(e);
614+
}
615+
};
616+
617+
try {
618+
InetSocketAddress addr = NetUtils.getConnectAddress(server);
619+
server.start();
620+
// Send 10 asynchronous requests.
621+
final AsyncCompletableFutureCaller caller =
622+
new AsyncCompletableFutureCaller(client, addr, 10);
623+
caller.start();
624+
caller.join();
625+
// Check if the values returned by the asynchronous call meet the expected values.
626+
caller.assertReturnValues();
627+
} finally {
628+
client.stop();
629+
server.stop();
630+
}
631+
}
541632
}

0 commit comments

Comments
 (0)