Simplify TcpTransport interface by reducing send code to a single send method (#19223)

Due to some optimization on the netty layer we had quite some code / cruft
added to the TcpTransport to allow for those optimizations. After cleaning
up BytesReference we can now move this optimization into TcpTransport and
have a simple send method on the implementation layer instead. This commit
adds a CompositeBytesReference that also allows message headers to be written
separately which simplify the header code as well since no skips are needed
anymore.
This commit is contained in:
Simon Willnauer 2016-07-05 08:33:19 +02:00 committed by GitHub
parent a00a54ebda
commit 44ccf67e33
12 changed files with 614 additions and 383 deletions

View File

@ -35,10 +35,7 @@ public final class BytesArray extends BytesReference {
private final int length;
public BytesArray(String bytes) {
BytesRef bytesRef = new BytesRef(bytes);
this.bytes = bytesRef.bytes;
this.offset = bytesRef.offset;
this.length = bytesRef.length;
this(new BytesRef(bytes));
}
public BytesArray(BytesRef bytesRef) {
@ -47,21 +44,15 @@ public final class BytesArray extends BytesReference {
public BytesArray(BytesRef bytesRef, boolean deepCopy) {
if (deepCopy) {
BytesRef copy = BytesRef.deepCopyOf(bytesRef);
bytes = copy.bytes;
offset = copy.offset;
length = copy.length;
} else {
bytes = bytesRef.bytes;
offset = bytesRef.offset;
length = bytesRef.length;
bytesRef = BytesRef.deepCopyOf(bytesRef);
}
bytes = bytesRef.bytes;
offset = bytesRef.offset;
length = bytesRef.length;
}
public BytesArray(byte[] bytes) {
this.bytes = bytes;
this.offset = 0;
this.length = bytes.length;
this(bytes, 0, bytes.length);
}
public BytesArray(byte[] bytes, int offset, int length) {
@ -105,4 +96,5 @@ public final class BytesArray extends BytesReference {
public long ramBytesUsed() {
return bytes.length;
}
}

View File

@ -24,6 +24,7 @@ import org.apache.lucene.util.BytesRefIterator;
import org.elasticsearch.common.io.stream.StreamInput;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.function.ToIntBiFunction;
@ -52,9 +53,8 @@ public abstract class BytesReference implements Accountable, Comparable<BytesRef
/**
* A stream input of the bytes.
*/
public StreamInput streamInput() {
BytesRef ref = toBytesRef();
return StreamInput.wrap(ref.bytes, ref.offset, ref.length);
public StreamInput streamInput() throws IOException {
return new MarkSupportingStreamInputWrapper(this);
}
/**
@ -208,4 +208,73 @@ public abstract class BytesReference implements Accountable, Comparable<BytesRef
ref.length -= length;
ref.offset += length;
}
/**
* Instead of adding the complexity of {@link InputStream#reset()} etc to the actual impl
* this wrapper builds it on top of the BytesReferenceStreamInput which is much simpler
* that way.
*/
private static final class MarkSupportingStreamInputWrapper extends StreamInput {
private final BytesReference reference;
private BytesReferenceStreamInput input;
private int mark = 0;
private MarkSupportingStreamInputWrapper(BytesReference reference) throws IOException {
this.reference = reference;
this.input = new BytesReferenceStreamInput(reference.iterator(), reference.length());
}
@Override
public byte readByte() throws IOException {
return input.readByte();
}
@Override
public void readBytes(byte[] b, int offset, int len) throws IOException {
input.readBytes(b, offset, len);
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
return input.read(b, off, len);
}
@Override
public void close() throws IOException {
input.close();
}
@Override
public int read() throws IOException {
return input.read();
}
@Override
public int available() throws IOException {
return input.available();
}
@Override
public void reset() throws IOException {
input = new BytesReferenceStreamInput(reference.iterator(), reference.length());
input.skip(mark);
}
@Override
public boolean markSupported() {
return true;
}
@Override
public void mark(int readLimit) {
// readLimit is optional it only guarantees that the stream remembers data upto this limit but it can remember more
// which we do in our case
this.mark = input.getOffset();
}
@Override
public long skip(long n) throws IOException {
return input.skip(n);
}
}
}

View File

@ -0,0 +1,136 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.bytes;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefIterator;
import org.elasticsearch.common.io.stream.StreamInput;
import java.io.EOFException;
import java.io.IOException;
/**
* A StreamInput that reads off a {@link BytesRefIterator}. This is used to provide
* generic stream access to {@link BytesReference} instances without materializing the
* underlying bytes reference.
*/
final class BytesReferenceStreamInput extends StreamInput {
private final BytesRefIterator iterator;
private int sliceOffset;
private BytesRef slice;
private final int length; // the total size of the stream
private int offset; // the current position of the stream
public BytesReferenceStreamInput(BytesRefIterator iterator, final int length) throws IOException {
this.iterator = iterator;
this.slice = iterator.next();
this.length = length;
this.offset = 0;
this.sliceOffset = 0;
}
@Override
public byte readByte() throws IOException {
if (offset >= length) {
throw new EOFException();
}
maybeNextSlice();
byte b = slice.bytes[slice.offset + (sliceOffset++)];
offset++;
return b;
}
private void maybeNextSlice() throws IOException {
while (sliceOffset == slice.length) {
slice = iterator.next();
sliceOffset = 0;
if (slice == null) {
throw new EOFException();
}
}
}
@Override
public void readBytes(byte[] b, int bOffset, int len) throws IOException {
if (offset + len > length) {
throw new IndexOutOfBoundsException("Cannot read " + len + " bytes from stream with length " + length + " at offset " + offset);
}
read(b, bOffset, len);
}
@Override
public int read() throws IOException {
if (offset >= length) {
return -1;
}
return Byte.toUnsignedInt(readByte());
}
@Override
public int read(final byte[] b, final int bOffset, final int len) throws IOException {
if (offset >= length) {
return -1;
}
final int numBytesToCopy = Math.min(len, length - offset);
int remaining = numBytesToCopy; // copy the full length or the remaining part
int destOffset = bOffset;
while (remaining > 0) {
maybeNextSlice();
final int currentLen = Math.min(remaining, slice.length - sliceOffset);
assert currentLen > 0 : "length has to be > 0 to make progress but was: " + currentLen;
System.arraycopy(slice.bytes, slice.offset + sliceOffset, b, destOffset, currentLen);
destOffset += currentLen;
remaining -= currentLen;
sliceOffset += currentLen;
offset += currentLen;
assert remaining >= 0 : "remaining: " + remaining;
}
return numBytesToCopy;
}
@Override
public void close() throws IOException {
// do nothing
}
@Override
public int available() throws IOException {
return length - offset;
}
@Override
public long skip(long n) throws IOException {
final int skip = (int) Math.min(Integer.MAX_VALUE, n);
final int numBytesSkipped = Math.min(skip, length - offset);
int remaining = numBytesSkipped;
while (remaining > 0) {
maybeNextSlice();
int currentLen = Math.min(remaining, slice.length - (slice.offset + sliceOffset));
remaining -= currentLen;
sliceOffset += currentLen;
offset += currentLen;
assert remaining >= 0 : "remaining: " + remaining;
}
return numBytesSkipped;
}
int getOffset() {
return offset;
}
}

View File

@ -0,0 +1,151 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.bytes;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.BytesRefIterator;
import org.apache.lucene.util.RamUsageEstimator;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
/**
* A composite {@link BytesReference} that allows joining multiple bytes references
* into one without copying.
*
* Note, {@link #toBytesRef()} will materialize all pages in this BytesReference.
*/
public final class CompositeBytesReference extends BytesReference {
private final BytesReference[] references;
private final int[] offsets;
private final int length;
private final long ramBytesUsed;
public CompositeBytesReference(BytesReference... references) {
this.references = Objects.requireNonNull(references, "references must not be null");
this.offsets = new int[references.length];
long ramBytesUsed = 0;
int offset = 0;
for (int i = 0; i < references.length; i++) {
BytesReference reference = references[i];
if (reference == null) {
throw new IllegalArgumentException("references must not be null");
}
offsets[i] = offset; // we use the offsets to seek into the right BytesReference for random access and slicing
offset += reference.length();
ramBytesUsed += reference.ramBytesUsed();
}
this.ramBytesUsed = ramBytesUsed
+ (Integer.BYTES * offsets.length + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) // offsets
+ (references.length * RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) // references
+ Integer.BYTES // length
+ Long.BYTES; // ramBytesUsed
length = offset;
}
@Override
public byte get(int index) {
final int i = getOffsetIndex(index);
return references[i].get(index - offsets[i]);
}
@Override
public int length() {
return length;
}
@Override
public BytesReference slice(int from, int length) {
// for slices we only need to find the start and the end reference
// adjust them and pass on the references in between as they are fully contained
final int to = from + length;
final int limit = getOffsetIndex(from + length);
final int start = getOffsetIndex(from);
final BytesReference[] inSlice = new BytesReference[1 + (limit - start)];
for (int i = 0, j = start; i < inSlice.length; i++) {
inSlice[i] = references[j++];
}
int inSliceOffset = from - offsets[start];
if (inSlice.length == 1) {
return inSlice[0].slice(inSliceOffset, length);
}
// now adjust slices in front and at the end
inSlice[0] = inSlice[0].slice(inSliceOffset, inSlice[0].length() - inSliceOffset);
inSlice[inSlice.length-1] = inSlice[inSlice.length-1].slice(0, to - offsets[limit]);
return new CompositeBytesReference(inSlice);
}
private final int getOffsetIndex(int offset) {
final int i = Arrays.binarySearch(offsets, offset);
return i < 0 ? (-(i + 1)) - 1 : i;
}
@Override
public BytesRef toBytesRef() {
BytesRefBuilder builder = new BytesRefBuilder();
builder.grow(length());
BytesRef spare;
BytesRefIterator iterator = iterator();
try {
while ((spare = iterator.next()) != null) {
builder.append(spare);
}
} catch (IOException ex) {
throw new AssertionError("won't happen", ex); // this is really an error since we don't do IO in our bytesreferences
}
return builder.toBytesRef();
}
@Override
public BytesRefIterator iterator() {
if (references.length > 0) {
return new BytesRefIterator() {
int index = 0;
private BytesRefIterator current = references[index++].iterator();
@Override
public BytesRef next() throws IOException {
BytesRef next = current.next();
if (next == null) {
while (index < references.length) {
current = references[index++].iterator();
next = current.next();
if (next != null) {
break;
}
}
}
return next;
}
};
} else {
return () -> null;
}
}
@Override
public long ramBytesUsed() {
return ramBytesUsed;
}
}

View File

@ -25,10 +25,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.ByteArray;
import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Arrays;
/**
* A page based bytes reference, internally holding the bytes in a paged
@ -42,7 +39,6 @@ public class PagedBytesReference extends BytesReference {
protected final ByteArray bytearray;
private final int offset;
private final int length;
private int hash = 0;
public PagedBytesReference(BigArrays bigarrays, ByteArray bytearray, int length) {
this(bigarrays, bytearray, 0, length);
@ -70,15 +66,9 @@ public class PagedBytesReference extends BytesReference {
if (from < 0 || (from + length) > length()) {
throw new IllegalArgumentException("can't slice a buffer with length [" + length() + "], with slice parameters from [" + from + "], length [" + length + "]");
}
return new PagedBytesReference(bigarrays, bytearray, offset + from, length);
}
@Override
public StreamInput streamInput() {
return new PagedBytesReferenceStreamInput(bytearray, offset, length);
}
@Override
public BytesRef toBytesRef() {
BytesRef bref = new BytesRef();
@ -87,109 +77,6 @@ public class PagedBytesReference extends BytesReference {
return bref;
}
private static class PagedBytesReferenceStreamInput extends StreamInput {
private final ByteArray bytearray;
private final BytesRef ref;
private final int offset;
private final int length;
private int pos;
private int mark;
public PagedBytesReferenceStreamInput(ByteArray bytearray, int offset, int length) {
this.bytearray = bytearray;
this.ref = new BytesRef();
this.offset = offset;
this.length = length;
this.pos = 0;
if (offset + length > bytearray.size()) {
throw new IndexOutOfBoundsException("offset+length >= bytearray.size()");
}
}
@Override
public byte readByte() throws IOException {
if (pos >= length) {
throw new EOFException();
}
return bytearray.get(offset + pos++);
}
@Override
public void readBytes(byte[] b, int bOffset, int len) throws IOException {
if (len > offset + length) {
throw new IndexOutOfBoundsException("Cannot read " + len + " bytes from stream with length " + length + " at pos " + pos);
}
read(b, bOffset, len);
}
@Override
public int read() throws IOException {
return (pos < length) ? Byte.toUnsignedInt(bytearray.get(offset + pos++)) : -1;
}
@Override
public int read(final byte[] b, final int bOffset, final int len) throws IOException {
if (len == 0) {
return 0;
}
if (pos >= offset + length) {
return -1;
}
final int numBytesToCopy = Math.min(len, length - pos); // copy the full length or the remaining part
// current offset into the underlying ByteArray
long byteArrayOffset = offset + pos;
// bytes already copied
int copiedBytes = 0;
while (copiedBytes < numBytesToCopy) {
long pageFragment = PAGE_SIZE - (byteArrayOffset % PAGE_SIZE); // how much can we read until hitting N*PAGE_SIZE?
int bulkSize = (int) Math.min(pageFragment, numBytesToCopy - copiedBytes); // we cannot copy more than a page fragment
boolean copied = bytearray.get(byteArrayOffset, bulkSize, ref); // get the fragment
assert (copied == false); // we should never ever get back a materialized byte[]
System.arraycopy(ref.bytes, ref.offset, b, bOffset + copiedBytes, bulkSize); // copy fragment contents
copiedBytes += bulkSize; // count how much we copied
byteArrayOffset += bulkSize; // advance ByteArray index
}
pos += copiedBytes; // finally advance our stream position
return copiedBytes;
}
@Override
public boolean markSupported() {
return true;
}
@Override
public void mark(int readlimit) {
this.mark = pos;
}
@Override
public void reset() throws IOException {
pos = mark;
}
@Override
public void close() throws IOException {
// do nothing
}
@Override
public int available() throws IOException {
return length - pos;
}
}
@Override
public final BytesRefIterator iterator() {
final int offset = this.offset;

View File

@ -380,12 +380,6 @@ public abstract class StreamInput extends InputStream {
return false;
}
/**
* Resets the stream.
*/
@Override
public abstract void reset() throws IOException;
/**
* Closes the stream to further operations.
*/

View File

@ -41,7 +41,7 @@ public class TcpHeader {
output.writeByte((byte)'E');
output.writeByte((byte)'S');
// write the size, the size indicates the remaining message size, not including the size int
output.writeInt(messageSize - TcpHeader.MARKER_BYTES_SIZE - TcpHeader.MESSAGE_LENGTH_SIZE);
output.writeInt(messageSize + REQUEST_ID_SIZE + STATUS_SIZE + VERSION_ID_SIZE);
output.writeLong(requestId);
output.writeByte(status);
output.writeInt(version.id);

View File

@ -29,6 +29,7 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.compress.Compressor;
@ -347,6 +348,10 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
}
}
public List<Channel[]> getChannelArrays() {
return Arrays.asList(recovery, bulk, reg, state, ping);
}
public synchronized void close() {
closeChannels(allChannels);
}
@ -869,7 +874,7 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
protected void stopInternal() {}
public boolean canCompress(TransportRequest request) {
return compress;
return compress && (!(request instanceof BytesTransportRequest));
}
@Override
@ -885,9 +890,8 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
status = TransportStatus.setRequest(status);
ReleasableBytesStreamOutput bStream = new ReleasableBytesStreamOutput(bigArrays);
boolean addedReleaseListener = false;
StreamOutput stream = bStream;
try {
bStream.skip(TcpHeader.HEADER_SIZE);
StreamOutput stream = bStream;
// only compress if asked, and, the request is not bytes, since then only
// the header part is compressed, and the "body" can't be extracted as compressed
if (options.compress() && canCompress(request)) {
@ -903,12 +907,7 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
stream.setVersion(version);
threadPool.getThreadContext().writeTo(stream);
stream.writeString(action);
Message<Channel> writeable = prepareSend(node.getVersion(), request, stream, bStream);
try (StreamOutput headerOutput = writeable.getHeaderOutput()) {
TcpHeader.writeHeader(headerOutput, requestId, status, version,
writeable.size());
}
BytesReference message = buildMessage(requestId, status, node.getVersion(), request, stream, bStream);
final TransportRequestOptions finalOptions = options;
Runnable onRequestSent = () -> {
try {
@ -917,10 +916,10 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
transportServiceAdapter.onRequestSent(node, requestId, action, request, finalOptions);
}
};
writeable.send(targetChannel, onRequestSent);
sendMessage(targetChannel, message, onRequestSent, false);
addedReleaseListener = true;
} finally {
IOUtils.close(stream);
if (!addedReleaseListener) {
Releasables.close(bStream.bytes());
}
@ -937,26 +936,19 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
*/
public void sendErrorResponse(Version nodeVersion, Channel channel, final Exception error, final long requestId,
final String action) throws IOException {
BytesStreamOutput stream = new BytesStreamOutput();
stream.setVersion(nodeVersion);
stream.skip(TcpHeader.HEADER_SIZE);
RemoteTransportException tx = new RemoteTransportException(
nodeName(), new InetSocketTransportAddress(getLocalAddress(channel)), action, error);
stream.writeThrowable(tx);
byte status = 0;
status = TransportStatus.setResponse(status);
status = TransportStatus.setError(status);
final BytesReference bytes = stream.bytes();
Message<Channel> writeable = prepareSend(nodeVersion, bytes);
try (StreamOutput headerOutput = writeable.getHeaderOutput()) {
TcpHeader.writeHeader(headerOutput, requestId, status, nodeVersion,
writeable.size());
try(BytesStreamOutput stream = new BytesStreamOutput()) {
stream.setVersion(nodeVersion);
RemoteTransportException tx = new RemoteTransportException(
nodeName(), new InetSocketTransportAddress(getLocalAddress(channel)), action, error);
stream.writeThrowable(tx);
byte status = 0;
status = TransportStatus.setResponse(status);
status = TransportStatus.setError(status);
final BytesReference bytes = stream.bytes();
final BytesReference header = buildHeader(requestId, status, nodeVersion, bytes.length());
Runnable onRequestSent = () -> transportServiceAdapter.onResponseSent(requestId, action, error);
sendMessage(channel, new CompositeBytesReference(header, bytes), onRequestSent, false);
}
Runnable onRequestSent = () -> {
transportServiceAdapter.onResponseSent(requestId, action, error);
};
writeable.send(channel, onRequestSent);
}
/**
@ -974,19 +966,15 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
status = TransportStatus.setResponse(status); // TODO share some code with sendRequest
ReleasableBytesStreamOutput bStream = new ReleasableBytesStreamOutput(bigArrays);
boolean addedReleaseListener = false;
StreamOutput stream = bStream;
try {
bStream.skip(TcpHeader.HEADER_SIZE);
StreamOutput stream = bStream;
if (options.compress()) {
status = TransportStatus.setCompress(status);
stream = CompressorFactory.COMPRESSOR.streamOutput(stream);
}
stream.setVersion(nodeVersion);
Message<Channel> writeable = prepareSend(nodeVersion, response, stream, bStream);
try (StreamOutput headerOutput = writeable.getHeaderOutput()) {
TcpHeader.writeHeader(headerOutput, requestId, status, nodeVersion,
writeable.size());
}
BytesReference reference = buildMessage(requestId, status,nodeVersion, response, stream, bStream);
final TransportResponseOptions finalOptions = options;
Runnable onRequestSent = () -> {
try {
@ -995,10 +983,11 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
transportServiceAdapter.onResponseSent(requestId, action, response, finalOptions);
}
};
writeable.send(channel, onRequestSent);
sendMessage(channel, reference, onRequestSent, false);
addedReleaseListener = true;
} finally {
IOUtils.close(stream);
if (!addedReleaseListener) {
Releasables.close(bStream.bytes());
}
@ -1006,44 +995,51 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
}
/**
* Serializes the given message into a bytes representation and forwards to {@link #prepareSend(Version, TransportMessage,
* StreamOutput, ReleasableBytesStream)}
* Writes the Tcp message header into a bytes reference.
*
* @param requestId the request ID
* @param status the request status
* @param protocolVersion the protocol version used to serialize the data in the message
* @param length the payload length in bytes
* @see TcpHeader
*/
protected Message<Channel> prepareSend(Version nodeVersion, TransportMessage message, StreamOutput stream,
ReleasableBytesStream writtenBytes) throws IOException {
message.writeTo(stream);
private BytesReference buildHeader(long requestId, byte status, Version protocolVersion, int length) throws IOException {
try (BytesStreamOutput headerOutput = new BytesStreamOutput(TcpHeader.HEADER_SIZE)) {
headerOutput.setVersion(protocolVersion);
TcpHeader.writeHeader(headerOutput, requestId, status, protocolVersion,
length);
final BytesReference bytes = headerOutput.bytes();
assert bytes.length() == TcpHeader.HEADER_SIZE : "header size mismatch expected: " + TcpHeader.HEADER_SIZE + " but was: "
+ bytes.length();
return bytes;
}
}
/**
* Serializes the given message into a bytes representation
*/
private BytesReference buildMessage(long requestId, byte status, Version nodeVersion, TransportMessage message, StreamOutput stream,
ReleasableBytesStream writtenBytes) throws IOException {
final BytesReference zeroCopyBuffer;
if (message instanceof BytesTransportRequest) { // what a shitty optimization - we should use a direct send method instead
BytesTransportRequest bRequest = (BytesTransportRequest) message;
assert nodeVersion.equals(bRequest.version());
bRequest.writeThin(stream);
zeroCopyBuffer = bRequest.bytes;
} else {
message.writeTo(stream);
zeroCopyBuffer = BytesArray.EMPTY;
}
// we have to close the stream here - flush is not enough since we might be compressing the content
// and if we do that the close method will write some marker bytes (EOS marker) and otherwise
// we barf on the decompressing end when we read past EOF on purpose in the #validateRequest method.
// this might be a problem in deflate after all but it's important to close it for now.
stream.close();
return prepareSend(nodeVersion, writtenBytes.bytes());
final BytesReference messageBody = writtenBytes.bytes();
final BytesReference header = buildHeader(requestId, status, stream.getVersion(), messageBody.length() + zeroCopyBuffer.length());
return new CompositeBytesReference(header, messageBody, zeroCopyBuffer);
}
/**
* prepares a implementation specific message to send across the network
*/
protected abstract Message<Channel> prepareSend(Version nodeVersion, BytesReference bytesReference) throws IOException;
/**
* Allows implementations to transform TransportMessages into implementation specific messages
*/
protected interface Message<Channel> {
/**
* Creates an output to write the message header to.
*/
StreamOutput getHeaderOutput();
/**
* Returns the size of the message in bytes
*/
int size();
/**
* sends the message to the channel
* @param channel the channe to send the message to
* @param onRequestSent a callback executed once the message has been fully send
*/
void send(Channel channel, Runnable onRequestSent);
}
/**
* Validates the first N bytes of the message header and returns <code>true</code> if the message is
* a ping message and has no payload ie. isn't a real user level message.

View File

@ -20,15 +20,11 @@
package org.elasticsearch.transport.netty;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Booleans;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasablePagedBytesReference;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.ReleasableBytesStream;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.network.NetworkService.TcpSettings;
@ -43,17 +39,12 @@ import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.monitor.jvm.JvmInfo;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.BytesTransportRequest;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportMessage;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportServiceAdapter;
import org.elasticsearch.transport.TransportSettings;
import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.AdaptiveReceiveBufferSizePredictorFactory;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
@ -72,11 +63,10 @@ import org.jboss.netty.channel.socket.oio.OioClientSocketChannelFactory;
import org.jboss.netty.channel.socket.oio.OioServerSocketChannelFactory;
import org.jboss.netty.util.HashedWheelTimer;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executors;
@ -346,81 +336,33 @@ public class NettyTransport extends TcpTransport<Channel> {
channels[0].getCloseFuture().addListener(new ChannelCloseListener(node));
return new NodeChannels(channels, channels, channels, channels, channels);
}
protected NodeChannels connectToChannels(DiscoveryNode node) {
final NodeChannels nodeChannels = new NodeChannels(new Channel[connectionsPerNodeRecovery], new Channel[connectionsPerNodeBulk],
new Channel[connectionsPerNodeReg], new Channel[connectionsPerNodeState],
new Channel[connectionsPerNodePing]);
boolean success = false;
try {
ChannelFuture[] connectRecovery = new ChannelFuture[nodeChannels.recovery.length];
ChannelFuture[] connectBulk = new ChannelFuture[nodeChannels.bulk.length];
ChannelFuture[] connectReg = new ChannelFuture[nodeChannels.reg.length];
ChannelFuture[] connectState = new ChannelFuture[nodeChannels.state.length];
ChannelFuture[] connectPing = new ChannelFuture[nodeChannels.ping.length];
int numConnections = connectionsPerNodeBulk + connectionsPerNodePing + connectionsPerNodeRecovery + connectionsPerNodeReg
+ connectionsPerNodeState;
ArrayList<ChannelFuture> connections = new ArrayList<>();
InetSocketAddress address = ((InetSocketTransportAddress) node.getAddress()).address();
for (int i = 0; i < connectRecovery.length; i++) {
connectRecovery[i] = clientBootstrap.connect(address);
for (int i = 0; i < numConnections; i++) {
connections.add(clientBootstrap.connect(address));
}
for (int i = 0; i < connectBulk.length; i++) {
connectBulk[i] = clientBootstrap.connect(address);
}
for (int i = 0; i < connectReg.length; i++) {
connectReg[i] = clientBootstrap.connect(address);
}
for (int i = 0; i < connectState.length; i++) {
connectState[i] = clientBootstrap.connect(address);
}
for (int i = 0; i < connectPing.length; i++) {
connectPing[i] = clientBootstrap.connect(address);
}
final Iterator<ChannelFuture> iterator = connections.iterator();
try {
for (int i = 0; i < connectRecovery.length; i++) {
connectRecovery[i].awaitUninterruptibly((long) (connectTimeout.millis() * 1.5));
if (!connectRecovery[i].isSuccess()) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", connectRecovery[i].getCause());
for (Channel[] channels : nodeChannels.getChannelArrays()) {
for (int i = 0; i < channels.length; i++) {
assert iterator.hasNext();
ChannelFuture future = iterator.next();
future.awaitUninterruptibly((long) (connectTimeout.millis() * 1.5));
if (!future.isSuccess()) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", future.getCause());
}
channels[i] = future.getChannel();
channels[i].getCloseFuture().addListener(new ChannelCloseListener(node));
}
nodeChannels.recovery[i] = connectRecovery[i].getChannel();
nodeChannels.recovery[i].getCloseFuture().addListener(new ChannelCloseListener(node));
}
for (int i = 0; i < connectBulk.length; i++) {
connectBulk[i].awaitUninterruptibly((long) (connectTimeout.millis() * 1.5));
if (!connectBulk[i].isSuccess()) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", connectBulk[i].getCause());
}
nodeChannels.bulk[i] = connectBulk[i].getChannel();
nodeChannels.bulk[i].getCloseFuture().addListener(new ChannelCloseListener(node));
}
for (int i = 0; i < connectReg.length; i++) {
connectReg[i].awaitUninterruptibly((long) (connectTimeout.millis() * 1.5));
if (!connectReg[i].isSuccess()) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", connectReg[i].getCause());
}
nodeChannels.reg[i] = connectReg[i].getChannel();
nodeChannels.reg[i].getCloseFuture().addListener(new ChannelCloseListener(node));
}
for (int i = 0; i < connectState.length; i++) {
connectState[i].awaitUninterruptibly((long) (connectTimeout.millis() * 1.5));
if (!connectState[i].isSuccess()) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", connectState[i].getCause());
}
nodeChannels.state[i] = connectState[i].getChannel();
nodeChannels.state[i].getCloseFuture().addListener(new ChannelCloseListener(node));
}
for (int i = 0; i < connectPing.length; i++) {
connectPing[i].awaitUninterruptibly((long) (connectTimeout.millis() * 1.5));
if (!connectPing[i].isSuccess()) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", connectPing[i].getCause());
}
nodeChannels.ping[i] = connectPing[i].getChannel();
nodeChannels.ping[i].getCloseFuture().addListener(new ChannelCloseListener(node));
}
if (nodeChannels.recovery.length == 0) {
if (nodeChannels.bulk.length > 0) {
nodeChannels.recovery = nodeChannels.bulk;
@ -432,14 +374,7 @@ public class NettyTransport extends TcpTransport<Channel> {
nodeChannels.bulk = nodeChannels.reg;
}
} catch (RuntimeException e) {
// clean the futures
List<ChannelFuture> futures = new ArrayList<>();
futures.addAll(Arrays.asList(connectRecovery));
futures.addAll(Arrays.asList(connectBulk));
futures.addAll(Arrays.asList(connectReg));
futures.addAll(Arrays.asList(connectState));
futures.addAll(Arrays.asList(connectPing));
for (ChannelFuture future : Collections.unmodifiableList(futures)) {
for (ChannelFuture future : Collections.unmodifiableList(connections)) {
future.cancel();
if (future.getChannel() != null && future.getChannel().isOpen()) {
try {
@ -546,6 +481,7 @@ public class NettyTransport extends TcpTransport<Channel> {
}
}
@Override
protected void sendMessage(Channel channel, BytesReference reference, Runnable sendListener, boolean close) {
final ChannelFuture future = channel.write(NettyUtils.toChannelBuffer(reference));
if (close) {
@ -617,93 +553,4 @@ public class NettyTransport extends TcpTransport<Channel> {
}
});
}
@Override
public Message<Channel> prepareSend(Version nodeVersion, TransportMessage message, StreamOutput stream,
ReleasableBytesStream writtenBytes) throws IOException {
// it might be nice to somehow generalize this optimization, maybe a smart "paged" bytes output
// that create paged channel buffers, but its tricky to know when to do it (where this option is
// more explicit).
if (message instanceof BytesTransportRequest) {
BytesTransportRequest bRequest = (BytesTransportRequest) message;
assert nodeVersion.equals(bRequest.version());
bRequest.writeThin(stream);
stream.close();
ReleasablePagedBytesReference bytes = writtenBytes.bytes();
ChannelBuffer headerBuffer = NettyUtils.toChannelBuffer(bytes);
ChannelBuffer contentBuffer = NettyUtils.toChannelBuffer(bRequest.bytes());
ChannelBuffer buffer = ChannelBuffers.wrappedBuffer(NettyUtils.DEFAULT_GATHERING, headerBuffer, contentBuffer);
return new NettyMessage(buffer);
} else {
return super.prepareSend(nodeVersion, message, stream, writtenBytes);
}
}
@Override
public Message<Channel> prepareSend(Version nodeVersion, BytesReference bytesReference) {
return new NettyMessage(NettyUtils.toChannelBuffer(bytesReference));
}
@Override
public boolean canCompress(TransportRequest request) {
return super.canCompress(request) && (!(request instanceof BytesTransportRequest));
}
private class NettyMessage implements Message<Channel> {
private final ChannelBuffer buffer;
public NettyMessage(ChannelBuffer buffer) {
this.buffer = buffer;
}
public StreamOutput getHeaderOutput() {
return new ChannelBufferStreamOutput(buffer);
}
public int size() {
return buffer.readableBytes();
}
@Override
public void send(Channel channel, Runnable onRequestSent) {
ChannelFuture future = channel.write(buffer);
ChannelFutureListener channelFutureListener = f -> onRequestSent.run();
future.addListener(channelFutureListener);
}
}
private static final class ChannelBufferStreamOutput extends StreamOutput {
private final ChannelBuffer buffer;
private int offset;
public ChannelBufferStreamOutput(ChannelBuffer buffer) {
this.buffer = buffer;
this.offset = buffer.readerIndex();
}
@Override
public void writeByte(byte b) throws IOException {
buffer.setByte(offset++, b);
}
@Override
public void writeBytes(byte[] b, int offset, int length) throws IOException {
buffer.setBytes(this.offset, b, offset, length);
this.offset += length;
}
@Override
public void flush() throws IOException {
}
@Override
public void close() throws IOException {
}
@Override
public void reset() throws IOException {
throw new UnsupportedOperationException();
}
}
}

View File

@ -42,11 +42,20 @@ public abstract class AbstractBytesReferenceTestCase extends ESTestCase {
public void testGet() throws IOException {
int length = randomIntBetween(1, PAGE_SIZE * 3);
BytesReference pbr = newBytesReference(length);
int sliceOffset = randomIntBetween(0, length / 2);
int sliceLength = Math.max(1, length - sliceOffset - 1);
BytesReference slice = pbr.slice(sliceOffset, sliceLength);
assertEquals(pbr.get(sliceOffset), slice.get(0));
assertEquals(pbr.get(sliceOffset + sliceLength - 1), slice.get(sliceLength - 1));
final int probes = randomIntBetween(20, 100);
BytesReference copy = new BytesArray(pbr.toBytesRef(), true);
for (int i = 0; i < probes; i++) {
int index = randomIntBetween(0, copy.length() - 1);
assertEquals(pbr.get(index), copy.get(index));
index = randomIntBetween(sliceOffset, sliceOffset + sliceLength);
assertEquals(pbr.get(index), slice.get(index - sliceOffset));
}
}
public void testLength() throws IOException {
@ -121,6 +130,26 @@ public abstract class AbstractBytesReferenceTestCase extends ESTestCase {
si.readBytes(targetBuf, 0, length * 2));
}
public void testStreamInputMarkAndReset() throws IOException {
int length = randomIntBetween(10, scaledRandomIntBetween(PAGE_SIZE * 2, PAGE_SIZE * 20));
BytesReference pbr = newBytesReference(length);
StreamInput si = pbr.streamInput();
assertNotNull(si);
StreamInput wrap = StreamInput.wrap(BytesReference.toBytes(pbr));
while(wrap.available() > 0) {
if (rarely()) {
wrap.mark(Integer.MAX_VALUE);
si.mark(Integer.MAX_VALUE);
} else if (rarely()) {
wrap.reset();
si.reset();
}
assertEquals(si.readByte(), wrap.readByte());
assertEquals(si.available(), wrap.available());
}
}
public void testStreamInputBulkReadWithOffset() throws IOException {
final int length = randomIntBetween(10, scaledRandomIntBetween(PAGE_SIZE * 2, PAGE_SIZE * 20));
BytesReference pbr = newBytesReference(length);
@ -233,6 +262,24 @@ public abstract class AbstractBytesReferenceTestCase extends ESTestCase {
out.close();
}
public void testInputStreamSkip() throws IOException {
int length = randomIntBetween(10, scaledRandomIntBetween(PAGE_SIZE * 2, PAGE_SIZE * 20));
BytesReference pbr = newBytesReference(length);
final int iters = randomIntBetween(5, 50);
for (int i = 0; i < iters; i++) {
try (StreamInput input = pbr.streamInput()) {
final int offset = randomIntBetween(0, length-1);
assertEquals(offset, input.skip(offset));
assertEquals(pbr.get(offset), input.readByte());
final int nextOffset = randomIntBetween(offset, length-2);
assertEquals(nextOffset - offset, input.skip(nextOffset - offset));
assertEquals(pbr.get(nextOffset+1), input.readByte()); // +1 for the one byte we read above
assertEquals(length - (nextOffset+2), input.skip(Long.MAX_VALUE));
assertEquals(0, input.skip(randomIntBetween(0, Integer.MAX_VALUE)));
}
}
}
public void testSliceWriteToOutputStream() throws IOException {
int length = randomIntBetween(10, PAGE_SIZE * randomIntBetween(2, 5));
BytesReference pbr = newBytesReference(length);
@ -252,6 +299,9 @@ public abstract class AbstractBytesReferenceTestCase extends ESTestCase {
BytesReference pbr = newBytesReference(sizes[i]);
byte[] bytes = BytesReference.toBytes(pbr);
assertEquals(sizes[i], bytes.length);
for (int j = 0; j < bytes.length; j++) {
assertEquals(bytes[j], pbr.get(j));
}
}
}

View File

@ -0,0 +1,110 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.bytes;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.BytesRefIterator;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class CompositeBytesReferenceTests extends AbstractBytesReferenceTestCase {
@Override
protected BytesReference newBytesReference(int length) throws IOException {
// we know bytes stream output always creates a paged bytes reference, we use it to create randomized content
List<BytesReference> referenceList = newRefList(length);
BytesReference ref = new CompositeBytesReference(referenceList.toArray(new BytesReference[0]));
assertEquals(length, ref.length());
return ref;
}
private List<BytesReference> newRefList(int length) throws IOException {
List<BytesReference> referenceList = new ArrayList<>();
for (int i = 0; i < length;) {
int remaining = length-i;
int sliceLength = randomIntBetween(1, remaining);
ReleasableBytesStreamOutput out = new ReleasableBytesStreamOutput(sliceLength, bigarrays);
for (int j = 0; j < sliceLength; j++) {
out.writeByte((byte) random().nextInt(1 << 8));
}
assertEquals(sliceLength, out.size());
referenceList.add(out.bytes());
i+=sliceLength;
}
return referenceList;
}
public void testCompositeBuffer() throws IOException {
List<BytesReference> referenceList = newRefList(randomIntBetween(1, PAGE_SIZE * 2));
BytesReference ref = new CompositeBytesReference(referenceList.toArray(new BytesReference[0]));
BytesRefIterator iterator = ref.iterator();
BytesRefBuilder builder = new BytesRefBuilder();
for (BytesReference reference : referenceList) {
BytesRefIterator innerIter = reference.iterator(); // sometimes we have a paged ref - pull an iter and walk all pages!
BytesRef scratch;
while ((scratch = innerIter.next()) != null) {
BytesRef next = iterator.next();
assertNotNull(next);
assertEquals(next, scratch);
builder.append(next);
}
}
assertNull(iterator.next());
int offset = 0;
for (BytesReference reference : referenceList) {
assertEquals(reference, ref.slice(offset, reference.length()));
int probes = randomIntBetween(Math.min(10, reference.length()), reference.length());
for (int i = 0; i < probes; i++) {
int index = randomIntBetween(0, reference.length()-1);
assertEquals(ref.get(offset + index), reference.get(index));
}
offset += reference.length();
}
BytesArray array = new BytesArray(builder.toBytesRef());
assertEquals(array, ref);
assertEquals(array.hashCode(), ref.hashCode());
BytesStreamOutput output = new BytesStreamOutput();
ref.writeTo(output);
assertEquals(array, output.bytes());
}
@Override
public void testToBytesRefSharedPage() throws IOException {
// CompositeBytesReference doesn't share pages
}
@Override
public void testSliceArrayOffset() throws IOException {
// the assertions in this test only work on no-composite buffers
}
@Override
public void testSliceToBytesRef() throws IOException {
// CompositeBytesReference shifts offsets
}
}

View File

@ -20,7 +20,6 @@
package org.elasticsearch.discovery.zen.ping.unicast;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;