HADOOP-7227. Remove protocol version check at proxy creation in Hadoop RPC. Contributed by jitendra.
git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1098792 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
76a7219ced
commit
accf84fd10
|
@ -102,6 +102,9 @@ Trunk (unreleased changes)
|
|||
HADOOP-7235. Refactor the tail command to conform to new FsCommand class.
|
||||
(Daryn Sharp via szetszwo)
|
||||
|
||||
HADOOP-7227. Remove protocol version check at proxy creation in Hadoop
|
||||
RPC. (jitendra)
|
||||
|
||||
OPTIMIZATIONS
|
||||
|
||||
BUG FIXES
|
||||
|
|
|
@ -61,6 +61,8 @@ public class AvroRpcEngine implements RpcEngine {
|
|||
|
||||
/** Tunnel an Avro RPC request and response through Hadoop's RPC. */
|
||||
private static interface TunnelProtocol extends VersionedProtocol {
|
||||
//WritableRpcEngine expects a versionID in every protocol.
|
||||
public static final long versionID = 0L;
|
||||
/** All Avro methods and responses go through this. */
|
||||
BufferListWritable call(BufferListWritable request) throws IOException;
|
||||
}
|
||||
|
@ -147,7 +149,7 @@ public class AvroRpcEngine implements RpcEngine {
|
|||
protocol.getClassLoader(),
|
||||
new Class[] { protocol },
|
||||
new Invoker(protocol, addr, ticket, conf, factory, rpcTimeout)),
|
||||
null);
|
||||
false);
|
||||
}
|
||||
|
||||
/** Stop this proxy. */
|
||||
|
|
|
@ -34,24 +34,54 @@ public class ProtocolProxy<T> {
|
|||
private Class<T> protocol;
|
||||
private T proxy;
|
||||
private HashSet<Integer> serverMethods = null;
|
||||
final private boolean supportServerMethodCheck;
|
||||
private boolean serverMethodsFetched = false;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param protocol protocol class
|
||||
* @param proxy its proxy
|
||||
* @param serverMethods a list of hash codes of the methods that it supports
|
||||
* @throws ClassNotFoundException
|
||||
* @param supportServerMethodCheck If false proxy will never fetch server
|
||||
* methods and isMethodSupported will always return true. If true,
|
||||
* server methods will be fetched for the first call to
|
||||
* isMethodSupported.
|
||||
*/
|
||||
public ProtocolProxy(Class<T> protocol, T proxy, int[] serverMethods) {
|
||||
public ProtocolProxy(Class<T> protocol, T proxy,
|
||||
boolean supportServerMethodCheck) {
|
||||
this.protocol = protocol;
|
||||
this.proxy = proxy;
|
||||
if (serverMethods != null) {
|
||||
this.serverMethods = new HashSet<Integer>(serverMethods.length);
|
||||
for (int method : serverMethods) {
|
||||
this.serverMethods.add(Integer.valueOf(method));
|
||||
this.supportServerMethodCheck = supportServerMethodCheck;
|
||||
}
|
||||
|
||||
private void fetchServerMethods(Method method) throws IOException {
|
||||
long clientVersion;
|
||||
try {
|
||||
clientVersion = method.getDeclaringClass().getField("versionID").getLong(
|
||||
method.getDeclaringClass());
|
||||
} catch (NoSuchFieldException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
} catch (IllegalAccessException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
int clientMethodsHash = ProtocolSignature.getFingerprint(method
|
||||
.getDeclaringClass().getMethods());
|
||||
ProtocolSignature serverInfo = ((VersionedProtocol) proxy)
|
||||
.getProtocolSignature(protocol.getName(), clientVersion,
|
||||
clientMethodsHash);
|
||||
long serverVersion = serverInfo.getVersion();
|
||||
if (serverVersion != clientVersion) {
|
||||
throw new RPC.VersionMismatch(protocol.getName(), clientVersion,
|
||||
serverVersion);
|
||||
}
|
||||
int[] serverMethodsCodes = serverInfo.getMethods();
|
||||
if (serverMethodsCodes != null) {
|
||||
serverMethods = new HashSet<Integer>(serverMethodsCodes.length);
|
||||
for (int m : serverMethodsCodes) {
|
||||
this.serverMethods.add(Integer.valueOf(m));
|
||||
}
|
||||
}
|
||||
serverMethodsFetched = true;
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -68,10 +98,10 @@ public class ProtocolProxy<T> {
|
|||
* @param parameterTypes a method's parameter types
|
||||
* @return true if the method is supported by the server
|
||||
*/
|
||||
public boolean isMethodSupported(String methodName,
|
||||
public synchronized boolean isMethodSupported(String methodName,
|
||||
Class<?>... parameterTypes)
|
||||
throws IOException {
|
||||
if (serverMethods == null) { // client & server have the same protocol
|
||||
if (!supportServerMethodCheck) {
|
||||
return true;
|
||||
}
|
||||
Method method;
|
||||
|
@ -82,6 +112,12 @@ public class ProtocolProxy<T> {
|
|||
} catch (NoSuchMethodException e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
if (!serverMethodsFetched) {
|
||||
fetchServerMethods(method);
|
||||
}
|
||||
if (serverMethods == null) { // client & server have the same protocol
|
||||
return true;
|
||||
}
|
||||
return serverMethods.contains(
|
||||
Integer.valueOf(ProtocolSignature.getFingerprint(method)));
|
||||
}
|
||||
|
|
|
@ -47,12 +47,22 @@ import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
|
|||
public class WritableRpcEngine implements RpcEngine {
|
||||
private static final Log LOG = LogFactory.getLog(RPC.class);
|
||||
|
||||
//writableRpcVersion should be updated if there is a change
|
||||
//in format of the rpc messages.
|
||||
public static long writableRpcVersion = 1L;
|
||||
|
||||
/** A method invocation, including the method name and its parameters.*/
|
||||
private static class Invocation implements Writable, Configurable {
|
||||
private String methodName;
|
||||
private Class<?>[] parameterClasses;
|
||||
private Object[] parameters;
|
||||
private Configuration conf;
|
||||
private long clientVersion;
|
||||
private int clientMethodsHash;
|
||||
|
||||
//This could be different from static writableRpcVersion when received
|
||||
//at server, if client is using a different version.
|
||||
private long rpcVersion;
|
||||
|
||||
public Invocation() {}
|
||||
|
||||
|
@ -60,6 +70,23 @@ public class WritableRpcEngine implements RpcEngine {
|
|||
this.methodName = method.getName();
|
||||
this.parameterClasses = method.getParameterTypes();
|
||||
this.parameters = parameters;
|
||||
rpcVersion = writableRpcVersion;
|
||||
if (method.getDeclaringClass().equals(VersionedProtocol.class)) {
|
||||
//VersionedProtocol is exempted from version check.
|
||||
clientVersion = 0;
|
||||
clientMethodsHash = 0;
|
||||
} else {
|
||||
try {
|
||||
this.clientVersion = method.getDeclaringClass().getField("versionID")
|
||||
.getLong(method.getDeclaringClass());
|
||||
} catch (NoSuchFieldException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
} catch (IllegalAccessException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
this.clientMethodsHash = ProtocolSignature.getFingerprint(method
|
||||
.getDeclaringClass().getMethods());
|
||||
}
|
||||
}
|
||||
|
||||
/** The name of the method invoked. */
|
||||
|
@ -71,8 +98,27 @@ public class WritableRpcEngine implements RpcEngine {
|
|||
/** The parameter instances. */
|
||||
public Object[] getParameters() { return parameters; }
|
||||
|
||||
private long getProtocolVersion() {
|
||||
return clientVersion;
|
||||
}
|
||||
|
||||
private int getClientMethodsHash() {
|
||||
return clientMethodsHash;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the rpc version used by the client.
|
||||
* @return rpcVersion
|
||||
*/
|
||||
public long getRpcVersion() {
|
||||
return rpcVersion;
|
||||
}
|
||||
|
||||
public void readFields(DataInput in) throws IOException {
|
||||
rpcVersion = in.readLong();
|
||||
methodName = UTF8.readString(in);
|
||||
clientVersion = in.readLong();
|
||||
clientMethodsHash = in.readInt();
|
||||
parameters = new Object[in.readInt()];
|
||||
parameterClasses = new Class[parameters.length];
|
||||
ObjectWritable objectWritable = new ObjectWritable();
|
||||
|
@ -83,7 +129,10 @@ public class WritableRpcEngine implements RpcEngine {
|
|||
}
|
||||
|
||||
public void write(DataOutput out) throws IOException {
|
||||
out.writeLong(rpcVersion);
|
||||
UTF8.writeString(out, methodName);
|
||||
out.writeLong(clientVersion);
|
||||
out.writeInt(clientMethodsHash);
|
||||
out.writeInt(parameterClasses.length);
|
||||
for (int i = 0; i < parameterClasses.length; i++) {
|
||||
ObjectWritable.writeObject(out, parameters[i], parameterClasses[i],
|
||||
|
@ -101,6 +150,9 @@ public class WritableRpcEngine implements RpcEngine {
|
|||
buffer.append(parameters[i]);
|
||||
}
|
||||
buffer.append(")");
|
||||
buffer.append(", rpc version="+rpcVersion);
|
||||
buffer.append(", client version="+clientVersion);
|
||||
buffer.append(", methodsFingerPrint="+clientMethodsHash);
|
||||
return buffer.toString();
|
||||
}
|
||||
|
||||
|
@ -230,22 +282,10 @@ public class WritableRpcEngine implements RpcEngine {
|
|||
int rpcTimeout)
|
||||
throws IOException {
|
||||
|
||||
T proxy = (T)Proxy.newProxyInstance
|
||||
(protocol.getClassLoader(), new Class[] { protocol },
|
||||
new Invoker(protocol, addr, ticket, conf, factory, rpcTimeout));
|
||||
int[] serverMethods = null;
|
||||
if (proxy instanceof VersionedProtocol) {
|
||||
ProtocolSignature serverInfo = ((VersionedProtocol)proxy)
|
||||
.getProtocolSignature(protocol.getName(), clientVersion,
|
||||
ProtocolSignature.getFingerprint(protocol.getMethods()));
|
||||
long serverVersion = serverInfo.getVersion();
|
||||
if (serverVersion != clientVersion) {
|
||||
throw new RPC.VersionMismatch(protocol.getName(), clientVersion,
|
||||
serverVersion);
|
||||
}
|
||||
serverMethods = serverInfo.getMethods();
|
||||
}
|
||||
return new ProtocolProxy<T>(protocol, proxy, serverMethods);
|
||||
T proxy = (T) Proxy.newProxyInstance(protocol.getClassLoader(),
|
||||
new Class[] { protocol }, new Invoker(protocol, addr, ticket, conf,
|
||||
factory, rpcTimeout));
|
||||
return new ProtocolProxy<T>(protocol, proxy, true);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -353,6 +393,31 @@ public class WritableRpcEngine implements RpcEngine {
|
|||
call.getParameterClasses());
|
||||
method.setAccessible(true);
|
||||
|
||||
// Verify rpc version
|
||||
if (call.getRpcVersion() != writableRpcVersion) {
|
||||
// Client is using a different version of WritableRpc
|
||||
throw new IOException(
|
||||
"WritableRpc version mismatch, client side version="
|
||||
+ call.getRpcVersion() + ", server side version="
|
||||
+ writableRpcVersion);
|
||||
}
|
||||
|
||||
//Verify protocol version.
|
||||
//Bypass the version check for VersionedProtocol
|
||||
if (!method.getDeclaringClass().equals(VersionedProtocol.class)) {
|
||||
long clientVersion = call.getProtocolVersion();
|
||||
ProtocolSignature serverInfo = ((VersionedProtocol) instance)
|
||||
.getProtocolSignature(protocol.getCanonicalName(), call
|
||||
.getProtocolVersion(), call.getClientMethodsHash());
|
||||
long serverVersion = serverInfo.getVersion();
|
||||
if (serverVersion != clientVersion) {
|
||||
LOG.warn("Version mismatch: client version=" + clientVersion
|
||||
+ ", server version=" + serverVersion);
|
||||
throw new RPC.VersionMismatch(protocol.getName(), clientVersion,
|
||||
serverVersion);
|
||||
}
|
||||
}
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
Object value = method.invoke(instance, call.getParameters());
|
||||
int processingTime = (int) (System.currentTimeMillis() - startTime);
|
||||
|
|
|
@ -290,8 +290,7 @@ public class TestRPC extends TestCase {
|
|||
// Check rpcMetrics
|
||||
server.rpcMetrics.doUpdates(new NullContext());
|
||||
|
||||
// Number 4 includes getProtocolVersion()
|
||||
assertEquals(4, server.rpcMetrics.rpcProcessingTime.getPreviousIntervalNumOps());
|
||||
assertEquals(3, server.rpcMetrics.rpcProcessingTime.getPreviousIntervalNumOps());
|
||||
assertTrue(server.rpcMetrics.sentBytes.getPreviousIntervalValue() > 0);
|
||||
assertTrue(server.rpcMetrics.receivedBytes.getPreviousIntervalValue() > 0);
|
||||
|
||||
|
@ -376,8 +375,9 @@ public class TestRPC extends TestCase {
|
|||
|
||||
public void testStandaloneClient() throws IOException {
|
||||
try {
|
||||
RPC.waitForProxy(TestProtocol.class,
|
||||
TestProtocol proxy = RPC.waitForProxy(TestProtocol.class,
|
||||
TestProtocol.versionID, new InetSocketAddress(ADDRESS, 20), conf, 15000L);
|
||||
proxy.echo("");
|
||||
fail("We should not have reached here");
|
||||
} catch (ConnectException ioe) {
|
||||
//this is what we expected
|
||||
|
@ -502,6 +502,7 @@ public class TestRPC extends TestCase {
|
|||
try {
|
||||
proxy = (TestProtocol) RPC.getProxy(TestProtocol.class,
|
||||
TestProtocol.versionID, addr, conf);
|
||||
proxy.echo("");
|
||||
} catch (RemoteException e) {
|
||||
LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage());
|
||||
assertTrue(e.unwrapRemoteException() instanceof AccessControlException);
|
||||
|
@ -527,6 +528,7 @@ public class TestRPC extends TestCase {
|
|||
try {
|
||||
proxy = (TestProtocol) RPC.getProxy(TestProtocol.class,
|
||||
TestProtocol.versionID, mulitServerAddr, conf);
|
||||
proxy.echo("");
|
||||
} catch (RemoteException e) {
|
||||
LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage());
|
||||
assertTrue(e.unwrapRemoteException() instanceof AccessControlException);
|
||||
|
|
|
@ -18,19 +18,21 @@
|
|||
|
||||
package org.apache.hadoop.ipc;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Method;
|
||||
import java.net.InetSocketAddress;
|
||||
|
||||
import org.apache.commons.logging.*;
|
||||
import junit.framework.Assert;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.net.NetUtils;
|
||||
|
||||
import org.junit.After;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
/** Unit test for supporting method-name based compatible RPCs. */
|
||||
|
@ -247,4 +249,26 @@ public class TestRPCCompatibility {
|
|||
int hash2 = ProtocolSignature.getFingerprint(new Method[] {strMethod, intMethod});
|
||||
assertEquals(hash1, hash2);
|
||||
}
|
||||
|
||||
public interface TestProtocol4 extends TestProtocol2 {
|
||||
public static final long versionID = 1L;
|
||||
int echo(int value) throws IOException;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVersionMismatch() throws IOException {
|
||||
server = RPC.getServer(TestProtocol2.class, new TestImpl0(), ADDRESS, 0, 2,
|
||||
false, conf, null);
|
||||
server.start();
|
||||
addr = NetUtils.getConnectAddress(server);
|
||||
|
||||
TestProtocol4 proxy = RPC.getProxy(TestProtocol4.class,
|
||||
TestProtocol4.versionID, addr, conf);
|
||||
try {
|
||||
proxy.echo(21);
|
||||
fail("The call must throw VersionMismatch exception");
|
||||
} catch (IOException ex) {
|
||||
Assert.assertTrue(ex.getMessage().contains("VersionMismatch"));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -321,17 +321,20 @@ public class TestSaslRPC {
|
|||
try {
|
||||
proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
|
||||
TestSaslProtocol.versionID, addr, newConf);
|
||||
proxy1.getAuthMethod();
|
||||
Client client = WritableRpcEngine.getClient(conf);
|
||||
Set<ConnectionId> conns = client.getConnectionIds();
|
||||
assertEquals("number of connections in cache is wrong", 1, conns.size());
|
||||
// same conf, connection should be re-used
|
||||
proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
|
||||
TestSaslProtocol.versionID, addr, newConf);
|
||||
proxy2.getAuthMethod();
|
||||
assertEquals("number of connections in cache is wrong", 1, conns.size());
|
||||
// different conf, new connection should be set up
|
||||
newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
|
||||
proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
|
||||
TestSaslProtocol.versionID, addr, newConf);
|
||||
proxy3.getAuthMethod();
|
||||
ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
|
||||
assertEquals("number of connections in cache is wrong", 2,
|
||||
connsArray.length);
|
||||
|
|
Loading…
Reference in New Issue