Create nio-transport plugin for NioTransport (#27949)

This is related to #27260. This commit moves the NioTransport from
:test:framework to a new nio-transport plugin. Additionally, supporting
tcp decoding classes are moved to this plugin. Generic byte reading and
writing contexts are moved to the nio library.

Additionally, this commit adds a basic MockNioTransport to
:test:framework that is a TcpTransport implementation for testing that
is driven by nio.
This commit is contained in:
Tim Brooks 2018-01-05 09:41:29 -07:00 committed by GitHub
parent fdb9b50747
commit 38701fb6ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 1393 additions and 741 deletions

View File

@ -27,6 +27,8 @@ import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.function.ToIntBiFunction;
/**
@ -148,6 +150,37 @@ public abstract class BytesReference implements Accountable, Comparable<BytesRef
return BytesRef.deepCopyOf(bytesRef).bytes;
}
/**
* Returns an array of byte buffers from the given BytesReference.
*/
public static ByteBuffer[] toByteBuffers(BytesReference reference) {
BytesRefIterator byteRefIterator = reference.iterator();
BytesRef r;
try {
ArrayList<ByteBuffer> buffers = new ArrayList<>();
while ((r = byteRefIterator.next()) != null) {
buffers.add(ByteBuffer.wrap(r.bytes, r.offset, r.length));
}
return buffers.toArray(new ByteBuffer[buffers.size()]);
} catch (IOException e) {
// this is really an error since we don't do IO in our bytesreferences
throw new AssertionError("won't happen", e);
}
}
/**
* Returns BytesReference composed of the provided ByteBuffers.
*/
public static BytesReference fromByteBuffers(ByteBuffer[] buffers) {
ByteBufferReference[] references = new ByteBufferReference[buffers.length];
for (int i = 0; i < references.length; ++i) {
references[i] = new ByteBufferReference(buffers[i]);
}
return new CompositeBytesReference(references);
}
@Override
public int compareTo(final BytesReference other) {
return compareIterators(this, other, (a, b) -> a.compareTo(b));

View File

@ -46,12 +46,17 @@ import java.util.concurrent.TimeoutException;
public interface TcpChannel extends Releasable {
/**
* Closes the channel. This might be an asynchronous process. There is notguarantee that the channel
* Closes the channel. This might be an asynchronous process. There is no guarantee that the channel
* will be closed when this method returns. Use the {@link #addCloseListener(ActionListener)} method
* to implement logic that depends on knowing when the channel is closed.
*/
void close();
/**
* This returns the profile for this channel.
*/
String getProfile();
/**
* Adds a listener that will be executed when the channel is closed. If the channel is still open when
* this listener is added, the listener will be executed by the thread that eventually closes the
@ -86,6 +91,13 @@ public interface TcpChannel extends Releasable {
*/
InetSocketAddress getLocalAddress();
/**
* Returns the remote address for this channel. Can be null if channel does not have a remote address.
*
* @return the remote address of this channel.
*/
InetSocketAddress getRemoteAddress();
/**
* Sends a tcp message to the channel. The listener will be executed once the send process has been
* completed.

View File

@ -184,8 +184,12 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
public static final Setting.AffixSetting<Integer> PUBLISH_PORT_PROFILE = affixKeySetting("transport.profiles.", "publish_port",
key -> intSetting(key, -1, -1, Setting.Property.NodeScope));
private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9);
// This is the number of bytes necessary to read the message size
public static final int BYTES_NEEDED_FOR_MESSAGE_SIZE = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
public static final int PING_DATA_SIZE = -1;
private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9);
private static final BytesReference EMPTY_BYTES_REFERENCE = new BytesArray(new byte[0]);
private final CircuitBreakerService circuitBreakerService;
// package visibility for tests
protected final ScheduledPing scheduledPing;
@ -317,8 +321,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
public class ScheduledPing extends AbstractLifecycleRunnable {
/**
* The magic number (must be lower than 0) for a ping message. This is handled
* specifically in {@link TcpTransport#validateMessageHeader}.
* The magic number (must be lower than 0) for a ping message.
*/
private final BytesReference pingHeader;
final CounterMetric successfulPings = new CounterMetric();
@ -1210,7 +1213,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
* @param length the payload length in bytes
* @see TcpHeader
*/
final BytesReference buildHeader(long requestId, byte status, Version protocolVersion, int length) throws IOException {
private BytesReference buildHeader(long requestId, byte status, Version protocolVersion, int length) throws IOException {
try (BytesStreamOutput headerOutput = new BytesStreamOutput(TcpHeader.HEADER_SIZE)) {
headerOutput.setVersion(protocolVersion);
TcpHeader.writeHeader(headerOutput, requestId, status, protocolVersion, length);
@ -1247,76 +1250,135 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
}
/**
* Validates the first N bytes of the message header and returns <code>false</code> if the message is
* a ping message and has no payload ie. isn't a real user level message.
* Consumes bytes that are available from network reads. This method returns the number of bytes consumed
* in this call.
*
* @throws IllegalStateException if the message is too short, less than the header or less that the header plus the message size
* @throws HttpOnTransportException if the message has no valid header and appears to be a HTTP message
* @throws IllegalArgumentException if the message is greater that the maximum allowed frame size. This is dependent on the available
* memory.
* @param channel the channel read from
* @param bytesReference the bytes available to consume
* @return the number of bytes consumed
* @throws StreamCorruptedException if the message header format is not recognized
* @throws TcpTransport.HttpOnTransportException if the message header appears to be a HTTP message
* @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size.
* This is dependent on the available memory.
*/
public static boolean validateMessageHeader(BytesReference buffer) throws IOException {
final int sizeHeaderLength = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
if (buffer.length() < sizeHeaderLength) {
throw new IllegalStateException("message size must be >= to the header size");
}
int offset = 0;
if (buffer.get(offset) != 'E' || buffer.get(offset + 1) != 'S') {
// special handling for what is probably HTTP
if (bufferStartsWith(buffer, offset, "GET ") ||
bufferStartsWith(buffer, offset, "POST ") ||
bufferStartsWith(buffer, offset, "PUT ") ||
bufferStartsWith(buffer, offset, "HEAD ") ||
bufferStartsWith(buffer, offset, "DELETE ") ||
bufferStartsWith(buffer, offset, "OPTIONS ") ||
bufferStartsWith(buffer, offset, "PATCH ") ||
bufferStartsWith(buffer, offset, "TRACE ")) {
public int consumeNetworkReads(TcpChannel channel, BytesReference bytesReference) throws IOException {
BytesReference message = decodeFrame(bytesReference);
throw new HttpOnTransportException("This is not a HTTP port");
if (message == null) {
return 0;
} else if (message.length() == 0) {
// This is a ping and should not be handled.
return BYTES_NEEDED_FOR_MESSAGE_SIZE;
} else {
try {
messageReceived(message, channel);
} catch (Exception e) {
onException(channel, e);
}
return message.length() + BYTES_NEEDED_FOR_MESSAGE_SIZE;
}
}
/**
* Attempts to a decode a message from the provided bytes. If a full message is not available, null is
* returned. If the message is a ping, an empty {@link BytesReference} will be returned.
*
* @param networkBytes the will be read
* @return the message decoded
* @throws StreamCorruptedException if the message header format is not recognized
* @throws TcpTransport.HttpOnTransportException if the message header appears to be a HTTP message
* @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size.
* This is dependent on the available memory.
*/
public static BytesReference decodeFrame(BytesReference networkBytes) throws IOException {
int messageLength = readMessageLength(networkBytes);
if (messageLength == -1) {
return null;
} else {
int totalLength = messageLength + BYTES_NEEDED_FOR_MESSAGE_SIZE;
if (totalLength > networkBytes.length()) {
return null;
} else if (totalLength == 6) {
return EMPTY_BYTES_REFERENCE;
} else {
return networkBytes.slice(BYTES_NEEDED_FOR_MESSAGE_SIZE, messageLength);
}
}
}
/**
* Validates the first 6 bytes of the message header and returns the length of the message. If 6 bytes
* are not available, it returns -1.
*
* @param networkBytes the will be read
* @return the length of the message
* @throws StreamCorruptedException if the message header format is not recognized
* @throws TcpTransport.HttpOnTransportException if the message header appears to be a HTTP message
* @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size.
* This is dependent on the available memory.
*/
public static int readMessageLength(BytesReference networkBytes) throws IOException {
if (networkBytes.length() < BYTES_NEEDED_FOR_MESSAGE_SIZE) {
return -1;
} else {
return readHeaderBuffer(networkBytes);
}
}
private static int readHeaderBuffer(BytesReference headerBuffer) throws IOException {
if (headerBuffer.get(0) != 'E' || headerBuffer.get(1) != 'S') {
if (appearsToBeHTTP(headerBuffer)) {
throw new TcpTransport.HttpOnTransportException("This is not a HTTP port");
}
// we have 6 readable bytes, show 4 (should be enough)
throw new StreamCorruptedException("invalid internal transport message format, got ("
+ Integer.toHexString(buffer.get(offset) & 0xFF) + ","
+ Integer.toHexString(buffer.get(offset + 1) & 0xFF) + ","
+ Integer.toHexString(buffer.get(offset + 2) & 0xFF) + ","
+ Integer.toHexString(buffer.get(offset + 3) & 0xFF) + ")");
+ Integer.toHexString(headerBuffer.get(0) & 0xFF) + ","
+ Integer.toHexString(headerBuffer.get(1) & 0xFF) + ","
+ Integer.toHexString(headerBuffer.get(2) & 0xFF) + ","
+ Integer.toHexString(headerBuffer.get(3) & 0xFF) + ")");
}
final int dataLen;
try (StreamInput input = buffer.streamInput()) {
final int messageLength;
try (StreamInput input = headerBuffer.streamInput()) {
input.skip(TcpHeader.MARKER_BYTES_SIZE);
dataLen = input.readInt();
if (dataLen == PING_DATA_SIZE) {
// discard the messages we read and continue, this is achieved by skipping the bytes
// and returning null
return false;
}
messageLength = input.readInt();
}
if (dataLen <= 0) {
throw new StreamCorruptedException("invalid data length: " + dataLen);
if (messageLength == TcpTransport.PING_DATA_SIZE) {
// This is a ping
return 0;
}
// safety against too large frames being sent
if (dataLen > NINETY_PER_HEAP_SIZE) {
throw new IllegalArgumentException("transport content length received [" + new ByteSizeValue(dataLen) + "] exceeded ["
if (messageLength <= 0) {
throw new StreamCorruptedException("invalid data length: " + messageLength);
}
if (messageLength > NINETY_PER_HEAP_SIZE) {
throw new IllegalArgumentException("transport content length received [" + new ByteSizeValue(messageLength) + "] exceeded ["
+ new ByteSizeValue(NINETY_PER_HEAP_SIZE) + "]");
}
if (buffer.length() < dataLen + sizeHeaderLength) {
throw new IllegalStateException("buffer must be >= to the message size but wasn't");
}
return true;
return messageLength;
}
private static boolean bufferStartsWith(BytesReference buffer, int offset, String method) {
private static boolean appearsToBeHTTP(BytesReference headerBuffer) {
return bufferStartsWith(headerBuffer, "GET") ||
bufferStartsWith(headerBuffer, "POST") ||
bufferStartsWith(headerBuffer, "PUT") ||
bufferStartsWith(headerBuffer, "HEAD") ||
bufferStartsWith(headerBuffer, "DELETE") ||
// Actually 'OPTIONS'. But we are only guaranteed to have read six bytes at this point.
bufferStartsWith(headerBuffer, "OPTION") ||
bufferStartsWith(headerBuffer, "PATCH") ||
bufferStartsWith(headerBuffer, "TRACE");
}
private static boolean bufferStartsWith(BytesReference buffer, String method) {
char[] chars = method.toCharArray();
for (int i = 0; i < chars.length; i++) {
if (buffer.get(offset + i) != chars[i]) {
if (buffer.get(i) != chars[i]) {
return false;
}
}
return true;
}
@ -1343,8 +1405,10 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
/**
* This method handles the message receive part for both request and responses
*/
public final void messageReceived(BytesReference reference, TcpChannel channel, String profileName,
InetSocketAddress remoteAddress, int messageLengthBytes) throws IOException {
public final void messageReceived(BytesReference reference, TcpChannel channel) throws IOException {
String profileName = channel.getProfile();
InetSocketAddress remoteAddress = channel.getRemoteAddress();
int messageLengthBytes = reference.length();
final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
readBytesMetric.inc(totalMessageSize);
// we have additional bytes to read, outside of the header

View File

@ -22,8 +22,10 @@ package org.elasticsearch.transport;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.compress.CompressorFactory;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
@ -37,12 +39,17 @@ import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import java.io.IOException;
import java.io.StreamCorruptedException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.IsInstanceOf.instanceOf;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
/** Unit tests for {@link TcpTransport} */
public class TcpTransportTests extends ESTestCase {
@ -246,6 +253,11 @@ public class TcpTransportTests extends ESTestCase {
public void close() {
}
@Override
public String getProfile() {
return null;
}
@Override
public void addCloseListener(ActionListener<Void> listener) {
}
@ -264,6 +276,11 @@ public class TcpTransportTests extends ESTestCase {
return null;
}
@Override
public InetSocketAddress getRemoteAddress() {
return null;
}
@Override
public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
messageCaptor.set(reference);
@ -354,4 +371,126 @@ public class TcpTransportTests extends ESTestCase {
assertEquals(3, profile.getNumConnectionsPerType(TransportRequestOptions.Type.BULK));
}
public void testDecodeWithIncompleteHeader() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.write(1);
streamOutput.write(1);
assertNull(TcpTransport.decodeFrame(streamOutput.bytes()));
}
public void testDecodePing() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(-1);
BytesReference message = TcpTransport.decodeFrame(streamOutput.bytes());
assertEquals(0, message.length());
}
public void testDecodePingWithStartOfSecondMessage() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(-1);
streamOutput.write('E');
streamOutput.write('S');
BytesReference message = TcpTransport.decodeFrame(streamOutput.bytes());
assertEquals(0, message.length());
}
public void testDecodeMessage() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(2);
streamOutput.write('M');
streamOutput.write('A');
BytesReference message = TcpTransport.decodeFrame(streamOutput.bytes());
assertEquals(streamOutput.bytes().slice(6, 2), message);
}
public void testDecodeIncompleteMessage() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(3);
streamOutput.write('M');
streamOutput.write('A');
BytesReference message = TcpTransport.decodeFrame(streamOutput.bytes());
assertNull(message);
}
public void testInvalidLength() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(-2);
streamOutput.write('M');
streamOutput.write('A');
try {
TcpTransport.decodeFrame(streamOutput.bytes());
fail("Expected exception");
} catch (Exception ex) {
assertThat(ex, instanceOf(StreamCorruptedException.class));
assertEquals("invalid data length: -2", ex.getMessage());
}
}
public void testInvalidHeader() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('C');
byte byte1 = randomByte();
byte byte2 = randomByte();
streamOutput.write(byte1);
streamOutput.write(byte2);
streamOutput.write(randomByte());
streamOutput.write(randomByte());
streamOutput.write(randomByte());
try {
TcpTransport.decodeFrame(streamOutput.bytes());
fail("Expected exception");
} catch (Exception ex) {
assertThat(ex, instanceOf(StreamCorruptedException.class));
String expected = "invalid internal transport message format, got (45,43,"
+ Integer.toHexString(byte1 & 0xFF) + ","
+ Integer.toHexString(byte2 & 0xFF) + ")";
assertEquals(expected, ex.getMessage());
}
}
public void testHTTPHeader() throws IOException {
String[] httpHeaders = {"GET", "POST", "PUT", "HEAD", "DELETE", "OPTIONS", "PATCH", "TRACE"};
for (String httpHeader : httpHeaders) {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
for (char c : httpHeader.toCharArray()) {
streamOutput.write((byte) c);
}
streamOutput.write(new byte[6]);
try {
BytesReference bytes = streamOutput.bytes();
TcpTransport.decodeFrame(bytes);
fail("Expected exception");
} catch (Exception ex) {
assertThat(ex, instanceOf(TcpTransport.HttpOnTransportException.class));
assertEquals("This is not a HTTP port", ex.getMessage());
}
}
}
}

View File

@ -0,0 +1,64 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.io.IOException;
public class BytesReadContext implements ReadContext {
private final NioSocketChannel channel;
private final ReadConsumer readConsumer;
private final InboundChannelBuffer channelBuffer;
public BytesReadContext(NioSocketChannel channel, ReadConsumer readConsumer, InboundChannelBuffer channelBuffer) {
this.channel = channel;
this.channelBuffer = channelBuffer;
this.readConsumer = readConsumer;
}
@Override
public int read() throws IOException {
if (channelBuffer.getRemaining() == 0) {
// Requiring one additional byte will ensure that a new page is allocated.
channelBuffer.ensureCapacity(channelBuffer.getCapacity() + 1);
}
int bytesRead = channel.read(channelBuffer.sliceBuffersFrom(channelBuffer.getIndex()));
if (bytesRead == -1) {
return bytesRead;
}
channelBuffer.incrementIndex(bytesRead);
int bytesConsumed = Integer.MAX_VALUE;
while (bytesConsumed > 0) {
bytesConsumed = readConsumer.consumeReads(channelBuffer);
channelBuffer.release(bytesConsumed);
}
return bytesRead;
}
@Override
public void close() {
channelBuffer.close();
}
}

View File

@ -17,41 +17,32 @@
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefIterator;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.nio.WriteContext;
import org.elasticsearch.nio.WriteOperation;
package org.elasticsearch.nio;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.function.BiConsumer;
public class TcpWriteContext implements WriteContext {
public class BytesWriteContext implements WriteContext {
private final NioSocketChannel channel;
private final LinkedList<WriteOperation> queued = new LinkedList<>();
public TcpWriteContext(NioSocketChannel channel) {
public BytesWriteContext(NioSocketChannel channel) {
this.channel = channel;
}
@Override
public void sendMessage(Object message, BiConsumer<Void, Throwable> listener) {
BytesReference reference = (BytesReference) message;
ByteBuffer[] buffers = (ByteBuffer[]) message;
if (channel.isWritable() == false) {
listener.accept(null, new ClosedChannelException());
return;
}
WriteOperation writeOperation = new WriteOperation(channel, toByteBuffers(reference), listener);
WriteOperation writeOperation = new WriteOperation(channel, buffers, listener);
SocketSelector selector = channel.getSelector();
if (selector.isOnCurrentThread() == false) {
selector.queueWrite(writeOperation);
@ -117,21 +108,4 @@ public class TcpWriteContext implements WriteContext {
lastOpCompleted = op.isFullyFlushed();
}
}
private static ByteBuffer[] toByteBuffers(BytesReference bytesReference) {
BytesRefIterator byteRefIterator = bytesReference.iterator();
BytesRef r;
try {
// Most network messages are composed of three buffers.
ArrayList<ByteBuffer> buffers = new ArrayList<>(3);
while ((r = byteRefIterator.next()) != null) {
buffers.add(ByteBuffer.wrap(r.bytes, r.offset, r.length));
}
return buffers.toArray(new ByteBuffer[buffers.size()]);
} catch (IOException e) {
// this is really an error since we don't do IO in our bytesreferences
throw new AssertionError("won't happen", e);
}
}
}

View File

@ -70,6 +70,18 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
}
}
public int read(ByteBuffer buffer) throws IOException {
return socketChannel.read(buffer);
}
public int read(ByteBuffer[] buffers) throws IOException {
if (buffers.length == 1) {
return socketChannel.read(buffers[0]);
} else {
return (int) socketChannel.read(buffers);
}
}
public int read(InboundChannelBuffer buffer) throws IOException {
int bytesRead = (int) socketChannel.read(buffer.sliceBuffersFrom(buffer.getIndex()));

View File

@ -28,4 +28,8 @@ public interface ReadContext extends AutoCloseable {
@Override
void close();
@FunctionalInterface
interface ReadConsumer {
int consumeReads(InboundChannelBuffer channelBuffer) throws IOException;
}
}

View File

@ -0,0 +1,142 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.function.Supplier;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class BytesReadContextTests extends ESTestCase {
private ReadContext.ReadConsumer readConsumer;
private NioSocketChannel channel;
private BytesReadContext readContext;
private InboundChannelBuffer channelBuffer;
private int messageLength;
@Before
public void init() {
readConsumer = mock(ReadContext.ReadConsumer.class);
messageLength = randomInt(96) + 20;
channel = mock(NioSocketChannel.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () ->
new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {});
channelBuffer = new InboundChannelBuffer(pageSupplier);
readContext = new BytesReadContext(channel, readConsumer, channelBuffer);
}
public void testSuccessfulRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
buffers[0].put(bytes);
return bytes.length;
});
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0);
assertEquals(messageLength, readContext.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
verify(readConsumer, times(2)).consumeReads(channelBuffer);
}
public void testMultipleReadsConsumed() throws IOException {
byte[] bytes = createMessage(messageLength * 2);
when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
buffers[0].put(bytes);
return bytes.length;
});
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0);
assertEquals(bytes.length, readContext.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
verify(readConsumer, times(3)).consumeReads(channelBuffer);
}
public void testPartialRead() throws IOException {
byte[] bytes = createMessage(messageLength);
when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
buffers[0].put(bytes);
return bytes.length;
});
when(readConsumer.consumeReads(channelBuffer)).thenReturn(0, messageLength);
assertEquals(messageLength, readContext.read());
assertEquals(bytes.length, channelBuffer.getIndex());
verify(readConsumer, times(1)).consumeReads(channelBuffer);
when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength * 2, 0);
assertEquals(messageLength, readContext.read());
assertEquals(0, channelBuffer.getIndex());
assertEquals(BigArrays.BYTE_PAGE_SIZE - (bytes.length * 2), channelBuffer.getCapacity());
verify(readConsumer, times(3)).consumeReads(channelBuffer);
}
public void testReadThrowsIOException() throws IOException {
IOException ioException = new IOException();
when(channel.read(any(ByteBuffer[].class))).thenThrow(ioException);
IOException ex = expectThrows(IOException.class, () -> readContext.read());
assertSame(ioException, ex);
}
public void closeClosesChannelBuffer() {
InboundChannelBuffer buffer = mock(InboundChannelBuffer.class);
BytesReadContext readContext = new BytesReadContext(channel, readConsumer, buffer);
readContext.close();
verify(buffer).close();
}
private static byte[] createMessage(int length) {
byte[] bytes = new byte[length];
for (int i = 0; i < length; ++i) {
bytes[i] = randomByte();
}
return bytes;
}
}

View File

@ -17,12 +17,8 @@
* under the License.
*/
package org.elasticsearch.transport.nio;
package org.elasticsearch.nio;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
@ -39,11 +35,11 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class TcpWriteContextTests extends ESTestCase {
public class BytesWriteContextTests extends ESTestCase {
private SocketSelector selector;
private BiConsumer<Void, Throwable> listener;
private TcpWriteContext writeContext;
private BytesWriteContext writeContext;
private NioSocketChannel channel;
@Before
@ -53,7 +49,7 @@ public class TcpWriteContextTests extends ESTestCase {
selector = mock(SocketSelector.class);
listener = mock(BiConsumer.class);
channel = mock(NioSocketChannel.class);
writeContext = new TcpWriteContext(channel);
writeContext = new BytesWriteContext(channel);
when(channel.getSelector()).thenReturn(selector);
when(selector.isOnCurrentThread()).thenReturn(true);
@ -62,44 +58,43 @@ public class TcpWriteContextTests extends ESTestCase {
public void testWriteFailsIfChannelNotWritable() throws Exception {
when(channel.isWritable()).thenReturn(false);
writeContext.sendMessage(new BytesArray(generateBytes(10)), listener);
ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))};
writeContext.sendMessage(buffers, listener);
verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class));
}
public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception {
byte[] bytes = generateBytes(10);
BytesArray bytesArray = new BytesArray(bytes);
ArgumentCaptor<WriteOperation> writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class);
when(selector.isOnCurrentThread()).thenReturn(false);
when(channel.isWritable()).thenReturn(true);
writeContext.sendMessage(bytesArray, listener);
ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))};
writeContext.sendMessage(buffers, listener);
verify(selector).queueWrite(writeOpCaptor.capture());
WriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertEquals(ByteBuffer.wrap(bytes), writeOp.getByteBuffers()[0]);
assertEquals(buffers[0], writeOp.getByteBuffers()[0]);
}
public void testSendMessageFromSameThreadIsQueuedInChannel() throws Exception {
byte[] bytes = generateBytes(10);
BytesArray bytesArray = new BytesArray(bytes);
ArgumentCaptor<WriteOperation> writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class);
when(channel.isWritable()).thenReturn(true);
writeContext.sendMessage(bytesArray, listener);
ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))};
writeContext.sendMessage(buffers, listener);
verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture());
WriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertEquals(ByteBuffer.wrap(bytes), writeOp.getByteBuffers()[0]);
assertEquals(buffers[0], writeOp.getByteBuffers()[0]);
}
public void testWriteIsQueuedInChannel() throws Exception {
@ -163,7 +158,7 @@ public class TcpWriteContextTests extends ESTestCase {
public void testMultipleWritesPartialFlushes() throws IOException {
assertFalse(writeContext.hasQueuedWriteOps());
BiConsumer listener2 = mock(BiConsumer.class);
BiConsumer<Void, Throwable> listener2 = mock(BiConsumer.class);
WriteOperation writeOperation1 = mock(WriteOperation.class);
WriteOperation writeOperation2 = mock(WriteOperation.class);
when(writeOperation1.getListener()).thenReturn(listener);

View File

@ -20,7 +20,6 @@
package org.elasticsearch.nio;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.TcpWriteContext;
import org.junit.Before;
import java.io.IOException;
@ -54,7 +53,7 @@ public class SocketEventHandlerTests extends ESTestCase {
readContext = mock(ReadContext.class);
when(rawChannel.finishConnect()).thenReturn(true);
channel.setContexts(readContext, new TcpWriteContext(channel), exceptionHandler);
channel.setContexts(readContext, new BytesWriteContext(channel), exceptionHandler);
channel.register();
channel.finishConnect();

View File

@ -56,12 +56,11 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler {
final int expectedReaderIndex = buffer.readerIndex() + remainingMessageSize;
try {
Channel channel = ctx.channel();
InetSocketAddress remoteAddress = (InetSocketAddress) channel.remoteAddress();
// netty always copies a buffer, either in NioWorker in its read handler, where it copies to a fresh
// buffer, or in the cumulative buffer, which is cleaned each time so it could be bigger than the actual size
BytesReference reference = Netty4Utils.toBytesReference(buffer, remainingMessageSize);
Attribute<NettyTcpChannel> channelAttribute = channel.attr(Netty4Transport.CHANNEL_KEY);
transport.messageReceived(reference, channelAttribute.get(), profileName, remoteAddress, remainingMessageSize);
transport.messageReceived(reference, channelAttribute.get());
} finally {
// Set the expected position of the buffer, no matter what happened
buffer.readerIndex(expectedReaderIndex);

View File

@ -23,6 +23,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.TooLongFrameException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.TcpHeader;
import org.elasticsearch.transport.TcpTransport;
@ -30,20 +31,25 @@ import java.util.List;
final class Netty4SizeHeaderFrameDecoder extends ByteToMessageDecoder {
private static final int HEADER_SIZE = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
try {
boolean continueProcessing = TcpTransport.validateMessageHeader(Netty4Utils.toBytesReference(in));
final ByteBuf message = in.skipBytes(TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE);
if (!continueProcessing) return;
out.add(message);
BytesReference networkBytes = Netty4Utils.toBytesReference(in);
int messageLength = TcpTransport.readMessageLength(networkBytes) + HEADER_SIZE;
// If the message length is -1, we have not read a complete header. If the message length is
// greater than the network bytes available, we have not read a complete frame.
if (messageLength != -1 && messageLength <= networkBytes.length()) {
final ByteBuf message = in.skipBytes(HEADER_SIZE);
// 6 bytes would mean it is a ping. And we should ignore.
if (messageLength != 6) {
out.add(message);
}
}
} catch (IllegalArgumentException ex) {
throw new TooLongFrameException(ex);
} catch (IllegalStateException ex) {
/* decode will be called until the ByteBuf is fully consumed; when it is fully
* consumed, transport#validateMessageHeader will throw an IllegalStateException which
* is okay, it means we have finished consuming the ByteBuf and we can get out
*/
}
}

View File

@ -249,7 +249,7 @@ public class Netty4Transport extends TcpTransport {
}
addClosedExceptionLogger(channel);
NettyTcpChannel nettyChannel = new NettyTcpChannel(channel);
NettyTcpChannel nettyChannel = new NettyTcpChannel(channel, "default");
channel.attr(CHANNEL_KEY).set(nettyChannel);
channelFuture.addListener(f -> {
@ -272,7 +272,7 @@ public class Netty4Transport extends TcpTransport {
@Override
protected NettyTcpChannel bind(String name, InetSocketAddress address) {
Channel channel = serverBootstraps.get(name).bind(address).syncUninterruptibly().channel();
NettyTcpChannel esChannel = new NettyTcpChannel(channel);
NettyTcpChannel esChannel = new NettyTcpChannel(channel, name);
channel.attr(CHANNEL_KEY).set(esChannel);
return esChannel;
}
@ -335,7 +335,7 @@ public class Netty4Transport extends TcpTransport {
@Override
protected void initChannel(Channel ch) throws Exception {
addClosedExceptionLogger(ch);
NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch);
NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch, name);
ch.attr(CHANNEL_KEY).set(nettyTcpChannel);
serverAcceptedChannel(nettyTcpChannel);
ch.pipeline().addLast("logging", new ESLoggingHandler());

View File

@ -38,10 +38,12 @@ import java.util.concurrent.CompletableFuture;
public class NettyTcpChannel implements TcpChannel {
private final Channel channel;
private final String profile;
private final CompletableFuture<Void> closeContext = new CompletableFuture<>();
NettyTcpChannel(Channel channel) {
NettyTcpChannel(Channel channel, String profile) {
this.channel = channel;
this.profile = profile;
this.channel.closeFuture().addListener(f -> {
if (f.isSuccess()) {
closeContext.complete(null);
@ -62,6 +64,11 @@ public class NettyTcpChannel implements TcpChannel {
channel.close();
}
@Override
public String getProfile() {
return profile;
}
@Override
public void addCloseListener(ActionListener<Void> listener) {
closeContext.whenComplete(ActionListener.toBiConsumer(listener));
@ -82,6 +89,11 @@ public class NettyTcpChannel implements TcpChannel {
return (InetSocketAddress) channel.localAddress();
}
@Override
public InetSocketAddress getRemoteAddress() {
return (InetSocketAddress) channel.remoteAddress();
}
@Override
public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
ChannelPromise writePromise = channel.newPromise();

View File

@ -0,0 +1,32 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
esplugin {
description 'The nio transport.'
classname 'org.elasticsearch.transport.nio.NioTransportPlugin'
}
dependencyLicenses.enabled = false
compileJava.options.compilerArgs << "-Xlint:-try"
compileTestJava.options.compilerArgs << "-Xlint:-rawtypes,-unchecked"
dependencies {
compile "org.elasticsearch:elasticsearch-nio:${version}"
}

View File

@ -0,0 +1,32 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.bootstrap.BootstrapCheck;
import org.elasticsearch.bootstrap.BootstrapContext;
public class NioNotEnabledBootstrapCheck implements BootstrapCheck {
@Override
public BootstrapCheckResult check(BootstrapContext context) {
return BootstrapCheckResult.failure("The transport-nio plugin is experimental and not ready for production usage. It should " +
"not be enabled in production.");
}
}

View File

@ -19,10 +19,10 @@
package org.elasticsearch.transport.nio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
@ -34,12 +34,15 @@ import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.AcceptorEventHandler;
import org.elasticsearch.nio.BytesReadContext;
import org.elasticsearch.nio.BytesWriteContext;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ReadContext;
import org.elasticsearch.nio.SocketEventHandler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transports;
@ -56,8 +59,8 @@ import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadF
public class NioTransport extends TcpTransport {
public static final String TRANSPORT_WORKER_THREAD_NAME_PREFIX = Transports.NIO_TRANSPORT_WORKER_THREAD_NAME_PREFIX;
public static final String TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX = Transports.NIO_TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX;
private static final String TRANSPORT_WORKER_THREAD_NAME_PREFIX = Transports.NIO_TRANSPORT_WORKER_THREAD_NAME_PREFIX;
private static final String TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX = Transports.NIO_TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX;
public static final Setting<Integer> NIO_WORKER_COUNT =
new Setting<>("transport.nio.worker_count",
@ -72,9 +75,9 @@ public class NioTransport extends TcpTransport {
private volatile NioGroup nioGroup;
private volatile TcpChannelFactory clientChannelFactory;
public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
CircuitBreakerService circuitBreakerService) {
NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
CircuitBreakerService circuitBreakerService) {
super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService);
this.pageCacheRecycler = pageCacheRecycler;
}
@ -104,17 +107,16 @@ public class NioTransport extends TcpTransport {
}
nioGroup = new NioGroup(logger, daemonThreadFactory(this.settings, TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX), acceptorCount,
AcceptorEventHandler::new, daemonThreadFactory(this.settings, TRANSPORT_WORKER_THREAD_NAME_PREFIX),
NioTransport.NIO_WORKER_COUNT.get(settings), this::getSocketEventHandler);
NioTransport.NIO_WORKER_COUNT.get(settings), SocketEventHandler::new);
ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default");
clientChannelFactory = new TcpChannelFactory(clientProfileSettings, getContextSetter("client"), getServerContextSetter());
clientChannelFactory = new TcpChannelFactory(clientProfileSettings, getContextSetter(), getServerContextSetter());
if (useNetworkServer) {
// loop through all profiles and start them up, special handling for default one
for (ProfileSettings profileSettings : profileSettings) {
String profileName = profileSettings.profileName;
Consumer<NioSocketChannel> contextSetter = getContextSetter(profileName);
TcpChannelFactory factory = new TcpChannelFactory(profileSettings, contextSetter, getServerContextSetter());
TcpChannelFactory factory = new TcpChannelFactory(profileSettings, getContextSetter(), getServerContextSetter());
profileToChannelFactory.putIfAbsent(profileName, factory);
bindServer(profileSettings);
}
@ -141,22 +143,20 @@ public class NioTransport extends TcpTransport {
profileToChannelFactory.clear();
}
protected SocketEventHandler getSocketEventHandler(Logger logger) {
return new SocketEventHandler(logger);
}
final void exceptionCaught(NioSocketChannel channel, Exception exception) {
onException((TcpNioSocketChannel) channel, exception);
onException((TcpChannel) channel, exception);
}
private Consumer<NioSocketChannel> getContextSetter(String profileName) {
private Consumer<TcpNioSocketChannel> getContextSetter() {
return (c) -> {
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
c.setContexts(new TcpReadContext(c, new TcpReadHandler(profileName, this), new InboundChannelBuffer(pageSupplier)),
new TcpWriteContext(c), this::exceptionCaught);
ReadContext.ReadConsumer nioReadConsumer = channelBuffer ->
consumeNetworkReads(c, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
BytesReadContext readContext = new BytesReadContext(c, nioReadConsumer, new InboundChannelBuffer(pageSupplier));
c.setContexts(readContext, new BytesWriteContext(c), this::exceptionCaught);
};
}
@ -165,7 +165,7 @@ public class NioTransport extends TcpTransport {
}
private Consumer<NioServerSocketChannel> getServerContextSetter() {
private Consumer<TcpNioServerSocketChannel> getServerContextSetter() {
return (c) -> c.setAcceptContext(this::acceptChannel);
}
}

View File

@ -16,10 +16,14 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.bootstrap.BootstrapCheck;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkModule;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
@ -29,7 +33,9 @@ import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
@ -37,21 +43,27 @@ public class NioTransportPlugin extends Plugin implements NetworkPlugin {
public static final String NIO_TRANSPORT_NAME = "nio-transport";
@Override
public List<Setting<?>> getSettings() {
return Arrays.asList(
NioTransport.NIO_WORKER_COUNT,
NioTransport.NIO_ACCEPTOR_COUNT
);
}
@Override
public Map<String, Supplier<Transport>> getTransports(Settings settings, ThreadPool threadPool, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler,
CircuitBreakerService circuitBreakerService,
NamedWriteableRegistry namedWriteableRegistry,
NetworkService networkService) {
Settings settings1;
if (NioTransport.NIO_WORKER_COUNT.exists(settings) == false) {
// As this is only used for tests right now, limit the number of worker threads.
settings1 = Settings.builder().put(settings).put(NioTransport.NIO_WORKER_COUNT.getKey(), 2).build();
} else {
settings1 = settings;
}
return Collections.singletonMap(NIO_TRANSPORT_NAME,
() -> new NioTransport(settings1, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry,
() -> new NioTransport(settings, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry,
circuitBreakerService));
}
@Override
public List<BootstrapCheck> getBootstrapChecks() {
return Collections.singletonList(new NioNotEnabledBootstrapCheck());
}
}

View File

@ -19,12 +19,10 @@
package org.elasticsearch.transport.nio;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.nio.AcceptingSelector;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.transport.TcpTransport;
import java.io.IOException;
import java.nio.channels.ServerSocketChannel;
@ -39,30 +37,32 @@ import java.util.function.Consumer;
*/
public class TcpChannelFactory extends ChannelFactory<TcpNioServerSocketChannel, TcpNioSocketChannel> {
private final Consumer<NioSocketChannel> contextSetter;
private final Consumer<NioServerSocketChannel> serverContextSetter;
private final Consumer<TcpNioSocketChannel> contextSetter;
private final Consumer<TcpNioServerSocketChannel> serverContextSetter;
private final String profileName;
TcpChannelFactory(TcpTransport.ProfileSettings profileSettings, Consumer<NioSocketChannel> contextSetter,
Consumer<NioServerSocketChannel> serverContextSetter) {
TcpChannelFactory(TcpTransport.ProfileSettings profileSettings, Consumer<TcpNioSocketChannel> contextSetter,
Consumer<TcpNioServerSocketChannel> serverContextSetter) {
super(new RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive,
profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())));
this.profileName = profileSettings.profileName;
this.contextSetter = contextSetter;
this.serverContextSetter = serverContextSetter;
}
@Override
public TcpNioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException {
TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(channel, selector);
TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(profileName, channel, selector);
contextSetter.accept(nioChannel);
return nioChannel;
}
@Override
public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException {
TcpNioServerSocketChannel nioServerChannel = new TcpNioServerSocketChannel(channel, this, selector);
TcpNioServerSocketChannel nioServerChannel = new TcpNioServerSocketChannel(profileName, channel, this, selector);
serverContextSetter.accept(nioServerChannel);
return nioServerChannel;
}

View File

@ -21,11 +21,12 @@ package org.elasticsearch.transport.nio;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.nio.AcceptingSelector;
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.nio.AcceptingSelector;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel;
/**
@ -34,9 +35,12 @@ import java.nio.channels.ServerSocketChannel;
*/
public class TcpNioServerSocketChannel extends NioServerSocketChannel implements TcpChannel {
TcpNioServerSocketChannel(ServerSocketChannel socketChannel, TcpChannelFactory channelFactory, AcceptingSelector selector)
throws IOException {
private final String profile;
TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel, TcpChannelFactory channelFactory,
AcceptingSelector selector) throws IOException {
super(socketChannel, channelFactory, selector);
this.profile = profile;
}
@Override
@ -49,6 +53,16 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements
throw new UnsupportedOperationException("Cannot set SO_LINGER on a server channel.");
}
@Override
public InetSocketAddress getRemoteAddress() {
return null;
}
@Override
public String getProfile() {
return profile;
}
@Override
public void addCloseListener(ActionListener<Void> listener) {
addCloseListener(ActionListener.toBiConsumer(listener));

View File

@ -22,8 +22,8 @@ package org.elasticsearch.transport.nio;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.transport.TcpChannel;
import java.io.IOException;
import java.net.StandardSocketOptions;
@ -31,12 +31,15 @@ import java.nio.channels.SocketChannel;
public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel {
public TcpNioSocketChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException {
private final String profile;
TcpNioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException {
super(socketChannel, selector);
this.profile = profile;
}
public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
getWriteContext().sendMessage(reference, ActionListener.toBiConsumer(listener));
getWriteContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener));
}
@Override
@ -46,6 +49,11 @@ public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel
}
}
@Override
public String getProfile() {
return profile;
}
@Override
public void addCloseListener(ActionListener<Void> listener) {
addCloseListener(ActionListener.toBiConsumer(listener));

View File

@ -29,14 +29,14 @@ public class TcpReadHandler {
private final String profile;
private final NioTransport transport;
public TcpReadHandler(String profile, NioTransport transport) {
TcpReadHandler(String profile, NioTransport transport) {
this.profile = profile;
this.transport = transport;
}
public void handleMessage(BytesReference reference, TcpNioSocketChannel channel, int messageBytesLength) {
try {
transport.messageReceived(reference, channel, profile, channel.getRemoteAddress(), messageBytesLength);
transport.messageReceived(reference, channel);
} catch (IOException e) {
handleException(channel, e);
}

View File

@ -0,0 +1,71 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch;
import org.elasticsearch.common.network.NetworkModule;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.transport.nio.NioTransport;
import org.elasticsearch.transport.nio.NioTransportPlugin;
import java.util.Collection;
import java.util.Collections;
public abstract class NioIntegTestCase extends ESIntegTestCase {
@Override
protected boolean ignoreExternalCluster() {
return true;
}
@Override
protected boolean addMockTransportService() {
return false;
}
@Override
protected Settings nodeSettings(int nodeOrdinal) {
Settings.Builder builder = Settings.builder().put(super.nodeSettings(nodeOrdinal));
// randomize netty settings
if (randomBoolean()) {
builder.put(NioTransport.NIO_WORKER_COUNT.getKey(), random().nextInt(3) + 1);
}
builder.put(NetworkModule.TRANSPORT_TYPE_KEY, NioTransportPlugin.NIO_TRANSPORT_NAME);
return builder.build();
}
@Override
protected Settings transportClientSettings() {
Settings.Builder builder = Settings.builder().put(super.transportClientSettings());
builder.put(NetworkModule.TRANSPORT_TYPE_KEY, NioTransportPlugin.NIO_TRANSPORT_NAME);
return builder.build();
}
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singletonList(NioTransportPlugin.class);
}
@Override
protected Collection<Class<? extends Plugin>> transportClientPlugins() {
return Collections.singletonList(NioTransportPlugin.class);
}
}

View File

@ -0,0 +1,132 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.NioIntegTestCase;
import org.elasticsearch.Version;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.health.ClusterHealthStatus;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.network.NetworkModule;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
import org.elasticsearch.test.ESIntegTestCase.Scope;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transport;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
@ClusterScope(scope = Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1)
public class NioTransportIT extends NioIntegTestCase {
// static so we can use it in anonymous classes
private static String channelProfileName = null;
@Override
protected Settings nodeSettings(int nodeOrdinal) {
return Settings.builder().put(super.nodeSettings(nodeOrdinal))
.put(NetworkModule.TRANSPORT_TYPE_KEY, "exception-throwing").build();
}
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
List<Class<? extends Plugin>> list = new ArrayList<>();
list.add(ExceptionThrowingNioTransport.TestPlugin.class);
list.addAll(super.nodePlugins());
return Collections.unmodifiableCollection(list);
}
public void testThatConnectionFailsAsIntended() throws Exception {
Client transportClient = internalCluster().transportClient();
ClusterHealthResponse clusterIndexHealths = transportClient.admin().cluster().prepareHealth().get();
assertThat(clusterIndexHealths.getStatus(), is(ClusterHealthStatus.GREEN));
try {
transportClient.filterWithHeader(Collections.singletonMap("ERROR", "MY MESSAGE")).admin().cluster().prepareHealth().get();
fail("Expected exception, but didn't happen");
} catch (ElasticsearchException e) {
assertThat(e.getMessage(), containsString("MY MESSAGE"));
assertThat(channelProfileName, is(TcpTransport.DEFAULT_PROFILE));
}
}
public static final class ExceptionThrowingNioTransport extends NioTransport {
public static class TestPlugin extends Plugin implements NetworkPlugin {
@Override
public Map<String, Supplier<Transport>> getTransports(Settings settings, ThreadPool threadPool, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler,
CircuitBreakerService circuitBreakerService,
NamedWriteableRegistry namedWriteableRegistry,
NetworkService networkService) {
return Collections.singletonMap("exception-throwing",
() -> new ExceptionThrowingNioTransport(settings, threadPool, networkService, bigArrays, pageCacheRecycler,
namedWriteableRegistry, circuitBreakerService));
}
}
ExceptionThrowingNioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
CircuitBreakerService circuitBreakerService) {
super(settings, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService);
}
@Override
protected String handleRequest(TcpChannel channel, String profileName,
StreamInput stream, long requestId, int messageLengthBytes, Version version,
InetSocketAddress remoteAddress, byte status) throws IOException {
String action = super.handleRequest(channel, profileName, stream, requestId, messageLengthBytes, version,
remoteAddress, status);
channelProfileName = TcpTransport.DEFAULT_PROFILE;
return action;
}
@Override
protected void validateRequest(StreamInput buffer, long requestId, String action)
throws IOException {
super.validateRequest(buffer, requestId, action);
String error = threadPool.getThreadContext().getHeader("ERROR");
if (error != null) {
throw new ElasticsearchException(error);
}
}
}
}

View File

@ -19,7 +19,6 @@
package org.elasticsearch.transport.nio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@ -31,7 +30,6 @@ import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.nio.SocketEventHandler;
import org.elasticsearch.node.Node;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ThreadPool;
@ -77,11 +75,6 @@ public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase {
protected Version getCurrentVersion() {
return version;
}
@Override
protected SocketEventHandler getSocketEventHandler(Logger logger) {
return new TestingSocketEventHandler(logger);
}
};
MockTransportService mockTransportService =
MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, clusterSettings);

View File

@ -32,7 +32,7 @@ import org.elasticsearch.env.Environment;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.transport.MockTcpTransportPlugin;
import org.elasticsearch.transport.client.PreBuiltTransportClient;
import org.elasticsearch.transport.nio.NioTransportPlugin;
import org.elasticsearch.transport.nio.MockNioTransportPlugin;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
@ -86,8 +86,8 @@ public abstract class ESSmokeClientTestCase extends LuceneTestCase {
String transportKey;
Class<? extends Plugin> transportPlugin;
if (usNio) {
transportKey = NioTransportPlugin.NIO_TRANSPORT_NAME;
transportPlugin = NioTransportPlugin.class;
transportKey = MockNioTransportPlugin.MOCK_NIO_TRANSPORT_NAME;
transportPlugin = MockNioTransportPlugin.class;
} else {
transportKey = MockTcpTransportPlugin.MOCK_TCP_TRANSPORT_NAME;
transportPlugin = MockTcpTransportPlugin.class;

View File

@ -24,7 +24,7 @@ import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.transport.MockTcpTransportPlugin;
import org.elasticsearch.transport.Netty4Plugin;
import org.elasticsearch.transport.nio.NioTransportPlugin;
import org.elasticsearch.transport.nio.MockNioTransportPlugin;
import org.junit.BeforeClass;
import java.util.Arrays;
@ -47,8 +47,8 @@ public abstract class HttpSmokeTestCase extends ESIntegTestCase {
private static String getTypeKey(Class<? extends Plugin> clazz) {
if (clazz.equals(MockTcpTransportPlugin.class)) {
return MockTcpTransportPlugin.MOCK_TCP_TRANSPORT_NAME;
} else if (clazz.equals(NioTransportPlugin.class)) {
return NioTransportPlugin.NIO_TRANSPORT_NAME;
} else if (clazz.equals(MockNioTransportPlugin.class)) {
return MockNioTransportPlugin.MOCK_NIO_TRANSPORT_NAME;
} else {
assert clazz.equals(Netty4Plugin.class);
return Netty4Plugin.NETTY_TRANSPORT_NAME;

View File

@ -64,6 +64,7 @@ List projects = [
'plugins:repository-s3',
'plugins:jvm-example',
'plugins:store-smb',
'plugins:transport-nio',
'qa:auto-create-index',
'qa:ccs-unavailable-clusters',
'qa:evil-tests',

View File

@ -105,7 +105,7 @@ import org.elasticsearch.test.junit.listeners.LoggingListener;
import org.elasticsearch.test.junit.listeners.ReproduceInfoPrinter;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.MockTcpTransportPlugin;
import org.elasticsearch.transport.nio.NioTransportPlugin;
import org.elasticsearch.transport.nio.MockNioTransportPlugin;
import org.joda.time.DateTimeZone;
import org.junit.After;
import org.junit.AfterClass;
@ -900,11 +900,11 @@ public abstract class ESTestCase extends LuceneTestCase {
}
public static String getTestTransportType() {
return useNio ? NioTransportPlugin.NIO_TRANSPORT_NAME : MockTcpTransportPlugin.MOCK_TCP_TRANSPORT_NAME;
return useNio ? MockNioTransportPlugin.MOCK_NIO_TRANSPORT_NAME : MockTcpTransportPlugin.MOCK_TCP_TRANSPORT_NAME;
}
public static Class<? extends Plugin> getTestTransportPlugin() {
return useNio ? NioTransportPlugin.class : MockTcpTransportPlugin.class;
return useNio ? MockNioTransportPlugin.class : MockTcpTransportPlugin.class;
}
private static final GeohashGenerator geohashGenerator = new GeohashGenerator();

View File

@ -25,7 +25,6 @@ import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@ -37,7 +36,7 @@ import org.elasticsearch.env.Environment;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.transport.MockTcpTransportPlugin;
import org.elasticsearch.transport.MockTransportClient;
import org.elasticsearch.transport.nio.NioTransportPlugin;
import org.elasticsearch.transport.nio.MockNioTransportPlugin;
import java.io.IOException;
import java.net.InetSocketAddress;
@ -86,10 +85,10 @@ public final class ExternalTestCluster extends TestCluster {
String transport = getTestTransportType();
clientSettingsBuilder.put(NetworkModule.TRANSPORT_TYPE_KEY, transport);
if (pluginClasses.contains(MockTcpTransportPlugin.class) == false &&
pluginClasses.contains(NioTransportPlugin.class) == false) {
pluginClasses.contains(MockNioTransportPlugin.class) == false) {
pluginClasses = new ArrayList<>(pluginClasses);
if (transport.equals(NioTransportPlugin.NIO_TRANSPORT_NAME)) {
pluginClasses.add(NioTransportPlugin.class);
if (transport.equals(MockNioTransportPlugin.MOCK_NIO_TRANSPORT_NAME)) {
pluginClasses.add(MockNioTransportPlugin.class);
} else {
pluginClasses.add(MockTcpTransportPlugin.class);
}

View File

@ -159,14 +159,7 @@ public class MockTcpTransport extends TcpTransport {
output.write(minimalHeader);
output.writeInt(msgSize);
output.write(buffer);
final BytesReference bytes = output.bytes();
if (TcpTransport.validateMessageHeader(bytes)) {
InetSocketAddress remoteAddress = (InetSocketAddress) socket.getRemoteSocketAddress();
messageReceived(bytes.slice(TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE, msgSize),
mockChannel, mockChannel.profile, remoteAddress, msgSize);
} else {
// ping message - we just drop all stuff
}
consumeNetworkReads(mockChannel, output.bytes());
}
}
@ -357,6 +350,11 @@ public class MockTcpTransport extends TcpTransport {
}
}
@Override
public String getProfile() {
return profile;
}
@Override
public void addCloseListener(ActionListener<Void> listener) {
closeFuture.whenComplete(ActionListener.toBiConsumer(listener));
@ -380,6 +378,11 @@ public class MockTcpTransport extends TcpTransport {
return localAddress;
}
@Override
public InetSocketAddress getRemoteAddress() {
return (InetSocketAddress) activeChannel.getRemoteSocketAddress();
}
@Override
public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
try {

View File

@ -23,7 +23,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkModule;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.transport.nio.NioTransportPlugin;
import org.elasticsearch.transport.nio.MockNioTransportPlugin;
import java.util.ArrayList;
import java.util.Arrays;
@ -59,12 +59,12 @@ public class MockTransportClient extends TransportClient {
plugins.add(MockTcpTransportPlugin.class);
return plugins;
}
} else if (NioTransportPlugin.NIO_TRANSPORT_NAME.equals(transportType)) {
if (plugins.contains(NioTransportPlugin.class)) {
} else if (MockNioTransportPlugin.MOCK_NIO_TRANSPORT_NAME.equals(transportType)) {
if (plugins.contains(MockNioTransportPlugin.class)) {
return plugins;
} else {
plugins = new ArrayList<>(plugins);
plugins.add(NioTransportPlugin.class);
plugins.add(MockNioTransportPlugin.class);
return plugins;
}
}

View File

@ -0,0 +1,252 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.AcceptingSelector;
import org.elasticsearch.nio.AcceptorEventHandler;
import org.elasticsearch.nio.BytesReadContext;
import org.elasticsearch.nio.BytesWriteContext;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ReadContext;
import org.elasticsearch.nio.SocketSelector;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transports;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Supplier;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory;
public class MockNioTransport extends TcpTransport {
private static final String TRANSPORT_WORKER_THREAD_NAME_PREFIX = Transports.NIO_TRANSPORT_WORKER_THREAD_NAME_PREFIX;
private static final String TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX = Transports.NIO_TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX;
private final PageCacheRecycler pageCacheRecycler;
private final ConcurrentMap<String, MockTcpChannelFactory> profileToChannelFactory = newConcurrentMap();
private volatile NioGroup nioGroup;
private volatile MockTcpChannelFactory clientChannelFactory;
MockNioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
CircuitBreakerService circuitBreakerService) {
super("mock-nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService);
this.pageCacheRecycler = pageCacheRecycler;
}
@Override
protected MockServerChannel bind(String name, InetSocketAddress address) throws IOException {
MockTcpChannelFactory channelFactory = this.profileToChannelFactory.get(name);
return nioGroup.bindServerChannel(address, channelFactory);
}
@Override
protected MockSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
MockSocketChannel channel = nioGroup.openChannel(node.getAddress().address(), clientChannelFactory);
channel.addConnectListener(ActionListener.toBiConsumer(connectListener));
return channel;
}
@Override
protected void doStart() {
boolean success = false;
try {
int acceptorCount = 0;
boolean useNetworkServer = NetworkService.NETWORK_SERVER.get(settings);
if (useNetworkServer) {
acceptorCount = 1;
}
nioGroup = new NioGroup(logger, daemonThreadFactory(this.settings, TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX), acceptorCount,
AcceptorEventHandler::new, daemonThreadFactory(this.settings, TRANSPORT_WORKER_THREAD_NAME_PREFIX),
2, TestingSocketEventHandler::new);
ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default");
clientChannelFactory = new MockTcpChannelFactory(clientProfileSettings, "client");
if (useNetworkServer) {
// loop through all profiles and start them up, special handling for default one
for (ProfileSettings profileSettings : profileSettings) {
String profileName = profileSettings.profileName;
MockTcpChannelFactory factory = new MockTcpChannelFactory(profileSettings, profileName);
profileToChannelFactory.putIfAbsent(profileName, factory);
bindServer(profileSettings);
}
}
super.doStart();
success = true;
} catch (IOException e) {
throw new ElasticsearchException(e);
} finally {
if (success == false) {
doStop();
}
}
}
@Override
protected void stopInternal() {
try {
nioGroup.close();
} catch (Exception e) {
logger.warn("unexpected exception while stopping nio group", e);
}
profileToChannelFactory.clear();
}
private void exceptionCaught(NioSocketChannel channel, Exception exception) {
onException((TcpChannel) channel, exception);
}
private void acceptChannel(NioSocketChannel channel) {
serverAcceptedChannel((TcpChannel) channel);
}
private class MockTcpChannelFactory extends ChannelFactory<MockServerChannel, MockSocketChannel> {
private final String profileName;
private MockTcpChannelFactory(ProfileSettings profileSettings, String profileName) {
super(new RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive,
profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())));
this.profileName = profileName;
}
@Override
public MockSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException {
MockSocketChannel nioChannel = new MockSocketChannel(profileName, channel, selector);
Supplier<InboundChannelBuffer.Page> pageSupplier = () -> {
Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
ReadContext.ReadConsumer nioReadConsumer = channelBuffer ->
consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
BytesReadContext readContext = new BytesReadContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier));
BytesWriteContext writeContext = new BytesWriteContext(nioChannel);
nioChannel.setContexts(readContext, writeContext, MockNioTransport.this::exceptionCaught);
return nioChannel;
}
@Override
public MockServerChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException {
MockServerChannel nioServerChannel = new MockServerChannel(profileName, channel, this, selector);
nioServerChannel.setAcceptContext(MockNioTransport.this::acceptChannel);
return nioServerChannel;
}
}
private static class MockServerChannel extends NioServerSocketChannel implements TcpChannel {
private final String profile;
MockServerChannel(String profile, ServerSocketChannel channel, ChannelFactory<?, ?> channelFactory, AcceptingSelector selector)
throws IOException {
super(channel, channelFactory, selector);
this.profile = profile;
}
@Override
public String getProfile() {
return profile;
}
@Override
public void addCloseListener(ActionListener<Void> listener) {
addCloseListener(ActionListener.toBiConsumer(listener));
}
@Override
public void setSoLinger(int value) throws IOException {
throw new UnsupportedOperationException("Cannot set SO_LINGER on a server channel.");
}
@Override
public InetSocketAddress getRemoteAddress() {
return null;
}
@Override
public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
throw new UnsupportedOperationException("Cannot send a message to a server channel.");
}
}
private static class MockSocketChannel extends NioSocketChannel implements TcpChannel {
private final String profile;
private MockSocketChannel(String profile, java.nio.channels.SocketChannel socketChannel, SocketSelector selector)
throws IOException {
super(socketChannel, selector);
this.profile = profile;
}
@Override
public String getProfile() {
return profile;
}
@Override
public void addCloseListener(ActionListener<Void> listener) {
addCloseListener(ActionListener.toBiConsumer(listener));
}
@Override
public void setSoLinger(int value) throws IOException {
if (isOpen()) {
getRawChannel().setOption(StandardSocketOptions.SO_LINGER, value);
}
}
@Override
public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
getWriteContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener));
}
}
}

View File

@ -0,0 +1,50 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import java.util.Collections;
import java.util.Map;
import java.util.function.Supplier;
public class MockNioTransportPlugin extends Plugin implements NetworkPlugin {
public static final String MOCK_NIO_TRANSPORT_NAME = "mock-nio";
@Override
public Map<String, Supplier<Transport>> getTransports(Settings settings, ThreadPool threadPool, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler,
CircuitBreakerService circuitBreakerService,
NamedWriteableRegistry namedWriteableRegistry,
NetworkService networkService) {
return Collections.singletonMap(MOCK_NIO_TRANSPORT_NAME,
() -> new MockNioTransport(settings, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry,
circuitBreakerService));
}
}

View File

@ -1,118 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.monitor.jvm.JvmInfo;
import org.elasticsearch.transport.TcpHeader;
import org.elasticsearch.transport.TcpTransport;
import java.io.IOException;
import java.io.StreamCorruptedException;
public class TcpFrameDecoder {
private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9);
private static final int HEADER_SIZE = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
private int expectedMessageLength = -1;
public BytesReference decode(BytesReference bytesReference) throws IOException {
if (bytesReference.length() >= 6) {
int messageLength = readHeaderBuffer(bytesReference);
int totalLength = messageLength + HEADER_SIZE;
if (totalLength > bytesReference.length()) {
expectedMessageLength = totalLength;
return null;
} else if (totalLength == bytesReference.length()) {
expectedMessageLength = -1;
return bytesReference;
} else {
expectedMessageLength = -1;
return bytesReference.slice(0, totalLength);
}
} else {
return null;
}
}
public int expectedMessageLength() {
return expectedMessageLength;
}
private int readHeaderBuffer(BytesReference headerBuffer) throws IOException {
if (headerBuffer.get(0) != 'E' || headerBuffer.get(1) != 'S') {
if (appearsToBeHTTP(headerBuffer)) {
throw new TcpTransport.HttpOnTransportException("This is not a HTTP port");
}
throw new StreamCorruptedException("invalid internal transport message format, got ("
+ Integer.toHexString(headerBuffer.get(0) & 0xFF) + ","
+ Integer.toHexString(headerBuffer.get(1) & 0xFF) + ","
+ Integer.toHexString(headerBuffer.get(2) & 0xFF) + ","
+ Integer.toHexString(headerBuffer.get(3) & 0xFF) + ")");
}
final int messageLength;
try (StreamInput input = headerBuffer.streamInput()) {
input.skip(TcpHeader.MARKER_BYTES_SIZE);
messageLength = input.readInt();
}
if (messageLength == -1) {
// This is a ping
return 0;
}
if (messageLength <= 0) {
throw new StreamCorruptedException("invalid data length: " + messageLength);
}
if (messageLength > NINETY_PER_HEAP_SIZE) {
throw new IllegalArgumentException("transport content length received [" + new ByteSizeValue(messageLength) + "] exceeded ["
+ new ByteSizeValue(NINETY_PER_HEAP_SIZE) + "]");
}
return messageLength;
}
private static boolean appearsToBeHTTP(BytesReference headerBuffer) {
return bufferStartsWith(headerBuffer, "GET") ||
bufferStartsWith(headerBuffer, "POST") ||
bufferStartsWith(headerBuffer, "PUT") ||
bufferStartsWith(headerBuffer, "HEAD") ||
bufferStartsWith(headerBuffer, "DELETE") ||
// TODO: Actually 'OPTIONS'. But that does not currently fit in 6 bytes
bufferStartsWith(headerBuffer, "OPTION") ||
bufferStartsWith(headerBuffer, "PATCH") ||
bufferStartsWith(headerBuffer, "TRACE");
}
private static boolean bufferStartsWith(BytesReference buffer, String method) {
char[] chars = method.toCharArray();
for (int i = 0; i < chars.length; i++) {
if (buffer.get(i) != chars[i]) {
return false;
}
}
return true;
}
}

View File

@ -1,96 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.common.bytes.ByteBufferReference;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ReadContext;
import java.io.IOException;
import java.nio.ByteBuffer;
public class TcpReadContext implements ReadContext {
private final TcpReadHandler handler;
private final TcpNioSocketChannel channel;
private final InboundChannelBuffer channelBuffer;
private final TcpFrameDecoder frameDecoder = new TcpFrameDecoder();
public TcpReadContext(NioSocketChannel channel, TcpReadHandler handler, InboundChannelBuffer channelBuffer) {
this.handler = handler;
this.channel = (TcpNioSocketChannel) channel;
this.channelBuffer = channelBuffer;
}
@Override
public int read() throws IOException {
if (channelBuffer.getRemaining() == 0) {
// Requiring one additional byte will ensure that a new page is allocated.
channelBuffer.ensureCapacity(channelBuffer.getCapacity() + 1);
}
int bytesRead = channel.read(channelBuffer);
if (bytesRead == -1) {
return bytesRead;
}
BytesReference message;
// Frame decoder will throw an exception if the message is improperly formatted, the header is incorrect,
// or the message is corrupted
while ((message = frameDecoder.decode(toBytesReference(channelBuffer))) != null) {
int messageLengthWithHeader = message.length();
try {
BytesReference messageWithoutHeader = message.slice(6, message.length() - 6);
// A message length of 6 bytes it is just a ping. Ignore for now.
if (messageLengthWithHeader != 6) {
handler.handleMessage(messageWithoutHeader, channel, messageWithoutHeader.length());
}
} catch (Exception e) {
handler.handleException(channel, e);
} finally {
channelBuffer.release(messageLengthWithHeader);
}
}
return bytesRead;
}
@Override
public void close() {
channelBuffer.close();
}
private static BytesReference toBytesReference(InboundChannelBuffer channelBuffer) {
ByteBuffer[] writtenToBuffers = channelBuffer.sliceBuffersTo(channelBuffer.getIndex());
ByteBufferReference[] references = new ByteBufferReference[writtenToBuffers.length];
for (int i = 0; i < references.length; ++i) {
references[i] = new ByteBufferReference(writtenToBuffers[i]);
}
return new CompositeBytesReference(references);
}
}

View File

@ -20,8 +20,8 @@
package org.elasticsearch.transport.nio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.nio.SocketEventHandler;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketEventHandler;
import java.io.IOException;
import java.util.Collections;

View File

@ -0,0 +1,137 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.node.Node;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
import org.elasticsearch.transport.BindTransportException;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportService;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Collections;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
public class SimpleMockNioTransportTests extends AbstractSimpleTransportTestCase {
public static MockTransportService nioFromThreadPool(Settings settings, ThreadPool threadPool, final Version version,
ClusterSettings clusterSettings, boolean doHandshake) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
NetworkService networkService = new NetworkService(Collections.emptyList());
Transport transport = new MockNioTransport(settings, threadPool,
networkService, BigArrays.NON_RECYCLING_INSTANCE, new MockPageCacheRecycler(settings), namedWriteableRegistry,
new NoneCircuitBreakerService()) {
@Override
protected Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException,
InterruptedException {
if (doHandshake) {
return super.executeHandshake(node, channel, timeout);
} else {
return version.minimumCompatibilityVersion();
}
}
@Override
protected Version getCurrentVersion() {
return version;
}
};
MockTransportService mockTransportService =
MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, clusterSettings);
mockTransportService.start();
return mockTransportService;
}
@Override
protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) {
settings = Settings.builder().put(settings)
.put(TcpTransport.PORT.getKey(), "0")
.build();
MockTransportService transportService = nioFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake);
transportService.start();
return transportService;
}
@Override
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
@SuppressWarnings("unchecked")
TcpTransport.NodeChannels channels = (TcpTransport.NodeChannels) connection;
TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true);
}
public void testConnectException() throws UnknownHostException {
try {
serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876),
emptyMap(), emptySet(),Version.CURRENT));
fail("Expected ConnectTransportException");
} catch (ConnectTransportException e) {
assertThat(e.getMessage(), containsString("connect_exception"));
assertThat(e.getMessage(), containsString("[127.0.0.1:9876]"));
Throwable cause = e.getCause();
assertThat(cause, instanceOf(IOException.class));
}
}
public void testBindUnavailableAddress() {
// this is on a lower level since it needs access to the TransportService before it's started
int port = serviceA.boundAddress().publishAddress().getPort();
Settings settings = Settings.builder()
.put(Node.NODE_NAME_SETTING.getKey(), "foobar")
.put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "")
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.put("transport.tcp.port", port)
.build();
ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> {
MockTransportService transportService = nioFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true);
try {
transportService.start();
} finally {
transportService.stop();
transportService.close();
}
});
assertEquals("Failed to bind to ["+ port + "]", bindTransportException.getMessage());
}
}

View File

@ -1,167 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.TcpTransport;
import java.io.IOException;
import java.io.StreamCorruptedException;
import static org.hamcrest.Matchers.instanceOf;
public class TcpFrameDecoderTests extends ESTestCase {
private TcpFrameDecoder frameDecoder = new TcpFrameDecoder();
public void testDefaultExceptedMessageLengthIsNegative1() {
assertEquals(-1, frameDecoder.expectedMessageLength());
}
public void testDecodeWithIncompleteHeader() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.write(1);
streamOutput.write(1);
assertNull(frameDecoder.decode(streamOutput.bytes()));
assertEquals(-1, frameDecoder.expectedMessageLength());
}
public void testDecodePing() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(-1);
BytesReference message = frameDecoder.decode(streamOutput.bytes());
assertEquals(-1, frameDecoder.expectedMessageLength());
assertEquals(streamOutput.bytes(), message);
}
public void testDecodePingWithStartOfSecondMessage() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(-1);
streamOutput.write('E');
streamOutput.write('S');
BytesReference message = frameDecoder.decode(streamOutput.bytes());
assertEquals(6, message.length());
assertEquals(streamOutput.bytes().slice(0, 6), message);
}
public void testDecodeMessage() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(2);
streamOutput.write('M');
streamOutput.write('A');
BytesReference message = frameDecoder.decode(streamOutput.bytes());
assertEquals(-1, frameDecoder.expectedMessageLength());
assertEquals(streamOutput.bytes(), message);
}
public void testDecodeIncompleteMessage() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(3);
streamOutput.write('M');
streamOutput.write('A');
BytesReference message = frameDecoder.decode(streamOutput.bytes());
assertEquals(9, frameDecoder.expectedMessageLength());
assertNull(message);
}
public void testInvalidLength() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(-2);
streamOutput.write('M');
streamOutput.write('A');
try {
frameDecoder.decode(streamOutput.bytes());
fail("Expected exception");
} catch (Exception ex) {
assertThat(ex, instanceOf(StreamCorruptedException.class));
assertEquals("invalid data length: -2", ex.getMessage());
}
}
public void testInvalidHeader() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('C');
byte byte1 = randomByte();
byte byte2 = randomByte();
streamOutput.write(byte1);
streamOutput.write(byte2);
streamOutput.write(randomByte());
streamOutput.write(randomByte());
streamOutput.write(randomByte());
try {
frameDecoder.decode(streamOutput.bytes());
fail("Expected exception");
} catch (Exception ex) {
assertThat(ex, instanceOf(StreamCorruptedException.class));
String expected = "invalid internal transport message format, got (45,43,"
+ Integer.toHexString(byte1 & 0xFF) + ","
+ Integer.toHexString(byte2 & 0xFF) + ")";
assertEquals(expected, ex.getMessage());
}
}
public void testHTTPHeader() throws IOException {
String[] httpHeaders = {"GET", "POST", "PUT", "HEAD", "DELETE", "OPTIONS", "PATCH", "TRACE"};
for (String httpHeader : httpHeaders) {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
for (char c : httpHeader.toCharArray()) {
streamOutput.write((byte) c);
}
streamOutput.write(new byte[6]);
try {
BytesReference bytes = streamOutput.bytes();
frameDecoder.decode(bytes);
fail("Expected exception");
} catch (Exception ex) {
assertThat(ex, instanceOf(TcpTransport.HttpOnTransportException.class));
assertEquals("This is not a HTTP port", ex.getMessage());
}
}
}
}

View File

@ -1,158 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
public class TcpReadContextTests extends ESTestCase {
private TcpReadHandler handler;
private int messageLength;
private TcpNioSocketChannel channel;
private TcpReadContext readContext;
@Before
public void init() {
handler = mock(TcpReadHandler.class);
messageLength = randomInt(96) + 4;
channel = mock(TcpNioSocketChannel.class);
Supplier<InboundChannelBuffer.Page> pageSupplier = () ->
new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {});
readContext = new TcpReadContext(channel, handler, new InboundChannelBuffer(pageSupplier));
}
public void testSuccessfulRead() throws IOException {
byte[] bytes = createMessage(messageLength);
byte[] fullMessage = combineMessageAndHeader(bytes);
final AtomicLong bufferCapacity = new AtomicLong();
when(channel.read(any(InboundChannelBuffer.class))).thenAnswer(invocationOnMock -> {
InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0];
ByteBuffer byteBuffer = buffer.sliceBuffersFrom(buffer.getIndex())[0];
bufferCapacity.set(buffer.getCapacity() - buffer.getIndex());
byteBuffer.put(fullMessage);
buffer.incrementIndex(fullMessage.length);
return fullMessage.length;
});
readContext.read();
verify(handler).handleMessage(new BytesArray(bytes), channel, messageLength);
assertEquals(1024 * 16, bufferCapacity.get());
BytesArray bytesArray = new BytesArray(new byte[10]);
bytesArray.slice(5, 5);
bytesArray.slice(5, 0);
}
public void testPartialRead() throws IOException {
byte[] part1 = createMessage(messageLength);
byte[] fullPart1 = combineMessageAndHeader(part1, messageLength + messageLength);
byte[] part2 = createMessage(messageLength);
final AtomicLong bufferCapacity = new AtomicLong();
final AtomicReference<byte[]> bytes = new AtomicReference<>();
when(channel.read(any(InboundChannelBuffer.class))).thenAnswer(invocationOnMock -> {
InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0];
ByteBuffer byteBuffer = buffer.sliceBuffersFrom(buffer.getIndex())[0];
bufferCapacity.set(buffer.getCapacity() - buffer.getIndex());
byteBuffer.put(bytes.get());
buffer.incrementIndex(bytes.get().length);
return bytes.get().length;
});
bytes.set(fullPart1);
readContext.read();
assertEquals(1024 * 16, bufferCapacity.get());
verifyZeroInteractions(handler);
bytes.set(part2);
readContext.read();
assertEquals(1024 * 16 - fullPart1.length, bufferCapacity.get());
CompositeBytesReference reference = new CompositeBytesReference(new BytesArray(part1), new BytesArray(part2));
verify(handler).handleMessage(reference, channel, messageLength + messageLength);
}
public void testReadThrowsIOException() throws IOException {
IOException ioException = new IOException();
when(channel.read(any())).thenThrow(ioException);
try {
readContext.read();
fail("Expected exception");
} catch (Exception ex) {
assertSame(ioException, ex);
}
}
public void closeClosesChannelBuffer() {
InboundChannelBuffer buffer = mock(InboundChannelBuffer.class);
TcpReadContext readContext = new TcpReadContext(channel, handler, buffer);
readContext.close();
verify(buffer).close();
}
private static byte[] combineMessageAndHeader(byte[] bytes) {
return combineMessageAndHeader(bytes, bytes.length);
}
private static byte[] combineMessageAndHeader(byte[] bytes, int messageLength) {
byte[] fullMessage = new byte[bytes.length + 6];
ByteBuffer wrapped = ByteBuffer.wrap(fullMessage);
wrapped.put((byte) 'E');
wrapped.put((byte) 'S');
wrapped.putInt(messageLength);
wrapped.put(bytes);
return fullMessage;
}
private static byte[] createMessage(int length) {
byte[] bytes = new byte[length];
for (int i = 0; i < length; ++i) {
bytes[i] = randomByte();
}
return bytes;
}
}