diff --git a/CHANGES.txt b/CHANGES.txt index 727ee6e7bc3..ce44d182dae 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -105,6 +105,9 @@ Trunk (unreleased changes) HADOOP-7179. Federation: Improve HDFS startup scripts. (Erik Steffl and Tanping Wang via suresh) + HADOOP-7227. Remove protocol version check at proxy creation in Hadoop + RPC. (jitendra) + OPTIMIZATIONS BUG FIXES diff --git a/src/java/org/apache/hadoop/ipc/AvroRpcEngine.java b/src/java/org/apache/hadoop/ipc/AvroRpcEngine.java index adef1eac6a3..180e7811b1c 100644 --- a/src/java/org/apache/hadoop/ipc/AvroRpcEngine.java +++ b/src/java/org/apache/hadoop/ipc/AvroRpcEngine.java @@ -61,6 +61,8 @@ public class AvroRpcEngine implements RpcEngine { /** Tunnel an Avro RPC request and response through Hadoop's RPC. */ private static interface TunnelProtocol extends VersionedProtocol { + //WritableRpcEngine expects a versionID in every protocol. + public static final long versionID = 0L; /** All Avro methods and responses go through this. */ BufferListWritable call(BufferListWritable request) throws IOException; } @@ -147,7 +149,7 @@ public class AvroRpcEngine implements RpcEngine { protocol.getClassLoader(), new Class[] { protocol }, new Invoker(protocol, addr, ticket, conf, factory, rpcTimeout)), - null); + false); } /** Stop this proxy. */ diff --git a/src/java/org/apache/hadoop/ipc/ProtocolProxy.java b/src/java/org/apache/hadoop/ipc/ProtocolProxy.java index 52005afd125..937031c6748 100644 --- a/src/java/org/apache/hadoop/ipc/ProtocolProxy.java +++ b/src/java/org/apache/hadoop/ipc/ProtocolProxy.java @@ -19,6 +19,7 @@ package org.apache.hadoop.ipc; import java.io.IOException; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.HashSet; @@ -34,24 +35,55 @@ public class ProtocolProxy { private Class protocol; private T proxy; private HashSet serverMethods = null; + final private boolean supportServerMethodCheck; + private boolean serverMethodsFetched = false; /** * Constructor * * @param protocol protocol class * @param proxy its proxy - * @param serverMethods a list of hash codes of the methods that it supports - * @throws ClassNotFoundException + * @param supportServerMethodCheck If false proxy will never fetch server + * methods and isMethodSupported will always return true. If true, + * server methods will be fetched for the first call to + * isMethodSupported. */ - public ProtocolProxy(Class protocol, T proxy, int[] serverMethods) { + public ProtocolProxy(Class protocol, T proxy, + boolean supportServerMethodCheck) { this.protocol = protocol; this.proxy = proxy; - if (serverMethods != null) { - this.serverMethods = new HashSet(serverMethods.length); - for (int method : serverMethods) { - this.serverMethods.add(Integer.valueOf(method)); + this.supportServerMethodCheck = supportServerMethodCheck; + } + + private void fetchServerMethods(Method method) throws IOException { + long clientVersion; + try { + Field versionField = method.getDeclaringClass().getField("versionID"); + versionField.setAccessible(true); + clientVersion = versionField.getLong(method.getDeclaringClass()); + } catch (NoSuchFieldException ex) { + throw new RuntimeException(ex); + } catch (IllegalAccessException ex) { + throw new RuntimeException(ex); + } + int clientMethodsHash = ProtocolSignature.getFingerprint(method + .getDeclaringClass().getMethods()); + ProtocolSignature serverInfo = ((VersionedProtocol) proxy) + .getProtocolSignature(protocol.getName(), clientVersion, + clientMethodsHash); + long serverVersion = serverInfo.getVersion(); + if (serverVersion != clientVersion) { + throw new RPC.VersionMismatch(protocol.getName(), clientVersion, + serverVersion); + } + int[] serverMethodsCodes = serverInfo.getMethods(); + if (serverMethodsCodes != null) { + serverMethods = new HashSet(serverMethodsCodes.length); + for (int m : serverMethodsCodes) { + this.serverMethods.add(Integer.valueOf(m)); } } + serverMethodsFetched = true; } /* @@ -68,10 +100,10 @@ public class ProtocolProxy { * @param parameterTypes a method's parameter types * @return true if the method is supported by the server */ - public boolean isMethodSupported(String methodName, + public synchronized boolean isMethodSupported(String methodName, Class... parameterTypes) throws IOException { - if (serverMethods == null) { // client & server have the same protocol + if (!supportServerMethodCheck) { return true; } Method method; @@ -82,6 +114,12 @@ public class ProtocolProxy { } catch (NoSuchMethodException e) { throw new IOException(e); } + if (!serverMethodsFetched) { + fetchServerMethods(method); + } + if (serverMethods == null) { // client & server have the same protocol + return true; + } return serverMethods.contains( Integer.valueOf(ProtocolSignature.getFingerprint(method))); } diff --git a/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java b/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java index 6077720d440..49feaf6d9e5 100644 --- a/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java +++ b/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java @@ -18,6 +18,7 @@ package org.apache.hadoop.ipc; +import java.lang.reflect.Field; import java.lang.reflect.Proxy; import java.lang.reflect.Method; import java.lang.reflect.Array; @@ -46,6 +47,10 @@ import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate; @InterfaceStability.Evolving public class WritableRpcEngine implements RpcEngine { private static final Log LOG = LogFactory.getLog(RPC.class); + + //writableRpcVersion should be updated if there is a change + //in format of the rpc messages. + public static long writableRpcVersion = 1L; /** A method invocation, including the method name and its parameters.*/ private static class Invocation implements Writable, Configurable { @@ -53,6 +58,12 @@ public class WritableRpcEngine implements RpcEngine { private Class[] parameterClasses; private Object[] parameters; private Configuration conf; + private long clientVersion; + private int clientMethodsHash; + + //This could be different from static writableRpcVersion when received + //at server, if client is using a different version. + private long rpcVersion; public Invocation() {} @@ -60,6 +71,24 @@ public class WritableRpcEngine implements RpcEngine { this.methodName = method.getName(); this.parameterClasses = method.getParameterTypes(); this.parameters = parameters; + rpcVersion = writableRpcVersion; + if (method.getDeclaringClass().equals(VersionedProtocol.class)) { + //VersionedProtocol is exempted from version check. + clientVersion = 0; + clientMethodsHash = 0; + } else { + try { + Field versionField = method.getDeclaringClass().getField("versionID"); + versionField.setAccessible(true); + this.clientVersion = versionField.getLong(method.getDeclaringClass()); + } catch (NoSuchFieldException ex) { + throw new RuntimeException(ex); + } catch (IllegalAccessException ex) { + throw new RuntimeException(ex); + } + this.clientMethodsHash = ProtocolSignature.getFingerprint(method + .getDeclaringClass().getMethods()); + } } /** The name of the method invoked. */ @@ -70,9 +99,28 @@ public class WritableRpcEngine implements RpcEngine { /** The parameter instances. */ public Object[] getParameters() { return parameters; } + + private long getProtocolVersion() { + return clientVersion; + } + + private int getClientMethodsHash() { + return clientMethodsHash; + } + + /** + * Returns the rpc version used by the client. + * @return rpcVersion + */ + public long getRpcVersion() { + return rpcVersion; + } public void readFields(DataInput in) throws IOException { + rpcVersion = in.readLong(); methodName = UTF8.readString(in); + clientVersion = in.readLong(); + clientMethodsHash = in.readInt(); parameters = new Object[in.readInt()]; parameterClasses = new Class[parameters.length]; ObjectWritable objectWritable = new ObjectWritable(); @@ -83,7 +131,10 @@ public class WritableRpcEngine implements RpcEngine { } public void write(DataOutput out) throws IOException { + out.writeLong(rpcVersion); UTF8.writeString(out, methodName); + out.writeLong(clientVersion); + out.writeInt(clientMethodsHash); out.writeInt(parameterClasses.length); for (int i = 0; i < parameterClasses.length; i++) { ObjectWritable.writeObject(out, parameters[i], parameterClasses[i], @@ -101,6 +152,9 @@ public class WritableRpcEngine implements RpcEngine { buffer.append(parameters[i]); } buffer.append(")"); + buffer.append(", rpc version="+rpcVersion); + buffer.append(", client version="+clientVersion); + buffer.append(", methodsFingerPrint="+clientMethodsHash); return buffer.toString(); } @@ -230,22 +284,10 @@ public class WritableRpcEngine implements RpcEngine { int rpcTimeout) throws IOException { - T proxy = (T)Proxy.newProxyInstance - (protocol.getClassLoader(), new Class[] { protocol }, - new Invoker(protocol, addr, ticket, conf, factory, rpcTimeout)); - int[] serverMethods = null; - if (proxy instanceof VersionedProtocol) { - ProtocolSignature serverInfo = ((VersionedProtocol)proxy) - .getProtocolSignature(protocol.getName(), clientVersion, - ProtocolSignature.getFingerprint(protocol.getMethods())); - long serverVersion = serverInfo.getVersion(); - if (serverVersion != clientVersion) { - throw new RPC.VersionMismatch(protocol.getName(), clientVersion, - serverVersion); - } - serverMethods = serverInfo.getMethods(); - } - return new ProtocolProxy(protocol, proxy, serverMethods); + T proxy = (T) Proxy.newProxyInstance(protocol.getClassLoader(), + new Class[] { protocol }, new Invoker(protocol, addr, ticket, conf, + factory, rpcTimeout)); + return new ProtocolProxy(protocol, proxy, true); } /** @@ -353,6 +395,31 @@ public class WritableRpcEngine implements RpcEngine { call.getParameterClasses()); method.setAccessible(true); + // Verify rpc version + if (call.getRpcVersion() != writableRpcVersion) { + // Client is using a different version of WritableRpc + throw new IOException( + "WritableRpc version mismatch, client side version=" + + call.getRpcVersion() + ", server side version=" + + writableRpcVersion); + } + + //Verify protocol version. + //Bypass the version check for VersionedProtocol + if (!method.getDeclaringClass().equals(VersionedProtocol.class)) { + long clientVersion = call.getProtocolVersion(); + ProtocolSignature serverInfo = ((VersionedProtocol) instance) + .getProtocolSignature(protocol.getCanonicalName(), call + .getProtocolVersion(), call.getClientMethodsHash()); + long serverVersion = serverInfo.getVersion(); + if (serverVersion != clientVersion) { + LOG.warn("Version mismatch: client version=" + clientVersion + + ", server version=" + serverVersion); + throw new RPC.VersionMismatch(protocol.getName(), clientVersion, + serverVersion); + } + } + long startTime = System.currentTimeMillis(); Object value = method.invoke(instance, call.getParameters()); int processingTime = (int) (System.currentTimeMillis() - startTime); diff --git a/src/test/core/org/apache/hadoop/ipc/TestRPC.java b/src/test/core/org/apache/hadoop/ipc/TestRPC.java index 5f284805843..a867a50bfdb 100644 --- a/src/test/core/org/apache/hadoop/ipc/TestRPC.java +++ b/src/test/core/org/apache/hadoop/ipc/TestRPC.java @@ -290,8 +290,7 @@ public class TestRPC extends TestCase { // Check rpcMetrics server.rpcMetrics.doUpdates(new NullContext()); - // Number 4 includes getProtocolVersion() - assertEquals(4, server.rpcMetrics.rpcProcessingTime.getPreviousIntervalNumOps()); + assertEquals(3, server.rpcMetrics.rpcProcessingTime.getPreviousIntervalNumOps()); assertTrue(server.rpcMetrics.sentBytes.getPreviousIntervalValue() > 0); assertTrue(server.rpcMetrics.receivedBytes.getPreviousIntervalValue() > 0); @@ -376,8 +375,9 @@ public class TestRPC extends TestCase { public void testStandaloneClient() throws IOException { try { - RPC.waitForProxy(TestProtocol.class, + TestProtocol proxy = RPC.waitForProxy(TestProtocol.class, TestProtocol.versionID, new InetSocketAddress(ADDRESS, 20), conf, 15000L); + proxy.echo(""); fail("We should not have reached here"); } catch (ConnectException ioe) { //this is what we expected @@ -502,6 +502,7 @@ public class TestRPC extends TestCase { try { proxy = (TestProtocol) RPC.getProxy(TestProtocol.class, TestProtocol.versionID, addr, conf); + proxy.echo(""); } catch (RemoteException e) { LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage()); assertTrue(e.unwrapRemoteException() instanceof AccessControlException); @@ -527,6 +528,7 @@ public class TestRPC extends TestCase { try { proxy = (TestProtocol) RPC.getProxy(TestProtocol.class, TestProtocol.versionID, mulitServerAddr, conf); + proxy.echo(""); } catch (RemoteException e) { LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage()); assertTrue(e.unwrapRemoteException() instanceof AccessControlException); diff --git a/src/test/core/org/apache/hadoop/ipc/TestRPCCompatibility.java b/src/test/core/org/apache/hadoop/ipc/TestRPCCompatibility.java index 18636581107..02ca2afe42a 100644 --- a/src/test/core/org/apache/hadoop/ipc/TestRPCCompatibility.java +++ b/src/test/core/org/apache/hadoop/ipc/TestRPCCompatibility.java @@ -18,19 +18,21 @@ package org.apache.hadoop.ipc; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.fail; import java.io.IOException; import java.lang.reflect.Method; import java.net.InetSocketAddress; -import org.apache.commons.logging.*; +import junit.framework.Assert; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.net.NetUtils; - import org.junit.After; - import org.junit.Test; /** Unit test for supporting method-name based compatible RPCs. */ @@ -247,4 +249,26 @@ public class TestRPCCompatibility { int hash2 = ProtocolSignature.getFingerprint(new Method[] {strMethod, intMethod}); assertEquals(hash1, hash2); } + + public interface TestProtocol4 extends TestProtocol2 { + public static final long versionID = 1L; + int echo(int value) throws IOException; + } + + @Test + public void testVersionMismatch() throws IOException { + server = RPC.getServer(TestProtocol2.class, new TestImpl0(), ADDRESS, 0, 2, + false, conf, null); + server.start(); + addr = NetUtils.getConnectAddress(server); + + TestProtocol4 proxy = RPC.getProxy(TestProtocol4.class, + TestProtocol4.versionID, addr, conf); + try { + proxy.echo(21); + fail("The call must throw VersionMismatch exception"); + } catch (IOException ex) { + Assert.assertTrue(ex.getMessage().contains("VersionMismatch")); + } + } } \ No newline at end of file diff --git a/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java b/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java index 5cddffc3451..b89e3a74431 100644 --- a/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java +++ b/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java @@ -321,17 +321,20 @@ public class TestSaslRPC { try { proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, newConf); + proxy1.getAuthMethod(); Client client = WritableRpcEngine.getClient(conf); Set conns = client.getConnectionIds(); assertEquals("number of connections in cache is wrong", 1, conns.size()); // same conf, connection should be re-used proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, newConf); + proxy2.getAuthMethod(); assertEquals("number of connections in cache is wrong", 1, conns.size()); // different conf, new connection should be set up newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2); proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, newConf); + proxy3.getAuthMethod(); ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]); assertEquals("number of connections in cache is wrong", 2, connsArray.length);