From f40e3eb0590f85bb42d2471992bf5d524628fdd6 Mon Sep 17 00:00:00 2001 From: hchaverr Date: Thu, 6 May 2021 16:40:45 -0700 Subject: [PATCH] HADOOP-17680. Allow ProtobufRpcEngine to be extensible (#2905) Contributed by Hector Chaverri. --- .../apache/hadoop/ipc/ProtobufRpcEngine.java | 30 +++++++++++++++---- .../apache/hadoop/ipc/ProtobufRpcEngine2.java | 30 +++++++++++++++---- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java index a1500d52a74..882cc141d89 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java @@ -122,7 +122,7 @@ public class ProtobufRpcEngine implements RpcEngine { factory)), false); } - private static class Invoker implements RpcInvocationHandler { + protected static class Invoker implements RpcInvocationHandler { private final Map returnTypes = new ConcurrentHashMap(); private boolean isClosed = false; @@ -133,7 +133,7 @@ public class ProtobufRpcEngine implements RpcEngine { private AtomicBoolean fallbackToSimpleAuth; private AlignmentContext alignmentContext; - private Invoker(Class protocol, InetSocketAddress addr, + protected Invoker(Class protocol, InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, SocketFactory factory, int rpcTimeout, RetryPolicy connectionRetryPolicy, AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext) @@ -148,7 +148,7 @@ public class ProtobufRpcEngine implements RpcEngine { /** * This constructor takes a connectionId, instead of creating a new one. */ - private Invoker(Class protocol, Client.ConnectionId connId, + protected Invoker(Class protocol, Client.ConnectionId connId, Configuration conf, SocketFactory factory) { this.remoteId = connId; this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class); @@ -225,8 +225,6 @@ public class ProtobufRpcEngine implements RpcEngine { traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method)); } - RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method); - if (LOG.isTraceEnabled()) { LOG.trace(Thread.currentThread().getId() + ": Call -> " + remoteId + ": " + method.getName() + @@ -238,7 +236,7 @@ public class ProtobufRpcEngine implements RpcEngine { final RpcWritable.Buffer val; try { val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER, - new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId, + constructRpcRequest(method, theRequest), remoteId, fallbackToSimpleAuth, alignmentContext); } catch (Throwable e) { @@ -283,6 +281,11 @@ public class ProtobufRpcEngine implements RpcEngine { } } + protected Writable constructRpcRequest(Method method, Message theRequest) { + RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method); + return new RpcProtobufRequest(rpcRequestHeader, theRequest); + } + private Message getReturnMessage(final Method method, final RpcWritable.Buffer buf) throws ServiceException { Message prototype = null; @@ -332,6 +335,14 @@ public class ProtobufRpcEngine implements RpcEngine { public ConnectionId getConnectionId() { return remoteId; } + + protected long getClientProtocolVersion() { + return clientProtocolVersion; + } + + protected String getProtocolName() { + return protocolName; + } } @VisibleForTesting @@ -518,6 +529,13 @@ public class ProtobufRpcEngine implements RpcEngine { String declaringClassProtoName = rpcRequest.getDeclaringClassProtocolName(); long clientVersion = rpcRequest.getClientProtocolVersion(); + return call(server, connectionProtocolName, request, receiveTime, + methodName, declaringClassProtoName, clientVersion); + } + + protected Writable call(RPC.Server server, String connectionProtocolName, + RpcWritable.Buffer request, long receiveTime, String methodName, + String declaringClassProtoName, long clientVersion) throws Exception { if (server.verbose) LOG.info("Call: connectionProtocolName=" + connectionProtocolName + ", method=" + methodName); diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine2.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine2.java index 310f44eebe2..2f5d56437d0 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine2.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine2.java @@ -116,7 +116,7 @@ public class ProtobufRpcEngine2 implements RpcEngine { factory)), false); } - private static final class Invoker implements RpcInvocationHandler { + protected static class Invoker implements RpcInvocationHandler { private final Map returnTypes = new ConcurrentHashMap(); private boolean isClosed = false; @@ -127,7 +127,7 @@ public class ProtobufRpcEngine2 implements RpcEngine { private AtomicBoolean fallbackToSimpleAuth; private AlignmentContext alignmentContext; - private Invoker(Class protocol, InetSocketAddress addr, + protected Invoker(Class protocol, InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, SocketFactory factory, int rpcTimeout, RetryPolicy connectionRetryPolicy, AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext) @@ -142,7 +142,7 @@ public class ProtobufRpcEngine2 implements RpcEngine { /** * This constructor takes a connectionId, instead of creating a new one. */ - private Invoker(Class protocol, Client.ConnectionId connId, + protected Invoker(Class protocol, Client.ConnectionId connId, Configuration conf, SocketFactory factory) { this.remoteId = connId; this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class); @@ -219,8 +219,6 @@ public class ProtobufRpcEngine2 implements RpcEngine { traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method)); } - RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method); - if (LOG.isTraceEnabled()) { LOG.trace(Thread.currentThread().getId() + ": Call -> " + remoteId + ": " + method.getName() + @@ -232,7 +230,7 @@ public class ProtobufRpcEngine2 implements RpcEngine { final RpcWritable.Buffer val; try { val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER, - new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId, + constructRpcRequest(method, theRequest), remoteId, fallbackToSimpleAuth, alignmentContext); } catch (Throwable e) { @@ -279,6 +277,11 @@ public class ProtobufRpcEngine2 implements RpcEngine { } } + protected Writable constructRpcRequest(Method method, Message theRequest) { + RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method); + return new RpcProtobufRequest(rpcRequestHeader, theRequest); + } + private Message getReturnMessage(final Method method, final RpcWritable.Buffer buf) throws ServiceException { Message prototype = null; @@ -328,6 +331,14 @@ public class ProtobufRpcEngine2 implements RpcEngine { public ConnectionId getConnectionId() { return remoteId; } + + protected long getClientProtocolVersion() { + return clientProtocolVersion; + } + + protected String getProtocolName() { + return protocolName; + } } @VisibleForTesting @@ -509,6 +520,13 @@ public class ProtobufRpcEngine2 implements RpcEngine { String declaringClassProtoName = rpcRequest.getDeclaringClassProtocolName(); long clientVersion = rpcRequest.getClientProtocolVersion(); + return call(server, connectionProtocolName, request, receiveTime, + methodName, declaringClassProtoName, clientVersion); + } + + protected Writable call(RPC.Server server, String connectionProtocolName, + RpcWritable.Buffer request, long receiveTime, String methodName, + String declaringClassProtoName, long clientVersion) throws Exception { if (server.verbose) { LOG.info("Call: connectionProtocolName=" + connectionProtocolName + ", method=" + methodName);