HADOOP-13547. Optimize IPC client protobuf decoding. Contributed by Daryn Sharp.

This commit is contained in:
Kihwal Lee 2016-09-02 11:12:05 -05:00
parent 1222433729
commit 28ea4122f0
5 changed files with 104 additions and 295 deletions

View File

@ -21,7 +21,6 @@ package org.apache.hadoop.ipc;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.protobuf.CodedOutputStream;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceAudience;
@ -31,13 +30,11 @@ import org.apache.hadoop.classification.InterfaceStability.Unstable;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.retry.RetryPolicies; import org.apache.hadoop.io.retry.RetryPolicies;
import org.apache.hadoop.io.retry.RetryPolicy; import org.apache.hadoop.io.retry.RetryPolicy;
import org.apache.hadoop.io.retry.RetryPolicy.RetryAction; import org.apache.hadoop.io.retry.RetryPolicy.RetryAction;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper;
import org.apache.hadoop.ipc.RPC.RpcKind; import org.apache.hadoop.ipc.RPC.RpcKind;
import org.apache.hadoop.ipc.Server.AuthProtocol; import org.apache.hadoop.ipc.Server.AuthProtocol;
import org.apache.hadoop.ipc.protobuf.IpcConnectionContextProtos.IpcConnectionContextProto; import org.apache.hadoop.ipc.protobuf.IpcConnectionContextProtos.IpcConnectionContextProto;
@ -54,7 +51,6 @@ import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.SecurityUtil; import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.util.ProtoUtil; import org.apache.hadoop.util.ProtoUtil;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.StringUtils; import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time; import org.apache.hadoop.util.Time;
import org.apache.hadoop.util.concurrent.AsyncGet; import org.apache.hadoop.util.concurrent.AsyncGet;
@ -65,6 +61,7 @@ import javax.net.SocketFactory;
import javax.security.sasl.Sasl; import javax.security.sasl.Sasl;
import java.io.*; import java.io.*;
import java.net.*; import java.net.*;
import java.nio.ByteBuffer;
import java.security.PrivilegedExceptionAction; import java.security.PrivilegedExceptionAction;
import java.util.*; import java.util.*;
import java.util.Map.Entry; import java.util.Map.Entry;
@ -431,7 +428,7 @@ public class Client implements AutoCloseable {
private final boolean doPing; //do we need to send ping message private final boolean doPing; //do we need to send ping message
private final int pingInterval; // how often sends ping to the server private final int pingInterval; // how often sends ping to the server
private final int soTimeout; // used by ipc ping and rpc timeout private final int soTimeout; // used by ipc ping and rpc timeout
private ByteArrayOutputStream pingRequest; // ping message private ResponseBuffer pingRequest; // ping message
// currently active calls // currently active calls
private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>(); private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>();
@ -461,7 +458,7 @@ public class Client implements AutoCloseable {
this.doPing = remoteId.getDoPing(); this.doPing = remoteId.getDoPing();
if (doPing) { if (doPing) {
// construct a RPC header with the callId as the ping callId // construct a RPC header with the callId as the ping callId
pingRequest = new ByteArrayOutputStream(); pingRequest = new ResponseBuffer();
RpcRequestHeaderProto pingHeader = ProtoUtil RpcRequestHeaderProto pingHeader = ProtoUtil
.makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER, .makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER,
OperationProto.RPC_FINAL_PACKET, PING_CALL_ID, OperationProto.RPC_FINAL_PACKET, PING_CALL_ID,
@ -981,12 +978,10 @@ public class Client implements AutoCloseable {
.makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER, .makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER,
OperationProto.RPC_FINAL_PACKET, CONNECTION_CONTEXT_CALL_ID, OperationProto.RPC_FINAL_PACKET, CONNECTION_CONTEXT_CALL_ID,
RpcConstants.INVALID_RETRY_COUNT, clientId); RpcConstants.INVALID_RETRY_COUNT, clientId);
RpcRequestMessageWrapper request = final ResponseBuffer buf = new ResponseBuffer();
new RpcRequestMessageWrapper(connectionContextHeader, message); connectionContextHeader.writeDelimitedTo(buf);
message.writeDelimitedTo(buf);
// Write out the packet length buf.writeTo(out);
out.writeInt(request.getLength());
request.write(out);
} }
/* wait till someone signals us to start reading RPC response or /* wait till someone signals us to start reading RPC response or
@ -1032,7 +1027,6 @@ public class Client implements AutoCloseable {
if ( curTime - lastActivity.get() >= pingInterval) { if ( curTime - lastActivity.get() >= pingInterval) {
lastActivity.set(curTime); lastActivity.set(curTime);
synchronized (out) { synchronized (out) {
out.writeInt(pingRequest.size());
pingRequest.writeTo(out); pingRequest.writeTo(out);
out.flush(); out.flush();
} }
@ -1087,12 +1081,13 @@ public class Client implements AutoCloseable {
// 2) RpcRequest // 2) RpcRequest
// //
// Items '1' and '2' are prepared here. // Items '1' and '2' are prepared here.
final DataOutputBuffer d = new DataOutputBuffer();
RpcRequestHeaderProto header = ProtoUtil.makeRpcRequestHeader( RpcRequestHeaderProto header = ProtoUtil.makeRpcRequestHeader(
call.rpcKind, OperationProto.RPC_FINAL_PACKET, call.id, call.retry, call.rpcKind, OperationProto.RPC_FINAL_PACKET, call.id, call.retry,
clientId); clientId);
header.writeDelimitedTo(d);
call.rpcRequest.write(d); final ResponseBuffer buf = new ResponseBuffer();
header.writeDelimitedTo(buf);
RpcWritable.wrap(call.rpcRequest).writeTo(buf);
synchronized (sendRpcRequestLock) { synchronized (sendRpcRequestLock) {
Future<?> senderFuture = sendParamsExecutor.submit(new Runnable() { Future<?> senderFuture = sendParamsExecutor.submit(new Runnable() {
@ -1103,14 +1098,10 @@ public class Client implements AutoCloseable {
if (shouldCloseConnection.get()) { if (shouldCloseConnection.get()) {
return; return;
} }
if (LOG.isDebugEnabled()) {
if (LOG.isDebugEnabled())
LOG.debug(getName() + " sending #" + call.id); LOG.debug(getName() + " sending #" + call.id);
}
byte[] data = d.getData(); buf.writeTo(out); // RpcRequestHeader + RpcRequest
int totalLength = d.getLength();
out.writeInt(totalLength); // Total Length
out.write(data, 0, totalLength);// RpcRequestHeader + RpcRequest
out.flush(); out.flush();
} }
} catch (IOException e) { } catch (IOException e) {
@ -1121,7 +1112,7 @@ public class Client implements AutoCloseable {
} finally { } finally {
//the buffer is just an in-memory buffer, but it is still polite to //the buffer is just an in-memory buffer, but it is still polite to
// close early // close early
IOUtils.closeStream(d); IOUtils.closeStream(buf);
} }
} }
}); });
@ -1153,12 +1144,13 @@ public class Client implements AutoCloseable {
try { try {
int totalLen = in.readInt(); int totalLen = in.readInt();
RpcResponseHeaderProto header = ByteBuffer bb = ByteBuffer.allocate(totalLen);
RpcResponseHeaderProto.parseDelimitedFrom(in); in.readFully(bb.array());
checkResponse(header);
int headerLen = header.getSerializedSize(); RpcWritable.Buffer packet = RpcWritable.Buffer.wrap(bb);
headerLen += CodedOutputStream.computeRawVarint32Size(headerLen); RpcResponseHeaderProto header =
packet.getValue(RpcResponseHeaderProto.getDefaultInstance());
checkResponse(header);
int callId = header.getCallId(); int callId = header.getCallId();
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
@ -1166,28 +1158,15 @@ public class Client implements AutoCloseable {
RpcStatusProto status = header.getStatus(); RpcStatusProto status = header.getStatus();
if (status == RpcStatusProto.SUCCESS) { if (status == RpcStatusProto.SUCCESS) {
Writable value = ReflectionUtils.newInstance(valueClass, conf); Writable value = packet.newInstance(valueClass, conf);
value.readFields(in); // read value
final Call call = calls.remove(callId); final Call call = calls.remove(callId);
call.setRpcResponse(value); call.setRpcResponse(value);
// verify that length was correct
// only for ProtobufEngine where len can be verified easily
if (call.getRpcResponse() instanceof ProtobufRpcEngine.RpcWrapper) {
ProtobufRpcEngine.RpcWrapper resWrapper =
(ProtobufRpcEngine.RpcWrapper) call.getRpcResponse();
if (totalLen != headerLen + resWrapper.getLength()) {
throw new RpcClientException(
"RPC response length mismatch on rpc success");
} }
// verify that packet length was correct
if (packet.remaining() > 0) {
throw new RpcClientException("RPC response length mismatch");
} }
} else { // Rpc Request failed if (status != RpcStatusProto.SUCCESS) { // Rpc Request failed
// Verify that length was correct
if (totalLen != headerLen) {
throw new RpcClientException(
"RPC response length mismatch on rpc error");
}
final String exceptionClassName = header.hasExceptionClassName() ? final String exceptionClassName = header.hasExceptionClassName() ?
header.getExceptionClassName() : header.getExceptionClassName() :
"ServerDidNotSetExceptionClassName"; "ServerDidNotSetExceptionClassName";

View File

@ -27,29 +27,22 @@ import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability; import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.classification.InterfaceStability.Unstable; import org.apache.hadoop.classification.InterfaceStability.Unstable;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DataOutputOutputStream;
import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.retry.RetryPolicy; import org.apache.hadoop.io.retry.RetryPolicy;
import org.apache.hadoop.ipc.Client.ConnectionId; import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.ipc.RPC.RpcInvoker; import org.apache.hadoop.ipc.RPC.RpcInvoker;
import org.apache.hadoop.ipc.RpcWritable; import org.apache.hadoop.ipc.RpcWritable;
import org.apache.hadoop.ipc.protobuf.ProtobufRpcEngineProtos.RequestHeaderProto; import org.apache.hadoop.ipc.protobuf.ProtobufRpcEngineProtos.RequestHeaderProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.util.ProtoUtil;
import org.apache.hadoop.util.Time; import org.apache.hadoop.util.Time;
import org.apache.hadoop.util.concurrent.AsyncGet; import org.apache.hadoop.util.concurrent.AsyncGet;
import org.apache.htrace.core.TraceScope; import org.apache.htrace.core.TraceScope;
import org.apache.htrace.core.Tracer; import org.apache.htrace.core.Tracer;
import javax.net.SocketFactory; import javax.net.SocketFactory;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.Proxy; import java.lang.reflect.Proxy;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
@ -146,7 +139,7 @@ public class ProtobufRpcEngine implements RpcEngine {
private Invoker(Class<?> protocol, Client.ConnectionId connId, private Invoker(Class<?> protocol, Client.ConnectionId connId,
Configuration conf, SocketFactory factory) { Configuration conf, SocketFactory factory) {
this.remoteId = connId; this.remoteId = connId;
this.client = CLIENTS.getClient(conf, factory, RpcResponseWrapper.class); this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class);
this.protocolName = RPC.getProtocolName(protocol); this.protocolName = RPC.getProtocolName(protocol);
this.clientProtocolVersion = RPC this.clientProtocolVersion = RPC
.getProtocolVersion(protocol); .getProtocolVersion(protocol);
@ -193,7 +186,7 @@ public class ProtobufRpcEngine implements RpcEngine {
* the server. * the server.
*/ */
@Override @Override
public Object invoke(Object proxy, final Method method, Object[] args) public Message invoke(Object proxy, final Method method, Object[] args)
throws ServiceException { throws ServiceException {
long startTime = 0; long startTime = 0;
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
@ -228,11 +221,11 @@ public class ProtobufRpcEngine implements RpcEngine {
} }
Message theRequest = (Message) args[1]; final Message theRequest = (Message) args[1];
final RpcResponseWrapper val; final RpcWritable.Buffer val;
try { try {
val = (RpcResponseWrapper) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER, val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER,
new RpcRequestWrapper(rpcRequestHeader, theRequest), remoteId, new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId,
fallbackToSimpleAuth); fallbackToSimpleAuth);
} catch (Throwable e) { } catch (Throwable e) {
@ -256,7 +249,7 @@ public class ProtobufRpcEngine implements RpcEngine {
} }
if (Client.isAsynchronousMode()) { if (Client.isAsynchronousMode()) {
final AsyncGet<RpcResponseWrapper, IOException> arr final AsyncGet<RpcWritable.Buffer, IOException> arr
= Client.getAsyncRpcResponse(); = Client.getAsyncRpcResponse();
final AsyncGet<Message, Exception> asyncGet final AsyncGet<Message, Exception> asyncGet
= new AsyncGet<Message, Exception>() { = new AsyncGet<Message, Exception>() {
@ -278,7 +271,7 @@ public class ProtobufRpcEngine implements RpcEngine {
} }
private Message getReturnMessage(final Method method, private Message getReturnMessage(final Method method,
final RpcResponseWrapper rrw) throws ServiceException { final RpcWritable.Buffer buf) throws ServiceException {
Message prototype = null; Message prototype = null;
try { try {
prototype = getReturnProtoType(method); prototype = getReturnProtoType(method);
@ -287,8 +280,7 @@ public class ProtobufRpcEngine implements RpcEngine {
} }
Message returnMessage; Message returnMessage;
try { try {
returnMessage = prototype.newBuilderForType() returnMessage = buf.getValue(prototype.getDefaultInstanceForType());
.mergeFrom(rrw.theResponseRead).build();
if (LOG.isTraceEnabled()) { if (LOG.isTraceEnabled()) {
LOG.trace(Thread.currentThread().getId() + ": Response <- " + LOG.trace(Thread.currentThread().getId() + ": Response <- " +
@ -329,201 +321,12 @@ public class ProtobufRpcEngine implements RpcEngine {
} }
} }
interface RpcWrapper extends Writable {
int getLength();
}
/**
* Wrapper for Protocol Buffer Requests
*
* Note while this wrapper is writable, the request on the wire is in
* Protobuf. Several methods on {@link org.apache.hadoop.ipc.Server and RPC}
* use type Writable as a wrapper to work across multiple RpcEngine kinds.
*/
private static abstract class RpcMessageWithHeader<T extends GeneratedMessage>
implements RpcWrapper {
T requestHeader;
Message theRequest; // for clientSide, the request is here
byte[] theRequestRead; // for server side, the request is here
public RpcMessageWithHeader() {
}
public RpcMessageWithHeader(T requestHeader, Message theRequest) {
this.requestHeader = requestHeader;
this.theRequest = theRequest;
}
@Override
public void write(DataOutput out) throws IOException {
OutputStream os = DataOutputOutputStream.constructOutputStream(out);
((Message)requestHeader).writeDelimitedTo(os);
theRequest.writeDelimitedTo(os);
}
@Override
public void readFields(DataInput in) throws IOException {
requestHeader = parseHeaderFrom(readVarintBytes(in));
theRequestRead = readMessageRequest(in);
}
abstract T parseHeaderFrom(byte[] bytes) throws IOException;
byte[] readMessageRequest(DataInput in) throws IOException {
return readVarintBytes(in);
}
private static byte[] readVarintBytes(DataInput in) throws IOException {
final int length = ProtoUtil.readRawVarint32(in);
final byte[] bytes = new byte[length];
in.readFully(bytes);
return bytes;
}
public T getMessageHeader() {
return requestHeader;
}
public byte[] getMessageBytes() {
return theRequestRead;
}
@Override
public int getLength() {
int headerLen = requestHeader.getSerializedSize();
int reqLen;
if (theRequest != null) {
reqLen = theRequest.getSerializedSize();
} else if (theRequestRead != null ) {
reqLen = theRequestRead.length;
} else {
throw new IllegalArgumentException(
"getLength on uninitialized RpcWrapper");
}
return CodedOutputStream.computeRawVarint32Size(headerLen) + headerLen
+ CodedOutputStream.computeRawVarint32Size(reqLen) + reqLen;
}
}
private static class RpcRequestWrapper
extends RpcMessageWithHeader<RequestHeaderProto> {
@SuppressWarnings("unused")
public RpcRequestWrapper() {}
public RpcRequestWrapper(
RequestHeaderProto requestHeader, Message theRequest) {
super(requestHeader, theRequest);
}
@Override
RequestHeaderProto parseHeaderFrom(byte[] bytes) throws IOException {
return RequestHeaderProto.parseFrom(bytes);
}
@Override
public String toString() {
return requestHeader.getDeclaringClassProtocolName() + "." +
requestHeader.getMethodName();
}
}
@InterfaceAudience.LimitedPrivate({"RPC"})
public static class RpcRequestMessageWrapper
extends RpcMessageWithHeader<RpcRequestHeaderProto> {
public RpcRequestMessageWrapper() {}
public RpcRequestMessageWrapper(
RpcRequestHeaderProto requestHeader, Message theRequest) {
super(requestHeader, theRequest);
}
@Override
RpcRequestHeaderProto parseHeaderFrom(byte[] bytes) throws IOException {
return RpcRequestHeaderProto.parseFrom(bytes);
}
}
@InterfaceAudience.LimitedPrivate({"RPC"})
public static class RpcResponseMessageWrapper
extends RpcMessageWithHeader<RpcResponseHeaderProto> {
public RpcResponseMessageWrapper() {}
public RpcResponseMessageWrapper(
RpcResponseHeaderProto responseHeader, Message theRequest) {
super(responseHeader, theRequest);
}
@Override
byte[] readMessageRequest(DataInput in) throws IOException {
// error message contain no message body
switch (requestHeader.getStatus()) {
case ERROR:
case FATAL:
return null;
default:
return super.readMessageRequest(in);
}
}
@Override
RpcResponseHeaderProto parseHeaderFrom(byte[] bytes) throws IOException {
return RpcResponseHeaderProto.parseFrom(bytes);
}
}
/**
* Wrapper for Protocol Buffer Responses
*
* Note while this wrapper is writable, the request on the wire is in
* Protobuf. Several methods on {@link org.apache.hadoop.ipc.Server and RPC}
* use type Writable as a wrapper to work across multiple RpcEngine kinds.
*/
@InterfaceAudience.LimitedPrivate({"RPC"}) // temporarily exposed
public static class RpcResponseWrapper implements RpcWrapper {
Message theResponse; // for senderSide, the response is here
byte[] theResponseRead; // for receiver side, the response is here
public RpcResponseWrapper() {
}
public RpcResponseWrapper(Message message) {
this.theResponse = message;
}
@Override
public void write(DataOutput out) throws IOException {
OutputStream os = DataOutputOutputStream.constructOutputStream(out);
theResponse.writeDelimitedTo(os);
}
@Override
public void readFields(DataInput in) throws IOException {
int length = ProtoUtil.readRawVarint32(in);
theResponseRead = new byte[length];
in.readFully(theResponseRead);
}
@Override
public int getLength() {
int resLen;
if (theResponse != null) {
resLen = theResponse.getSerializedSize();
} else if (theResponseRead != null ) {
resLen = theResponseRead.length;
} else {
throw new IllegalArgumentException(
"getLength on uninitialized RpcWrapper");
}
return CodedOutputStream.computeRawVarint32Size(resLen) + resLen;
}
}
@VisibleForTesting @VisibleForTesting
@InterfaceAudience.Private @InterfaceAudience.Private
@InterfaceStability.Unstable @InterfaceStability.Unstable
static Client getClient(Configuration conf) { static Client getClient(Configuration conf) {
return CLIENTS.getClient(conf, SocketFactory.getDefault(), return CLIENTS.getClient(conf, SocketFactory.getDefault(),
RpcResponseWrapper.class); RpcWritable.Buffer.class);
} }
@ -672,16 +475,30 @@ public class ProtobufRpcEngine implements RpcEngine {
// which uses the rpc header. in the normal case we want to defer decoding // which uses the rpc header. in the normal case we want to defer decoding
// the rpc header until needed by the rpc engine. // the rpc header until needed by the rpc engine.
static class RpcProtobufRequest extends RpcWritable.Buffer { static class RpcProtobufRequest extends RpcWritable.Buffer {
private RequestHeaderProto lazyHeader; private volatile RequestHeaderProto requestHeader;
private Message payload;
public RpcProtobufRequest() { public RpcProtobufRequest() {
} }
synchronized RequestHeaderProto getRequestHeader() throws IOException { RpcProtobufRequest(RequestHeaderProto header, Message payload) {
if (lazyHeader == null) { this.requestHeader = header;
lazyHeader = getValue(RequestHeaderProto.getDefaultInstance()); this.payload = payload;
}
RequestHeaderProto getRequestHeader() throws IOException {
if (getByteBuffer() != null && requestHeader == null) {
requestHeader = getValue(RequestHeaderProto.getDefaultInstance());
}
return requestHeader;
}
@Override
public void writeTo(ResponseBuffer out) throws IOException {
requestHeader.writeDelimitedTo(out);
if (payload != null) {
payload.writeDelimitedTo(out);
} }
return lazyHeader;
} }
// this is used by htrace to name the span. // this is used by htrace to name the span.

View File

@ -27,8 +27,14 @@ import java.util.Arrays;
import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceAudience;
@InterfaceAudience.Private @InterfaceAudience.Private
class ResponseBuffer extends DataOutputStream { /** generates byte-length framed buffers. */
ResponseBuffer(int capacity) { public class ResponseBuffer extends DataOutputStream {
public ResponseBuffer() {
this(1024);
}
public ResponseBuffer(int capacity) {
super(new FramedBuffer(capacity)); super(new FramedBuffer(capacity));
} }
@ -39,7 +45,7 @@ class ResponseBuffer extends DataOutputStream {
return buf; return buf;
} }
void writeTo(OutputStream out) throws IOException { public void writeTo(OutputStream out) throws IOException {
getFramedBuffer().writeTo(out); getFramedBuffer().writeTo(out);
} }

View File

@ -24,7 +24,6 @@ import java.io.DataInputStream;
import java.io.DataOutput; import java.io.DataOutput;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configurable; import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
@ -34,6 +33,7 @@ import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream; import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.Message; import com.google.protobuf.Message;
// note anything marked public is solely for access by SaslRpcClient
@InterfaceAudience.Private @InterfaceAudience.Private
public abstract class RpcWritable implements Writable { public abstract class RpcWritable implements Writable {
@ -99,6 +99,10 @@ public abstract class RpcWritable implements Writable {
this.message = message; this.message = message;
} }
Message getMessage() {
return message;
}
@Override @Override
void writeTo(ResponseBuffer out) throws IOException { void writeTo(ResponseBuffer out) throws IOException {
int length = message.getSerializedSize(); int length = message.getSerializedSize();
@ -128,11 +132,13 @@ public abstract class RpcWritable implements Writable {
} }
} }
// adapter to allow decoding of writables and protobufs from a byte buffer. /**
static class Buffer extends RpcWritable { * adapter to allow decoding of writables and protobufs from a byte buffer.
*/
public static class Buffer extends RpcWritable {
private ByteBuffer bb; private ByteBuffer bb;
static Buffer wrap(ByteBuffer bb) { public static Buffer wrap(ByteBuffer bb) {
return new Buffer(bb); return new Buffer(bb);
} }
@ -142,6 +148,10 @@ public abstract class RpcWritable implements Writable {
this.bb = bb; this.bb = bb;
} }
ByteBuffer getByteBuffer() {
return bb;
}
@Override @Override
void writeTo(ResponseBuffer out) throws IOException { void writeTo(ResponseBuffer out) throws IOException {
out.ensureCapacity(bb.remaining()); out.ensureCapacity(bb.remaining());
@ -177,7 +187,7 @@ public abstract class RpcWritable implements Writable {
return RpcWritable.wrap(value).readFrom(bb); return RpcWritable.wrap(value).readFrom(bb);
} }
int remaining() { public int remaining() {
return bb.remaining(); return bb.remaining();
} }
} }

View File

@ -53,11 +53,11 @@ import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability; import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.GlobPattern; import org.apache.hadoop.fs.GlobPattern;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseMessageWrapper;
import org.apache.hadoop.ipc.RPC.RpcKind; import org.apache.hadoop.ipc.RPC.RpcKind;
import org.apache.hadoop.ipc.RemoteException; import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.ipc.ResponseBuffer;
import org.apache.hadoop.ipc.RpcConstants; import org.apache.hadoop.ipc.RpcConstants;
import org.apache.hadoop.ipc.RpcWritable;
import org.apache.hadoop.ipc.Server.AuthProtocol; import org.apache.hadoop.ipc.Server.AuthProtocol;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto.OperationProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto.OperationProto;
@ -367,11 +367,13 @@ public class SaslRpcClient {
// loop until sasl is complete or a rpc error occurs // loop until sasl is complete or a rpc error occurs
boolean done = false; boolean done = false;
do { do {
int totalLen = inStream.readInt(); int rpcLen = inStream.readInt();
RpcResponseMessageWrapper responseWrapper = ByteBuffer bb = ByteBuffer.allocate(rpcLen);
new RpcResponseMessageWrapper(); inStream.readFully(bb.array());
responseWrapper.readFields(inStream);
RpcResponseHeaderProto header = responseWrapper.getMessageHeader(); RpcWritable.Buffer saslPacket = RpcWritable.Buffer.wrap(bb);
RpcResponseHeaderProto header =
saslPacket.getValue(RpcResponseHeaderProto.getDefaultInstance());
switch (header.getStatus()) { switch (header.getStatus()) {
case ERROR: // might get a RPC error during case ERROR: // might get a RPC error during
case FATAL: case FATAL:
@ -379,15 +381,14 @@ public class SaslRpcClient {
header.getErrorMsg()); header.getErrorMsg());
default: break; default: break;
} }
if (totalLen != responseWrapper.getLength()) {
throw new SaslException("Received malformed response length");
}
if (header.getCallId() != AuthProtocol.SASL.callId) { if (header.getCallId() != AuthProtocol.SASL.callId) {
throw new SaslException("Non-SASL response during negotiation"); throw new SaslException("Non-SASL response during negotiation");
} }
RpcSaslProto saslMessage = RpcSaslProto saslMessage =
RpcSaslProto.parseFrom(responseWrapper.getMessageBytes()); saslPacket.getValue(RpcSaslProto.getDefaultInstance());
if (saslPacket.remaining() > 0) {
throw new SaslException("Received malformed response length");
}
// handle sasl negotiation process // handle sasl negotiation process
RpcSaslProto.Builder response = null; RpcSaslProto.Builder response = null;
switch (saslMessage.getState()) { switch (saslMessage.getState()) {
@ -451,15 +452,15 @@ public class SaslRpcClient {
return authMethod; return authMethod;
} }
private void sendSaslMessage(DataOutputStream out, RpcSaslProto message) private void sendSaslMessage(OutputStream out, RpcSaslProto message)
throws IOException { throws IOException {
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("Sending sasl message "+message); LOG.debug("Sending sasl message "+message);
} }
RpcRequestMessageWrapper request = ResponseBuffer buf = new ResponseBuffer();
new RpcRequestMessageWrapper(saslHeader, message); saslHeader.writeDelimitedTo(buf);
out.writeInt(request.getLength()); message.writeDelimitedTo(buf);
request.write(out); buf.writeTo(out);
out.flush(); out.flush();
} }
@ -633,11 +634,7 @@ public class SaslRpcClient {
.setState(SaslState.WRAP) .setState(SaslState.WRAP)
.setToken(ByteString.copyFrom(buf, 0, buf.length)) .setToken(ByteString.copyFrom(buf, 0, buf.length))
.build(); .build();
RpcRequestMessageWrapper request = sendSaslMessage(out, saslMessage);
new RpcRequestMessageWrapper(saslHeader, saslMessage);
DataOutputStream dob = new DataOutputStream(out);
dob.writeInt(request.getLength());
request.write(dob);
} }
} }