HADOOP-9716. Rpc retries should use the same call ID as the original call.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1504362 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tsz-wo Sze 2013-07-18 04:42:56 +00:00
parent a5cd4b9bee
commit 7ec67c5118
9 changed files with 160 additions and 72 deletions

View File

@ -477,6 +477,9 @@ Release 2.1.0-beta - 2013-07-02
HADOOP-9734. Common protobuf definitions for GetUserMappingsProtocol,
RefreshAuthorizationPolicyProtocol and RefreshUserMappingsProtocol (jlowe)
HADOOP-9716. Rpc retries should use the same call ID as the original call.
(szetszwo)
OPTIMIZATIONS
HADOOP-9150. Avoid unnecessary DNS resolution attempts for logical URIs

View File

@ -18,18 +18,22 @@
package org.apache.hadoop.io.retry;
import java.io.IOException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Collections;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.retry.RetryPolicy.RetryAction;
import org.apache.hadoop.util.ThreadUtil;
import org.apache.hadoop.ipc.Client;
import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.RpcConstants;
import org.apache.hadoop.ipc.RpcInvocationHandler;
import org.apache.hadoop.util.ThreadUtil;
class RetryInvocationHandler implements RpcInvocationHandler {
public static final Log LOG = LogFactory.getLog(RetryInvocationHandler.class);
@ -44,13 +48,13 @@ class RetryInvocationHandler implements RpcInvocationHandler {
private final RetryPolicy defaultPolicy;
private final Map<String,RetryPolicy> methodNameToPolicyMap;
private Object currentProxy;
public RetryInvocationHandler(FailoverProxyProvider proxyProvider,
RetryInvocationHandler(FailoverProxyProvider proxyProvider,
RetryPolicy retryPolicy) {
this(proxyProvider, retryPolicy, Collections.<String, RetryPolicy>emptyMap());
}
public RetryInvocationHandler(FailoverProxyProvider proxyProvider,
RetryInvocationHandler(FailoverProxyProvider proxyProvider,
RetryPolicy defaultPolicy,
Map<String, RetryPolicy> methodNameToPolicyMap) {
this.proxyProvider = proxyProvider;
@ -69,6 +73,8 @@ public Object invoke(Object proxy, Method method, Object[] args)
// The number of times this method invocation has been failed over.
int invocationFailoverCount = 0;
final boolean isRpc = isRpcInvocation();
final int callId = isRpc? Client.nextCallId(): RpcConstants.INVALID_CALL_ID;
int retries = 0;
while (true) {
// The number of times this invocation handler has ever been failed over,
@ -78,6 +84,10 @@ public Object invoke(Object proxy, Method method, Object[] args)
synchronized (proxyProvider) {
invocationAttemptFailoverCount = proxyProviderFailoverCount;
}
if (isRpc) {
Client.setCallId(callId);
}
try {
Object ret = invokeMethod(method, args);
hasMadeASuccessfulCall = true;
@ -166,6 +176,14 @@ private Object invokeMethod(Method method, Object[] args) throws Throwable {
}
}
private boolean isRpcInvocation() {
if (!Proxy.isProxyClass(currentProxy.getClass())) {
return false;
}
final InvocationHandler ih = Proxy.getInvocationHandler(currentProxy);
return ih instanceof RpcInvocationHandler;
}
@Override
public void close() throws IOException {
proxyProvider.close();

View File

@ -91,6 +91,7 @@
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.protobuf.CodedOutputStream;
@ -106,11 +107,22 @@ public class Client {
public static final Log LOG = LogFactory.getLog(Client.class);
/** A counter for generating call IDs. */
private static final AtomicInteger callIdCounter = new AtomicInteger();
private static final ThreadLocal<Integer> callId = new ThreadLocal<Integer>();
/** Set call id for the next call. */
public static void setCallId(int cid) {
Preconditions.checkArgument(cid != RpcConstants.INVALID_CALL_ID);
Preconditions.checkState(callId.get() == null);
callId.set(cid);
}
private Hashtable<ConnectionId, Connection> connections =
new Hashtable<ConnectionId, Connection>();
private Class<? extends Writable> valueClass; // class of call values
private final AtomicInteger counter = new AtomicInteger(); // call ID sequence
private AtomicBoolean running = new AtomicBoolean(true); // if client runs
final private Configuration conf;
@ -259,11 +271,15 @@ synchronized void decCount() {
synchronized boolean isZeroReference() {
return refCount==0;
}
Call createCall(RPC.RpcKind rpcKind, Writable rpcRequest) {
return new Call(rpcKind, rpcRequest);
}
/**
* Class that represents an RPC call
*/
private class Call {
static class Call {
final int id; // call id
final Writable rpcRequest; // the serialized rpc request
Writable rpcResponse; // null if rpc has error
@ -271,10 +287,17 @@ private class Call {
final RPC.RpcKind rpcKind; // Rpc EngineKind
boolean done; // true when call is done
protected Call(RPC.RpcKind rpcKind, Writable param) {
private Call(RPC.RpcKind rpcKind, Writable param) {
this.rpcKind = rpcKind;
this.rpcRequest = param;
this.id = nextCallId();
final Integer id = callId.get();
if (id == null) {
this.id = nextCallId();
} else {
callId.set(null);
this.id = id;
}
}
/** Indicate when the call is complete and the
@ -1346,7 +1369,7 @@ public Writable call(RPC.RpcKind rpcKind, Writable rpcRequest,
public Writable call(RPC.RpcKind rpcKind, Writable rpcRequest,
ConnectionId remoteId, int serviceClass)
throws InterruptedException, IOException {
Call call = new Call(rpcKind, rpcRequest);
final Call call = createCall(rpcKind, rpcRequest);
Connection connection = getConnection(remoteId, call, serviceClass);
try {
connection.sendRpcRequest(call); // send the rpc request
@ -1633,9 +1656,9 @@ public String toString() {
* versions of the client did not mask off the sign bit, so a server may still
* see a negative call ID if it receives connections from an old client.
*
* @return int next valid call ID
* @return next call ID
*/
private int nextCallId() {
return counter.getAndIncrement() & 0x7FFFFFFF;
public static int nextCallId() {
return callIdCounter.getAndIncrement() & 0x7FFFFFFF;
}
}

View File

@ -124,7 +124,7 @@ private Invoker(Class<?> protocol, InetSocketAddress addr,
/**
* This constructor takes a connectionId, instead of creating a new one.
*/
public Invoker(Class<?> protocol, Client.ConnectionId connId,
private Invoker(Class<?> protocol, Client.ConnectionId connId,
Configuration conf, SocketFactory factory) {
this.remoteId = connId;
this.client = CLIENTS.getClient(conf, factory, RpcResponseWrapper.class);

View File

@ -278,7 +278,7 @@ public static Server get() {
*
* @return int sequential ID number of currently active RPC call
*/
public static int getCallId() {
static int getCallId() {
Call call = CurCall.get();
return call != null ? call.callId : RpcConstants.INVALID_CALL_ID;
}
@ -464,12 +464,12 @@ private static class Call {
private final RPC.RpcKind rpcKind;
private final byte[] clientId;
public Call(int id, Writable param, Connection connection) {
private Call(int id, Writable param, Connection connection) {
this(id, param, connection, RPC.RpcKind.RPC_BUILTIN,
RpcConstants.DUMMY_CLIENT_ID);
}
public Call(int id, Writable param, Connection connection,
private Call(int id, Writable param, Connection connection,
RPC.RpcKind kind, byte[] clientId) {
this.callId = id;
this.rpcRequest = param;
@ -482,7 +482,7 @@ public Call(int id, Writable param, Connection connection,
@Override
public String toString() {
return rpcRequest.toString() + " from " + connection.toString();
return rpcRequest + " from " + connection + " Call#" + callId;
}
public void setResponse(ByteBuffer response) {
@ -987,8 +987,7 @@ private boolean processResponse(LinkedList<Call> responseQueue,
call = responseQueue.removeFirst();
SocketChannel channel = call.connection.channel;
if (LOG.isDebugEnabled()) {
LOG.debug(getName() + ": responding to #" + call.callId + " from " +
call.connection);
LOG.debug(getName() + ": responding to " + call);
}
//
// Send as much data as we can in the non-blocking fashion
@ -1007,8 +1006,8 @@ private boolean processResponse(LinkedList<Call> responseQueue,
done = false; // more calls pending to be sent.
}
if (LOG.isDebugEnabled()) {
LOG.debug(getName() + ": responding to #" + call.callId + " from " +
call.connection + " Wrote " + numBytes + " bytes.");
LOG.debug(getName() + ": responding to " + call
+ " Wrote " + numBytes + " bytes.");
}
} else {
//
@ -1035,9 +1034,8 @@ private boolean processResponse(LinkedList<Call> responseQueue,
}
}
if (LOG.isDebugEnabled()) {
LOG.debug(getName() + ": responding to #" + call.callId + " from " +
call.connection + " Wrote partial " + numBytes +
" bytes.");
LOG.debug(getName() + ": responding to " + call
+ " Wrote partial " + numBytes + " bytes.");
}
}
error = false; // everything went off well
@ -2004,8 +2002,7 @@ public void run() {
try {
final Call call = callQueue.take(); // pop the queue; maybe blocked here
if (LOG.isDebugEnabled()) {
LOG.debug(getName() + ": has Call#" + call.callId +
"for RpcKind " + call.rpcKind + " from " + call.connection);
LOG.debug(getName() + ": " + call + " for RpcKind " + call.rpcKind);
}
String errorClass = null;
String error = null;

View File

@ -18,40 +18,49 @@
package org.apache.hadoop.ipc;
import org.apache.commons.logging.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.ipc.Server.Connection;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.net.ConnectTimeoutException;
import org.apache.hadoop.net.NetUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.File;
import java.io.DataOutput;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import javax.net.SocketFactory;
import org.junit.Test;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.ipc.RPC.RpcKind;
import org.apache.hadoop.ipc.Server.Connection;
import org.apache.hadoop.net.ConnectTimeoutException;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.util.StringUtils;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
@ -655,6 +664,48 @@ public void testConnectionRetriesOnSocketTimeoutExceptions() throws Exception {
assertRetriesOnSocketTimeouts(conf, 4);
}
private static class CallId {
int id = RpcConstants.INVALID_CALL_ID;
}
/**
* Test if the rpc server uses the call id generated by the rpc client.
*/
@Test
public void testCallIds() throws Exception {
final CallId callId = new CallId();
// Override client to store the call id
final Client client = new Client(LongWritable.class, conf) {
@Override
Call createCall(RpcKind rpcKind, Writable rpcRequest) {
final Call call = super.createCall(rpcKind, rpcRequest);
callId.id = call.id;
return call;
}
};
// Attach a listener that tracks every call ID received by the server.
final TestServer server = new TestServer(1, false);
server.callListener = new Runnable() {
@Override
public void run() {
Assert.assertEquals(callId.id, Server.getCallId());
}
};
try {
InetSocketAddress addr = NetUtils.getConnectAddress(server);
server.start();
final SerialCaller caller = new SerialCaller(client, addr, 10);
caller.run();
assertFalse(caller.failed);
} finally {
client.stop();
server.stop();
}
}
/**
* Tests that client generates a unique sequential call ID for each RPC call,
* even if multiple threads are using the same client.
@ -701,8 +752,9 @@ public void run() {
// of client call ID, so we must sort the call IDs before checking that it
// contains every expected value.
Collections.sort(callIds);
final int startID = callIds.get(0).intValue();
for (int i = 0; i < expectedCallCount; ++i) {
assertEquals(i, callIds.get(i).intValue());
assertEquals(startID + i, callIds.get(i).intValue());
}
}

View File

@ -105,10 +105,10 @@ public void run() {
byte[] bytes = new byte[byteSize];
System.arraycopy(BYTES, 0, bytes, 0, byteSize);
Writable param = new BytesWritable(bytes);
Writable value = client.call(param, address);
client.call(param, address);
Thread.sleep(RANDOM.nextInt(20));
} catch (Exception e) {
LOG.fatal("Caught: " + e);
LOG.fatal("Caught Exception", e);
failed = true;
}
}

View File

@ -151,10 +151,8 @@ public void tearDown() throws Exception {
private static TestRpcService getClient() throws IOException {
// Set RPC engine to protobuf RPC engine
RPC.setProtocolEngine(conf, TestRpcService.class,
ProtobufRpcEngine.class);
return RPC.getProxy(TestRpcService.class, 0, addr,
conf);
RPC.setProtocolEngine(conf, TestRpcService.class, ProtobufRpcEngine.class);
return RPC.getProxy(TestRpcService.class, 0, addr, conf);
}
private static TestRpcService2 getClient2() throws IOException {
@ -191,6 +189,7 @@ public static void testProtoBufRpc(TestRpcService client) throws Exception {
RemoteException re = (RemoteException)e.getCause();
RpcServerException rse = (RpcServerException) re
.unwrapRemoteException(RpcServerException.class);
Assert.assertNotNull(rse);
Assert.assertTrue(re.getErrorCode().equals(
RpcErrorCodeProto.ERROR_RPC_SERVER));
}
@ -246,6 +245,7 @@ public void testExtraLongRpc() throws Exception {
.setMessage(shortString).build();
// short message goes through
EchoResponseProto echoResponse = client.echo2(null, echoRequest);
Assert.assertEquals(shortString, echoResponse.getMessage());
final String longString = StringUtils.repeat("X", 4096);
echoRequest = EchoRequestProto.newBuilder()

View File

@ -94,7 +94,7 @@ public void setupConf() {
int datasize = 1024*100;
int numThreads = 50;
public interface TestProtocol extends VersionedProtocol {
public static final long versionID = 1L;
@ -360,8 +360,7 @@ public void testProxyAddress() throws Exception {
InetSocketAddress addr = NetUtils.getConnectAddress(server);
// create a client
proxy = (TestProtocol)RPC.getProxy(
TestProtocol.class, TestProtocol.versionID, addr, conf);
proxy = RPC.getProxy(TestProtocol.class, TestProtocol.versionID, addr, conf);
assertEquals(addr, RPC.getServerAddress(proxy));
} finally {
@ -388,8 +387,7 @@ public void testSlowRpc() throws Exception {
InetSocketAddress addr = NetUtils.getConnectAddress(server);
// create a client
proxy = (TestProtocol)RPC.getProxy(
TestProtocol.class, TestProtocol.versionID, addr, conf);
proxy = RPC.getProxy(TestProtocol.class, TestProtocol.versionID, addr, conf);
SlowRPC slowrpc = new SlowRPC(proxy);
Thread thread = new Thread(slowrpc, "SlowRPC");
@ -432,8 +430,7 @@ private void testCallsInternal(Configuration conf) throws Exception {
server.start();
InetSocketAddress addr = NetUtils.getConnectAddress(server);
proxy = (TestProtocol)RPC.getProxy(
TestProtocol.class, TestProtocol.versionID, addr, conf);
proxy = RPC.getProxy(TestProtocol.class, TestProtocol.versionID, addr, conf);
proxy.ping();
@ -557,8 +554,7 @@ private void doRPCs(Configuration conf, boolean expectFailure) throws Exception
InetSocketAddress addr = NetUtils.getConnectAddress(server);
try {
proxy = (TestProtocol)RPC.getProxy(
TestProtocol.class, TestProtocol.versionID, addr, conf);
proxy = RPC.getProxy(TestProtocol.class, TestProtocol.versionID, addr, conf);
proxy.ping();
if (expectFailure) {
@ -660,7 +656,7 @@ public void testStopMockObject() throws Exception {
@Test
public void testStopProxy() throws IOException {
StoppedProtocol proxy = (StoppedProtocol) RPC.getProxy(StoppedProtocol.class,
StoppedProtocol proxy = RPC.getProxy(StoppedProtocol.class,
StoppedProtocol.versionID, null, conf);
StoppedInvocationHandler invocationHandler = (StoppedInvocationHandler)
Proxy.getInvocationHandler(proxy);
@ -671,7 +667,7 @@ public void testStopProxy() throws IOException {
@Test
public void testWrappedStopProxy() throws IOException {
StoppedProtocol wrappedProxy = (StoppedProtocol) RPC.getProxy(StoppedProtocol.class,
StoppedProtocol wrappedProxy = RPC.getProxy(StoppedProtocol.class,
StoppedProtocol.versionID, null, conf);
StoppedInvocationHandler invocationHandler = (StoppedInvocationHandler)
Proxy.getInvocationHandler(wrappedProxy);
@ -701,8 +697,7 @@ public void testErrorMsgForInsecureClient() throws Exception {
final InetSocketAddress addr = NetUtils.getConnectAddress(server);
TestProtocol proxy = null;
try {
proxy = (TestProtocol) RPC.getProxy(TestProtocol.class,
TestProtocol.versionID, addr, conf);
proxy = RPC.getProxy(TestProtocol.class, TestProtocol.versionID, addr, conf);
proxy.echo("");
} catch (RemoteException e) {
LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage());
@ -730,7 +725,7 @@ public void testErrorMsgForInsecureClient() throws Exception {
proxy = null;
try {
UserGroupInformation.setConfiguration(conf);
proxy = (TestProtocol) RPC.getProxy(TestProtocol.class,
proxy = RPC.getProxy(TestProtocol.class,
TestProtocol.versionID, mulitServerAddr, conf);
proxy.echo("");
} catch (RemoteException e) {
@ -847,7 +842,7 @@ TestProtocol.class, new TestImpl(), ADDRESS, 0, 5, true, conf, null
try {
InetSocketAddress addr = NetUtils.getConnectAddress(server);
final TestProtocol proxy = (TestProtocol) RPC.getProxy(
final TestProtocol proxy = RPC.getProxy(
TestProtocol.class, TestProtocol.versionID, addr, conf);
// Connect to the server
proxy.ping();
@ -887,8 +882,8 @@ TestProtocol.class, new TestImpl(), ADDRESS, 0, 5, true, conf, null
for (int i = 0; i < numConcurrentRPC; i++) {
final int num = i;
final TestProtocol proxy = (TestProtocol) RPC.getProxy(
TestProtocol.class, TestProtocol.versionID, addr, conf);
final TestProtocol proxy = RPC.getProxy(TestProtocol.class,
TestProtocol.versionID, addr, conf);
Thread rpcThread = new Thread(new Runnable() {
@Override
public void run() {
@ -906,7 +901,7 @@ public void run() {
error.set(e);
}
LOG.error(e);
LOG.error("thread " + num, e);
} finally {
latch.countDown();
}