HADOOP-12909. Change ipc.Client to support asynchronous calls. Contributed by Xiaobing Zhou

This commit is contained in:
Tsz-Wo Nicholas Sze 2016-04-07 14:01:33 +08:00
parent 3c18a53cbd
commit a62637a413
3 changed files with 436 additions and 12 deletions

View File

@ -62,6 +62,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability; import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.classification.InterfaceStability.Unstable;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
@ -96,6 +97,7 @@ import org.apache.htrace.core.Tracer;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.protobuf.CodedOutputStream; import com.google.protobuf.CodedOutputStream;
@ -107,7 +109,7 @@ import com.google.protobuf.CodedOutputStream;
*/ */
@InterfaceAudience.LimitedPrivate(value = { "Common", "HDFS", "MapReduce", "Yarn" }) @InterfaceAudience.LimitedPrivate(value = { "Common", "HDFS", "MapReduce", "Yarn" })
@InterfaceStability.Evolving @InterfaceStability.Evolving
public class Client { public class Client implements AutoCloseable {
public static final Log LOG = LogFactory.getLog(Client.class); public static final Log LOG = LogFactory.getLog(Client.class);
@ -116,6 +118,20 @@ public class Client {
private static final ThreadLocal<Integer> callId = new ThreadLocal<Integer>(); private static final ThreadLocal<Integer> callId = new ThreadLocal<Integer>();
private static final ThreadLocal<Integer> retryCount = new ThreadLocal<Integer>(); private static final ThreadLocal<Integer> retryCount = new ThreadLocal<Integer>();
private static final ThreadLocal<Future<?>> returnValue = new ThreadLocal<>();
private static final ThreadLocal<Boolean> asynchronousMode =
new ThreadLocal<Boolean>() {
@Override
protected Boolean initialValue() {
return false;
}
};
@SuppressWarnings("unchecked")
@Unstable
public static <T> Future<T> getReturnValue() {
return (Future<T>) returnValue.get();
}
/** Set call id and retry count for the next call. */ /** Set call id and retry count for the next call. */
public static void setCallIdAndRetryCount(int cid, int rc) { public static void setCallIdAndRetryCount(int cid, int rc) {
@ -1354,7 +1370,7 @@ public class Client {
ConnectionId remoteId, int serviceClass, ConnectionId remoteId, int serviceClass,
AtomicBoolean fallbackToSimpleAuth) throws IOException { AtomicBoolean fallbackToSimpleAuth) throws IOException {
final Call call = createCall(rpcKind, rpcRequest); final Call call = createCall(rpcKind, rpcRequest);
Connection connection = getConnection(remoteId, call, serviceClass, final Connection connection = getConnection(remoteId, call, serviceClass,
fallbackToSimpleAuth); fallbackToSimpleAuth);
try { try {
connection.sendRpcRequest(call); // send the rpc request connection.sendRpcRequest(call); // send the rpc request
@ -1366,6 +1382,51 @@ public class Client {
throw new IOException(e); throw new IOException(e);
} }
if (isAsynchronousMode()) {
Future<Writable> returnFuture = new AbstractFuture<Writable>() {
@Override
public Writable get() throws InterruptedException, ExecutionException {
try {
set(getRpcResponse(call, connection));
} catch (IOException ie) {
setException(ie);
}
return super.get();
}
};
returnValue.set(returnFuture);
return null;
} else {
return getRpcResponse(call, connection);
}
}
/**
* Check if RPC is in asynchronous mode or not.
*
* @returns true, if RPC is in asynchronous mode, otherwise false for
* synchronous mode.
*/
@Unstable
static boolean isAsynchronousMode() {
return asynchronousMode.get();
}
/**
* Set RPC to asynchronous or synchronous mode.
*
* @param async
* true, RPC will be in asynchronous mode, otherwise false for
* synchronous mode
*/
@Unstable
public static void setAsynchronousMode(boolean async) {
asynchronousMode.set(async);
}
private Writable getRpcResponse(final Call call, final Connection connection)
throws IOException {
synchronized (call) { synchronized (call) {
while (!call.done) { while (!call.done) {
try { try {
@ -1640,4 +1701,10 @@ public class Client {
public static int nextCallId() { public static int nextCallId() {
return callIdCounter.getAndIncrement() & 0x7FFFFFFF; return callIdCounter.getAndIncrement() & 0x7FFFFFFF;
} }
@Override
@Unstable
public void close() throws Exception {
stop();
}
} }

View File

@ -0,0 +1,346 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.ipc;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.ipc.RPC.RpcKind;
import org.apache.hadoop.ipc.TestIPC.CallInfo;
import org.apache.hadoop.ipc.TestIPC.TestServer;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.util.StringUtils;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
public class TestAsyncIPC {
private static Configuration conf;
private static final Log LOG = LogFactory.getLog(TestAsyncIPC.class);
@Before
public void setupConf() {
conf = new Configuration();
Client.setPingInterval(conf, TestIPC.PING_INTERVAL);
// set asynchronous mode for main thread
Client.setAsynchronousMode(true);
}
protected static class SerialCaller extends Thread {
private Client client;
private InetSocketAddress server;
private int count;
private boolean failed;
Map<Integer, Future<LongWritable>> returnFutures =
new HashMap<Integer, Future<LongWritable>>();
Map<Integer, Long> expectedValues = new HashMap<Integer, Long>();
public SerialCaller(Client client, InetSocketAddress server, int count) {
this.client = client;
this.server = server;
this.count = count;
// set asynchronous mode, since SerialCaller extends Thread
Client.setAsynchronousMode(true);
}
@Override
public void run() {
// in case Thread#Start is called, which will spawn new thread
Client.setAsynchronousMode(true);
for (int i = 0; i < count; i++) {
try {
final long param = TestIPC.RANDOM.nextLong();
TestIPC.call(client, param, server, conf);
Future<LongWritable> returnFuture = Client.getReturnValue();
returnFutures.put(i, returnFuture);
expectedValues.put(i, param);
} catch (Exception e) {
LOG.fatal("Caught: " + StringUtils.stringifyException(e));
failed = true;
}
}
}
public void waitForReturnValues() throws InterruptedException,
ExecutionException {
for (int i = 0; i < count; i++) {
LongWritable value = returnFutures.get(i).get();
if (expectedValues.get(i) != value.get()) {
LOG.fatal(String.format("Call-%d failed!", i));
failed = true;
break;
}
}
}
}
@Test
public void testSerial() throws IOException, InterruptedException,
ExecutionException {
internalTestSerial(3, false, 2, 5, 100);
internalTestSerial(3, true, 2, 5, 10);
}
public void internalTestSerial(int handlerCount, boolean handlerSleep,
int clientCount, int callerCount, int callCount) throws IOException,
InterruptedException, ExecutionException {
Server server = new TestIPC.TestServer(handlerCount, handlerSleep, conf);
InetSocketAddress addr = NetUtils.getConnectAddress(server);
server.start();
Client[] clients = new Client[clientCount];
for (int i = 0; i < clientCount; i++) {
clients[i] = new Client(LongWritable.class, conf);
}
SerialCaller[] callers = new SerialCaller[callerCount];
for (int i = 0; i < callerCount; i++) {
callers[i] = new SerialCaller(clients[i % clientCount], addr, callCount);
callers[i].start();
}
for (int i = 0; i < callerCount; i++) {
callers[i].join();
callers[i].waitForReturnValues();
String msg = String.format("Expected not failed for caller-%d: %s.", i,
callers[i]);
assertFalse(msg, callers[i].failed);
}
for (int i = 0; i < clientCount; i++) {
clients[i].stop();
}
server.stop();
}
/**
* Test if (1) the rpc server uses the call id/retry provided by the rpc
* client, and (2) the rpc client receives the same call id/retry from the rpc
* server.
*
* @throws ExecutionException
* @throws InterruptedException
*/
@Test(timeout = 60000)
public void testCallIdAndRetry() throws IOException, InterruptedException,
ExecutionException {
final Map<Integer, CallInfo> infoMap = new HashMap<Integer, CallInfo>();
// Override client to store the call info and check response
final Client client = new Client(LongWritable.class, conf) {
@Override
Call createCall(RpcKind rpcKind, Writable rpcRequest) {
// Set different call id and retry count for the next call
Client.setCallIdAndRetryCount(Client.nextCallId(),
TestIPC.RANDOM.nextInt(255));
final Call call = super.createCall(rpcKind, rpcRequest);
CallInfo info = new CallInfo();
info.id = call.id;
info.retry = call.retry;
infoMap.put(call.id, info);
return call;
}
@Override
void checkResponse(RpcResponseHeaderProto header) throws IOException {
super.checkResponse(header);
Assert.assertEquals(infoMap.get(header.getCallId()).retry,
header.getRetryCount());
}
};
// Attach a listener that tracks every call received by the server.
final TestServer server = new TestIPC.TestServer(1, false, conf);
server.callListener = new Runnable() {
@Override
public void run() {
Assert.assertEquals(infoMap.get(Server.getCallId()).retry,
Server.getCallRetryCount());
}
};
try {
InetSocketAddress addr = NetUtils.getConnectAddress(server);
server.start();
final SerialCaller caller = new SerialCaller(client, addr, 4);
caller.run();
caller.waitForReturnValues();
String msg = String.format("Expected not failed for caller: %s.", caller);
assertFalse(msg, caller.failed);
} finally {
client.stop();
server.stop();
}
}
/**
* Test if the rpc server gets the retry count from client.
*
* @throws ExecutionException
* @throws InterruptedException
*/
@Test(timeout = 60000)
public void testCallRetryCount() throws IOException, InterruptedException,
ExecutionException {
final int retryCount = 255;
// Override client to store the call id
final Client client = new Client(LongWritable.class, conf);
Client.setCallIdAndRetryCount(Client.nextCallId(), retryCount);
// Attach a listener that tracks every call ID received by the server.
final TestServer server = new TestIPC.TestServer(1, false, conf);
server.callListener = new Runnable() {
@Override
public void run() {
// we have not set the retry count for the client, thus on the server
// side we should see retry count as 0
Assert.assertEquals(retryCount, Server.getCallRetryCount());
}
};
try {
InetSocketAddress addr = NetUtils.getConnectAddress(server);
server.start();
final SerialCaller caller = new SerialCaller(client, addr, 10);
caller.run();
caller.waitForReturnValues();
String msg = String.format("Expected not failed for caller: %s.", caller);
assertFalse(msg, caller.failed);
} finally {
client.stop();
server.stop();
}
}
/**
* Test if the rpc server gets the default retry count (0) from client.
*
* @throws ExecutionException
* @throws InterruptedException
*/
@Test(timeout = 60000)
public void testInitialCallRetryCount() throws IOException,
InterruptedException, ExecutionException {
// Override client to store the call id
final Client client = new Client(LongWritable.class, conf);
// Attach a listener that tracks every call ID received by the server.
final TestServer server = new TestIPC.TestServer(1, false, conf);
server.callListener = new Runnable() {
@Override
public void run() {
// we have not set the retry count for the client, thus on the server
// side we should see retry count as 0
Assert.assertEquals(0, Server.getCallRetryCount());
}
};
try {
InetSocketAddress addr = NetUtils.getConnectAddress(server);
server.start();
final SerialCaller caller = new SerialCaller(client, addr, 10);
caller.run();
caller.waitForReturnValues();
String msg = String.format("Expected not failed for caller: %s.", caller);
assertFalse(msg, caller.failed);
} finally {
client.stop();
server.stop();
}
}
/**
* Tests that client generates a unique sequential call ID for each RPC call,
* even if multiple threads are using the same client.
*
* @throws InterruptedException
* @throws ExecutionException
*/
@Test(timeout = 60000)
public void testUniqueSequentialCallIds() throws IOException,
InterruptedException, ExecutionException {
int serverThreads = 10, callerCount = 100, perCallerCallCount = 100;
TestServer server = new TestIPC.TestServer(serverThreads, false, conf);
// Attach a listener that tracks every call ID received by the server. This
// list must be synchronized, because multiple server threads will add to
// it.
final List<Integer> callIds = Collections
.synchronizedList(new ArrayList<Integer>());
server.callListener = new Runnable() {
@Override
public void run() {
callIds.add(Server.getCallId());
}
};
Client client = new Client(LongWritable.class, conf);
try {
InetSocketAddress addr = NetUtils.getConnectAddress(server);
server.start();
SerialCaller[] callers = new SerialCaller[callerCount];
for (int i = 0; i < callerCount; ++i) {
callers[i] = new SerialCaller(client, addr, perCallerCallCount);
callers[i].start();
}
for (int i = 0; i < callerCount; ++i) {
callers[i].join();
callers[i].waitForReturnValues();
String msg = String.format("Expected not failed for caller-%d: %s.", i,
callers[i]);
assertFalse(msg, callers[i].failed);
}
} finally {
client.stop();
server.stop();
}
int expectedCallCount = callerCount * perCallerCallCount;
assertEquals(expectedCallCount, callIds.size());
// It is not guaranteed that the server executes requests in sequential
// order
// of client call ID, so we must sort the call IDs before checking that it
// contains every expected value.
Collections.sort(callIds);
final int startID = callIds.get(0).intValue();
for (int i = 0; i < expectedCallCount; ++i) {
assertEquals(startID + i, callIds.get(i).intValue());
}
}
}

View File

@ -99,7 +99,7 @@ public class TestIPC {
LogFactory.getLog(TestIPC.class); LogFactory.getLog(TestIPC.class);
private static Configuration conf; private static Configuration conf;
final static private int PING_INTERVAL = 1000; final static int PING_INTERVAL = 1000;
final static private int MIN_SLEEP_TIME = 1000; final static private int MIN_SLEEP_TIME = 1000;
/** /**
* Flag used to turn off the fault injection behavior * Flag used to turn off the fault injection behavior
@ -114,7 +114,7 @@ public class TestIPC {
Client.setPingInterval(conf, PING_INTERVAL); Client.setPingInterval(conf, PING_INTERVAL);
} }
private static final Random RANDOM = new Random(); static final Random RANDOM = new Random();
private static final String ADDRESS = "0.0.0.0"; private static final String ADDRESS = "0.0.0.0";
@ -148,11 +148,11 @@ public class TestIPC {
RPC.RPC_SERVICE_CLASS_DEFAULT, null); RPC.RPC_SERVICE_CLASS_DEFAULT, null);
} }
private static class TestServer extends Server { static class TestServer extends Server {
// Tests can set callListener to run a piece of code each time the server // Tests can set callListener to run a piece of code each time the server
// receives a call. This code executes on the server thread, so it has // receives a call. This code executes on the server thread, so it has
// visibility of that thread's thread-local storage. // visibility of that thread's thread-local storage.
private Runnable callListener; Runnable callListener;
private boolean sleep; private boolean sleep;
private Class<? extends Writable> responseClass; private Class<? extends Writable> responseClass;
@ -160,9 +160,20 @@ public class TestIPC {
this(handlerCount, sleep, LongWritable.class, null); this(handlerCount, sleep, LongWritable.class, null);
} }
public TestServer(int handlerCount, boolean sleep, Configuration conf)
throws IOException {
this(handlerCount, sleep, LongWritable.class, null, conf);
}
public TestServer(int handlerCount, boolean sleep, public TestServer(int handlerCount, boolean sleep,
Class<? extends Writable> paramClass, Class<? extends Writable> paramClass,
Class<? extends Writable> responseClass) Class<? extends Writable> responseClass) throws IOException {
this(handlerCount, sleep, paramClass, responseClass, conf);
}
public TestServer(int handlerCount, boolean sleep,
Class<? extends Writable> paramClass,
Class<? extends Writable> responseClass, Configuration conf)
throws IOException { throws IOException {
super(ADDRESS, 0, paramClass, handlerCount, conf); super(ADDRESS, 0, paramClass, handlerCount, conf);
this.sleep = sleep; this.sleep = sleep;
@ -1070,7 +1081,7 @@ public class TestIPC {
assertRetriesOnSocketTimeouts(conf, 4); assertRetriesOnSocketTimeouts(conf, 4);
} }
private static class CallInfo { static class CallInfo {
int id = RpcConstants.INVALID_CALL_ID; int id = RpcConstants.INVALID_CALL_ID;
int retry = RpcConstants.INVALID_RETRY_COUNT; int retry = RpcConstants.INVALID_RETRY_COUNT;
} }
@ -1125,7 +1136,7 @@ public class TestIPC {
} }
/** A dummy protocol */ /** A dummy protocol */
private interface DummyProtocol { interface DummyProtocol {
@Idempotent @Idempotent
public void dummyRun() throws IOException; public void dummyRun() throws IOException;
} }