HADOOP-13426. More efficiently build IPC responses. Contributed by Daryn Sharp.

This commit is contained in:
Kihwal Lee 2016-08-03 09:30:24 -05:00
parent d848184e90
commit 2d8227605f
3 changed files with 250 additions and 78 deletions

View File

@ -0,0 +1,98 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.ipc;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Arrays;
import org.apache.hadoop.classification.InterfaceAudience;
@InterfaceAudience.Private
class ResponseBuffer extends DataOutputStream {
ResponseBuffer(int capacity) {
super(new FramedBuffer(capacity));
}
// update framing bytes based on bytes written to stream.
private FramedBuffer getFramedBuffer() {
FramedBuffer buf = (FramedBuffer)out;
buf.setSize(written);
return buf;
}
void writeTo(OutputStream out) throws IOException {
getFramedBuffer().writeTo(out);
}
byte[] toByteArray() {
return getFramedBuffer().toByteArray();
}
int capacity() {
return ((FramedBuffer)out).capacity();
}
void setCapacity(int capacity) {
((FramedBuffer)out).setCapacity(capacity);
}
void ensureCapacity(int capacity) {
if (((FramedBuffer)out).capacity() < capacity) {
((FramedBuffer)out).setCapacity(capacity);
}
}
ResponseBuffer reset() {
written = 0;
((FramedBuffer)out).reset();
return this;
}
private static class FramedBuffer extends ByteArrayOutputStream {
private static final int FRAMING_BYTES = 4;
FramedBuffer(int capacity) {
super(capacity + FRAMING_BYTES);
reset();
}
@Override
public int size() {
return count - FRAMING_BYTES;
}
void setSize(int size) {
buf[0] = (byte)((size >>> 24) & 0xFF);
buf[1] = (byte)((size >>> 16) & 0xFF);
buf[2] = (byte)((size >>> 8) & 0xFF);
buf[3] = (byte)((size >>> 0) & 0xFF);
}
int capacity() {
return buf.length - FRAMING_BYTES;
}
void setCapacity(int capacity) {
buf = Arrays.copyOf(buf, capacity + FRAMING_BYTES);
}
@Override
public void reset() {
count = FRAMING_BYTES;
setSize(0);
}
};
}

View File

@ -80,12 +80,11 @@ import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configuration.IntegerRanges;
import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseMessageWrapper;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseWrapper;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcWrapper;
import org.apache.hadoop.ipc.RPC.RpcInvoker;
import org.apache.hadoop.ipc.RPC.VersionMismatch;
import org.apache.hadoop.ipc.metrics.RpcDetailedMetrics;
@ -423,6 +422,13 @@ public abstract class Server {
private int maxQueueSize;
private final int maxRespSize;
private final ThreadLocal<ResponseBuffer> responseBuffer =
new ThreadLocal<ResponseBuffer>(){
@Override
protected ResponseBuffer initialValue() {
return new ResponseBuffer(INITIAL_RESP_BUF_SIZE);
}
};
private int socketSendBufferSize;
private final int maxDataLength;
private final boolean tcpNoDelay; // if T then disable Nagle's Algorithm
@ -744,14 +750,7 @@ public abstract class Server {
public void abortResponse(Throwable t) throws IOException {
// don't send response if the call was already sent or aborted.
if (responseWaitCount.getAndSet(-1) > 0) {
// clone the call to prevent a race with the other thread stomping
// on the response while being sent. the original call is
// effectively discarded since the wait count won't hit zero
Call call = new Call(this);
setupResponse(new ByteArrayOutputStream(), call,
RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER,
null, t.getClass().getName(), StringUtils.stringifyException(t));
call.sendResponse();
connection.abortResponse(this, t);
}
}
@ -1272,9 +1271,7 @@ public abstract class Server {
// must only wrap before adding to the responseQueue to prevent
// postponed responses from being encrypted and sent out of order.
if (call.connection.useWrap) {
ByteArrayOutputStream response = new ByteArrayOutputStream();
wrapWithSasl(response, call);
call.setResponse(ByteBuffer.wrap(response.toByteArray()));
wrapWithSasl(call);
}
call.connection.responseQueue.addLast(call);
if (call.connection.responseQueue.size() == 1) {
@ -1395,7 +1392,6 @@ public abstract class Server {
// Fake 'call' for failed authorization response
private final Call authFailedCall = new Call(AUTHORIZATION_FAILED_CALL_ID,
RpcConstants.INVALID_RETRY_COUNT, null, this);
private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream();
private boolean sentNegotiate = false;
private boolean useWrap = false;
@ -1715,15 +1711,14 @@ public abstract class Server {
private void doSaslReply(Message message) throws IOException {
final Call saslCall = new Call(AuthProtocol.SASL.callId,
RpcConstants.INVALID_RETRY_COUNT, null, this);
final ByteArrayOutputStream saslResponse = new ByteArrayOutputStream();
setupResponse(saslResponse, saslCall,
setupResponse(saslCall,
RpcStatusProto.SUCCESS, null,
new RpcResponseWrapper(message), null, null);
saslCall.sendResponse();
}
private void doSaslReply(Exception ioe) throws IOException {
setupResponse(authFailedResponse, authFailedCall,
setupResponse(authFailedCall,
RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_UNAUTHORIZED,
null, ioe.getClass().getName(), ioe.getLocalizedMessage());
authFailedCall.sendResponse();
@ -1934,7 +1929,7 @@ public abstract class Server {
// Versions >>9 understand the normal response
Call fakeCall = new Call(-1, RpcConstants.INVALID_RETRY_COUNT, null,
this);
setupResponse(buffer, fakeCall,
setupResponse(fakeCall,
RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_VERSION_MISMATCH,
null, VersionMismatch.class.getName(), errMsg);
fakeCall.sendResponse();
@ -2111,7 +2106,7 @@ public abstract class Server {
} catch (WrappedRpcServerException wrse) { // inform client of error
Throwable ioe = wrse.getCause();
final Call call = new Call(callId, retry, null, this);
setupResponse(authFailedResponse, call,
setupResponse(call,
RpcStatusProto.FATAL, wrse.getRpcErrorCodeProto(), null,
ioe.getClass().getName(), ioe.getMessage());
call.sendResponse();
@ -2341,6 +2336,17 @@ public abstract class Server {
responder.doRespond(call);
}
private void abortResponse(Call call, Throwable t) throws IOException {
// clone the call to prevent a race with the other thread stomping
// on the response while being sent. the original call is
// effectively discarded since the wait count won't hit zero
call = new Call(call);
setupResponse(call,
RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER,
null, t.getClass().getName(), StringUtils.stringifyException(t));
call.sendResponse();
}
/**
* Get service class for connection
* @return the serviceClass
@ -2384,8 +2390,6 @@ public abstract class Server {
public void run() {
LOG.debug(Thread.currentThread().getName() + ": starting");
SERVER.set(Server.this);
ByteArrayOutputStream buf =
new ByteArrayOutputStream(INITIAL_RESP_BUF_SIZE);
while (running) {
TraceScope traceScope = null;
try {
@ -2455,16 +2459,8 @@ public abstract class Server {
}
CurCall.set(null);
synchronized (call.connection.responseQueue) {
setupResponse(buf, call, returnStatus, detailedErr,
setupResponse(call, returnStatus, detailedErr,
value, errorClass, error);
// Discard the large buf and reset it back to smaller size
// to free up heap.
if (buf.size() > maxRespSize) {
LOG.warn("Large response size " + buf.size() + " for call "
+ call.toString());
buf = new ByteArrayOutputStream(INITIAL_RESP_BUF_SIZE);
}
call.sendResponse();
}
} catch (InterruptedException e) {
@ -2682,12 +2678,10 @@ public abstract class Server {
* @param error error message, if the call failed
* @throws IOException
*/
private static void setupResponse(ByteArrayOutputStream responseBuf,
Call call, RpcStatusProto status, RpcErrorCodeProto erCode,
Writable rv, String errorClass, String error)
throws IOException {
responseBuf.reset();
DataOutputStream out = new DataOutputStream(responseBuf);
private void setupResponse(
Call call, RpcStatusProto status, RpcErrorCodeProto erCode,
Writable rv, String errorClass, String error)
throws IOException {
RpcResponseHeaderProto.Builder headerBuilder =
RpcResponseHeaderProto.newBuilder();
headerBuilder.setClientId(ByteString.copyFrom(call.clientId));
@ -2698,32 +2692,14 @@ public abstract class Server {
if (status == RpcStatusProto.SUCCESS) {
RpcResponseHeaderProto header = headerBuilder.build();
final int headerLen = header.getSerializedSize();
int fullLength = CodedOutputStream.computeRawVarint32Size(headerLen) +
headerLen;
try {
if (rv instanceof ProtobufRpcEngine.RpcWrapper) {
ProtobufRpcEngine.RpcWrapper resWrapper =
(ProtobufRpcEngine.RpcWrapper) rv;
fullLength += resWrapper.getLength();
out.writeInt(fullLength);
header.writeDelimitedTo(out);
rv.write(out);
} else { // Have to serialize to buffer to get len
final DataOutputBuffer buf = new DataOutputBuffer();
rv.write(buf);
byte[] data = buf.getData();
fullLength += buf.getLength();
out.writeInt(fullLength);
header.writeDelimitedTo(out);
out.write(data, 0, buf.getLength());
}
setupResponse(call, header, rv);
} catch (Throwable t) {
LOG.warn("Error serializing call response for call " + call, t);
// Call back to same function - this is OK since the
// buffer is reset at the top, and since status is changed
// to ERROR it won't infinite loop.
setupResponse(responseBuf, call, RpcStatusProto.ERROR,
setupResponse(call, RpcStatusProto.ERROR,
RpcErrorCodeProto.ERROR_SERIALIZING_RESPONSE,
null, t.getClass().getName(),
StringUtils.stringifyException(t));
@ -2733,14 +2709,33 @@ public abstract class Server {
headerBuilder.setExceptionClassName(errorClass);
headerBuilder.setErrorMsg(error);
headerBuilder.setErrorDetail(erCode);
RpcResponseHeaderProto header = headerBuilder.build();
int headerLen = header.getSerializedSize();
final int fullLength =
CodedOutputStream.computeRawVarint32Size(headerLen) + headerLen;
out.writeInt(fullLength);
header.writeDelimitedTo(out);
setupResponse(call, headerBuilder.build(), null);
}
}
private void setupResponse(Call call,
RpcResponseHeaderProto header, Writable rv) throws IOException {
ResponseBuffer buf = responseBuffer.get().reset();
// adjust capacity on estimated length to reduce resizing copies
int estimatedLen = header.getSerializedSize();
estimatedLen += CodedOutputStream.computeRawVarint32Size(estimatedLen);
// if it's not a wrapped protobuf, just let it grow on its own
if (rv instanceof RpcWrapper) {
estimatedLen += ((RpcWrapper)rv).getLength();
}
buf.ensureCapacity(estimatedLen);
header.writeDelimitedTo(buf);
if (rv != null) { // null for exceptions
rv.write(buf);
}
call.setResponse(ByteBuffer.wrap(buf.toByteArray()));
// Discard a large buf and reset it back to smaller size
// to free up heap.
if (buf.capacity() > maxRespSize) {
LOG.warn("Large response size " + buf.size() + " for call "
+ call.toString());
buf.setCapacity(INITIAL_RESP_BUF_SIZE);
}
call.setResponse(ByteBuffer.wrap(responseBuf.toByteArray()));
}
/**
@ -2770,9 +2765,7 @@ public abstract class Server {
call.setResponse(ByteBuffer.wrap(response.toByteArray()));
}
private static void wrapWithSasl(ByteArrayOutputStream response, Call call)
throws IOException {
private void wrapWithSasl(Call call) throws IOException {
if (call.connection.saslServer != null) {
byte[] token = call.rpcResponse.array();
// synchronization may be needed since there can be multiple Handler
@ -2783,7 +2776,6 @@ public abstract class Server {
if (LOG.isDebugEnabled())
LOG.debug("Adding saslServer wrapped token of size " + token.length
+ " as call response.");
response.reset();
// rebuild with sasl header and payload
RpcResponseHeaderProto saslHeader = RpcResponseHeaderProto.newBuilder()
.setCallId(AuthProtocol.SASL.callId)
@ -2791,14 +2783,9 @@ public abstract class Server {
.build();
RpcSaslProto saslMessage = RpcSaslProto.newBuilder()
.setState(SaslState.WRAP)
.setToken(ByteString.copyFrom(token, 0, token.length))
.setToken(ByteString.copyFrom(token))
.build();
RpcResponseMessageWrapper saslResponse =
new RpcResponseMessageWrapper(saslHeader, saslMessage);
DataOutputStream out = new DataOutputStream(response);
out.writeInt(saslResponse.getLength());
saslResponse.write(out);
setupResponse(call, saslHeader, new RpcResponseWrapper(saslMessage));
}
}

View File

@ -0,0 +1,87 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.ipc;
import static org.junit.Assert.assertEquals;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import org.apache.hadoop.ipc.ResponseBuffer;
import org.junit.Test;
/** Unit tests for ResponseBuffer. */
public class TestResponseBuffer {
@Test
public void testBuffer() throws IOException {
final int startSize = 8;
final String empty = "";
ResponseBuffer buf = new ResponseBuffer(startSize);
assertEquals(startSize, buf.capacity());
// verify it's initially empty
checkBuffer(buf, empty);
// write "nothing" and re-verify it's empty
buf.writeBytes(empty);
checkBuffer(buf, empty);
// write to the buffer twice and verify it's properly encoded
String s1 = "testing123";
buf.writeBytes(s1);
checkBuffer(buf, s1);
String s2 = "456!";
buf.writeBytes(s2);
checkBuffer(buf, s1 + s2);
// reset should not change length of underlying byte array
int length = buf.capacity();
buf.reset();
assertEquals(length, buf.capacity());
checkBuffer(buf, empty);
// setCapacity will change length of underlying byte array
buf.setCapacity(startSize);
assertEquals(startSize, buf.capacity());
checkBuffer(buf, empty);
// make sure it still works
buf.writeBytes(s1);
checkBuffer(buf, s1);
buf.writeBytes(s2);
checkBuffer(buf, s1 + s2);
}
private void checkBuffer(ResponseBuffer buf, String expected)
throws IOException {
// buffer payload length matches expected length
int expectedLength = expected.getBytes().length;
assertEquals(expectedLength, buf.size());
// buffer has the framing bytes (int)
byte[] framed = buf.toByteArray();
assertEquals(expectedLength + 4, framed.length);
// verify encoding of buffer: framing (int) + payload bytes
DataInputStream dis =
new DataInputStream(new ByteArrayInputStream(framed));
assertEquals(expectedLength, dis.readInt());
assertEquals(expectedLength, dis.available());
byte[] payload = new byte[expectedLength];
dis.readFully(payload);
assertEquals(expected, new String(payload));
}
}