HADOOP-7227. Remove protocol version check at proxy creation in Hadoop RPC. Contributed by jitendra.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1099284 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Jitendra Nath Pandey 2011-05-03 22:16:14 +00:00
parent bef14d0918
commit 78a7b9768d
7 changed files with 172 additions and 33 deletions

View File

@ -105,6 +105,9 @@ Trunk (unreleased changes)
HADOOP-7179. Federation: Improve HDFS startup scripts. (Erik Steffl HADOOP-7179. Federation: Improve HDFS startup scripts. (Erik Steffl
and Tanping Wang via suresh) and Tanping Wang via suresh)
HADOOP-7227. Remove protocol version check at proxy creation in Hadoop
RPC. (jitendra)
OPTIMIZATIONS OPTIMIZATIONS
BUG FIXES BUG FIXES

View File

@ -61,6 +61,8 @@ public class AvroRpcEngine implements RpcEngine {
/** Tunnel an Avro RPC request and response through Hadoop's RPC. */ /** Tunnel an Avro RPC request and response through Hadoop's RPC. */
private static interface TunnelProtocol extends VersionedProtocol { private static interface TunnelProtocol extends VersionedProtocol {
//WritableRpcEngine expects a versionID in every protocol.
public static final long versionID = 0L;
/** All Avro methods and responses go through this. */ /** All Avro methods and responses go through this. */
BufferListWritable call(BufferListWritable request) throws IOException; BufferListWritable call(BufferListWritable request) throws IOException;
} }
@ -147,7 +149,7 @@ public class AvroRpcEngine implements RpcEngine {
protocol.getClassLoader(), protocol.getClassLoader(),
new Class[] { protocol }, new Class[] { protocol },
new Invoker(protocol, addr, ticket, conf, factory, rpcTimeout)), new Invoker(protocol, addr, ticket, conf, factory, rpcTimeout)),
null); false);
} }
/** Stop this proxy. */ /** Stop this proxy. */

View File

@ -19,6 +19,7 @@
package org.apache.hadoop.ipc; package org.apache.hadoop.ipc;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.HashSet; import java.util.HashSet;
@ -34,24 +35,55 @@ public class ProtocolProxy<T> {
private Class<T> protocol; private Class<T> protocol;
private T proxy; private T proxy;
private HashSet<Integer> serverMethods = null; private HashSet<Integer> serverMethods = null;
final private boolean supportServerMethodCheck;
private boolean serverMethodsFetched = false;
/** /**
* Constructor * Constructor
* *
* @param protocol protocol class * @param protocol protocol class
* @param proxy its proxy * @param proxy its proxy
* @param serverMethods a list of hash codes of the methods that it supports * @param supportServerMethodCheck If false proxy will never fetch server
* @throws ClassNotFoundException * methods and isMethodSupported will always return true. If true,
* server methods will be fetched for the first call to
* isMethodSupported.
*/ */
public ProtocolProxy(Class<T> protocol, T proxy, int[] serverMethods) { public ProtocolProxy(Class<T> protocol, T proxy,
boolean supportServerMethodCheck) {
this.protocol = protocol; this.protocol = protocol;
this.proxy = proxy; this.proxy = proxy;
if (serverMethods != null) { this.supportServerMethodCheck = supportServerMethodCheck;
this.serverMethods = new HashSet<Integer>(serverMethods.length); }
for (int method : serverMethods) {
this.serverMethods.add(Integer.valueOf(method)); private void fetchServerMethods(Method method) throws IOException {
long clientVersion;
try {
Field versionField = method.getDeclaringClass().getField("versionID");
versionField.setAccessible(true);
clientVersion = versionField.getLong(method.getDeclaringClass());
} catch (NoSuchFieldException ex) {
throw new RuntimeException(ex);
} catch (IllegalAccessException ex) {
throw new RuntimeException(ex);
}
int clientMethodsHash = ProtocolSignature.getFingerprint(method
.getDeclaringClass().getMethods());
ProtocolSignature serverInfo = ((VersionedProtocol) proxy)
.getProtocolSignature(protocol.getName(), clientVersion,
clientMethodsHash);
long serverVersion = serverInfo.getVersion();
if (serverVersion != clientVersion) {
throw new RPC.VersionMismatch(protocol.getName(), clientVersion,
serverVersion);
}
int[] serverMethodsCodes = serverInfo.getMethods();
if (serverMethodsCodes != null) {
serverMethods = new HashSet<Integer>(serverMethodsCodes.length);
for (int m : serverMethodsCodes) {
this.serverMethods.add(Integer.valueOf(m));
} }
} }
serverMethodsFetched = true;
} }
/* /*
@ -68,10 +100,10 @@ public class ProtocolProxy<T> {
* @param parameterTypes a method's parameter types * @param parameterTypes a method's parameter types
* @return true if the method is supported by the server * @return true if the method is supported by the server
*/ */
public boolean isMethodSupported(String methodName, public synchronized boolean isMethodSupported(String methodName,
Class<?>... parameterTypes) Class<?>... parameterTypes)
throws IOException { throws IOException {
if (serverMethods == null) { // client & server have the same protocol if (!supportServerMethodCheck) {
return true; return true;
} }
Method method; Method method;
@ -82,6 +114,12 @@ public class ProtocolProxy<T> {
} catch (NoSuchMethodException e) { } catch (NoSuchMethodException e) {
throw new IOException(e); throw new IOException(e);
} }
if (!serverMethodsFetched) {
fetchServerMethods(method);
}
if (serverMethods == null) { // client & server have the same protocol
return true;
}
return serverMethods.contains( return serverMethods.contains(
Integer.valueOf(ProtocolSignature.getFingerprint(method))); Integer.valueOf(ProtocolSignature.getFingerprint(method)));
} }

View File

@ -18,6 +18,7 @@
package org.apache.hadoop.ipc; package org.apache.hadoop.ipc;
import java.lang.reflect.Field;
import java.lang.reflect.Proxy; import java.lang.reflect.Proxy;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.Array; import java.lang.reflect.Array;
@ -46,6 +47,10 @@ import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
@InterfaceStability.Evolving @InterfaceStability.Evolving
public class WritableRpcEngine implements RpcEngine { public class WritableRpcEngine implements RpcEngine {
private static final Log LOG = LogFactory.getLog(RPC.class); private static final Log LOG = LogFactory.getLog(RPC.class);
//writableRpcVersion should be updated if there is a change
//in format of the rpc messages.
public static long writableRpcVersion = 1L;
/** A method invocation, including the method name and its parameters.*/ /** A method invocation, including the method name and its parameters.*/
private static class Invocation implements Writable, Configurable { private static class Invocation implements Writable, Configurable {
@ -53,6 +58,12 @@ public class WritableRpcEngine implements RpcEngine {
private Class<?>[] parameterClasses; private Class<?>[] parameterClasses;
private Object[] parameters; private Object[] parameters;
private Configuration conf; private Configuration conf;
private long clientVersion;
private int clientMethodsHash;
//This could be different from static writableRpcVersion when received
//at server, if client is using a different version.
private long rpcVersion;
public Invocation() {} public Invocation() {}
@ -60,6 +71,24 @@ public class WritableRpcEngine implements RpcEngine {
this.methodName = method.getName(); this.methodName = method.getName();
this.parameterClasses = method.getParameterTypes(); this.parameterClasses = method.getParameterTypes();
this.parameters = parameters; this.parameters = parameters;
rpcVersion = writableRpcVersion;
if (method.getDeclaringClass().equals(VersionedProtocol.class)) {
//VersionedProtocol is exempted from version check.
clientVersion = 0;
clientMethodsHash = 0;
} else {
try {
Field versionField = method.getDeclaringClass().getField("versionID");
versionField.setAccessible(true);
this.clientVersion = versionField.getLong(method.getDeclaringClass());
} catch (NoSuchFieldException ex) {
throw new RuntimeException(ex);
} catch (IllegalAccessException ex) {
throw new RuntimeException(ex);
}
this.clientMethodsHash = ProtocolSignature.getFingerprint(method
.getDeclaringClass().getMethods());
}
} }
/** The name of the method invoked. */ /** The name of the method invoked. */
@ -70,9 +99,28 @@ public class WritableRpcEngine implements RpcEngine {
/** The parameter instances. */ /** The parameter instances. */
public Object[] getParameters() { return parameters; } public Object[] getParameters() { return parameters; }
private long getProtocolVersion() {
return clientVersion;
}
private int getClientMethodsHash() {
return clientMethodsHash;
}
/**
* Returns the rpc version used by the client.
* @return rpcVersion
*/
public long getRpcVersion() {
return rpcVersion;
}
public void readFields(DataInput in) throws IOException { public void readFields(DataInput in) throws IOException {
rpcVersion = in.readLong();
methodName = UTF8.readString(in); methodName = UTF8.readString(in);
clientVersion = in.readLong();
clientMethodsHash = in.readInt();
parameters = new Object[in.readInt()]; parameters = new Object[in.readInt()];
parameterClasses = new Class[parameters.length]; parameterClasses = new Class[parameters.length];
ObjectWritable objectWritable = new ObjectWritable(); ObjectWritable objectWritable = new ObjectWritable();
@ -83,7 +131,10 @@ public class WritableRpcEngine implements RpcEngine {
} }
public void write(DataOutput out) throws IOException { public void write(DataOutput out) throws IOException {
out.writeLong(rpcVersion);
UTF8.writeString(out, methodName); UTF8.writeString(out, methodName);
out.writeLong(clientVersion);
out.writeInt(clientMethodsHash);
out.writeInt(parameterClasses.length); out.writeInt(parameterClasses.length);
for (int i = 0; i < parameterClasses.length; i++) { for (int i = 0; i < parameterClasses.length; i++) {
ObjectWritable.writeObject(out, parameters[i], parameterClasses[i], ObjectWritable.writeObject(out, parameters[i], parameterClasses[i],
@ -101,6 +152,9 @@ public class WritableRpcEngine implements RpcEngine {
buffer.append(parameters[i]); buffer.append(parameters[i]);
} }
buffer.append(")"); buffer.append(")");
buffer.append(", rpc version="+rpcVersion);
buffer.append(", client version="+clientVersion);
buffer.append(", methodsFingerPrint="+clientMethodsHash);
return buffer.toString(); return buffer.toString();
} }
@ -230,22 +284,10 @@ public class WritableRpcEngine implements RpcEngine {
int rpcTimeout) int rpcTimeout)
throws IOException { throws IOException {
T proxy = (T)Proxy.newProxyInstance T proxy = (T) Proxy.newProxyInstance(protocol.getClassLoader(),
(protocol.getClassLoader(), new Class[] { protocol }, new Class[] { protocol }, new Invoker(protocol, addr, ticket, conf,
new Invoker(protocol, addr, ticket, conf, factory, rpcTimeout)); factory, rpcTimeout));
int[] serverMethods = null; return new ProtocolProxy<T>(protocol, proxy, true);
if (proxy instanceof VersionedProtocol) {
ProtocolSignature serverInfo = ((VersionedProtocol)proxy)
.getProtocolSignature(protocol.getName(), clientVersion,
ProtocolSignature.getFingerprint(protocol.getMethods()));
long serverVersion = serverInfo.getVersion();
if (serverVersion != clientVersion) {
throw new RPC.VersionMismatch(protocol.getName(), clientVersion,
serverVersion);
}
serverMethods = serverInfo.getMethods();
}
return new ProtocolProxy<T>(protocol, proxy, serverMethods);
} }
/** /**
@ -353,6 +395,31 @@ public class WritableRpcEngine implements RpcEngine {
call.getParameterClasses()); call.getParameterClasses());
method.setAccessible(true); method.setAccessible(true);
// Verify rpc version
if (call.getRpcVersion() != writableRpcVersion) {
// Client is using a different version of WritableRpc
throw new IOException(
"WritableRpc version mismatch, client side version="
+ call.getRpcVersion() + ", server side version="
+ writableRpcVersion);
}
//Verify protocol version.
//Bypass the version check for VersionedProtocol
if (!method.getDeclaringClass().equals(VersionedProtocol.class)) {
long clientVersion = call.getProtocolVersion();
ProtocolSignature serverInfo = ((VersionedProtocol) instance)
.getProtocolSignature(protocol.getCanonicalName(), call
.getProtocolVersion(), call.getClientMethodsHash());
long serverVersion = serverInfo.getVersion();
if (serverVersion != clientVersion) {
LOG.warn("Version mismatch: client version=" + clientVersion
+ ", server version=" + serverVersion);
throw new RPC.VersionMismatch(protocol.getName(), clientVersion,
serverVersion);
}
}
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
Object value = method.invoke(instance, call.getParameters()); Object value = method.invoke(instance, call.getParameters());
int processingTime = (int) (System.currentTimeMillis() - startTime); int processingTime = (int) (System.currentTimeMillis() - startTime);

View File

@ -290,8 +290,7 @@ public class TestRPC extends TestCase {
// Check rpcMetrics // Check rpcMetrics
server.rpcMetrics.doUpdates(new NullContext()); server.rpcMetrics.doUpdates(new NullContext());
// Number 4 includes getProtocolVersion() assertEquals(3, server.rpcMetrics.rpcProcessingTime.getPreviousIntervalNumOps());
assertEquals(4, server.rpcMetrics.rpcProcessingTime.getPreviousIntervalNumOps());
assertTrue(server.rpcMetrics.sentBytes.getPreviousIntervalValue() > 0); assertTrue(server.rpcMetrics.sentBytes.getPreviousIntervalValue() > 0);
assertTrue(server.rpcMetrics.receivedBytes.getPreviousIntervalValue() > 0); assertTrue(server.rpcMetrics.receivedBytes.getPreviousIntervalValue() > 0);
@ -376,8 +375,9 @@ public class TestRPC extends TestCase {
public void testStandaloneClient() throws IOException { public void testStandaloneClient() throws IOException {
try { try {
RPC.waitForProxy(TestProtocol.class, TestProtocol proxy = RPC.waitForProxy(TestProtocol.class,
TestProtocol.versionID, new InetSocketAddress(ADDRESS, 20), conf, 15000L); TestProtocol.versionID, new InetSocketAddress(ADDRESS, 20), conf, 15000L);
proxy.echo("");
fail("We should not have reached here"); fail("We should not have reached here");
} catch (ConnectException ioe) { } catch (ConnectException ioe) {
//this is what we expected //this is what we expected
@ -502,6 +502,7 @@ public class TestRPC extends TestCase {
try { try {
proxy = (TestProtocol) RPC.getProxy(TestProtocol.class, proxy = (TestProtocol) RPC.getProxy(TestProtocol.class,
TestProtocol.versionID, addr, conf); TestProtocol.versionID, addr, conf);
proxy.echo("");
} catch (RemoteException e) { } catch (RemoteException e) {
LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage()); LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage());
assertTrue(e.unwrapRemoteException() instanceof AccessControlException); assertTrue(e.unwrapRemoteException() instanceof AccessControlException);
@ -527,6 +528,7 @@ public class TestRPC extends TestCase {
try { try {
proxy = (TestProtocol) RPC.getProxy(TestProtocol.class, proxy = (TestProtocol) RPC.getProxy(TestProtocol.class,
TestProtocol.versionID, mulitServerAddr, conf); TestProtocol.versionID, mulitServerAddr, conf);
proxy.echo("");
} catch (RemoteException e) { } catch (RemoteException e) {
LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage()); LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage());
assertTrue(e.unwrapRemoteException() instanceof AccessControlException); assertTrue(e.unwrapRemoteException() instanceof AccessControlException);

View File

@ -18,19 +18,21 @@
package org.apache.hadoop.ipc; package org.apache.hadoop.ipc;
import static org.junit.Assert.*; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import org.apache.commons.logging.*; import junit.framework.Assert;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.net.NetUtils;
import org.junit.After; import org.junit.After;
import org.junit.Test; import org.junit.Test;
/** Unit test for supporting method-name based compatible RPCs. */ /** Unit test for supporting method-name based compatible RPCs. */
@ -247,4 +249,26 @@ public class TestRPCCompatibility {
int hash2 = ProtocolSignature.getFingerprint(new Method[] {strMethod, intMethod}); int hash2 = ProtocolSignature.getFingerprint(new Method[] {strMethod, intMethod});
assertEquals(hash1, hash2); assertEquals(hash1, hash2);
} }
public interface TestProtocol4 extends TestProtocol2 {
public static final long versionID = 1L;
int echo(int value) throws IOException;
}
@Test
public void testVersionMismatch() throws IOException {
server = RPC.getServer(TestProtocol2.class, new TestImpl0(), ADDRESS, 0, 2,
false, conf, null);
server.start();
addr = NetUtils.getConnectAddress(server);
TestProtocol4 proxy = RPC.getProxy(TestProtocol4.class,
TestProtocol4.versionID, addr, conf);
try {
proxy.echo(21);
fail("The call must throw VersionMismatch exception");
} catch (IOException ex) {
Assert.assertTrue(ex.getMessage().contains("VersionMismatch"));
}
}
} }

View File

@ -321,17 +321,20 @@ public class TestSaslRPC {
try { try {
proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf); TestSaslProtocol.versionID, addr, newConf);
proxy1.getAuthMethod();
Client client = WritableRpcEngine.getClient(conf); Client client = WritableRpcEngine.getClient(conf);
Set<ConnectionId> conns = client.getConnectionIds(); Set<ConnectionId> conns = client.getConnectionIds();
assertEquals("number of connections in cache is wrong", 1, conns.size()); assertEquals("number of connections in cache is wrong", 1, conns.size());
// same conf, connection should be re-used // same conf, connection should be re-used
proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf); TestSaslProtocol.versionID, addr, newConf);
proxy2.getAuthMethod();
assertEquals("number of connections in cache is wrong", 1, conns.size()); assertEquals("number of connections in cache is wrong", 1, conns.size());
// different conf, new connection should be set up // different conf, new connection should be set up
newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2); newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf); TestSaslProtocol.versionID, addr, newConf);
proxy3.getAuthMethod();
ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]); ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
assertEquals("number of connections in cache is wrong", 2, assertEquals("number of connections in cache is wrong", 2,
connsArray.length); connsArray.length);