HADOOP-13426. More efficiently build IPC responses. Contributed by Daryn Sharp.

(cherry picked from commit 2d8227605f)
This commit is contained in:
Kihwal Lee 2016-08-03 09:33:04 -05:00
parent 22e182f295
commit afc8da0d86
3 changed files with 250 additions and 78 deletions

View File

@ -0,0 +1,98 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.ipc;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Arrays;
import org.apache.hadoop.classification.InterfaceAudience;
@InterfaceAudience.Private
class ResponseBuffer extends DataOutputStream {
ResponseBuffer(int capacity) {
super(new FramedBuffer(capacity));
}
// update framing bytes based on bytes written to stream.
private FramedBuffer getFramedBuffer() {
FramedBuffer buf = (FramedBuffer)out;
buf.setSize(written);
return buf;
}
void writeTo(OutputStream out) throws IOException {
getFramedBuffer().writeTo(out);
}
byte[] toByteArray() {
return getFramedBuffer().toByteArray();
}
int capacity() {
return ((FramedBuffer)out).capacity();
}
void setCapacity(int capacity) {
((FramedBuffer)out).setCapacity(capacity);
}
void ensureCapacity(int capacity) {
if (((FramedBuffer)out).capacity() < capacity) {
((FramedBuffer)out).setCapacity(capacity);
}
}
ResponseBuffer reset() {
written = 0;
((FramedBuffer)out).reset();
return this;
}
private static class FramedBuffer extends ByteArrayOutputStream {
private static final int FRAMING_BYTES = 4;
FramedBuffer(int capacity) {
super(capacity + FRAMING_BYTES);
reset();
}
@Override
public int size() {
return count - FRAMING_BYTES;
}
void setSize(int size) {
buf[0] = (byte)((size >>> 24) & 0xFF);
buf[1] = (byte)((size >>> 16) & 0xFF);
buf[2] = (byte)((size >>> 8) & 0xFF);
buf[3] = (byte)((size >>> 0) & 0xFF);
}
int capacity() {
return buf.length - FRAMING_BYTES;
}
void setCapacity(int capacity) {
buf = Arrays.copyOf(buf, capacity + FRAMING_BYTES);
}
@Override
public void reset() {
count = FRAMING_BYTES;
setSize(0);
}
};
}

View File

@ -79,12 +79,11 @@ import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configuration.IntegerRanges; import org.apache.hadoop.conf.Configuration.IntegerRanges;
import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils; import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseMessageWrapper;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseWrapper; import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseWrapper;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcWrapper;
import org.apache.hadoop.ipc.RPC.RpcInvoker; import org.apache.hadoop.ipc.RPC.RpcInvoker;
import org.apache.hadoop.ipc.RPC.VersionMismatch; import org.apache.hadoop.ipc.RPC.VersionMismatch;
import org.apache.hadoop.ipc.metrics.RpcDetailedMetrics; import org.apache.hadoop.ipc.metrics.RpcDetailedMetrics;
@ -422,6 +421,13 @@ public abstract class Server {
private int maxQueueSize; private int maxQueueSize;
private final int maxRespSize; private final int maxRespSize;
private final ThreadLocal<ResponseBuffer> responseBuffer =
new ThreadLocal<ResponseBuffer>(){
@Override
protected ResponseBuffer initialValue() {
return new ResponseBuffer(INITIAL_RESP_BUF_SIZE);
}
};
private int socketSendBufferSize; private int socketSendBufferSize;
private final int maxDataLength; private final int maxDataLength;
private final boolean tcpNoDelay; // if T then disable Nagle's Algorithm private final boolean tcpNoDelay; // if T then disable Nagle's Algorithm
@ -743,14 +749,7 @@ public abstract class Server {
public void abortResponse(Throwable t) throws IOException { public void abortResponse(Throwable t) throws IOException {
// don't send response if the call was already sent or aborted. // don't send response if the call was already sent or aborted.
if (responseWaitCount.getAndSet(-1) > 0) { if (responseWaitCount.getAndSet(-1) > 0) {
// clone the call to prevent a race with the other thread stomping connection.abortResponse(this, t);
// on the response while being sent. the original call is
// effectively discarded since the wait count won't hit zero
Call call = new Call(this);
setupResponse(new ByteArrayOutputStream(), call,
RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER,
null, t.getClass().getName(), StringUtils.stringifyException(t));
call.sendResponse();
} }
} }
@ -1271,9 +1270,7 @@ public abstract class Server {
// must only wrap before adding to the responseQueue to prevent // must only wrap before adding to the responseQueue to prevent
// postponed responses from being encrypted and sent out of order. // postponed responses from being encrypted and sent out of order.
if (call.connection.useWrap) { if (call.connection.useWrap) {
ByteArrayOutputStream response = new ByteArrayOutputStream(); wrapWithSasl(call);
wrapWithSasl(response, call);
call.setResponse(ByteBuffer.wrap(response.toByteArray()));
} }
call.connection.responseQueue.addLast(call); call.connection.responseQueue.addLast(call);
if (call.connection.responseQueue.size() == 1) { if (call.connection.responseQueue.size() == 1) {
@ -1394,7 +1391,6 @@ public abstract class Server {
// Fake 'call' for failed authorization response // Fake 'call' for failed authorization response
private final Call authFailedCall = new Call(AUTHORIZATION_FAILED_CALL_ID, private final Call authFailedCall = new Call(AUTHORIZATION_FAILED_CALL_ID,
RpcConstants.INVALID_RETRY_COUNT, null, this); RpcConstants.INVALID_RETRY_COUNT, null, this);
private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream();
private boolean sentNegotiate = false; private boolean sentNegotiate = false;
private boolean useWrap = false; private boolean useWrap = false;
@ -1674,15 +1670,14 @@ public abstract class Server {
private void doSaslReply(Message message) throws IOException { private void doSaslReply(Message message) throws IOException {
final Call saslCall = new Call(AuthProtocol.SASL.callId, final Call saslCall = new Call(AuthProtocol.SASL.callId,
RpcConstants.INVALID_RETRY_COUNT, null, this); RpcConstants.INVALID_RETRY_COUNT, null, this);
final ByteArrayOutputStream saslResponse = new ByteArrayOutputStream(); setupResponse(saslCall,
setupResponse(saslResponse, saslCall,
RpcStatusProto.SUCCESS, null, RpcStatusProto.SUCCESS, null,
new RpcResponseWrapper(message), null, null); new RpcResponseWrapper(message), null, null);
saslCall.sendResponse(); saslCall.sendResponse();
} }
private void doSaslReply(Exception ioe) throws IOException { private void doSaslReply(Exception ioe) throws IOException {
setupResponse(authFailedResponse, authFailedCall, setupResponse(authFailedCall,
RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_UNAUTHORIZED, RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_UNAUTHORIZED,
null, ioe.getClass().getName(), ioe.getLocalizedMessage()); null, ioe.getClass().getName(), ioe.getLocalizedMessage());
authFailedCall.sendResponse(); authFailedCall.sendResponse();
@ -1860,7 +1855,7 @@ public abstract class Server {
// Versions >>9 understand the normal response // Versions >>9 understand the normal response
Call fakeCall = new Call(-1, RpcConstants.INVALID_RETRY_COUNT, null, Call fakeCall = new Call(-1, RpcConstants.INVALID_RETRY_COUNT, null,
this); this);
setupResponse(buffer, fakeCall, setupResponse(fakeCall,
RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_VERSION_MISMATCH, RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_VERSION_MISMATCH,
null, VersionMismatch.class.getName(), errMsg); null, VersionMismatch.class.getName(), errMsg);
fakeCall.sendResponse(); fakeCall.sendResponse();
@ -2026,7 +2021,7 @@ public abstract class Server {
} catch (WrappedRpcServerException wrse) { // inform client of error } catch (WrappedRpcServerException wrse) { // inform client of error
Throwable ioe = wrse.getCause(); Throwable ioe = wrse.getCause();
final Call call = new Call(callId, retry, null, this); final Call call = new Call(callId, retry, null, this);
setupResponse(authFailedResponse, call, setupResponse(call,
RpcStatusProto.FATAL, wrse.getRpcErrorCodeProto(), null, RpcStatusProto.FATAL, wrse.getRpcErrorCodeProto(), null,
ioe.getClass().getName(), ioe.getMessage()); ioe.getClass().getName(), ioe.getMessage());
call.sendResponse(); call.sendResponse();
@ -2252,6 +2247,17 @@ public abstract class Server {
responder.doRespond(call); responder.doRespond(call);
} }
private void abortResponse(Call call, Throwable t) throws IOException {
// clone the call to prevent a race with the other thread stomping
// on the response while being sent. the original call is
// effectively discarded since the wait count won't hit zero
call = new Call(call);
setupResponse(call,
RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER,
null, t.getClass().getName(), StringUtils.stringifyException(t));
call.sendResponse();
}
/** /**
* Get service class for connection * Get service class for connection
* @return the serviceClass * @return the serviceClass
@ -2295,8 +2301,6 @@ public abstract class Server {
public void run() { public void run() {
LOG.debug(Thread.currentThread().getName() + ": starting"); LOG.debug(Thread.currentThread().getName() + ": starting");
SERVER.set(Server.this); SERVER.set(Server.this);
ByteArrayOutputStream buf =
new ByteArrayOutputStream(INITIAL_RESP_BUF_SIZE);
while (running) { while (running) {
TraceScope traceScope = null; TraceScope traceScope = null;
try { try {
@ -2366,16 +2370,8 @@ public abstract class Server {
} }
CurCall.set(null); CurCall.set(null);
synchronized (call.connection.responseQueue) { synchronized (call.connection.responseQueue) {
setupResponse(buf, call, returnStatus, detailedErr, setupResponse(call, returnStatus, detailedErr,
value, errorClass, error); value, errorClass, error);
// Discard the large buf and reset it back to smaller size
// to free up heap.
if (buf.size() > maxRespSize) {
LOG.warn("Large response size " + buf.size() + " for call "
+ call.toString());
buf = new ByteArrayOutputStream(INITIAL_RESP_BUF_SIZE);
}
call.sendResponse(); call.sendResponse();
} }
} catch (InterruptedException e) { } catch (InterruptedException e) {
@ -2593,12 +2589,10 @@ public abstract class Server {
* @param error error message, if the call failed * @param error error message, if the call failed
* @throws IOException * @throws IOException
*/ */
private static void setupResponse(ByteArrayOutputStream responseBuf, private void setupResponse(
Call call, RpcStatusProto status, RpcErrorCodeProto erCode, Call call, RpcStatusProto status, RpcErrorCodeProto erCode,
Writable rv, String errorClass, String error) Writable rv, String errorClass, String error)
throws IOException { throws IOException {
responseBuf.reset();
DataOutputStream out = new DataOutputStream(responseBuf);
RpcResponseHeaderProto.Builder headerBuilder = RpcResponseHeaderProto.Builder headerBuilder =
RpcResponseHeaderProto.newBuilder(); RpcResponseHeaderProto.newBuilder();
headerBuilder.setClientId(ByteString.copyFrom(call.clientId)); headerBuilder.setClientId(ByteString.copyFrom(call.clientId));
@ -2609,32 +2603,14 @@ public abstract class Server {
if (status == RpcStatusProto.SUCCESS) { if (status == RpcStatusProto.SUCCESS) {
RpcResponseHeaderProto header = headerBuilder.build(); RpcResponseHeaderProto header = headerBuilder.build();
final int headerLen = header.getSerializedSize();
int fullLength = CodedOutputStream.computeRawVarint32Size(headerLen) +
headerLen;
try { try {
if (rv instanceof ProtobufRpcEngine.RpcWrapper) { setupResponse(call, header, rv);
ProtobufRpcEngine.RpcWrapper resWrapper =
(ProtobufRpcEngine.RpcWrapper) rv;
fullLength += resWrapper.getLength();
out.writeInt(fullLength);
header.writeDelimitedTo(out);
rv.write(out);
} else { // Have to serialize to buffer to get len
final DataOutputBuffer buf = new DataOutputBuffer();
rv.write(buf);
byte[] data = buf.getData();
fullLength += buf.getLength();
out.writeInt(fullLength);
header.writeDelimitedTo(out);
out.write(data, 0, buf.getLength());
}
} catch (Throwable t) { } catch (Throwable t) {
LOG.warn("Error serializing call response for call " + call, t); LOG.warn("Error serializing call response for call " + call, t);
// Call back to same function - this is OK since the // Call back to same function - this is OK since the
// buffer is reset at the top, and since status is changed // buffer is reset at the top, and since status is changed
// to ERROR it won't infinite loop. // to ERROR it won't infinite loop.
setupResponse(responseBuf, call, RpcStatusProto.ERROR, setupResponse(call, RpcStatusProto.ERROR,
RpcErrorCodeProto.ERROR_SERIALIZING_RESPONSE, RpcErrorCodeProto.ERROR_SERIALIZING_RESPONSE,
null, t.getClass().getName(), null, t.getClass().getName(),
StringUtils.stringifyException(t)); StringUtils.stringifyException(t));
@ -2644,14 +2620,33 @@ public abstract class Server {
headerBuilder.setExceptionClassName(errorClass); headerBuilder.setExceptionClassName(errorClass);
headerBuilder.setErrorMsg(error); headerBuilder.setErrorMsg(error);
headerBuilder.setErrorDetail(erCode); headerBuilder.setErrorDetail(erCode);
RpcResponseHeaderProto header = headerBuilder.build(); setupResponse(call, headerBuilder.build(), null);
int headerLen = header.getSerializedSize(); }
final int fullLength = }
CodedOutputStream.computeRawVarint32Size(headerLen) + headerLen;
out.writeInt(fullLength); private void setupResponse(Call call,
header.writeDelimitedTo(out); RpcResponseHeaderProto header, Writable rv) throws IOException {
ResponseBuffer buf = responseBuffer.get().reset();
// adjust capacity on estimated length to reduce resizing copies
int estimatedLen = header.getSerializedSize();
estimatedLen += CodedOutputStream.computeRawVarint32Size(estimatedLen);
// if it's not a wrapped protobuf, just let it grow on its own
if (rv instanceof RpcWrapper) {
estimatedLen += ((RpcWrapper)rv).getLength();
}
buf.ensureCapacity(estimatedLen);
header.writeDelimitedTo(buf);
if (rv != null) { // null for exceptions
rv.write(buf);
}
call.setResponse(ByteBuffer.wrap(buf.toByteArray()));
// Discard a large buf and reset it back to smaller size
// to free up heap.
if (buf.capacity() > maxRespSize) {
LOG.warn("Large response size " + buf.size() + " for call "
+ call.toString());
buf.setCapacity(INITIAL_RESP_BUF_SIZE);
} }
call.setResponse(ByteBuffer.wrap(responseBuf.toByteArray()));
} }
/** /**
@ -2681,9 +2676,7 @@ public abstract class Server {
call.setResponse(ByteBuffer.wrap(response.toByteArray())); call.setResponse(ByteBuffer.wrap(response.toByteArray()));
} }
private void wrapWithSasl(Call call) throws IOException {
private static void wrapWithSasl(ByteArrayOutputStream response, Call call)
throws IOException {
if (call.connection.saslServer != null) { if (call.connection.saslServer != null) {
byte[] token = call.rpcResponse.array(); byte[] token = call.rpcResponse.array();
// synchronization may be needed since there can be multiple Handler // synchronization may be needed since there can be multiple Handler
@ -2694,7 +2687,6 @@ public abstract class Server {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Adding saslServer wrapped token of size " + token.length LOG.debug("Adding saslServer wrapped token of size " + token.length
+ " as call response."); + " as call response.");
response.reset();
// rebuild with sasl header and payload // rebuild with sasl header and payload
RpcResponseHeaderProto saslHeader = RpcResponseHeaderProto.newBuilder() RpcResponseHeaderProto saslHeader = RpcResponseHeaderProto.newBuilder()
.setCallId(AuthProtocol.SASL.callId) .setCallId(AuthProtocol.SASL.callId)
@ -2702,14 +2694,9 @@ public abstract class Server {
.build(); .build();
RpcSaslProto saslMessage = RpcSaslProto.newBuilder() RpcSaslProto saslMessage = RpcSaslProto.newBuilder()
.setState(SaslState.WRAP) .setState(SaslState.WRAP)
.setToken(ByteString.copyFrom(token, 0, token.length)) .setToken(ByteString.copyFrom(token))
.build(); .build();
RpcResponseMessageWrapper saslResponse = setupResponse(call, saslHeader, new RpcResponseWrapper(saslMessage));
new RpcResponseMessageWrapper(saslHeader, saslMessage);
DataOutputStream out = new DataOutputStream(response);
out.writeInt(saslResponse.getLength());
saslResponse.write(out);
} }
} }

View File

@ -0,0 +1,87 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.ipc;
import static org.junit.Assert.assertEquals;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import org.apache.hadoop.ipc.ResponseBuffer;
import org.junit.Test;
/** Unit tests for ResponseBuffer. */
public class TestResponseBuffer {
@Test
public void testBuffer() throws IOException {
final int startSize = 8;
final String empty = "";
ResponseBuffer buf = new ResponseBuffer(startSize);
assertEquals(startSize, buf.capacity());
// verify it's initially empty
checkBuffer(buf, empty);
// write "nothing" and re-verify it's empty
buf.writeBytes(empty);
checkBuffer(buf, empty);
// write to the buffer twice and verify it's properly encoded
String s1 = "testing123";
buf.writeBytes(s1);
checkBuffer(buf, s1);
String s2 = "456!";
buf.writeBytes(s2);
checkBuffer(buf, s1 + s2);
// reset should not change length of underlying byte array
int length = buf.capacity();
buf.reset();
assertEquals(length, buf.capacity());
checkBuffer(buf, empty);
// setCapacity will change length of underlying byte array
buf.setCapacity(startSize);
assertEquals(startSize, buf.capacity());
checkBuffer(buf, empty);
// make sure it still works
buf.writeBytes(s1);
checkBuffer(buf, s1);
buf.writeBytes(s2);
checkBuffer(buf, s1 + s2);
}
private void checkBuffer(ResponseBuffer buf, String expected)
throws IOException {
// buffer payload length matches expected length
int expectedLength = expected.getBytes().length;
assertEquals(expectedLength, buf.size());
// buffer has the framing bytes (int)
byte[] framed = buf.toByteArray();
assertEquals(expectedLength + 4, framed.length);
// verify encoding of buffer: framing (int) + payload bytes
DataInputStream dis =
new DataInputStream(new ByteArrayInputStream(framed));
assertEquals(expectedLength, dis.readInt());
assertEquals(expectedLength, dis.available());
byte[] payload = new byte[expectedLength];
dis.readFully(payload);
assertEquals(expected, new String(payload));
}
}