diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java new file mode 100644 index 00000000000..ac96a24178e --- /dev/null +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.ipc; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.util.Arrays; + +import org.apache.hadoop.classification.InterfaceAudience; + +@InterfaceAudience.Private +class ResponseBuffer extends DataOutputStream { + ResponseBuffer(int capacity) { + super(new FramedBuffer(capacity)); + } + + // update framing bytes based on bytes written to stream. + private FramedBuffer getFramedBuffer() { + FramedBuffer buf = (FramedBuffer)out; + buf.setSize(written); + return buf; + } + + void writeTo(OutputStream out) throws IOException { + getFramedBuffer().writeTo(out); + } + + byte[] toByteArray() { + return getFramedBuffer().toByteArray(); + } + + int capacity() { + return ((FramedBuffer)out).capacity(); + } + + void setCapacity(int capacity) { + ((FramedBuffer)out).setCapacity(capacity); + } + + void ensureCapacity(int capacity) { + if (((FramedBuffer)out).capacity() < capacity) { + ((FramedBuffer)out).setCapacity(capacity); + } + } + + ResponseBuffer reset() { + written = 0; + ((FramedBuffer)out).reset(); + return this; + } + + private static class FramedBuffer extends ByteArrayOutputStream { + private static final int FRAMING_BYTES = 4; + FramedBuffer(int capacity) { + super(capacity + FRAMING_BYTES); + reset(); + } + @Override + public int size() { + return count - FRAMING_BYTES; + } + void setSize(int size) { + buf[0] = (byte)((size >>> 24) & 0xFF); + buf[1] = (byte)((size >>> 16) & 0xFF); + buf[2] = (byte)((size >>> 8) & 0xFF); + buf[3] = (byte)((size >>> 0) & 0xFF); + } + int capacity() { + return buf.length - FRAMING_BYTES; + } + void setCapacity(int capacity) { + buf = Arrays.copyOf(buf, capacity + FRAMING_BYTES); + } + @Override + public void reset() { + count = FRAMING_BYTES; + setSize(0); + } + }; +} \ No newline at end of file diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java index 6eb93fdb426..90117f96984 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java @@ -79,12 +79,11 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration.IntegerRanges; import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.fs.CommonConfigurationKeysPublic; -import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.WritableUtils; -import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseMessageWrapper; import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseWrapper; +import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcWrapper; import org.apache.hadoop.ipc.RPC.RpcInvoker; import org.apache.hadoop.ipc.RPC.VersionMismatch; import org.apache.hadoop.ipc.metrics.RpcDetailedMetrics; @@ -422,6 +421,13 @@ public abstract class Server { private int maxQueueSize; private final int maxRespSize; + private final ThreadLocal responseBuffer = + new ThreadLocal(){ + @Override + protected ResponseBuffer initialValue() { + return new ResponseBuffer(INITIAL_RESP_BUF_SIZE); + } + }; private int socketSendBufferSize; private final int maxDataLength; private final boolean tcpNoDelay; // if T then disable Nagle's Algorithm @@ -743,14 +749,7 @@ public abstract class Server { public void abortResponse(Throwable t) throws IOException { // don't send response if the call was already sent or aborted. if (responseWaitCount.getAndSet(-1) > 0) { - // clone the call to prevent a race with the other thread stomping - // on the response while being sent. the original call is - // effectively discarded since the wait count won't hit zero - Call call = new Call(this); - setupResponse(new ByteArrayOutputStream(), call, - RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER, - null, t.getClass().getName(), StringUtils.stringifyException(t)); - call.sendResponse(); + connection.abortResponse(this, t); } } @@ -1271,9 +1270,7 @@ public abstract class Server { // must only wrap before adding to the responseQueue to prevent // postponed responses from being encrypted and sent out of order. if (call.connection.useWrap) { - ByteArrayOutputStream response = new ByteArrayOutputStream(); - wrapWithSasl(response, call); - call.setResponse(ByteBuffer.wrap(response.toByteArray())); + wrapWithSasl(call); } call.connection.responseQueue.addLast(call); if (call.connection.responseQueue.size() == 1) { @@ -1394,8 +1391,7 @@ public abstract class Server { // Fake 'call' for failed authorization response private final Call authFailedCall = new Call(AUTHORIZATION_FAILED_CALL_ID, RpcConstants.INVALID_RETRY_COUNT, null, this); - private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream(); - + private boolean sentNegotiate = false; private boolean useWrap = false; @@ -1677,15 +1673,14 @@ public abstract class Server { private void doSaslReply(Message message) throws IOException { final Call saslCall = new Call(AuthProtocol.SASL.callId, RpcConstants.INVALID_RETRY_COUNT, null, this); - final ByteArrayOutputStream saslResponse = new ByteArrayOutputStream(); - setupResponse(saslResponse, saslCall, + setupResponse(saslCall, RpcStatusProto.SUCCESS, null, new RpcResponseWrapper(message), null, null); saslCall.sendResponse(); } private void doSaslReply(Exception ioe) throws IOException { - setupResponse(authFailedResponse, authFailedCall, + setupResponse(authFailedCall, RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_UNAUTHORIZED, null, ioe.getClass().getName(), ioe.getLocalizedMessage()); authFailedCall.sendResponse(); @@ -1865,7 +1860,7 @@ public abstract class Server { // Versions >>9 understand the normal response Call fakeCall = new Call(-1, RpcConstants.INVALID_RETRY_COUNT, null, this); - setupResponse(buffer, fakeCall, + setupResponse(fakeCall, RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_VERSION_MISMATCH, null, VersionMismatch.class.getName(), errMsg); fakeCall.sendResponse(); @@ -2031,7 +2026,7 @@ public abstract class Server { } catch (WrappedRpcServerException wrse) { // inform client of error Throwable ioe = wrse.getCause(); final Call call = new Call(callId, retry, null, this); - setupResponse(authFailedResponse, call, + setupResponse(call, RpcStatusProto.FATAL, wrse.getRpcErrorCodeProto(), null, ioe.getClass().getName(), ioe.getMessage()); call.sendResponse(); @@ -2257,6 +2252,17 @@ public abstract class Server { responder.doRespond(call); } + private void abortResponse(Call call, Throwable t) throws IOException { + // clone the call to prevent a race with the other thread stomping + // on the response while being sent. the original call is + // effectively discarded since the wait count won't hit zero + call = new Call(call); + setupResponse(call, + RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER, + null, t.getClass().getName(), StringUtils.stringifyException(t)); + call.sendResponse(); + } + /** * Get service class for connection * @return the serviceClass @@ -2300,8 +2306,6 @@ public abstract class Server { public void run() { LOG.debug(Thread.currentThread().getName() + ": starting"); SERVER.set(Server.this); - ByteArrayOutputStream buf = - new ByteArrayOutputStream(INITIAL_RESP_BUF_SIZE); while (running) { TraceScope traceScope = null; try { @@ -2371,16 +2375,8 @@ public abstract class Server { } CurCall.set(null); synchronized (call.connection.responseQueue) { - setupResponse(buf, call, returnStatus, detailedErr, + setupResponse(call, returnStatus, detailedErr, value, errorClass, error); - - // Discard the large buf and reset it back to smaller size - // to free up heap. - if (buf.size() > maxRespSize) { - LOG.warn("Large response size " + buf.size() + " for call " - + call.toString()); - buf = new ByteArrayOutputStream(INITIAL_RESP_BUF_SIZE); - } call.sendResponse(); } } catch (InterruptedException e) { @@ -2598,13 +2594,11 @@ public abstract class Server { * @param error error message, if the call failed * @throws IOException */ - private static void setupResponse(ByteArrayOutputStream responseBuf, - Call call, RpcStatusProto status, RpcErrorCodeProto erCode, - Writable rv, String errorClass, String error) - throws IOException { - responseBuf.reset(); - DataOutputStream out = new DataOutputStream(responseBuf); - RpcResponseHeaderProto.Builder headerBuilder = + private void setupResponse( + Call call, RpcStatusProto status, RpcErrorCodeProto erCode, + Writable rv, String errorClass, String error) + throws IOException { + RpcResponseHeaderProto.Builder headerBuilder = RpcResponseHeaderProto.newBuilder(); headerBuilder.setClientId(ByteString.copyFrom(call.clientId)); headerBuilder.setCallId(call.callId); @@ -2614,32 +2608,14 @@ public abstract class Server { if (status == RpcStatusProto.SUCCESS) { RpcResponseHeaderProto header = headerBuilder.build(); - final int headerLen = header.getSerializedSize(); - int fullLength = CodedOutputStream.computeRawVarint32Size(headerLen) + - headerLen; try { - if (rv instanceof ProtobufRpcEngine.RpcWrapper) { - ProtobufRpcEngine.RpcWrapper resWrapper = - (ProtobufRpcEngine.RpcWrapper) rv; - fullLength += resWrapper.getLength(); - out.writeInt(fullLength); - header.writeDelimitedTo(out); - rv.write(out); - } else { // Have to serialize to buffer to get len - final DataOutputBuffer buf = new DataOutputBuffer(); - rv.write(buf); - byte[] data = buf.getData(); - fullLength += buf.getLength(); - out.writeInt(fullLength); - header.writeDelimitedTo(out); - out.write(data, 0, buf.getLength()); - } + setupResponse(call, header, rv); } catch (Throwable t) { LOG.warn("Error serializing call response for call " + call, t); // Call back to same function - this is OK since the // buffer is reset at the top, and since status is changed // to ERROR it won't infinite loop. - setupResponse(responseBuf, call, RpcStatusProto.ERROR, + setupResponse(call, RpcStatusProto.ERROR, RpcErrorCodeProto.ERROR_SERIALIZING_RESPONSE, null, t.getClass().getName(), StringUtils.stringifyException(t)); @@ -2649,16 +2625,35 @@ public abstract class Server { headerBuilder.setExceptionClassName(errorClass); headerBuilder.setErrorMsg(error); headerBuilder.setErrorDetail(erCode); - RpcResponseHeaderProto header = headerBuilder.build(); - int headerLen = header.getSerializedSize(); - final int fullLength = - CodedOutputStream.computeRawVarint32Size(headerLen) + headerLen; - out.writeInt(fullLength); - header.writeDelimitedTo(out); + setupResponse(call, headerBuilder.build(), null); } - call.setResponse(ByteBuffer.wrap(responseBuf.toByteArray())); } - + + private void setupResponse(Call call, + RpcResponseHeaderProto header, Writable rv) throws IOException { + ResponseBuffer buf = responseBuffer.get().reset(); + // adjust capacity on estimated length to reduce resizing copies + int estimatedLen = header.getSerializedSize(); + estimatedLen += CodedOutputStream.computeRawVarint32Size(estimatedLen); + // if it's not a wrapped protobuf, just let it grow on its own + if (rv instanceof RpcWrapper) { + estimatedLen += ((RpcWrapper)rv).getLength(); + } + buf.ensureCapacity(estimatedLen); + header.writeDelimitedTo(buf); + if (rv != null) { // null for exceptions + rv.write(buf); + } + call.setResponse(ByteBuffer.wrap(buf.toByteArray())); + // Discard a large buf and reset it back to smaller size + // to free up heap. + if (buf.capacity() > maxRespSize) { + LOG.warn("Large response size " + buf.size() + " for call " + + call.toString()); + buf.setCapacity(INITIAL_RESP_BUF_SIZE); + } + } + /** * Setup response for the IPC Call on Fatal Error from a * client that is using old version of Hadoop. @@ -2685,10 +2680,8 @@ public abstract class Server { WritableUtils.writeString(out, error); call.setResponse(ByteBuffer.wrap(response.toByteArray())); } - - - private static void wrapWithSasl(ByteArrayOutputStream response, Call call) - throws IOException { + + private void wrapWithSasl(Call call) throws IOException { if (call.connection.saslServer != null) { byte[] token = call.rpcResponse.array(); // synchronization may be needed since there can be multiple Handler @@ -2699,7 +2692,6 @@ public abstract class Server { if (LOG.isDebugEnabled()) LOG.debug("Adding saslServer wrapped token of size " + token.length + " as call response."); - response.reset(); // rebuild with sasl header and payload RpcResponseHeaderProto saslHeader = RpcResponseHeaderProto.newBuilder() .setCallId(AuthProtocol.SASL.callId) @@ -2707,14 +2699,9 @@ public abstract class Server { .build(); RpcSaslProto saslMessage = RpcSaslProto.newBuilder() .setState(SaslState.WRAP) - .setToken(ByteString.copyFrom(token, 0, token.length)) + .setToken(ByteString.copyFrom(token)) .build(); - RpcResponseMessageWrapper saslResponse = - new RpcResponseMessageWrapper(saslHeader, saslMessage); - - DataOutputStream out = new DataOutputStream(response); - out.writeInt(saslResponse.getLength()); - saslResponse.write(out); + setupResponse(call, saslHeader, new RpcResponseWrapper(saslMessage)); } } diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestResponseBuffer.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestResponseBuffer.java new file mode 100644 index 00000000000..98743be94a4 --- /dev/null +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestResponseBuffer.java @@ -0,0 +1,87 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.ipc; + +import static org.junit.Assert.assertEquals; +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; +import org.apache.hadoop.ipc.ResponseBuffer; +import org.junit.Test; + +/** Unit tests for ResponseBuffer. */ +public class TestResponseBuffer { + @Test + public void testBuffer() throws IOException { + final int startSize = 8; + final String empty = ""; + ResponseBuffer buf = new ResponseBuffer(startSize); + assertEquals(startSize, buf.capacity()); + + // verify it's initially empty + checkBuffer(buf, empty); + // write "nothing" and re-verify it's empty + buf.writeBytes(empty); + checkBuffer(buf, empty); + + // write to the buffer twice and verify it's properly encoded + String s1 = "testing123"; + buf.writeBytes(s1); + checkBuffer(buf, s1); + String s2 = "456!"; + buf.writeBytes(s2); + checkBuffer(buf, s1 + s2); + + // reset should not change length of underlying byte array + int length = buf.capacity(); + buf.reset(); + assertEquals(length, buf.capacity()); + checkBuffer(buf, empty); + + // setCapacity will change length of underlying byte array + buf.setCapacity(startSize); + assertEquals(startSize, buf.capacity()); + checkBuffer(buf, empty); + + // make sure it still works + buf.writeBytes(s1); + checkBuffer(buf, s1); + buf.writeBytes(s2); + checkBuffer(buf, s1 + s2); + } + + private void checkBuffer(ResponseBuffer buf, String expected) + throws IOException { + // buffer payload length matches expected length + int expectedLength = expected.getBytes().length; + assertEquals(expectedLength, buf.size()); + // buffer has the framing bytes (int) + byte[] framed = buf.toByteArray(); + assertEquals(expectedLength + 4, framed.length); + + // verify encoding of buffer: framing (int) + payload bytes + DataInputStream dis = + new DataInputStream(new ByteArrayInputStream(framed)); + assertEquals(expectedLength, dis.readInt()); + assertEquals(expectedLength, dis.available()); + byte[] payload = new byte[expectedLength]; + dis.readFully(payload); + assertEquals(expected, new String(payload)); + } +}