HADOOP-7227. Remove protocol version check at proxy creation in Hadoop RPC. Contributed by jitendra.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1098792 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Jitendra Nath Pandey 2011-05-02 20:57:48 +00:00
parent 76a7219ced
commit accf84fd10
7 changed files with 168 additions and 33 deletions

View File

@ -102,6 +102,9 @@ Trunk (unreleased changes)
HADOOP-7235. Refactor the tail command to conform to new FsCommand class.
(Daryn Sharp via szetszwo)
HADOOP-7227. Remove protocol version check at proxy creation in Hadoop
RPC. (jitendra)
OPTIMIZATIONS
BUG FIXES

View File

@ -61,6 +61,8 @@ public class AvroRpcEngine implements RpcEngine {
/** Tunnel an Avro RPC request and response through Hadoop's RPC. */
private static interface TunnelProtocol extends VersionedProtocol {
//WritableRpcEngine expects a versionID in every protocol.
public static final long versionID = 0L;
/** All Avro methods and responses go through this. */
BufferListWritable call(BufferListWritable request) throws IOException;
}
@ -147,7 +149,7 @@ public class AvroRpcEngine implements RpcEngine {
protocol.getClassLoader(),
new Class[] { protocol },
new Invoker(protocol, addr, ticket, conf, factory, rpcTimeout)),
null);
false);
}
/** Stop this proxy. */

View File

@ -34,24 +34,54 @@ public class ProtocolProxy<T> {
private Class<T> protocol;
private T proxy;
private HashSet<Integer> serverMethods = null;
final private boolean supportServerMethodCheck;
private boolean serverMethodsFetched = false;
/**
* Constructor
*
* @param protocol protocol class
* @param proxy its proxy
* @param serverMethods a list of hash codes of the methods that it supports
* @throws ClassNotFoundException
* @param supportServerMethodCheck If false proxy will never fetch server
* methods and isMethodSupported will always return true. If true,
* server methods will be fetched for the first call to
* isMethodSupported.
*/
public ProtocolProxy(Class<T> protocol, T proxy, int[] serverMethods) {
public ProtocolProxy(Class<T> protocol, T proxy,
boolean supportServerMethodCheck) {
this.protocol = protocol;
this.proxy = proxy;
if (serverMethods != null) {
this.serverMethods = new HashSet<Integer>(serverMethods.length);
for (int method : serverMethods) {
this.serverMethods.add(Integer.valueOf(method));
this.supportServerMethodCheck = supportServerMethodCheck;
}
private void fetchServerMethods(Method method) throws IOException {
long clientVersion;
try {
clientVersion = method.getDeclaringClass().getField("versionID").getLong(
method.getDeclaringClass());
} catch (NoSuchFieldException ex) {
throw new RuntimeException(ex);
} catch (IllegalAccessException ex) {
throw new RuntimeException(ex);
}
int clientMethodsHash = ProtocolSignature.getFingerprint(method
.getDeclaringClass().getMethods());
ProtocolSignature serverInfo = ((VersionedProtocol) proxy)
.getProtocolSignature(protocol.getName(), clientVersion,
clientMethodsHash);
long serverVersion = serverInfo.getVersion();
if (serverVersion != clientVersion) {
throw new RPC.VersionMismatch(protocol.getName(), clientVersion,
serverVersion);
}
int[] serverMethodsCodes = serverInfo.getMethods();
if (serverMethodsCodes != null) {
serverMethods = new HashSet<Integer>(serverMethodsCodes.length);
for (int m : serverMethodsCodes) {
this.serverMethods.add(Integer.valueOf(m));
}
}
serverMethodsFetched = true;
}
/*
@ -68,10 +98,10 @@ public class ProtocolProxy<T> {
* @param parameterTypes a method's parameter types
* @return true if the method is supported by the server
*/
public boolean isMethodSupported(String methodName,
public synchronized boolean isMethodSupported(String methodName,
Class<?>... parameterTypes)
throws IOException {
if (serverMethods == null) { // client & server have the same protocol
if (!supportServerMethodCheck) {
return true;
}
Method method;
@ -82,6 +112,12 @@ public class ProtocolProxy<T> {
} catch (NoSuchMethodException e) {
throw new IOException(e);
}
if (!serverMethodsFetched) {
fetchServerMethods(method);
}
if (serverMethods == null) { // client & server have the same protocol
return true;
}
return serverMethods.contains(
Integer.valueOf(ProtocolSignature.getFingerprint(method)));
}

View File

@ -47,12 +47,22 @@ import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
public class WritableRpcEngine implements RpcEngine {
private static final Log LOG = LogFactory.getLog(RPC.class);
//writableRpcVersion should be updated if there is a change
//in format of the rpc messages.
public static long writableRpcVersion = 1L;
/** A method invocation, including the method name and its parameters.*/
private static class Invocation implements Writable, Configurable {
private String methodName;
private Class<?>[] parameterClasses;
private Object[] parameters;
private Configuration conf;
private long clientVersion;
private int clientMethodsHash;
//This could be different from static writableRpcVersion when received
//at server, if client is using a different version.
private long rpcVersion;
public Invocation() {}
@ -60,6 +70,23 @@ public class WritableRpcEngine implements RpcEngine {
this.methodName = method.getName();
this.parameterClasses = method.getParameterTypes();
this.parameters = parameters;
rpcVersion = writableRpcVersion;
if (method.getDeclaringClass().equals(VersionedProtocol.class)) {
//VersionedProtocol is exempted from version check.
clientVersion = 0;
clientMethodsHash = 0;
} else {
try {
this.clientVersion = method.getDeclaringClass().getField("versionID")
.getLong(method.getDeclaringClass());
} catch (NoSuchFieldException ex) {
throw new RuntimeException(ex);
} catch (IllegalAccessException ex) {
throw new RuntimeException(ex);
}
this.clientMethodsHash = ProtocolSignature.getFingerprint(method
.getDeclaringClass().getMethods());
}
}
/** The name of the method invoked. */
@ -71,8 +98,27 @@ public class WritableRpcEngine implements RpcEngine {
/** The parameter instances. */
public Object[] getParameters() { return parameters; }
private long getProtocolVersion() {
return clientVersion;
}
private int getClientMethodsHash() {
return clientMethodsHash;
}
/**
* Returns the rpc version used by the client.
* @return rpcVersion
*/
public long getRpcVersion() {
return rpcVersion;
}
public void readFields(DataInput in) throws IOException {
rpcVersion = in.readLong();
methodName = UTF8.readString(in);
clientVersion = in.readLong();
clientMethodsHash = in.readInt();
parameters = new Object[in.readInt()];
parameterClasses = new Class[parameters.length];
ObjectWritable objectWritable = new ObjectWritable();
@ -83,7 +129,10 @@ public class WritableRpcEngine implements RpcEngine {
}
public void write(DataOutput out) throws IOException {
out.writeLong(rpcVersion);
UTF8.writeString(out, methodName);
out.writeLong(clientVersion);
out.writeInt(clientMethodsHash);
out.writeInt(parameterClasses.length);
for (int i = 0; i < parameterClasses.length; i++) {
ObjectWritable.writeObject(out, parameters[i], parameterClasses[i],
@ -101,6 +150,9 @@ public class WritableRpcEngine implements RpcEngine {
buffer.append(parameters[i]);
}
buffer.append(")");
buffer.append(", rpc version="+rpcVersion);
buffer.append(", client version="+clientVersion);
buffer.append(", methodsFingerPrint="+clientMethodsHash);
return buffer.toString();
}
@ -230,22 +282,10 @@ public class WritableRpcEngine implements RpcEngine {
int rpcTimeout)
throws IOException {
T proxy = (T)Proxy.newProxyInstance
(protocol.getClassLoader(), new Class[] { protocol },
new Invoker(protocol, addr, ticket, conf, factory, rpcTimeout));
int[] serverMethods = null;
if (proxy instanceof VersionedProtocol) {
ProtocolSignature serverInfo = ((VersionedProtocol)proxy)
.getProtocolSignature(protocol.getName(), clientVersion,
ProtocolSignature.getFingerprint(protocol.getMethods()));
long serverVersion = serverInfo.getVersion();
if (serverVersion != clientVersion) {
throw new RPC.VersionMismatch(protocol.getName(), clientVersion,
serverVersion);
}
serverMethods = serverInfo.getMethods();
}
return new ProtocolProxy<T>(protocol, proxy, serverMethods);
T proxy = (T) Proxy.newProxyInstance(protocol.getClassLoader(),
new Class[] { protocol }, new Invoker(protocol, addr, ticket, conf,
factory, rpcTimeout));
return new ProtocolProxy<T>(protocol, proxy, true);
}
/**
@ -353,6 +393,31 @@ public class WritableRpcEngine implements RpcEngine {
call.getParameterClasses());
method.setAccessible(true);
// Verify rpc version
if (call.getRpcVersion() != writableRpcVersion) {
// Client is using a different version of WritableRpc
throw new IOException(
"WritableRpc version mismatch, client side version="
+ call.getRpcVersion() + ", server side version="
+ writableRpcVersion);
}
//Verify protocol version.
//Bypass the version check for VersionedProtocol
if (!method.getDeclaringClass().equals(VersionedProtocol.class)) {
long clientVersion = call.getProtocolVersion();
ProtocolSignature serverInfo = ((VersionedProtocol) instance)
.getProtocolSignature(protocol.getCanonicalName(), call
.getProtocolVersion(), call.getClientMethodsHash());
long serverVersion = serverInfo.getVersion();
if (serverVersion != clientVersion) {
LOG.warn("Version mismatch: client version=" + clientVersion
+ ", server version=" + serverVersion);
throw new RPC.VersionMismatch(protocol.getName(), clientVersion,
serverVersion);
}
}
long startTime = System.currentTimeMillis();
Object value = method.invoke(instance, call.getParameters());
int processingTime = (int) (System.currentTimeMillis() - startTime);

View File

@ -290,8 +290,7 @@ public class TestRPC extends TestCase {
// Check rpcMetrics
server.rpcMetrics.doUpdates(new NullContext());
// Number 4 includes getProtocolVersion()
assertEquals(4, server.rpcMetrics.rpcProcessingTime.getPreviousIntervalNumOps());
assertEquals(3, server.rpcMetrics.rpcProcessingTime.getPreviousIntervalNumOps());
assertTrue(server.rpcMetrics.sentBytes.getPreviousIntervalValue() > 0);
assertTrue(server.rpcMetrics.receivedBytes.getPreviousIntervalValue() > 0);
@ -376,8 +375,9 @@ public class TestRPC extends TestCase {
public void testStandaloneClient() throws IOException {
try {
RPC.waitForProxy(TestProtocol.class,
TestProtocol proxy = RPC.waitForProxy(TestProtocol.class,
TestProtocol.versionID, new InetSocketAddress(ADDRESS, 20), conf, 15000L);
proxy.echo("");
fail("We should not have reached here");
} catch (ConnectException ioe) {
//this is what we expected
@ -502,6 +502,7 @@ public class TestRPC extends TestCase {
try {
proxy = (TestProtocol) RPC.getProxy(TestProtocol.class,
TestProtocol.versionID, addr, conf);
proxy.echo("");
} catch (RemoteException e) {
LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage());
assertTrue(e.unwrapRemoteException() instanceof AccessControlException);
@ -527,6 +528,7 @@ public class TestRPC extends TestCase {
try {
proxy = (TestProtocol) RPC.getProxy(TestProtocol.class,
TestProtocol.versionID, mulitServerAddr, conf);
proxy.echo("");
} catch (RemoteException e) {
LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage());
assertTrue(e.unwrapRemoteException() instanceof AccessControlException);

View File

@ -18,19 +18,21 @@
package org.apache.hadoop.ipc;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
import java.io.IOException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import org.apache.commons.logging.*;
import junit.framework.Assert;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.net.NetUtils;
import org.junit.After;
import org.junit.Test;
/** Unit test for supporting method-name based compatible RPCs. */
@ -247,4 +249,26 @@ public class TestRPCCompatibility {
int hash2 = ProtocolSignature.getFingerprint(new Method[] {strMethod, intMethod});
assertEquals(hash1, hash2);
}
public interface TestProtocol4 extends TestProtocol2 {
public static final long versionID = 1L;
int echo(int value) throws IOException;
}
@Test
public void testVersionMismatch() throws IOException {
server = RPC.getServer(TestProtocol2.class, new TestImpl0(), ADDRESS, 0, 2,
false, conf, null);
server.start();
addr = NetUtils.getConnectAddress(server);
TestProtocol4 proxy = RPC.getProxy(TestProtocol4.class,
TestProtocol4.versionID, addr, conf);
try {
proxy.echo(21);
fail("The call must throw VersionMismatch exception");
} catch (IOException ex) {
Assert.assertTrue(ex.getMessage().contains("VersionMismatch"));
}
}
}

View File

@ -321,17 +321,20 @@ public class TestSaslRPC {
try {
proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf);
proxy1.getAuthMethod();
Client client = WritableRpcEngine.getClient(conf);
Set<ConnectionId> conns = client.getConnectionIds();
assertEquals("number of connections in cache is wrong", 1, conns.size());
// same conf, connection should be re-used
proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf);
proxy2.getAuthMethod();
assertEquals("number of connections in cache is wrong", 1, conns.size());
// different conf, new connection should be set up
newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf);
proxy3.getAuthMethod();
ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
assertEquals("number of connections in cache is wrong", 2,
connsArray.length);