diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java index d76b4eaa5d1..1c0dd039b9f 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java @@ -62,6 +62,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.classification.InterfaceStability.Unstable; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.fs.CommonConfigurationKeysPublic; @@ -96,6 +97,7 @@ import org.apache.htrace.core.Tracer; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.protobuf.CodedOutputStream; @@ -107,7 +109,7 @@ import com.google.protobuf.CodedOutputStream; */ @InterfaceAudience.LimitedPrivate(value = { "Common", "HDFS", "MapReduce", "Yarn" }) @InterfaceStability.Evolving -public class Client { +public class Client implements AutoCloseable { public static final Log LOG = LogFactory.getLog(Client.class); @@ -116,6 +118,20 @@ public class Client { private static final ThreadLocal callId = new ThreadLocal(); private static final ThreadLocal retryCount = new ThreadLocal(); + private static final ThreadLocal> returnValue = new ThreadLocal<>(); + private static final ThreadLocal asynchronousMode = + new ThreadLocal() { + @Override + protected Boolean initialValue() { + return false; + } + }; + + @SuppressWarnings("unchecked") + @Unstable + public static Future getReturnValue() { + return (Future) returnValue.get(); + } /** Set call id and retry count for the next call. */ public static void setCallIdAndRetryCount(int cid, int rc) { @@ -1356,8 +1372,8 @@ public class Client { ConnectionId remoteId, int serviceClass, AtomicBoolean fallbackToSimpleAuth) throws IOException { final Call call = createCall(rpcKind, rpcRequest); - Connection connection = getConnection(remoteId, call, serviceClass, - fallbackToSimpleAuth); + final Connection connection = getConnection(remoteId, call, serviceClass, + fallbackToSimpleAuth); try { connection.sendRpcRequest(call); // send the rpc request } catch (RejectedExecutionException e) { @@ -1368,6 +1384,51 @@ public class Client { throw new IOException(e); } + if (isAsynchronousMode()) { + Future returnFuture = new AbstractFuture() { + @Override + public Writable get() throws InterruptedException, ExecutionException { + try { + set(getRpcResponse(call, connection)); + } catch (IOException ie) { + setException(ie); + } + return super.get(); + } + }; + + returnValue.set(returnFuture); + return null; + } else { + return getRpcResponse(call, connection); + } + } + + /** + * Check if RPC is in asynchronous mode or not. + * + * @returns true, if RPC is in asynchronous mode, otherwise false for + * synchronous mode. + */ + @Unstable + static boolean isAsynchronousMode() { + return asynchronousMode.get(); + } + + /** + * Set RPC to asynchronous or synchronous mode. + * + * @param async + * true, RPC will be in asynchronous mode, otherwise false for + * synchronous mode + */ + @Unstable + public static void setAsynchronousMode(boolean async) { + asynchronousMode.set(async); + } + + private Writable getRpcResponse(final Call call, final Connection connection) + throws IOException { synchronized (call) { while (!call.done) { try { @@ -1642,4 +1703,10 @@ public class Client { public static int nextCallId() { return callIdCounter.getAndIncrement() & 0x7FFFFFFF; } + + @Override + @Unstable + public void close() throws Exception { + stop(); + } } diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestAsyncIPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestAsyncIPC.java new file mode 100644 index 00000000000..de4395e8f25 --- /dev/null +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestAsyncIPC.java @@ -0,0 +1,346 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.ipc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.ipc.RPC.RpcKind; +import org.apache.hadoop.ipc.TestIPC.CallInfo; +import org.apache.hadoop.ipc.TestIPC.TestServer; +import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.util.StringUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestAsyncIPC { + + private static Configuration conf; + private static final Log LOG = LogFactory.getLog(TestAsyncIPC.class); + + @Before + public void setupConf() { + conf = new Configuration(); + Client.setPingInterval(conf, TestIPC.PING_INTERVAL); + // set asynchronous mode for main thread + Client.setAsynchronousMode(true); + } + + protected static class SerialCaller extends Thread { + private Client client; + private InetSocketAddress server; + private int count; + private boolean failed; + Map> returnFutures = + new HashMap>(); + Map expectedValues = new HashMap(); + + public SerialCaller(Client client, InetSocketAddress server, int count) { + this.client = client; + this.server = server; + this.count = count; + // set asynchronous mode, since SerialCaller extends Thread + Client.setAsynchronousMode(true); + } + + @Override + public void run() { + // in case Thread#Start is called, which will spawn new thread + Client.setAsynchronousMode(true); + for (int i = 0; i < count; i++) { + try { + final long param = TestIPC.RANDOM.nextLong(); + TestIPC.call(client, param, server, conf); + Future returnFuture = Client.getReturnValue(); + returnFutures.put(i, returnFuture); + expectedValues.put(i, param); + } catch (Exception e) { + LOG.fatal("Caught: " + StringUtils.stringifyException(e)); + failed = true; + } + } + } + + public void waitForReturnValues() throws InterruptedException, + ExecutionException { + for (int i = 0; i < count; i++) { + LongWritable value = returnFutures.get(i).get(); + if (expectedValues.get(i) != value.get()) { + LOG.fatal(String.format("Call-%d failed!", i)); + failed = true; + break; + } + } + } + } + + @Test + public void testSerial() throws IOException, InterruptedException, + ExecutionException { + internalTestSerial(3, false, 2, 5, 100); + internalTestSerial(3, true, 2, 5, 10); + } + + public void internalTestSerial(int handlerCount, boolean handlerSleep, + int clientCount, int callerCount, int callCount) throws IOException, + InterruptedException, ExecutionException { + Server server = new TestIPC.TestServer(handlerCount, handlerSleep, conf); + InetSocketAddress addr = NetUtils.getConnectAddress(server); + server.start(); + + Client[] clients = new Client[clientCount]; + for (int i = 0; i < clientCount; i++) { + clients[i] = new Client(LongWritable.class, conf); + } + + SerialCaller[] callers = new SerialCaller[callerCount]; + for (int i = 0; i < callerCount; i++) { + callers[i] = new SerialCaller(clients[i % clientCount], addr, callCount); + callers[i].start(); + } + for (int i = 0; i < callerCount; i++) { + callers[i].join(); + callers[i].waitForReturnValues(); + String msg = String.format("Expected not failed for caller-%d: %s.", i, + callers[i]); + assertFalse(msg, callers[i].failed); + } + for (int i = 0; i < clientCount; i++) { + clients[i].stop(); + } + server.stop(); + } + + /** + * Test if (1) the rpc server uses the call id/retry provided by the rpc + * client, and (2) the rpc client receives the same call id/retry from the rpc + * server. + * + * @throws ExecutionException + * @throws InterruptedException + */ + @Test(timeout = 60000) + public void testCallIdAndRetry() throws IOException, InterruptedException, + ExecutionException { + final Map infoMap = new HashMap(); + + // Override client to store the call info and check response + final Client client = new Client(LongWritable.class, conf) { + @Override + Call createCall(RpcKind rpcKind, Writable rpcRequest) { + // Set different call id and retry count for the next call + Client.setCallIdAndRetryCount(Client.nextCallId(), + TestIPC.RANDOM.nextInt(255)); + + final Call call = super.createCall(rpcKind, rpcRequest); + + CallInfo info = new CallInfo(); + info.id = call.id; + info.retry = call.retry; + infoMap.put(call.id, info); + + return call; + } + + @Override + void checkResponse(RpcResponseHeaderProto header) throws IOException { + super.checkResponse(header); + Assert.assertEquals(infoMap.get(header.getCallId()).retry, + header.getRetryCount()); + } + }; + + // Attach a listener that tracks every call received by the server. + final TestServer server = new TestIPC.TestServer(1, false, conf); + server.callListener = new Runnable() { + @Override + public void run() { + Assert.assertEquals(infoMap.get(Server.getCallId()).retry, + Server.getCallRetryCount()); + } + }; + + try { + InetSocketAddress addr = NetUtils.getConnectAddress(server); + server.start(); + final SerialCaller caller = new SerialCaller(client, addr, 4); + caller.run(); + caller.waitForReturnValues(); + String msg = String.format("Expected not failed for caller: %s.", caller); + assertFalse(msg, caller.failed); + } finally { + client.stop(); + server.stop(); + } + } + + /** + * Test if the rpc server gets the retry count from client. + * + * @throws ExecutionException + * @throws InterruptedException + */ + @Test(timeout = 60000) + public void testCallRetryCount() throws IOException, InterruptedException, + ExecutionException { + final int retryCount = 255; + // Override client to store the call id + final Client client = new Client(LongWritable.class, conf); + Client.setCallIdAndRetryCount(Client.nextCallId(), retryCount); + + // Attach a listener that tracks every call ID received by the server. + final TestServer server = new TestIPC.TestServer(1, false, conf); + server.callListener = new Runnable() { + @Override + public void run() { + // we have not set the retry count for the client, thus on the server + // side we should see retry count as 0 + Assert.assertEquals(retryCount, Server.getCallRetryCount()); + } + }; + + try { + InetSocketAddress addr = NetUtils.getConnectAddress(server); + server.start(); + final SerialCaller caller = new SerialCaller(client, addr, 10); + caller.run(); + caller.waitForReturnValues(); + String msg = String.format("Expected not failed for caller: %s.", caller); + assertFalse(msg, caller.failed); + } finally { + client.stop(); + server.stop(); + } + } + + /** + * Test if the rpc server gets the default retry count (0) from client. + * + * @throws ExecutionException + * @throws InterruptedException + */ + @Test(timeout = 60000) + public void testInitialCallRetryCount() throws IOException, + InterruptedException, ExecutionException { + // Override client to store the call id + final Client client = new Client(LongWritable.class, conf); + + // Attach a listener that tracks every call ID received by the server. + final TestServer server = new TestIPC.TestServer(1, false, conf); + server.callListener = new Runnable() { + @Override + public void run() { + // we have not set the retry count for the client, thus on the server + // side we should see retry count as 0 + Assert.assertEquals(0, Server.getCallRetryCount()); + } + }; + + try { + InetSocketAddress addr = NetUtils.getConnectAddress(server); + server.start(); + final SerialCaller caller = new SerialCaller(client, addr, 10); + caller.run(); + caller.waitForReturnValues(); + String msg = String.format("Expected not failed for caller: %s.", caller); + assertFalse(msg, caller.failed); + } finally { + client.stop(); + server.stop(); + } + } + + /** + * Tests that client generates a unique sequential call ID for each RPC call, + * even if multiple threads are using the same client. + * + * @throws InterruptedException + * @throws ExecutionException + */ + @Test(timeout = 60000) + public void testUniqueSequentialCallIds() throws IOException, + InterruptedException, ExecutionException { + int serverThreads = 10, callerCount = 100, perCallerCallCount = 100; + TestServer server = new TestIPC.TestServer(serverThreads, false, conf); + + // Attach a listener that tracks every call ID received by the server. This + // list must be synchronized, because multiple server threads will add to + // it. + final List callIds = Collections + .synchronizedList(new ArrayList()); + server.callListener = new Runnable() { + @Override + public void run() { + callIds.add(Server.getCallId()); + } + }; + + Client client = new Client(LongWritable.class, conf); + + try { + InetSocketAddress addr = NetUtils.getConnectAddress(server); + server.start(); + SerialCaller[] callers = new SerialCaller[callerCount]; + for (int i = 0; i < callerCount; ++i) { + callers[i] = new SerialCaller(client, addr, perCallerCallCount); + callers[i].start(); + } + for (int i = 0; i < callerCount; ++i) { + callers[i].join(); + callers[i].waitForReturnValues(); + String msg = String.format("Expected not failed for caller-%d: %s.", i, + callers[i]); + assertFalse(msg, callers[i].failed); + } + } finally { + client.stop(); + server.stop(); + } + + int expectedCallCount = callerCount * perCallerCallCount; + assertEquals(expectedCallCount, callIds.size()); + + // It is not guaranteed that the server executes requests in sequential + // order + // of client call ID, so we must sort the call IDs before checking that it + // contains every expected value. + Collections.sort(callIds); + final int startID = callIds.get(0).intValue(); + for (int i = 0; i < expectedCallCount; ++i) { + assertEquals(startID + i, callIds.get(i).intValue()); + } + } +} \ No newline at end of file diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java index d65818238eb..6bfcc537da4 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java @@ -99,7 +99,7 @@ public class TestIPC { LogFactory.getLog(TestIPC.class); private static Configuration conf; - final static private int PING_INTERVAL = 1000; + final static int PING_INTERVAL = 1000; final static private int MIN_SLEEP_TIME = 1000; /** * Flag used to turn off the fault injection behavior @@ -114,7 +114,7 @@ public class TestIPC { Client.setPingInterval(conf, PING_INTERVAL); } - private static final Random RANDOM = new Random(); + static final Random RANDOM = new Random(); private static final String ADDRESS = "0.0.0.0"; @@ -148,22 +148,33 @@ public class TestIPC { RPC.RPC_SERVICE_CLASS_DEFAULT, null); } - private static class TestServer extends Server { + static class TestServer extends Server { // Tests can set callListener to run a piece of code each time the server // receives a call. This code executes on the server thread, so it has // visibility of that thread's thread-local storage. - private Runnable callListener; + Runnable callListener; private boolean sleep; private Class responseClass; public TestServer(int handlerCount, boolean sleep) throws IOException { this(handlerCount, sleep, LongWritable.class, null); } - + + public TestServer(int handlerCount, boolean sleep, Configuration conf) + throws IOException { + this(handlerCount, sleep, LongWritable.class, null, conf); + } + public TestServer(int handlerCount, boolean sleep, Class paramClass, - Class responseClass) - throws IOException { + Class responseClass) throws IOException { + this(handlerCount, sleep, paramClass, responseClass, conf); + } + + public TestServer(int handlerCount, boolean sleep, + Class paramClass, + Class responseClass, Configuration conf) + throws IOException { super(ADDRESS, 0, paramClass, handlerCount, conf); this.sleep = sleep; this.responseClass = responseClass; @@ -1070,7 +1081,7 @@ public class TestIPC { assertRetriesOnSocketTimeouts(conf, 4); } - private static class CallInfo { + static class CallInfo { int id = RpcConstants.INVALID_CALL_ID; int retry = RpcConstants.INVALID_RETRY_COUNT; } @@ -1125,7 +1136,7 @@ public class TestIPC { } /** A dummy protocol */ - private interface DummyProtocol { + interface DummyProtocol { @Idempotent public void dummyRun() throws IOException; }