diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java index 5db2cc13257..25a4fd4bba7 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java @@ -115,7 +115,7 @@ public class ProtobufRpcEngine implements RpcEngine { factory)), false); } - private static class Invoker implements RpcInvocationHandler { + protected static class Invoker implements RpcInvocationHandler { private final Map returnTypes = new ConcurrentHashMap(); private boolean isClosed = false; @@ -126,7 +126,7 @@ public class ProtobufRpcEngine implements RpcEngine { private AtomicBoolean fallbackToSimpleAuth; private AlignmentContext alignmentContext; - private Invoker(Class protocol, InetSocketAddress addr, + protected Invoker(Class protocol, InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, SocketFactory factory, int rpcTimeout, RetryPolicy connectionRetryPolicy, AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext) @@ -141,7 +141,7 @@ public class ProtobufRpcEngine implements RpcEngine { /** * This constructor takes a connectionId, instead of creating a new one. */ - private Invoker(Class protocol, Client.ConnectionId connId, + protected Invoker(Class protocol, Client.ConnectionId connId, Configuration conf, SocketFactory factory) { this.remoteId = connId; this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class); @@ -218,8 +218,6 @@ public class ProtobufRpcEngine implements RpcEngine { traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method)); } - RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method); - if (LOG.isTraceEnabled()) { LOG.trace(Thread.currentThread().getId() + ": Call -> " + remoteId + ": " + method.getName() + @@ -231,7 +229,7 @@ public class ProtobufRpcEngine implements RpcEngine { final RpcWritable.Buffer val; try { val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER, - new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId, + constructRpcRequest(method, theRequest), remoteId, fallbackToSimpleAuth, alignmentContext); } catch (Throwable e) { @@ -276,6 +274,11 @@ public class ProtobufRpcEngine implements RpcEngine { } } + protected Writable constructRpcRequest(Method method, Message theRequest) { + RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method); + return new RpcProtobufRequest(rpcRequestHeader, theRequest); + } + private Message getReturnMessage(final Method method, final RpcWritable.Buffer buf) throws ServiceException { Message prototype = null; @@ -325,6 +328,14 @@ public class ProtobufRpcEngine implements RpcEngine { public ConnectionId getConnectionId() { return remoteId; } + + protected long getClientProtocolVersion() { + return clientProtocolVersion; + } + + protected String getProtocolName() { + return protocolName; + } } @VisibleForTesting @@ -504,6 +515,13 @@ public class ProtobufRpcEngine implements RpcEngine { String declaringClassProtoName = rpcRequest.getDeclaringClassProtocolName(); long clientVersion = rpcRequest.getClientProtocolVersion(); + return call(server, connectionProtocolName, request, receiveTime, + methodName, declaringClassProtoName, clientVersion); + } + + protected Writable call(RPC.Server server, String connectionProtocolName, + RpcWritable.Buffer request, long receiveTime, String methodName, + String declaringClassProtoName, long clientVersion) throws Exception { if (server.verbose) LOG.info("Call: connectionProtocolName=" + connectionProtocolName + ", method=" + methodName);