Introduce NioTransport into framework for testing (#24262)

This commit introduces a nio based tcp transport into framework for
testing.

Currently Elasticsearch uses a simple blocking tcp transport for
testing purposes (MockTcpTransport). This diverges from production
where our current transport (netty) is non-blocking.

The point of this commit is to introduce a testing variant that more
closely matches the behavior of production instances.
This commit is contained in:
Tim Brooks 2017-06-28 10:51:20 -05:00 committed by GitHub
parent 9ce9c21b83
commit 5f8be0e090
48 changed files with 5494 additions and 2 deletions

View File

@ -25,7 +25,7 @@ commonscodec = 1.10
hamcrest = 1.3
securemock = 1.2
# When updating mocksocket, please also update core/src/main/resources/org/elasticsearch/bootstrap/test-framework.policy
mocksocket = 1.1
mocksocket = 1.2
# benchmark dependencies
jmh = 1.17.3

View File

@ -58,7 +58,7 @@ grant codeBase "${codebase.junit-4.12.jar}" {
permission java.lang.RuntimePermission "accessDeclaredMembers";
};
grant codeBase "${codebase.mocksocket-1.1.jar}" {
grant codeBase "${codebase.mocksocket-1.2.jar}" {
// mocksocket makes and accepts socket connections
permission java.net.SocketPermission "*", "accept,connect";
};

View File

@ -0,0 +1,115 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import java.io.IOException;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
/**
* Selector implementation that handles {@link NioServerSocketChannel}. It's main piece of functionality is
* accepting new channels.
*/
public class AcceptingSelector extends ESSelector {
private final AcceptorEventHandler eventHandler;
private final ConcurrentLinkedQueue<NioServerSocketChannel> newChannels = new ConcurrentLinkedQueue<>();
public AcceptingSelector(AcceptorEventHandler eventHandler) throws IOException {
super(eventHandler);
this.eventHandler = eventHandler;
}
public AcceptingSelector(AcceptorEventHandler eventHandler, Selector selector) throws IOException {
super(eventHandler, selector);
this.eventHandler = eventHandler;
}
@Override
void doSelect(int timeout) throws IOException, ClosedSelectorException {
setUpNewServerChannels();
int ready = selector.select(timeout);
if (ready > 0) {
Set<SelectionKey> selectionKeys = selector.selectedKeys();
Iterator<SelectionKey> keyIterator = selectionKeys.iterator();
while (keyIterator.hasNext()) {
SelectionKey sk = keyIterator.next();
keyIterator.remove();
acceptChannel(sk);
}
}
}
@Override
void cleanup() {
channelsToClose.addAll(registeredChannels);
closePendingChannels();
}
/**
* Registers a NioServerSocketChannel to be handled by this selector. The channel will by queued and
* eventually registered next time through the event loop.
* @param serverSocketChannel the channel to register
*/
public void registerServerChannel(NioServerSocketChannel serverSocketChannel) {
newChannels.add(serverSocketChannel);
wakeup();
}
private void setUpNewServerChannels() throws ClosedChannelException {
NioServerSocketChannel newChannel;
while ((newChannel = this.newChannels.poll()) != null) {
if (newChannel.register(this)) {
SelectionKey selectionKey = newChannel.getSelectionKey();
selectionKey.attach(newChannel);
registeredChannels.add(newChannel);
eventHandler.serverChannelRegistered(newChannel);
}
}
}
private void acceptChannel(SelectionKey sk) {
NioServerSocketChannel serverChannel = (NioServerSocketChannel) sk.attachment();
if (sk.isValid()) {
try {
if (sk.isAcceptable()) {
try {
eventHandler.acceptChannel(serverChannel);
} catch (IOException e) {
eventHandler.acceptException(serverChannel, e);
}
}
} catch (CancelledKeyException ex) {
eventHandler.genericServerChannelException(serverChannel, ex);
}
} else {
eventHandler.genericServerChannelException(serverChannel, new CancelledKeyException());
}
}
}

View File

@ -0,0 +1,91 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.transport.nio.channel.ChannelFactory;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.SelectionKeyUtils;
import java.io.IOException;
import java.util.function.Supplier;
/**
* Event handler designed to handle events from server sockets
*/
public class AcceptorEventHandler extends EventHandler {
private final Supplier<SocketSelector> selectorSupplier;
private final OpenChannels openChannels;
public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier<SocketSelector> selectorSupplier) {
super(logger);
this.openChannels = openChannels;
this.selectorSupplier = selectorSupplier;
}
/**
* This method is called when a NioServerSocketChannel is successfully registered. It should only be
* called once per channel.
*
* @param nioServerSocketChannel that was registered
*/
public void serverChannelRegistered(NioServerSocketChannel nioServerSocketChannel) {
SelectionKeyUtils.setAcceptInterested(nioServerSocketChannel);
openChannels.serverChannelOpened(nioServerSocketChannel);
}
/**
* This method is called when a server channel signals it is ready to accept a connection. All of the
* accept logic should occur in this call.
*
* @param nioServerChannel that can accept a connection
*/
public void acceptChannel(NioServerSocketChannel nioServerChannel) throws IOException {
ChannelFactory channelFactory = nioServerChannel.getChannelFactory();
NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel);
openChannels.acceptedChannelOpened(nioSocketChannel);
nioSocketChannel.getCloseFuture().setListener(openChannels::channelClosed);
selectorSupplier.get().registerSocketChannel(nioSocketChannel);
}
/**
* This method is called when an attempt to accept a connection throws an exception.
*
* @param nioServerChannel that accepting a connection
* @param exception that occurred
*/
public void acceptException(NioServerSocketChannel nioServerChannel, Exception exception) {
logger.debug("exception while accepting new channel", exception);
}
/**
* This method is called when handling an event from a channel fails due to an unexpected exception.
* An example would be if checking ready ops on a {@link java.nio.channels.SelectionKey} threw
* {@link java.nio.channels.CancelledKeyException}.
*
* @param channel that caused the exception
* @param exception that was thrown
*/
public void genericServerChannelException(NioServerSocketChannel channel, Exception exception) {
logger.debug("event handling exception", exception);
}
}

View File

@ -0,0 +1,196 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.transport.nio.channel.NioChannel;
import java.io.Closeable;
import java.io.IOException;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.Selector;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;
/**
* This is a basic selector abstraction used by {@link org.elasticsearch.transport.nio.NioTransport}. This
* selector wraps a raw nio {@link Selector}. When you call {@link #runLoop()}, the selector will run until
* {@link #close()} is called. This instance handles closing of channels. Users should call
* {@link #queueChannelClose(NioChannel)} to schedule a channel for close by this selector.
* <p>
* Children of this class should implement the specific {@link #doSelect(int)} and {@link #cleanup()}
* functionality.
*/
public abstract class ESSelector implements Closeable {
final Selector selector;
final ConcurrentLinkedQueue<NioChannel> channelsToClose = new ConcurrentLinkedQueue<>();
final Set<NioChannel> registeredChannels = Collections.newSetFromMap(new ConcurrentHashMap<NioChannel, Boolean>());
private final EventHandler eventHandler;
private final ReentrantLock runLock = new ReentrantLock();
private final AtomicBoolean isClosed = new AtomicBoolean(false);
private final PlainActionFuture<Boolean> isRunningFuture = PlainActionFuture.newFuture();
private volatile Thread thread;
ESSelector(EventHandler eventHandler) throws IOException {
this(eventHandler, Selector.open());
}
ESSelector(EventHandler eventHandler, Selector selector) throws IOException {
this.eventHandler = eventHandler;
this.selector = selector;
}
/**
* Starts this selector. The selector will run until {@link #close()} or {@link #close(boolean)} is
* called.
*/
public void runLoop() {
if (runLock.tryLock()) {
isRunningFuture.onResponse(true);
try {
setThread();
while (isOpen()) {
singleLoop();
}
} finally {
try {
cleanup();
} finally {
runLock.unlock();
}
}
} else {
throw new IllegalStateException("selector is already running");
}
}
void singleLoop() {
try {
closePendingChannels();
doSelect(300);
} catch (ClosedSelectorException e) {
if (isOpen()) {
throw e;
}
} catch (IOException e) {
eventHandler.selectException(e);
} catch (Exception e) {
eventHandler.uncaughtException(e);
}
}
/**
* Should implement the specific select logic. This will be called once per {@link #singleLoop()}
*
* @param timeout to pass to the raw select operation
* @throws IOException thrown by the raw select operation
* @throws ClosedSelectorException thrown if the raw selector is closed
*/
abstract void doSelect(int timeout) throws IOException, ClosedSelectorException;
void setThread() {
thread = Thread.currentThread();
}
public boolean isOnCurrentThread() {
return Thread.currentThread() == thread;
}
public void wakeup() {
// TODO: Do I need the wakeup optimizations that some other libraries use?
selector.wakeup();
}
public Set<NioChannel> getRegisteredChannels() {
return registeredChannels;
}
@Override
public void close() throws IOException {
close(false);
}
public void close(boolean shouldInterrupt) throws IOException {
if (isClosed.compareAndSet(false, true)) {
selector.close();
if (shouldInterrupt && thread != null) {
thread.interrupt();
} else {
wakeup();
}
runLock.lock(); // wait for the shutdown to complete
}
}
public void queueChannelClose(NioChannel channel) {
ensureOpen();
channelsToClose.offer(channel);
wakeup();
}
void closePendingChannels() {
NioChannel channel;
while ((channel = channelsToClose.poll()) != null) {
closeChannel(channel);
}
}
/**
* Called once as the selector is being closed.
*/
abstract void cleanup();
public Selector rawSelector() {
return selector;
}
public boolean isOpen() {
return isClosed.get() == false;
}
public boolean isRunning() {
return runLock.isLocked();
}
public PlainActionFuture<Boolean> isRunningFuture() {
return isRunningFuture;
}
private void closeChannel(NioChannel channel) {
try {
eventHandler.handleClose(channel);
} finally {
registeredChannels.remove(channel);
}
}
private void ensureOpen() {
if (isClosed.get()) {
throw new IllegalStateException("selector is already closed");
}
}
}

View File

@ -0,0 +1,71 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.transport.nio.channel.CloseFuture;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import java.io.IOException;
import java.nio.channels.Selector;
public abstract class EventHandler {
protected final Logger logger;
public EventHandler(Logger logger) {
this.logger = logger;
}
/**
* This method handles an IOException that was thrown during a call to {@link Selector#select(long)}.
*
* @param exception that was uncaught
*/
public void selectException(IOException exception) {
logger.warn("io exception during select", exception);
}
/**
* This method handles an exception that was uncaught during a select loop.
*
* @param exception that was uncaught
*/
public void uncaughtException(Exception exception) {
Thread thread = Thread.currentThread();
thread.getUncaughtExceptionHandler().uncaughtException(thread, exception);
}
/**
* This method handles the closing of an NioChannel
*
* @param channel that should be closed
*/
public void handleClose(NioChannel channel) {
channel.closeFromSelector();
CloseFuture closeFuture = channel.getCloseFuture();
assert closeFuture.isDone() : "Should always be done as we are on the selector thread";
IOException closeException = closeFuture.getCloseException();
if (closeException != null) {
logger.trace("exception while closing channel", closeException);
}
}
}

View File

@ -0,0 +1,157 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import java.nio.ByteBuffer;
import java.util.Iterator;
public class NetworkBytesReference extends BytesReference {
private final BytesArray bytesArray;
private final ByteBuffer writeBuffer;
private final ByteBuffer readBuffer;
private int writeIndex;
private int readIndex;
public NetworkBytesReference(BytesArray bytesArray, int writeIndex, int readIndex) {
this.bytesArray = bytesArray;
this.writeIndex = writeIndex;
this.readIndex = readIndex;
this.writeBuffer = ByteBuffer.wrap(bytesArray.array());
this.readBuffer = ByteBuffer.wrap(bytesArray.array());
}
public static NetworkBytesReference wrap(BytesArray bytesArray) {
return wrap(bytesArray, 0, 0);
}
public static NetworkBytesReference wrap(BytesArray bytesArray, int writeIndex, int readIndex) {
if (readIndex > writeIndex) {
throw new IndexOutOfBoundsException("Read index [" + readIndex + "] was greater than write index [" + writeIndex + "]");
}
return new NetworkBytesReference(bytesArray, writeIndex, readIndex);
}
@Override
public byte get(int index) {
return bytesArray.get(index);
}
@Override
public int length() {
return bytesArray.length();
}
@Override
public NetworkBytesReference slice(int from, int length) {
BytesReference ref = bytesArray.slice(from, length);
BytesArray newBytesArray;
if (ref instanceof BytesArray) {
newBytesArray = (BytesArray) ref;
} else {
newBytesArray = new BytesArray(ref.toBytesRef());
}
int newReadIndex = Math.min(Math.max(readIndex - from, 0), length);
int newWriteIndex = Math.min(Math.max(writeIndex - from, 0), length);
return wrap(newBytesArray, newWriteIndex, newReadIndex);
}
@Override
public BytesRef toBytesRef() {
return bytesArray.toBytesRef();
}
@Override
public long ramBytesUsed() {
return bytesArray.ramBytesUsed();
}
public int getWriteIndex() {
return writeIndex;
}
public void incrementWrite(int delta) {
int newWriteIndex = writeIndex + delta;
if (newWriteIndex > bytesArray.length()) {
throw new IndexOutOfBoundsException("New write index [" + newWriteIndex + "] would be greater than length" +
" [" + bytesArray.length() + "]");
}
writeIndex = newWriteIndex;
}
public int getWriteRemaining() {
return bytesArray.length() - writeIndex;
}
public boolean hasWriteRemaining() {
return getWriteRemaining() > 0;
}
public int getReadIndex() {
return readIndex;
}
public void incrementRead(int delta) {
int newReadIndex = readIndex + delta;
if (newReadIndex > writeIndex) {
throw new IndexOutOfBoundsException("New read index [" + newReadIndex + "] would be greater than write" +
" index [" + writeIndex + "]");
}
readIndex = newReadIndex;
}
public int getReadRemaining() {
return writeIndex - readIndex;
}
public boolean hasReadRemaining() {
return getReadRemaining() > 0;
}
public ByteBuffer getWriteByteBuffer() {
writeBuffer.position(bytesArray.offset() + writeIndex);
writeBuffer.limit(bytesArray.offset() + bytesArray.length());
return writeBuffer;
}
public ByteBuffer getReadByteBuffer() {
readBuffer.position(bytesArray.offset() + readIndex);
readBuffer.limit(bytesArray.offset() + writeIndex);
return readBuffer;
}
public static void vectorizedIncrementReadIndexes(Iterable<NetworkBytesReference> references, int delta) {
Iterator<NetworkBytesReference> refs = references.iterator();
while (delta != 0) {
NetworkBytesReference ref = refs.next();
int amountToInc = Math.min(ref.getReadRemaining(), delta);
ref.incrementRead(amountToInc);
delta -= amountToInc;
}
}
}

View File

@ -0,0 +1,155 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.nio.channel.ChannelFactory;
import org.elasticsearch.transport.nio.channel.ConnectFuture;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;
import java.util.function.Consumer;
import java.util.function.Supplier;
public class NioClient {
private static final int CLOSED = -1;
private final Logger logger;
private final OpenChannels openChannels;
private final Supplier<SocketSelector> selectorSupplier;
private final TimeValue defaultConnectTimeout;
private final ChannelFactory channelFactory;
private final Semaphore semaphore = new Semaphore(Integer.MAX_VALUE);
public NioClient(Logger logger, OpenChannels openChannels, Supplier<SocketSelector> selectorSupplier, TimeValue connectTimeout,
ChannelFactory channelFactory) {
this.logger = logger;
this.openChannels = openChannels;
this.selectorSupplier = selectorSupplier;
this.defaultConnectTimeout = connectTimeout;
this.channelFactory = channelFactory;
}
public boolean connectToChannels(DiscoveryNode node, NioSocketChannel[] channels, TimeValue connectTimeout,
Consumer<NioChannel> closeListener) throws IOException {
boolean allowedToConnect = semaphore.tryAcquire();
if (allowedToConnect == false) {
return false;
}
final ArrayList<NioSocketChannel> connections = new ArrayList<>(channels.length);
connectTimeout = getConnectTimeout(connectTimeout);
final InetSocketAddress address = node.getAddress().address();
try {
for (int i = 0; i < channels.length; i++) {
SocketSelector socketSelector = selectorSupplier.get();
NioSocketChannel nioSocketChannel = channelFactory.openNioChannel(address);
openChannels.clientChannelOpened(nioSocketChannel);
nioSocketChannel.getCloseFuture().setListener(closeListener);
connections.add(nioSocketChannel);
socketSelector.registerSocketChannel(nioSocketChannel);
}
Exception ex = null;
boolean allConnected = true;
for (NioSocketChannel socketChannel : connections) {
ConnectFuture connectFuture = socketChannel.getConnectFuture();
boolean success = connectFuture.awaitConnectionComplete(connectTimeout.getMillis(), TimeUnit.MILLISECONDS);
if (success == false) {
allConnected = false;
Exception exception = connectFuture.getException();
if (exception != null) {
ex = exception;
break;
}
}
}
if (allConnected == false) {
if (ex == null) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]");
} else {
throw new ConnectTransportException(node, "connect_exception", ex);
}
}
addConnectionsToList(channels, connections);
return true;
} catch (IOException | RuntimeException e) {
closeChannels(connections, e);
throw e;
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
closeChannels(connections, e);
throw new ElasticsearchException(e);
} finally {
semaphore.release();
}
}
public void close() {
semaphore.acquireUninterruptibly(Integer.MAX_VALUE);
}
private TimeValue getConnectTimeout(TimeValue connectTimeout) {
if (connectTimeout != null && connectTimeout.equals(defaultConnectTimeout) == false) {
return connectTimeout;
} else {
return defaultConnectTimeout;
}
}
private static void addConnectionsToList(NioSocketChannel[] channels, ArrayList<NioSocketChannel> connections) {
final Iterator<NioSocketChannel> iterator = connections.iterator();
for (int i = 0; i < channels.length; i++) {
assert iterator.hasNext();
channels[i] = iterator.next();
}
assert iterator.hasNext() == false : "not all created connection have been consumed";
}
private void closeChannels(ArrayList<NioSocketChannel> connections, Exception e) {
for (final NioSocketChannel socketChannel : connections) {
try {
socketChannel.closeAsync().awaitClose();
} catch (InterruptedException inner) {
logger.trace("exception while closing channel", e);
e.addSuppressed(inner);
Thread.currentThread().interrupt();
} catch (Exception inner) {
logger.trace("exception while closing channel", e);
e.addSuppressed(inner);
}
}
}
}

View File

@ -0,0 +1,66 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.concurrent.CountDownLatch;
public class NioShutdown {
private final Logger logger;
public NioShutdown(Logger logger) {
this.logger = logger;
}
void orderlyShutdown(OpenChannels openChannels, NioClient client, ArrayList<AcceptingSelector> acceptors,
ArrayList<SocketSelector> socketSelectors) {
// Close the client. This ensures that no new send connections will be opened. Client could be null if exception was
// throw on start up
if (client != null) {
client.close();
}
// Start by closing the server channels. Once these are closed, we are guaranteed to no accept new connections
openChannels.closeServerChannels();
for (AcceptingSelector acceptor : acceptors) {
shutdownSelector(acceptor);
}
openChannels.close();
for (SocketSelector selector : socketSelectors) {
shutdownSelector(selector);
}
}
private void shutdownSelector(ESSelector selector) {
try {
selector.close();
} catch (IOException | ElasticsearchException e) {
logger.warn("unexpected exception while stopping selector", e);
}
}
}

View File

@ -0,0 +1,289 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportSettings;
import org.elasticsearch.transport.nio.channel.ChannelFactory;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ThreadFactory;
import java.util.function.Consumer;
import java.util.function.Supplier;
import static org.elasticsearch.common.settings.Setting.intSetting;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory;
public class NioTransport extends TcpTransport<NioChannel> {
// TODO: Need to add to places where we check if transport thread
public static final String TRANSPORT_WORKER_THREAD_NAME_PREFIX = "transport_worker";
public static final String TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX = "transport_acceptor";
public static final Setting<Integer> NIO_WORKER_COUNT =
new Setting<>("transport.nio.worker_count",
(s) -> Integer.toString(EsExecutors.numberOfProcessors(s) * 2),
(s) -> Setting.parseInt(s, 1, "transport.nio.worker_count"), Setting.Property.NodeScope);
public static final Setting<Integer> NIO_ACCEPTOR_COUNT =
intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope);
private final TcpReadHandler tcpReadHandler = new TcpReadHandler(this);
private final BigArrays bigArrays;
private final ConcurrentMap<String, ChannelFactory> profileToChannelFactory = newConcurrentMap();
private final OpenChannels openChannels = new OpenChannels(logger);
private final ArrayList<AcceptingSelector> acceptors = new ArrayList<>();
private final ArrayList<SocketSelector> socketSelectors = new ArrayList<>();
private NioClient client;
private int acceptorNumber;
public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) {
super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService);
this.bigArrays = bigArrays;
}
@Override
public long getNumOpenServerConnections() {
return openChannels.serverChannelsCount();
}
@Override
protected InetSocketAddress getLocalAddress(NioChannel channel) {
return channel.getLocalAddress();
}
@Override
protected NioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException {
ChannelFactory channelFactory = this.profileToChannelFactory.get(name);
NioServerSocketChannel serverSocketChannel = channelFactory.openNioServerSocketChannel(name, address);
acceptors.get(++acceptorNumber % NioTransport.NIO_ACCEPTOR_COUNT.get(settings)).registerServerChannel(serverSocketChannel);
return serverSocketChannel;
}
@Override
protected void closeChannels(List<NioChannel> channels) throws IOException {
IOException closingExceptions = null;
for (final NioChannel channel : channels) {
if (channel != null && channel.isOpen()) {
try {
channel.closeAsync().awaitClose();
} catch (Exception e) {
if (closingExceptions == null) {
closingExceptions = new IOException("failed to close channels");
}
closingExceptions.addSuppressed(e.getCause());
}
}
}
if (closingExceptions != null) {
throw closingExceptions;
}
}
@Override
protected void sendMessage(NioChannel channel, BytesReference reference, ActionListener<NioChannel> listener) {
if (channel instanceof NioSocketChannel) {
NioSocketChannel nioSocketChannel = (NioSocketChannel) channel;
nioSocketChannel.getWriteContext().sendMessage(reference, listener);
} else {
logger.error("cannot send message to channel of this type [{}]", channel.getClass());
}
}
@Override
protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer<NioChannel> onChannelClose)
throws IOException {
NioSocketChannel[] channels = new NioSocketChannel[profile.getNumConnections()];
ClientChannelCloseListener closeListener = new ClientChannelCloseListener(onChannelClose);
boolean connected = client.connectToChannels(node, channels, profile.getConnectTimeout(), closeListener);
if (connected == false) {
throw new ElasticsearchException("client is shutdown");
}
return new NodeChannels(node, channels, profile);
}
@Override
protected boolean isOpen(NioChannel channel) {
return channel.isOpen();
}
@Override
protected void doStart() {
boolean success = false;
try {
if (NetworkService.NETWORK_SERVER.get(settings)) {
int workerCount = NioTransport.NIO_WORKER_COUNT.get(settings);
for (int i = 0; i < workerCount; ++i) {
SocketSelector selector = new SocketSelector(getSocketEventHandler());
socketSelectors.add(selector);
}
int acceptorCount = NioTransport.NIO_ACCEPTOR_COUNT.get(settings);
for (int i = 0; i < acceptorCount; ++i) {
Supplier<SocketSelector> selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors);
AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, openChannels, selectorSupplier);
AcceptingSelector acceptor = new AcceptingSelector(eventHandler);
acceptors.add(acceptor);
}
// loop through all profiles and start them up, special handling for default one
for (Map.Entry<String, Settings> entry : buildProfileSettings().entrySet()) {
// merge fallback settings with default settings with profile settings so we have complete settings with default values
final Settings settings = Settings.builder()
.put(createFallbackSettings())
.put(entry.getValue()).build();
profileToChannelFactory.putIfAbsent(entry.getKey(), new ChannelFactory(settings, tcpReadHandler));
bindServer(entry.getKey(), settings);
}
}
client = createClient();
for (SocketSelector selector : socketSelectors) {
if (selector.isRunning() == false) {
ThreadFactory threadFactory = daemonThreadFactory(this.settings, TRANSPORT_WORKER_THREAD_NAME_PREFIX);
threadFactory.newThread(selector::runLoop).start();
selector.isRunningFuture().actionGet();
}
}
for (AcceptingSelector acceptor : acceptors) {
if (acceptor.isRunning() == false) {
ThreadFactory threadFactory = daemonThreadFactory(this.settings, TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX);
threadFactory.newThread(acceptor::runLoop).start();
acceptor.isRunningFuture().actionGet();
}
}
super.doStart();
success = true;
} catch (IOException e) {
throw new ElasticsearchException(e);
} finally {
if (success == false) {
doStop();
}
}
}
@Override
protected void stopInternal() {
NioShutdown nioShutdown = new NioShutdown(logger);
nioShutdown.orderlyShutdown(openChannels, client, acceptors, socketSelectors);
profileToChannelFactory.clear();
socketSelectors.clear();
}
protected SocketEventHandler getSocketEventHandler() {
return new SocketEventHandler(logger, this::exceptionCaught);
}
final void exceptionCaught(NioSocketChannel channel, Throwable cause) {
final Throwable unwrapped = ExceptionsHelper.unwrap(cause, ElasticsearchException.class);
final Throwable t = unwrapped != null ? unwrapped : cause;
onException(channel, t instanceof Exception ? (Exception) t : new ElasticsearchException(t));
}
private Settings createFallbackSettings() {
Settings.Builder fallbackSettingsBuilder = Settings.builder();
List<String> fallbackBindHost = TransportSettings.BIND_HOST.get(settings);
if (fallbackBindHost.isEmpty() == false) {
fallbackSettingsBuilder.putArray("bind_host", fallbackBindHost);
}
List<String> fallbackPublishHost = TransportSettings.PUBLISH_HOST.get(settings);
if (fallbackPublishHost.isEmpty() == false) {
fallbackSettingsBuilder.putArray("publish_host", fallbackPublishHost);
}
boolean fallbackTcpNoDelay = settings.getAsBoolean("transport.nio.tcp_no_delay",
NetworkService.TcpSettings.TCP_NO_DELAY.get(settings));
fallbackSettingsBuilder.put("tcp_no_delay", fallbackTcpNoDelay);
boolean fallbackTcpKeepAlive = settings.getAsBoolean("transport.nio.tcp_keep_alive",
NetworkService.TcpSettings.TCP_KEEP_ALIVE.get(settings));
fallbackSettingsBuilder.put("tcp_keep_alive", fallbackTcpKeepAlive);
boolean fallbackReuseAddress = settings.getAsBoolean("transport.nio.reuse_address",
NetworkService.TcpSettings.TCP_REUSE_ADDRESS.get(settings));
fallbackSettingsBuilder.put("reuse_address", fallbackReuseAddress);
ByteSizeValue fallbackTcpSendBufferSize = settings.getAsBytesSize("transport.nio.tcp_send_buffer_size",
TCP_SEND_BUFFER_SIZE.get(settings));
if (fallbackTcpSendBufferSize.getBytes() >= 0) {
fallbackSettingsBuilder.put("tcp_send_buffer_size", fallbackTcpSendBufferSize);
}
ByteSizeValue fallbackTcpBufferSize = settings.getAsBytesSize("transport.nio.tcp_receive_buffer_size",
TCP_RECEIVE_BUFFER_SIZE.get(settings));
if (fallbackTcpBufferSize.getBytes() >= 0) {
fallbackSettingsBuilder.put("tcp_receive_buffer_size", fallbackTcpBufferSize);
}
return fallbackSettingsBuilder.build();
}
private NioClient createClient() {
Supplier<SocketSelector> selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors);
ChannelFactory channelFactory = new ChannelFactory(settings, tcpReadHandler);
return new NioClient(logger, openChannels, selectorSupplier, defaultConnectionProfile.getConnectTimeout(), channelFactory);
}
class ClientChannelCloseListener implements Consumer<NioChannel> {
private final Consumer<NioChannel> consumer;
private ClientChannelCloseListener(Consumer<NioChannel> consumer) {
this.consumer = consumer;
}
@Override
public void accept(final NioChannel channel) {
consumer.accept(channel);
openChannels.channelClosed(channel);
}
}
}

View File

@ -0,0 +1,120 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
public class OpenChannels implements Releasable {
// TODO: Maybe set concurrency levels?
private final ConcurrentMap<NioSocketChannel, Long> openClientChannels = newConcurrentMap();
private final ConcurrentMap<NioSocketChannel, Long> openAcceptedChannels = newConcurrentMap();
private final ConcurrentMap<NioServerSocketChannel, Long> openServerChannels = newConcurrentMap();
private final Logger logger;
public OpenChannels(Logger logger) {
this.logger = logger;
}
public void serverChannelOpened(NioServerSocketChannel channel) {
boolean added = openServerChannels.putIfAbsent(channel, System.nanoTime()) == null;
if (added && logger.isTraceEnabled()) {
logger.trace("server channel opened: {}", channel);
}
}
public long serverChannelsCount() {
return openServerChannels.size();
}
public void acceptedChannelOpened(NioSocketChannel channel) {
boolean added = openAcceptedChannels.putIfAbsent(channel, System.nanoTime()) == null;
if (added && logger.isTraceEnabled()) {
logger.trace("accepted channel opened: {}", channel);
}
}
public HashSet<NioSocketChannel> getAcceptedChannels() {
return new HashSet<>(openAcceptedChannels.keySet());
}
public void clientChannelOpened(NioSocketChannel channel) {
boolean added = openClientChannels.putIfAbsent(channel, System.nanoTime()) == null;
if (added && logger.isTraceEnabled()) {
logger.trace("client channel opened: {}", channel);
}
}
public void channelClosed(NioChannel channel) {
boolean removed;
if (channel instanceof NioServerSocketChannel) {
removed = openServerChannels.remove(channel) != null;
} else {
NioSocketChannel socketChannel = (NioSocketChannel) channel;
removed = openClientChannels.remove(socketChannel) != null;
if (removed == false) {
removed = openAcceptedChannels.remove(socketChannel) != null;
}
}
if (removed && logger.isTraceEnabled()) {
logger.trace("channel closed: {}", channel);
}
}
public void closeServerChannels() {
for (NioServerSocketChannel channel : openServerChannels.keySet()) {
ensureClosedInternal(channel);
}
openServerChannels.clear();
}
@Override
public void close() {
for (NioSocketChannel channel : openClientChannels.keySet()) {
ensureClosedInternal(channel);
}
for (NioSocketChannel channel : openAcceptedChannels.keySet()) {
ensureClosedInternal(channel);
}
openClientChannels.clear();
openAcceptedChannels.clear();
}
private void ensureClosedInternal(NioChannel channel) {
try {
channel.closeAsync().get();
} catch (Exception e) {
logger.trace("exception while closing channels", e);
}
}
}

View File

@ -0,0 +1,40 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
public class RoundRobinSelectorSupplier implements Supplier<SocketSelector> {
private final ArrayList<SocketSelector> selectors;
private final int count;
private AtomicInteger counter = new AtomicInteger(0);
public RoundRobinSelectorSupplier(ArrayList<SocketSelector> selectors) {
this.count = selectors.size();
this.selectors = selectors;
}
public SocketSelector get() {
return selectors.get(counter.getAndIncrement() % count);
}
}

View File

@ -0,0 +1,154 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.SelectionKeyUtils;
import org.elasticsearch.transport.nio.channel.WriteContext;
import java.io.IOException;
import java.util.function.BiConsumer;
/**
* Event handler designed to handle events from non-server sockets
*/
public class SocketEventHandler extends EventHandler {
private final BiConsumer<NioSocketChannel, Throwable> exceptionHandler;
private final Logger logger;
public SocketEventHandler(Logger logger, BiConsumer<NioSocketChannel, Throwable> exceptionHandler) {
super(logger);
this.exceptionHandler = exceptionHandler;
this.logger = logger;
}
/**
* This method is called when a NioSocketChannel is successfully registered. It should only be called
* once per channel.
*
* @param channel that was registered
*/
public void handleRegistration(NioSocketChannel channel) {
SelectionKeyUtils.setConnectAndReadInterested(channel);
}
/**
* This method is called when an attempt to register a channel throws an exception.
*
* @param channel that was registered
* @param exception that occurred
*/
public void registrationException(NioSocketChannel channel, Exception exception) {
logger.trace("failed to register channel", exception);
exceptionCaught(channel, exception);
}
/**
* This method is called when a NioSocketChannel is successfully connected. It should only be called
* once per channel.
*
* @param channel that was registered
*/
public void handleConnect(NioSocketChannel channel) {
SelectionKeyUtils.removeConnectInterested(channel);
}
/**
* This method is called when an attempt to connect a channel throws an exception.
*
* @param channel that was connecting
* @param exception that occurred
*/
public void connectException(NioSocketChannel channel, Exception exception) {
logger.trace("failed to connect to channel", exception);
exceptionCaught(channel, exception);
}
/**
* This method is called when a channel signals it is ready for be read. All of the read logic should
* occur in this call.
*
* @param channel that can be read
*/
public void handleRead(NioSocketChannel channel) throws IOException {
int bytesRead = channel.getReadContext().read();
if (bytesRead == -1) {
handleClose(channel);
}
}
/**
* This method is called when an attempt to read from a channel throws an exception.
*
* @param channel that was being read
* @param exception that occurred
*/
public void readException(NioSocketChannel channel, Exception exception) {
logger.trace("failed to read from channel", exception);
exceptionCaught(channel, exception);
}
/**
* This method is called when a channel signals it is ready to receive writes. All of the write logic
* should occur in this call.
*
* @param channel that can be read
*/
public void handleWrite(NioSocketChannel channel) throws IOException {
WriteContext channelContext = channel.getWriteContext();
channelContext.flushChannel();
if (channelContext.hasQueuedWriteOps()) {
SelectionKeyUtils.setWriteInterested(channel);
} else {
SelectionKeyUtils.removeWriteInterested(channel);
}
}
/**
* This method is called when an attempt to write to a channel throws an exception.
*
* @param channel that was being written to
* @param exception that occurred
*/
public void writeException(NioSocketChannel channel, Exception exception) {
logger.trace("failed to write to channel", exception);
exceptionCaught(channel, exception);
}
/**
* This method is called when handling an event from a channel fails due to an unexpected exception.
* An example would be if checking ready ops on a {@link java.nio.channels.SelectionKey} threw
* {@link java.nio.channels.CancelledKeyException}.
*
* @param channel that caused the exception
* @param exception that was thrown
*/
public void genericChannelException(NioSocketChannel channel, Exception exception) {
logger.trace("event handling failed", exception);
exceptionCaught(channel, exception);
}
private void exceptionCaught(NioSocketChannel channel, Exception e) {
exceptionHandler.accept(channel, e);
}
}

View File

@ -0,0 +1,216 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.SelectionKeyUtils;
import org.elasticsearch.transport.nio.channel.WriteContext;
import java.io.IOException;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
/**
* Selector implementation that handles {@link NioSocketChannel}. It's main piece of functionality is
* handling connect, read, and write events.
*/
public class SocketSelector extends ESSelector {
private final ConcurrentLinkedQueue<NioSocketChannel> newChannels = new ConcurrentLinkedQueue<>();
private final ConcurrentLinkedQueue<WriteOperation> queuedWrites = new ConcurrentLinkedQueue<>();
private final SocketEventHandler eventHandler;
public SocketSelector(SocketEventHandler eventHandler) throws IOException {
super(eventHandler);
this.eventHandler = eventHandler;
}
public SocketSelector(SocketEventHandler eventHandler, Selector selector) throws IOException {
super(eventHandler, selector);
this.eventHandler = eventHandler;
}
@Override
void doSelect(int timeout) throws IOException, ClosedSelectorException {
setUpNewChannels();
handleQueuedWrites();
int ready = selector.select(timeout);
if (ready > 0) {
Set<SelectionKey> selectionKeys = selector.selectedKeys();
processKeys(selectionKeys);
}
}
@Override
void cleanup() {
WriteOperation op;
while ((op = queuedWrites.poll()) != null) {
op.getListener().onFailure(new ClosedSelectorException());
}
channelsToClose.addAll(newChannels);
channelsToClose.addAll(registeredChannels);
closePendingChannels();
}
/**
* Registers a NioSocketChannel to be handled by this selector. The channel will by queued and eventually
* registered next time through the event loop.
* @param nioSocketChannel the channel to register
*/
public void registerSocketChannel(NioSocketChannel nioSocketChannel) {
newChannels.offer(nioSocketChannel);
wakeup();
}
/**
* Queues a write operation to be handled by the event loop. This can be called by any thread and is the
* api available for non-selector threads to schedule writes.
*
* @param writeOperation to be queued
*/
public void queueWrite(WriteOperation writeOperation) {
queuedWrites.offer(writeOperation);
if (isOpen() == false) {
boolean wasRemoved = queuedWrites.remove(writeOperation);
if (wasRemoved) {
writeOperation.getListener().onFailure(new ClosedSelectorException());
}
} else {
wakeup();
}
}
/**
* Queues a write operation directly in a channel's buffer. Channel buffers are only safe to be accessed
* by the selector thread. As a result, this method should only be called by the selector thread.
*
* @param writeOperation to be queued in a channel's buffer
*/
public void queueWriteInChannelBuffer(WriteOperation writeOperation) {
assert isOnCurrentThread() : "Must be on selector thread";
NioSocketChannel channel = writeOperation.getChannel();
WriteContext context = channel.getWriteContext();
try {
SelectionKeyUtils.setWriteInterested(channel);
context.queueWriteOperations(writeOperation);
} catch (Exception e) {
writeOperation.getListener().onFailure(e);
}
}
private void processKeys(Set<SelectionKey> selectionKeys) {
Iterator<SelectionKey> keyIterator = selectionKeys.iterator();
while (keyIterator.hasNext()) {
SelectionKey sk = keyIterator.next();
keyIterator.remove();
NioSocketChannel nioSocketChannel = (NioSocketChannel) sk.attachment();
if (sk.isValid()) {
try {
int ops = sk.readyOps();
if ((ops & SelectionKey.OP_CONNECT) != 0) {
attemptConnect(nioSocketChannel);
}
if (nioSocketChannel.isConnectComplete()) {
if ((ops & SelectionKey.OP_WRITE) != 0) {
handleWrite(nioSocketChannel);
}
if ((ops & SelectionKey.OP_READ) != 0) {
handleRead(nioSocketChannel);
}
}
} catch (CancelledKeyException e) {
eventHandler.genericChannelException(nioSocketChannel, e);
}
} else {
eventHandler.genericChannelException(nioSocketChannel, new CancelledKeyException());
}
}
}
private void handleWrite(NioSocketChannel nioSocketChannel) {
try {
eventHandler.handleWrite(nioSocketChannel);
} catch (Exception e) {
eventHandler.writeException(nioSocketChannel, e);
}
}
private void handleRead(NioSocketChannel nioSocketChannel) {
try {
eventHandler.handleRead(nioSocketChannel);
} catch (Exception e) {
eventHandler.readException(nioSocketChannel, e);
}
}
private void handleQueuedWrites() {
WriteOperation writeOperation;
while ((writeOperation = queuedWrites.poll()) != null) {
if (writeOperation.getChannel().isWritable()) {
queueWriteInChannelBuffer(writeOperation);
} else {
writeOperation.getListener().onFailure(new ClosedChannelException());
}
}
}
private void setUpNewChannels() {
NioSocketChannel newChannel;
while ((newChannel = this.newChannels.poll()) != null) {
setupChannel(newChannel);
}
}
private void setupChannel(NioSocketChannel newChannel) {
try {
if (newChannel.register(this)) {
registeredChannels.add(newChannel);
SelectionKey key = newChannel.getSelectionKey();
key.attach(newChannel);
eventHandler.handleRegistration(newChannel);
attemptConnect(newChannel);
}
} catch (Exception e) {
eventHandler.registrationException(newChannel, e);
}
}
private void attemptConnect(NioSocketChannel newChannel) {
try {
if (newChannel.finishConnect()) {
eventHandler.handleConnect(newChannel);
}
} catch (Exception e) {
eventHandler.connectException(newChannel, e);
}
}
}

View File

@ -0,0 +1,47 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import java.io.IOException;
public class TcpReadHandler {
private final NioTransport transport;
public TcpReadHandler(NioTransport transport) {
this.transport = transport;
}
public void handleMessage(BytesReference reference, NioSocketChannel channel, String profileName,
int messageBytesLength) {
try {
transport.messageReceived(reference, channel, profileName, channel.getRemoteAddress(), messageBytesLength);
} catch (IOException e) {
handleException(channel, e);
}
}
public void handleException(NioSocketChannel channel, Exception e) {
transport.exceptionCaught(channel, e);
}
}

View File

@ -0,0 +1,81 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefIterator;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import java.io.IOException;
import java.util.ArrayList;
public class WriteOperation {
private final NioSocketChannel channel;
private final ActionListener<NioChannel> listener;
private final NetworkBytesReference[] references;
public WriteOperation(NioSocketChannel channel, BytesReference bytesReference, ActionListener<NioChannel> listener) {
this.channel = channel;
this.listener = listener;
this.references = toArray(bytesReference);
}
public NetworkBytesReference[] getByteReferences() {
return references;
}
public ActionListener<NioChannel> getListener() {
return listener;
}
public NioSocketChannel getChannel() {
return channel;
}
public boolean isFullyFlushed() {
return references[references.length - 1].hasReadRemaining() == false;
}
public int flush() throws IOException {
return channel.write(references);
}
private static NetworkBytesReference[] toArray(BytesReference reference) {
BytesRefIterator byteRefIterator = reference.iterator();
BytesRef r;
try {
// Most network messages are composed of three buffers
ArrayList<NetworkBytesReference> references = new ArrayList<>(3);
while ((r = byteRefIterator.next()) != null) {
references.add(NetworkBytesReference.wrap(new BytesArray(r), r.length, 0));
}
return references.toArray(new NetworkBytesReference[references.size()]);
} catch (IOException e) {
// this is really an error since we don't do IO in our bytesreferences
throw new AssertionError("won't happen", e);
}
}
}

View File

@ -0,0 +1,205 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.transport.nio.ESSelector;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.NetworkChannel;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.util.concurrent.atomic.AtomicInteger;
/**
* This is a basic channel abstraction used by the {@link org.elasticsearch.transport.nio.NioTransport}.
* <p>
* A channel is open once it is constructed. The channel remains open and {@link #isOpen()} will return
* true until the channel is explicitly closed.
* <p>
* A channel lifecycle has four stages:
* <ol>
* <li>UNREGISTERED - When a channel is created and prior to it being registered with a selector.
* <li>REGISTERED - When a channel has been registered with a selector. This is the state of a channel that
* can perform normal operations.
* <li>CLOSING - When a channel has been marked for closed, but is not yet closed. {@link #isOpen()} will
* still return true. Normal operations should be rejected. The most common scenario for a channel to be
* CLOSING is when channel that was REGISTERED has {@link #closeAsync()} called, but the selector thread
* has not yet closed the channel.
* <li>CLOSED - The channel has been closed.
* </ol>
*
* @param <S> the type of raw channel this AbstractNioChannel uses
*/
public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkChannel> implements NioChannel {
static final int UNREGISTERED = 0;
static final int REGISTERED = 1;
static final int CLOSING = 2;
static final int CLOSED = 3;
final S socketChannel;
final AtomicInteger state = new AtomicInteger(UNREGISTERED);
private final InetSocketAddress localAddress;
private final String profile;
private final CloseFuture closeFuture = new CloseFuture();
private volatile ESSelector selector;
private SelectionKey selectionKey;
public AbstractNioChannel(String profile, S socketChannel) throws IOException {
this.profile = profile;
this.socketChannel = socketChannel;
this.localAddress = (InetSocketAddress) socketChannel.getLocalAddress();
}
@Override
public boolean isOpen() {
return closeFuture.isClosed() == false;
}
@Override
public InetSocketAddress getLocalAddress() {
return localAddress;
}
@Override
public String getProfile() {
return profile;
}
/**
* Schedules a channel to be closed by the selector event loop with which it is registered.
* <p>
* If the current state is UNREGISTERED, the call will attempt to transition the state from UNREGISTERED
* to CLOSING. If this transition is successful, the channel can no longer be registered with an event
* loop and the channel will be synchronously closed in this method call.
* <p>
* If the channel is REGISTERED and the state can be transitioned to CLOSING, the close operation will
* be scheduled with the event loop.
* <p>
* If the channel is CLOSING or CLOSED, nothing will be done.
*
* @return future that will be complete when the channel is closed
*/
@Override
public CloseFuture closeAsync() {
if (selector != null && selector.isOnCurrentThread()) {
closeFromSelector();
return closeFuture;
}
for (; ; ) {
int state = this.state.get();
if (state == UNREGISTERED && this.state.compareAndSet(UNREGISTERED, CLOSING)) {
close0();
break;
} else if (state == REGISTERED && this.state.compareAndSet(REGISTERED, CLOSING)) {
selector.queueChannelClose(this);
break;
} else if (state == CLOSING || state == CLOSED) {
break;
}
}
return closeFuture;
}
/**
* Closes the channel synchronously. This method should only be called from the selector thread.
* <p>
* Once this method returns, the channel will be closed.
*/
@Override
public void closeFromSelector() {
// This will not exit the loop until this thread or someone else has set the state to CLOSED.
// Whichever thread succeeds in setting the state to CLOSED will close the raw channel.
for (; ; ) {
int state = this.state.get();
if (state < CLOSING && this.state.compareAndSet(state, CLOSING)) {
close0();
} else if (state == CLOSING) {
close0();
} else if (state == CLOSED) {
break;
}
}
}
/**
* This method attempts to registered a channel with a selector. If method returns true the channel was
* successfully registered. If it returns false, the registration failed. The reason a registered might
* fail is if something else closed this channel.
*
* @param selector to register the channel
* @return if the channel was successfully registered
* @throws ClosedChannelException if the raw channel was closed
*/
@Override
public boolean register(ESSelector selector) throws ClosedChannelException {
if (markRegistered(selector)) {
setSelectionKey(socketChannel.register(selector.rawSelector(), 0));
return true;
} else {
return false;
}
}
@Override
public ESSelector getSelector() {
return selector;
}
@Override
public SelectionKey getSelectionKey() {
return selectionKey;
}
@Override
public CloseFuture getCloseFuture() {
return closeFuture;
}
@Override
public S getRawChannel() {
return socketChannel;
}
// Package visibility for testing
void setSelectionKey(SelectionKey selectionKey) {
this.selectionKey = selectionKey;
}
boolean markRegistered(ESSelector selector) {
this.selector = selector;
return state.compareAndSet(UNREGISTERED, REGISTERED);
}
private void close0() {
if (this.state.compareAndSet(CLOSING, CLOSED)) {
try {
socketChannel.close();
closeFuture.channelClosed(this);
} catch (IOException e) {
closeFuture.channelCloseThrewException(this, e);
}
}
}
}

View File

@ -0,0 +1,105 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.mocksocket.PrivilegedSocketAccess;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.nio.TcpReadHandler;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
public class ChannelFactory {
private final boolean tcpNoDelay;
private final boolean tcpKeepAlive;
private final boolean tcpReusedAddress;
private final int tcpSendBufferSize;
private final int tcpReceiveBufferSize;
private final TcpReadHandler handler;
public ChannelFactory(Settings settings, TcpReadHandler handler) {
tcpNoDelay = TcpTransport.TCP_NO_DELAY.get(settings);
tcpKeepAlive = TcpTransport.TCP_KEEP_ALIVE.get(settings);
tcpReusedAddress = TcpTransport.TCP_REUSE_ADDRESS.get(settings);
tcpSendBufferSize = Math.toIntExact(TcpTransport.TCP_SEND_BUFFER_SIZE.get(settings).getBytes());
tcpReceiveBufferSize = Math.toIntExact(TcpTransport.TCP_RECEIVE_BUFFER_SIZE.get(settings).getBytes());
this.handler = handler;
}
public NioSocketChannel openNioChannel(InetSocketAddress remoteAddress) throws IOException {
SocketChannel rawChannel = SocketChannel.open();
configureSocketChannel(rawChannel);
PrivilegedSocketAccess.connect(rawChannel, remoteAddress);
NioSocketChannel channel = new NioSocketChannel(NioChannel.CLIENT, rawChannel);
channel.setContexts(new TcpReadContext(channel, handler), new TcpWriteContext(channel));
return channel;
}
public NioSocketChannel acceptNioChannel(NioServerSocketChannel serverChannel) throws IOException {
ServerSocketChannel serverSocketChannel = serverChannel.getRawChannel();
SocketChannel rawChannel = PrivilegedSocketAccess.accept(serverSocketChannel);
configureSocketChannel(rawChannel);
NioSocketChannel channel = new NioSocketChannel(serverChannel.getProfile(), rawChannel);
channel.setContexts(new TcpReadContext(channel, handler), new TcpWriteContext(channel));
return channel;
}
public NioServerSocketChannel openNioServerSocketChannel(String profileName, InetSocketAddress address)
throws IOException {
ServerSocketChannel socketChannel = ServerSocketChannel.open();
socketChannel.configureBlocking(false);
ServerSocket socket = socketChannel.socket();
socket.setReuseAddress(tcpReusedAddress);
socketChannel.bind(address);
return new NioServerSocketChannel(profileName, socketChannel, this);
}
private void configureSocketChannel(SocketChannel channel) throws IOException {
channel.configureBlocking(false);
Socket socket = channel.socket();
socket.setTcpNoDelay(tcpNoDelay);
socket.setKeepAlive(tcpKeepAlive);
socket.setReuseAddress(tcpReusedAddress);
if (tcpSendBufferSize > 0) {
socket.setSendBufferSize(tcpSendBufferSize);
}
if (tcpReceiveBufferSize > 0) {
socket.setSendBufferSize(tcpReceiveBufferSize);
}
}
private static <T> T getSocketChannel(CheckedSupplier<T, IOException> supplier) throws IOException {
try {
return AccessController.doPrivileged((PrivilegedExceptionAction<T>) supplier::get);
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}
}

View File

@ -0,0 +1,104 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.util.concurrent.BaseFuture;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
public class CloseFuture extends BaseFuture<NioChannel> {
private final SetOnce<Consumer<NioChannel>> listener = new SetOnce<>();
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
throw new UnsupportedOperationException("Cannot cancel close future");
}
public void awaitClose() throws InterruptedException, IOException {
try {
super.get();
} catch (ExecutionException e) {
throw (IOException) e.getCause();
}
}
public void awaitClose(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException, IOException {
try {
super.get(timeout, unit);
} catch (ExecutionException e) {
throw (IOException) e.getCause();
}
}
public IOException getCloseException() {
if (isDone()) {
try {
super.get(0, TimeUnit.NANOSECONDS);
return null;
} catch (ExecutionException e) {
// We only make a setter for IOException
return (IOException) e.getCause();
} catch (TimeoutException e) {
return null;
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return null;
}
} else {
return null;
}
}
public boolean isClosed() {
return super.isDone();
}
public void setListener(Consumer<NioChannel> listener) {
this.listener.set(listener);
}
void channelClosed(NioChannel channel) {
boolean set = set(channel);
if (set) {
Consumer<NioChannel> listener = this.listener.get();
if (listener != null) {
listener.accept(channel);
}
}
}
void channelCloseThrewException(NioChannel channel, IOException ex) {
boolean set = setException(ex);
if (set) {
Consumer<NioChannel> listener = this.listener.get();
if (listener != null) {
listener.accept(channel);
}
}
}
}

View File

@ -0,0 +1,94 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.common.util.concurrent.BaseFuture;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
public class ConnectFuture extends BaseFuture<NioSocketChannel> {
public boolean awaitConnectionComplete(long timeout, TimeUnit unit) throws InterruptedException {
try {
super.get(timeout, unit);
return true;
} catch (ExecutionException | TimeoutException e) {
return false;
}
}
public Exception getException() {
if (isDone()) {
try {
// Get should always return without blocking as we already checked 'isDone'
// We are calling 'get' here in order to throw the ExecutionException
super.get();
return null;
} catch (ExecutionException e) {
// We only make a public setters for IOException or RuntimeException
return (Exception) e.getCause();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return null;
}
} else {
return null;
}
}
public boolean isConnectComplete() {
return getChannel() != null;
}
public boolean connectFailed() {
return getException() != null;
}
void setConnectionComplete(NioSocketChannel channel) {
set(channel);
}
void setConnectionFailed(IOException e) {
setException(e);
}
void setConnectionFailed(RuntimeException e) {
setException(e);
}
private NioSocketChannel getChannel() {
if (isDone()) {
try {
// Get should always return without blocking as we already checked 'isDone'
return super.get();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return null;
} catch (ExecutionException e) {
return null;
}
} else {
return null;
}
}
}

View File

@ -0,0 +1,52 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.transport.nio.ESSelector;
import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.NetworkChannel;
import java.nio.channels.SelectionKey;
public interface NioChannel {
String CLIENT = "client-socket";
boolean isOpen();
InetSocketAddress getLocalAddress();
String getProfile();
CloseFuture closeAsync();
void closeFromSelector();
boolean register(ESSelector selector) throws ClosedChannelException;
ESSelector getSelector();
SelectionKey getSelectionKey();
CloseFuture getCloseFuture();
NetworkChannel getRawChannel();
}

View File

@ -0,0 +1,37 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import java.io.IOException;
import java.nio.channels.ServerSocketChannel;
public class NioServerSocketChannel extends AbstractNioChannel<ServerSocketChannel> {
private final ChannelFactory channelFactory;
public NioServerSocketChannel(String profile, ServerSocketChannel socketChannel, ChannelFactory channelFactory) throws IOException {
super(profile, socketChannel);
this.channelFactory = channelFactory;
}
public ChannelFactory getChannelFactory() {
return channelFactory;
}
}

View File

@ -0,0 +1,189 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.transport.nio.NetworkBytesReference;
import org.elasticsearch.transport.nio.ESSelector;
import org.elasticsearch.transport.nio.SocketSelector;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel;
import java.util.Arrays;
public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
private final InetSocketAddress remoteAddress;
private final ConnectFuture connectFuture = new ConnectFuture();
private volatile SocketSelector socketSelector;
private WriteContext writeContext;
private ReadContext readContext;
public NioSocketChannel(String profile, SocketChannel socketChannel) throws IOException {
super(profile, socketChannel);
this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress();
}
@Override
public CloseFuture closeAsync() {
clearQueuedWrites();
return super.closeAsync();
}
@Override
public void closeFromSelector() {
// Even if the channel has already been closed we will clear any pending write operations just in case
clearQueuedWrites();
super.closeFromSelector();
}
@Override
public SocketSelector getSelector() {
return socketSelector;
}
@Override
boolean markRegistered(ESSelector selector) {
this.socketSelector = (SocketSelector) selector;
return super.markRegistered(selector);
}
public int write(NetworkBytesReference[] references) throws IOException {
int written;
if (references.length == 1) {
written = socketChannel.write(references[0].getReadByteBuffer());
} else {
ByteBuffer[] buffers = new ByteBuffer[references.length];
for (int i = 0; i < references.length; ++i) {
buffers[i] = references[i].getReadByteBuffer();
}
written = (int) socketChannel.write(buffers);
}
if (written <= 0) {
return written;
}
NetworkBytesReference.vectorizedIncrementReadIndexes(Arrays.asList(references), written);
return written;
}
public int read(NetworkBytesReference reference) throws IOException {
int bytesRead = socketChannel.read(reference.getWriteByteBuffer());
if (bytesRead == -1) {
return bytesRead;
}
reference.incrementWrite(bytesRead);
return bytesRead;
}
public void setContexts(ReadContext readContext, WriteContext writeContext) {
this.readContext = readContext;
this.writeContext = writeContext;
}
public WriteContext getWriteContext() {
return writeContext;
}
public ReadContext getReadContext() {
return readContext;
}
public InetSocketAddress getRemoteAddress() {
return remoteAddress;
}
public boolean isConnectComplete() {
return connectFuture.isConnectComplete();
}
public boolean isWritable() {
return state.get() == REGISTERED;
}
public boolean isReadable() {
return state.get() == REGISTERED;
}
/**
* This method will attempt to complete the connection process for this channel. It should be called for
* new channels or for a channel that has produced a OP_CONNECT event. If this method returns true then
* the connection is complete and the channel is ready for reads and writes. If it returns false, the
* channel is not yet connected and this method should be called again when a OP_CONNECT event is
* received.
*
* @return true if the connection process is complete
* @throws IOException if an I/O error occurs
*/
public boolean finishConnect() throws IOException {
if (connectFuture.isConnectComplete()) {
return true;
} else if (connectFuture.connectFailed()) {
Exception exception = connectFuture.getException();
if (exception instanceof IOException) {
throw (IOException) exception;
} else {
throw (RuntimeException) exception;
}
}
boolean isConnected = socketChannel.isConnected();
if (isConnected == false) {
isConnected = internalFinish();
}
if (isConnected) {
connectFuture.setConnectionComplete(this);
}
return isConnected;
}
public ConnectFuture getConnectFuture() {
return connectFuture;
}
private boolean internalFinish() throws IOException {
try {
return socketChannel.finishConnect();
} catch (IOException e) {
connectFuture.setConnectionFailed(e);
throw e;
} catch (RuntimeException e) {
connectFuture.setConnectionFailed(e);
throw e;
}
}
private void clearQueuedWrites() {
// Even if the channel has already been closed we will clear any pending write operations just in case
if (state.get() > UNREGISTERED) {
SocketSelector selector = getSelector();
if (selector != null && selector.isOnCurrentThread() && writeContext.hasQueuedWriteOps()) {
writeContext.clearQueuedWriteOps(new ClosedChannelException());
}
}
}
}

View File

@ -0,0 +1,28 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import java.io.IOException;
public interface ReadContext {
int read() throws IOException;
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.SelectionKey;
public final class SelectionKeyUtils {
private SelectionKeyUtils() {}
public static void setWriteInterested(NioChannel channel) throws CancelledKeyException {
SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_WRITE);
}
public static void removeWriteInterested(NioChannel channel) throws CancelledKeyException {
SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_WRITE);
}
public static void setConnectAndReadInterested(NioChannel channel) throws CancelledKeyException {
SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ);
}
public static void removeConnectInterested(NioChannel channel) throws CancelledKeyException {
SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_CONNECT);
}
public static void setAcceptInterested(NioServerSocketChannel channel) {
SelectionKey selectionKey = channel.getSelectionKey();
selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_ACCEPT);
}
}

View File

@ -0,0 +1,118 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.monitor.jvm.JvmInfo;
import org.elasticsearch.transport.TcpHeader;
import org.elasticsearch.transport.TcpTransport;
import java.io.IOException;
import java.io.StreamCorruptedException;
public class TcpFrameDecoder {
private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9);
private static final int HEADER_SIZE = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
private int expectedMessageLength = -1;
public BytesReference decode(BytesReference bytesReference, int currentBufferSize) throws IOException {
if (currentBufferSize >= 6) {
int messageLength = readHeaderBuffer(bytesReference);
int totalLength = messageLength + HEADER_SIZE;
if (totalLength > currentBufferSize) {
expectedMessageLength = totalLength;
return null;
} else if (totalLength == bytesReference.length()) {
expectedMessageLength = -1;
return bytesReference;
} else {
expectedMessageLength = -1;
return bytesReference.slice(0, totalLength);
}
} else {
return null;
}
}
public int expectedMessageLength() {
return expectedMessageLength;
}
private int readHeaderBuffer(BytesReference headerBuffer) throws IOException {
if (headerBuffer.get(0) != 'E' || headerBuffer.get(1) != 'S') {
if (appearsToBeHTTP(headerBuffer)) {
throw new TcpTransport.HttpOnTransportException("This is not a HTTP port");
}
throw new StreamCorruptedException("invalid internal transport message format, got ("
+ Integer.toHexString(headerBuffer.get(0) & 0xFF) + ","
+ Integer.toHexString(headerBuffer.get(1) & 0xFF) + ","
+ Integer.toHexString(headerBuffer.get(2) & 0xFF) + ","
+ Integer.toHexString(headerBuffer.get(3) & 0xFF) + ")");
}
final int messageLength;
try (StreamInput input = headerBuffer.streamInput()) {
input.skip(TcpHeader.MARKER_BYTES_SIZE);
messageLength = input.readInt();
}
if (messageLength == -1) {
// This is a ping
return 0;
}
if (messageLength <= 0) {
throw new StreamCorruptedException("invalid data length: " + messageLength);
}
if (messageLength > NINETY_PER_HEAP_SIZE) {
throw new IllegalArgumentException("transport content length received [" + new ByteSizeValue(messageLength) + "] exceeded ["
+ new ByteSizeValue(NINETY_PER_HEAP_SIZE) + "]");
}
return messageLength;
}
private static boolean appearsToBeHTTP(BytesReference headerBuffer) {
return bufferStartsWith(headerBuffer, "GET") ||
bufferStartsWith(headerBuffer, "POST") ||
bufferStartsWith(headerBuffer, "PUT") ||
bufferStartsWith(headerBuffer, "HEAD") ||
bufferStartsWith(headerBuffer, "DELETE") ||
// TODO: Actually 'OPTIONS'. But that does not currently fit in 6 bytes
bufferStartsWith(headerBuffer, "OPTION") ||
bufferStartsWith(headerBuffer, "PATCH") ||
bufferStartsWith(headerBuffer, "TRACE");
}
private static boolean bufferStartsWith(BytesReference buffer, String method) {
char[] chars = method.toCharArray();
for (int i = 0; i < chars.length; i++) {
if (buffer.get(i) != chars[i]) {
return false;
}
}
return true;
}
}

View File

@ -0,0 +1,109 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.transport.nio.NetworkBytesReference;
import org.elasticsearch.transport.nio.TcpReadHandler;
import java.io.IOException;
import java.util.Iterator;
import java.util.LinkedList;
public class TcpReadContext implements ReadContext {
private static final int DEFAULT_READ_LENGTH = 1 << 14;
private final TcpReadHandler handler;
private final NioSocketChannel channel;
private final TcpFrameDecoder frameDecoder;
private final LinkedList<NetworkBytesReference> references = new LinkedList<>();
private int rawBytesCount = 0;
public TcpReadContext(NioSocketChannel channel, TcpReadHandler handler) {
this(channel, handler, new TcpFrameDecoder());
}
public TcpReadContext(NioSocketChannel channel, TcpReadHandler handler, TcpFrameDecoder frameDecoder) {
this.handler = handler;
this.channel = channel;
this.frameDecoder = frameDecoder;
this.references.add(NetworkBytesReference.wrap(new BytesArray(new byte[DEFAULT_READ_LENGTH])));
}
@Override
public int read() throws IOException {
NetworkBytesReference last = references.peekLast();
if (last == null || last.hasWriteRemaining() == false) {
this.references.add(NetworkBytesReference.wrap(new BytesArray(new byte[DEFAULT_READ_LENGTH])));
}
int bytesRead = channel.read(references.getLast());
if (bytesRead == -1) {
return bytesRead;
}
rawBytesCount += bytesRead;
BytesReference message;
while ((message = frameDecoder.decode(createCompositeBuffer(), rawBytesCount)) != null) {
int messageLengthWithHeader = message.length();
NetworkBytesReference.vectorizedIncrementReadIndexes(references, messageLengthWithHeader);
trimDecodedMessages(messageLengthWithHeader);
rawBytesCount -= messageLengthWithHeader;
try {
BytesReference messageWithoutHeader = message.slice(6, message.length() - 6);
handler.handleMessage(messageWithoutHeader, channel, channel.getProfile(), messageWithoutHeader.length());
} catch (Exception e) {
handler.handleException(channel, e);
}
}
return bytesRead;
}
private CompositeBytesReference createCompositeBuffer() {
return new CompositeBytesReference(references.toArray(new BytesReference[references.size()]));
}
private void trimDecodedMessages(int bytesToTrim) {
while (bytesToTrim != 0) {
NetworkBytesReference ref = references.getFirst();
int readIndex = ref.getReadIndex();
bytesToTrim -= readIndex;
if (readIndex == ref.length()) {
references.removeFirst();
} else {
assert bytesToTrim == 0;
if (readIndex != 0) {
references.removeFirst();
NetworkBytesReference slicedRef = ref.slice(readIndex, ref.length() - readIndex);
references.addFirst(slicedRef);
}
}
}
}
}

View File

@ -0,0 +1,108 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.nio.SocketSelector;
import org.elasticsearch.transport.nio.WriteOperation;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.util.LinkedList;
public class TcpWriteContext implements WriteContext {
private final NioSocketChannel channel;
private final LinkedList<WriteOperation> queued = new LinkedList<>();
public TcpWriteContext(NioSocketChannel channel) {
this.channel = channel;
}
@Override
public void sendMessage(BytesReference reference, ActionListener<NioChannel> listener) {
if (channel.isWritable() == false) {
listener.onFailure(new ClosedChannelException());
return;
}
WriteOperation writeOperation = new WriteOperation(channel, reference, listener);
SocketSelector selector = channel.getSelector();
if (selector.isOnCurrentThread() == false) {
selector.queueWrite(writeOperation);
return;
}
// TODO: Eval if we will allow writes from sendMessage
selector.queueWriteInChannelBuffer(writeOperation);
}
@Override
public void queueWriteOperations(WriteOperation writeOperation) {
assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to queue writes";
queued.add(writeOperation);
}
@Override
public void flushChannel() throws IOException {
assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to flush writes";
int ops = queued.size();
if (ops == 1) {
singleFlush(queued.pop());
} else if (ops > 1) {
multiFlush();
}
}
@Override
public boolean hasQueuedWriteOps() {
assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to access queued writes";
return queued.isEmpty() == false;
}
@Override
public void clearQueuedWriteOps(Exception e) {
assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to clear queued writes";
for (WriteOperation op : queued) {
op.getListener().onFailure(e);
}
queued.clear();
}
private void singleFlush(WriteOperation headOp) throws IOException {
headOp.flush();
if (headOp.isFullyFlushed()) {
headOp.getListener().onResponse(channel);
} else {
queued.push(headOp);
}
}
private void multiFlush() throws IOException {
boolean lastOpCompleted = true;
while (lastOpCompleted && queued.isEmpty() == false) {
WriteOperation op = queued.pop();
singleFlush(op);
lastOpCompleted = op.isFullyFlushed();
}
}
}

View File

@ -0,0 +1,40 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.nio.WriteOperation;
import java.io.IOException;
public interface WriteContext {
void sendMessage(BytesReference reference, ActionListener<NioChannel> listener);
void queueWriteOperations(WriteOperation writeOperation);
void flushChannel() throws IOException;
boolean hasQueuedWriteOps();
void clearQueuedWriteOps(Exception e);
}

View File

@ -0,0 +1,113 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.utils.TestSelectionKey;
import org.junit.Before;
import java.io.IOException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.security.PrivilegedActionException;
import java.util.HashSet;
import java.util.Set;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class AcceptingSelectorTests extends ESTestCase {
private AcceptingSelector selector;
private NioServerSocketChannel serverChannel;
private AcceptorEventHandler eventHandler;
private TestSelectionKey selectionKey;
private HashSet<SelectionKey> keySet = new HashSet<>();
@Before
public void setUp() throws Exception {
super.setUp();
eventHandler = mock(AcceptorEventHandler.class);
serverChannel = mock(NioServerSocketChannel.class);
Selector rawSelector = mock(Selector.class);
selector = new AcceptingSelector(eventHandler, rawSelector);
this.selector.setThread();
selectionKey = new TestSelectionKey(0);
selectionKey.attach(serverChannel);
when(serverChannel.getSelectionKey()).thenReturn(selectionKey);
when(rawSelector.selectedKeys()).thenReturn(keySet);
when(rawSelector.select(0)).thenReturn(1);
}
public void testRegisteredChannel() throws IOException, PrivilegedActionException {
selector.registerServerChannel(serverChannel);
when(serverChannel.register(selector)).thenReturn(true);
selector.doSelect(0);
verify(eventHandler).serverChannelRegistered(serverChannel);
Set<NioChannel> registeredChannels = selector.getRegisteredChannels();
assertEquals(1, registeredChannels.size());
assertTrue(registeredChannels.contains(serverChannel));
}
public void testAcceptEvent() throws IOException {
selectionKey.setReadyOps(SelectionKey.OP_ACCEPT);
keySet.add(selectionKey);
selector.doSelect(0);
verify(eventHandler).acceptChannel(serverChannel);
}
public void testAcceptException() throws IOException {
selectionKey.setReadyOps(SelectionKey.OP_ACCEPT);
keySet.add(selectionKey);
IOException ioException = new IOException();
doThrow(ioException).when(eventHandler).acceptChannel(serverChannel);
selector.doSelect(0);
verify(eventHandler).acceptException(serverChannel, ioException);
}
public void testCleanup() throws IOException {
selector.registerServerChannel(serverChannel);
when(serverChannel.register(selector)).thenReturn(true);
selector.doSelect(0);
assertEquals(1, selector.getRegisteredChannels().size());
selector.cleanup();
verify(eventHandler).handleClose(serverChannel);
}
}

View File

@ -0,0 +1,99 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.channel.ChannelFactory;
import org.elasticsearch.transport.nio.channel.DoNotRegisterServerChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.junit.Before;
import java.io.IOException;
import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class AcceptorEventHandlerTests extends ESTestCase {
private AcceptorEventHandler handler;
private SocketSelector socketSelector;
private ChannelFactory channelFactory;
private OpenChannels openChannels;
private NioServerSocketChannel channel;
@Before
public void setUpHandler() throws IOException {
channelFactory = mock(ChannelFactory.class);
socketSelector = mock(SocketSelector.class);
openChannels = new OpenChannels(logger);
ArrayList<SocketSelector> selectors = new ArrayList<>();
selectors.add(socketSelector);
handler = new AcceptorEventHandler(logger, openChannels, new RoundRobinSelectorSupplier(selectors));
channel = new DoNotRegisterServerChannel("", mock(ServerSocketChannel.class), channelFactory);
channel.register(mock(ESSelector.class));
}
public void testHandleRegisterAdjustsOpenChannels() {
assertEquals(0, openChannels.serverChannelsCount());
handler.serverChannelRegistered(channel);
assertEquals(1, openChannels.serverChannelsCount());
}
public void testHandleRegisterSetsOP_ACCEPTInterest() {
assertEquals(0, channel.getSelectionKey().interestOps());
handler.serverChannelRegistered(channel);
assertEquals(SelectionKey.OP_ACCEPT, channel.getSelectionKey().interestOps());
}
public void testHandleAcceptRegistersWithSelector() throws IOException {
NioSocketChannel childChannel = new NioSocketChannel("", mock(SocketChannel.class));
when(channelFactory.acceptNioChannel(channel)).thenReturn(childChannel);
handler.acceptChannel(channel);
verify(socketSelector).registerSocketChannel(childChannel);
}
public void testHandleAcceptAddsToOpenChannelsAndAddsCloseListenerToRemove() throws IOException {
NioSocketChannel childChannel = new NioSocketChannel("", SocketChannel.open());
when(channelFactory.acceptNioChannel(channel)).thenReturn(childChannel);
handler.acceptChannel(channel);
assertEquals(new HashSet<>(Arrays.asList(childChannel)), openChannels.getAcceptedChannels());
childChannel.closeAsync();
assertEquals(new HashSet<>(), openChannels.getAcceptedChannels());
}
}

View File

@ -0,0 +1,155 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.test.ESTestCase;
import java.nio.ByteBuffer;
public class ByteBufferReferenceTests extends ESTestCase {
private NetworkBytesReference buffer;
public void testBasicGetByte() {
byte[] bytes = new byte[10];
initializeBytes(bytes);
buffer = NetworkBytesReference.wrap(new BytesArray(bytes));
assertEquals(10, buffer.length());
for (int i = 0 ; i < bytes.length; ++i) {
assertEquals(i, buffer.get(i));
}
}
public void testBasicGetByteWithOffset() {
byte[] bytes = new byte[10];
initializeBytes(bytes);
buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 8));
assertEquals(8, buffer.length());
for (int i = 2 ; i < bytes.length; ++i) {
assertEquals(i, buffer.get(i - 2));
}
}
public void testBasicGetByteWithOffsetAndLimit() {
byte[] bytes = new byte[10];
initializeBytes(bytes);
buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 6));
assertEquals(6, buffer.length());
for (int i = 2 ; i < bytes.length - 2; ++i) {
assertEquals(i, buffer.get(i - 2));
}
}
public void testGetWriteBufferRespectsWriteIndex() {
byte[] bytes = new byte[10];
buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 8));
ByteBuffer writeByteBuffer = buffer.getWriteByteBuffer();
assertEquals(2, writeByteBuffer.position());
assertEquals(10, writeByteBuffer.limit());
buffer.incrementWrite(2);
writeByteBuffer = buffer.getWriteByteBuffer();
assertEquals(4, writeByteBuffer.position());
assertEquals(10, writeByteBuffer.limit());
}
public void testGetReadBufferRespectsReadIndex() {
byte[] bytes = new byte[10];
buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 3, 6), 6, 0);
ByteBuffer readByteBuffer = buffer.getReadByteBuffer();
assertEquals(3, readByteBuffer.position());
assertEquals(9, readByteBuffer.limit());
buffer.incrementRead(2);
readByteBuffer = buffer.getReadByteBuffer();
assertEquals(5, readByteBuffer.position());
assertEquals(9, readByteBuffer.limit());
}
public void testWriteAndReadRemaining() {
byte[] bytes = new byte[10];
buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 8));
assertEquals(0, buffer.getReadRemaining());
assertEquals(8, buffer.getWriteRemaining());
buffer.incrementWrite(3);
buffer.incrementRead(2);
assertEquals(1, buffer.getReadRemaining());
assertEquals(5, buffer.getWriteRemaining());
}
public void testBasicSlice() {
byte[] bytes = new byte[20];
initializeBytes(bytes);
buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 18));
NetworkBytesReference slice = buffer.slice(4, 14);
assertEquals(14, slice.length());
assertEquals(0, slice.getReadIndex());
assertEquals(0, slice.getWriteIndex());
for (int i = 6; i < 20; ++i) {
assertEquals(i, slice.get(i - 6));
}
}
public void testSliceWithReadAndWriteIndexes() {
byte[] bytes = new byte[20];
initializeBytes(bytes);
buffer = NetworkBytesReference.wrap(new BytesArray(bytes, 2, 18));
buffer.incrementWrite(9);
buffer.incrementRead(5);
NetworkBytesReference slice = buffer.slice(6, 12);
assertEquals(12, slice.length());
assertEquals(0, slice.getReadIndex());
assertEquals(3, slice.getWriteIndex());
for (int i = 8; i < 20; ++i) {
assertEquals(i, slice.get(i - 8));
}
}
private void initializeBytes(byte[] bytes) {
for (int i = 0 ; i < bytes.length; ++i) {
bytes[i] = (byte) i;
}
}
}

View File

@ -0,0 +1,114 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.junit.Before;
import java.io.IOException;
import java.nio.channels.ClosedSelectorException;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
public class ESSelectorTests extends ESTestCase {
private ESSelector selector;
private EventHandler handler;
@Before
public void setUp() throws Exception {
super.setUp();
handler = mock(EventHandler.class);
selector = new TestSelector(handler);
}
public void testQueueChannelForClosed() throws IOException {
NioChannel channel = mock(NioChannel.class);
selector.registeredChannels.add(channel);
selector.queueChannelClose(channel);
assertEquals(1, selector.getRegisteredChannels().size());
selector.singleLoop();
verify(handler).handleClose(channel);
assertEquals(0, selector.getRegisteredChannels().size());
}
public void testSelectorClosedExceptionIsNotCaughtWhileRunning() throws IOException {
((TestSelector) this.selector).setClosedSelectorException(new ClosedSelectorException());
boolean closedSelectorExceptionCaught = false;
try {
this.selector.singleLoop();
} catch (ClosedSelectorException e) {
closedSelectorExceptionCaught = true;
}
assertTrue(closedSelectorExceptionCaught);
}
public void testIOExceptionWhileSelect() throws IOException {
IOException ioException = new IOException();
((TestSelector) this.selector).setIOException(ioException);
this.selector.singleLoop();
verify(handler).selectException(ioException);
}
private static class TestSelector extends ESSelector {
private ClosedSelectorException closedSelectorException;
private IOException ioException;
protected TestSelector(EventHandler eventHandler) throws IOException {
super(eventHandler);
}
@Override
void doSelect(int timeout) throws IOException, ClosedSelectorException {
if (closedSelectorException != null) {
throw closedSelectorException;
}
if (ioException != null) {
throw ioException;
}
}
@Override
void cleanup() {
}
public void setClosedSelectorException(ClosedSelectorException exception) {
this.closedSelectorException = exception;
}
public void setIOException(IOException ioException) {
this.ioException = ioException;
}
}
}

View File

@ -0,0 +1,193 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.nio.channel.ChannelFactory;
import org.elasticsearch.transport.nio.channel.CloseFuture;
import org.elasticsearch.transport.nio.channel.ConnectFuture;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.junit.Before;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Supplier;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class NioClientTests extends ESTestCase {
private NioClient client;
private SocketSelector selector;
private ChannelFactory channelFactory;
private OpenChannels openChannels = new OpenChannels(logger);
private NioSocketChannel[] channels;
private DiscoveryNode node;
private Consumer<NioChannel> listener;
private TransportAddress address;
@Before
@SuppressWarnings("unchecked")
public void setUpClient() {
channelFactory = mock(ChannelFactory.class);
selector = mock(SocketSelector.class);
listener = mock(Consumer.class);
ArrayList<SocketSelector> selectors = new ArrayList<>();
selectors.add(selector);
Supplier<SocketSelector> selectorSupplier = new RoundRobinSelectorSupplier(selectors);
client = new NioClient(logger, openChannels, selectorSupplier, TimeValue.timeValueMillis(5), channelFactory);
channels = new NioSocketChannel[2];
address = new TransportAddress(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
node = new DiscoveryNode("node-id", address, Version.CURRENT);
}
public void testCreateConnections() throws IOException, InterruptedException {
NioSocketChannel channel1 = mock(NioSocketChannel.class);
ConnectFuture connectFuture1 = mock(ConnectFuture.class);
CloseFuture closeFuture1 = mock(CloseFuture.class);
NioSocketChannel channel2 = mock(NioSocketChannel.class);
ConnectFuture connectFuture2 = mock(ConnectFuture.class);
CloseFuture closeFuture2 = mock(CloseFuture.class);
when(channelFactory.openNioChannel(address.address())).thenReturn(channel1, channel2);
when(channel1.getCloseFuture()).thenReturn(closeFuture1);
when(channel1.getConnectFuture()).thenReturn(connectFuture1);
when(channel2.getCloseFuture()).thenReturn(closeFuture2);
when(channel2.getConnectFuture()).thenReturn(connectFuture2);
when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true);
when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true);
client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener);
verify(closeFuture1).setListener(listener);
verify(closeFuture2).setListener(listener);
verify(selector).registerSocketChannel(channel1);
verify(selector).registerSocketChannel(channel2);
assertEquals(channel1, channels[0]);
assertEquals(channel2, channels[1]);
}
public void testWithADifferentConnectTimeout() throws IOException, InterruptedException {
NioSocketChannel channel1 = mock(NioSocketChannel.class);
ConnectFuture connectFuture1 = mock(ConnectFuture.class);
CloseFuture closeFuture1 = mock(CloseFuture.class);
when(channelFactory.openNioChannel(address.address())).thenReturn(channel1);
when(channel1.getCloseFuture()).thenReturn(closeFuture1);
when(channel1.getConnectFuture()).thenReturn(connectFuture1);
when(connectFuture1.awaitConnectionComplete(3, TimeUnit.MILLISECONDS)).thenReturn(true);
channels = new NioSocketChannel[1];
client.connectToChannels(node, channels, TimeValue.timeValueMillis(3), listener);
verify(closeFuture1).setListener(listener);
verify(selector).registerSocketChannel(channel1);
assertEquals(channel1, channels[0]);
}
public void testConnectionTimeout() throws IOException, InterruptedException {
NioSocketChannel channel1 = mock(NioSocketChannel.class);
ConnectFuture connectFuture1 = mock(ConnectFuture.class);
CloseFuture closeFuture1 = mock(CloseFuture.class);
NioSocketChannel channel2 = mock(NioSocketChannel.class);
ConnectFuture connectFuture2 = mock(ConnectFuture.class);
CloseFuture closeFuture2 = mock(CloseFuture.class);
when(channelFactory.openNioChannel(address.address())).thenReturn(channel1, channel2);
when(channel1.getCloseFuture()).thenReturn(closeFuture1);
when(channel1.getConnectFuture()).thenReturn(connectFuture1);
when(channel2.getCloseFuture()).thenReturn(closeFuture2);
when(channel2.getConnectFuture()).thenReturn(connectFuture2);
when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true);
when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(false);
try {
client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener);
fail("Should have thrown ConnectTransportException");
} catch (ConnectTransportException e) {
assertTrue(e.getMessage().contains("connect_timeout[5ms]"));
}
verify(channel1).closeAsync();
verify(channel2).closeAsync();
assertNull(channels[0]);
assertNull(channels[1]);
}
public void testConnectionException() throws IOException, InterruptedException {
NioSocketChannel channel1 = mock(NioSocketChannel.class);
ConnectFuture connectFuture1 = mock(ConnectFuture.class);
CloseFuture closeFuture1 = mock(CloseFuture.class);
NioSocketChannel channel2 = mock(NioSocketChannel.class);
ConnectFuture connectFuture2 = mock(ConnectFuture.class);
CloseFuture closeFuture2 = mock(CloseFuture.class);
IOException ioException = new IOException();
when(channelFactory.openNioChannel(address.address())).thenReturn(channel1, channel2);
when(channel1.getCloseFuture()).thenReturn(closeFuture1);
when(channel1.getConnectFuture()).thenReturn(connectFuture1);
when(channel2.getCloseFuture()).thenReturn(closeFuture2);
when(channel2.getConnectFuture()).thenReturn(connectFuture2);
when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true);
when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(false);
when(connectFuture2.getException()).thenReturn(ioException);
try {
client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener);
fail("Should have thrown ConnectTransportException");
} catch (ConnectTransportException e) {
assertTrue(e.getMessage().contains("connect_exception"));
assertSame(ioException, e.getCause());
}
verify(channel1).closeAsync();
verify(channel2).closeAsync();
assertNull(channels[0]);
assertNull(channels[1]);
}
public void testCloseDoesNotAllowConnections() throws IOException {
client.close();
assertFalse(client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener));
for (NioSocketChannel channel : channels) {
assertNull(channel);
}
}
}

View File

@ -0,0 +1,132 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.node.Node;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
import org.elasticsearch.transport.BindTransportException;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.TransportSettings;
import org.elasticsearch.transport.nio.channel.NioChannel;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Collections;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase {
public static MockTransportService nioFromThreadPool(Settings settings, ThreadPool threadPool, final Version version,
ClusterSettings clusterSettings, boolean doHandshake) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
NetworkService networkService = new NetworkService(settings, Collections.emptyList());
Transport transport = new NioTransport(settings, threadPool,
networkService,
BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) {
@Override
protected Version executeHandshake(DiscoveryNode node, NioChannel channel, TimeValue timeout) throws IOException,
InterruptedException {
if (doHandshake) {
return super.executeHandshake(node, channel, timeout);
} else {
return version.minimumCompatibilityVersion();
}
}
@Override
protected Version getCurrentVersion() {
return version;
}
@Override
protected SocketEventHandler getSocketEventHandler() {
return new TestingSocketEventHandler(logger, this::exceptionCaught);
}
};
MockTransportService mockTransportService =
MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, clusterSettings);
mockTransportService.start();
return mockTransportService;
}
@Override
protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) {
settings = Settings.builder().put(settings).put(TransportSettings.PORT.getKey(), "0").build();
MockTransportService transportService = nioFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake);
transportService.start();
return transportService;
}
public void testConnectException() throws UnknownHostException {
try {
serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876),
emptyMap(), emptySet(),Version.CURRENT));
fail("Expected ConnectTransportException");
} catch (ConnectTransportException e) {
assertThat(e.getMessage(), containsString("connect_exception"));
assertThat(e.getMessage(), containsString("[127.0.0.1:9876]"));
Throwable cause = e.getCause();
assertThat(cause, instanceOf(IOException.class));
assertEquals("Connection refused", cause.getMessage());
}
}
public void testBindUnavailableAddress() {
// this is on a lower level since it needs access to the TransportService before it's started
int port = serviceA.boundAddress().publishAddress().getPort();
Settings settings = Settings.builder()
.put(Node.NODE_NAME_SETTING.getKey(), "foobar")
.put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "")
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.put("transport.tcp.port", port)
.build();
ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> {
MockTransportService transportService = nioFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true);
try {
transportService.start();
} finally {
transportService.stop();
transportService.close();
}
});
assertEquals("Failed to bind to ["+ port + "]", bindTransportException.getMessage());
}
}

View File

@ -0,0 +1,175 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.channel.CloseFuture;
import org.elasticsearch.transport.nio.channel.DoNotRegisterChannel;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.ReadContext;
import org.elasticsearch.transport.nio.channel.SelectionKeyUtils;
import org.elasticsearch.transport.nio.channel.TcpWriteContext;
import org.junit.Before;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.function.BiConsumer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class SocketEventHandlerTests extends ESTestCase {
private BiConsumer<NioSocketChannel, Throwable> exceptionHandler;
private SocketEventHandler handler;
private NioSocketChannel channel;
private ReadContext readContext;
private SocketChannel rawChannel;
@Before
@SuppressWarnings("unchecked")
public void setUpHandler() throws IOException {
exceptionHandler = mock(BiConsumer.class);
SocketSelector socketSelector = mock(SocketSelector.class);
handler = new SocketEventHandler(logger, exceptionHandler);
rawChannel = mock(SocketChannel.class);
channel = new DoNotRegisterChannel("", rawChannel);
readContext = mock(ReadContext.class);
when(rawChannel.finishConnect()).thenReturn(true);
channel.setContexts(readContext, new TcpWriteContext(channel));
channel.register(socketSelector);
channel.finishConnect();
when(socketSelector.isOnCurrentThread()).thenReturn(true);
}
public void testRegisterAddsOP_CONNECTAndOP_READInterest() throws IOException {
handler.handleRegistration(channel);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, channel.getSelectionKey().interestOps());
}
public void testRegistrationExceptionCallsExceptionHandler() throws IOException {
CancelledKeyException exception = new CancelledKeyException();
handler.registrationException(channel, exception);
verify(exceptionHandler).accept(channel, exception);
}
public void testConnectRemovesOP_CONNECTInterest() throws IOException {
SelectionKeyUtils.setConnectAndReadInterested(channel);
handler.handleConnect(channel);
assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps());
}
public void testConnectExceptionCallsExceptionHandler() throws IOException {
IOException exception = new IOException();
handler.connectException(channel, exception);
verify(exceptionHandler).accept(channel, exception);
}
public void testHandleReadDelegatesToReadContext() throws IOException {
when(readContext.read()).thenReturn(1);
handler.handleRead(channel);
verify(readContext).read();
}
public void testHandleReadMarksChannelForCloseIfPeerClosed() throws IOException {
NioSocketChannel nioSocketChannel = mock(NioSocketChannel.class);
CloseFuture closeFuture = mock(CloseFuture.class);
when(nioSocketChannel.getReadContext()).thenReturn(readContext);
when(readContext.read()).thenReturn(-1);
when(nioSocketChannel.getCloseFuture()).thenReturn(closeFuture);
when(closeFuture.isDone()).thenReturn(true);
handler.handleRead(nioSocketChannel);
verify(nioSocketChannel).closeFromSelector();
}
public void testReadExceptionCallsExceptionHandler() throws IOException {
IOException exception = new IOException();
handler.readException(channel, exception);
verify(exceptionHandler).accept(channel, exception);
}
@SuppressWarnings("unchecked")
public void testHandleWriteWithCompleteFlushRemovesOP_WRITEInterest() throws IOException {
SelectionKey selectionKey = channel.getSelectionKey();
setWriteAndRead(channel);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps());
BytesArray bytesArray = new BytesArray(new byte[1]);
NetworkBytesReference networkBuffer = NetworkBytesReference.wrap(bytesArray);
channel.getWriteContext().queueWriteOperations(new WriteOperation(channel, networkBuffer, mock(ActionListener.class)));
when(rawChannel.write(ByteBuffer.wrap(bytesArray.array()))).thenReturn(1);
handler.handleWrite(channel);
assertEquals(SelectionKey.OP_READ, selectionKey.interestOps());
}
@SuppressWarnings("unchecked")
public void testHandleWriteWithInCompleteFlushLeavesOP_WRITEInterest() throws IOException {
SelectionKey selectionKey = channel.getSelectionKey();
setWriteAndRead(channel);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps());
BytesArray bytesArray = new BytesArray(new byte[1]);
NetworkBytesReference networkBuffer = NetworkBytesReference.wrap(bytesArray, 1, 0);
channel.getWriteContext().queueWriteOperations(new WriteOperation(channel, networkBuffer, mock(ActionListener.class)));
when(rawChannel.write(ByteBuffer.wrap(bytesArray.array()))).thenReturn(0);
handler.handleWrite(channel);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps());
}
public void testHandleWriteWithNoOpsRemovesOP_WRITEInterest() throws IOException {
SelectionKey selectionKey = channel.getSelectionKey();
setWriteAndRead(channel);
assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps());
handler.handleWrite(channel);
assertEquals(SelectionKey.OP_READ, selectionKey.interestOps());
}
private void setWriteAndRead(NioChannel channel) {
SelectionKeyUtils.setConnectAndReadInterested(channel);
SelectionKeyUtils.removeConnectInterested(channel);
SelectionKeyUtils.setWriteInterested(channel);
}
public void testWriteExceptionCallsExceptionHandler() throws IOException {
IOException exception = new IOException();
handler.writeException(channel, exception);
verify(exceptionHandler).accept(channel, exception);
}
}

View File

@ -0,0 +1,336 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.WriteContext;
import org.elasticsearch.transport.nio.utils.TestSelectionKey;
import org.junit.Before;
import java.io.IOException;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.HashSet;
import java.util.Set;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class SocketSelectorTests extends ESTestCase {
private SocketSelector socketSelector;
private SocketEventHandler eventHandler;
private NioSocketChannel channel;
private TestSelectionKey selectionKey;
private WriteContext writeContext;
private HashSet<SelectionKey> keySet = new HashSet<>();
private ActionListener<NioChannel> listener;
private NetworkBytesReference bufferReference = NetworkBytesReference.wrap(new BytesArray(new byte[1]));
@Before
@SuppressWarnings("unchecked")
public void setUp() throws Exception {
super.setUp();
eventHandler = mock(SocketEventHandler.class);
channel = mock(NioSocketChannel.class);
writeContext = mock(WriteContext.class);
listener = mock(ActionListener.class);
selectionKey = new TestSelectionKey(0);
selectionKey.attach(channel);
Selector rawSelector = mock(Selector.class);
this.socketSelector = new SocketSelector(eventHandler, rawSelector);
this.socketSelector.setThread();
when(rawSelector.selectedKeys()).thenReturn(keySet);
when(rawSelector.select(0)).thenReturn(1);
when(channel.getSelectionKey()).thenReturn(selectionKey);
when(channel.getWriteContext()).thenReturn(writeContext);
when(channel.isConnectComplete()).thenReturn(true);
}
public void testRegisterChannel() throws Exception {
socketSelector.registerSocketChannel(channel);
when(channel.register(socketSelector)).thenReturn(true);
socketSelector.doSelect(0);
verify(eventHandler).handleRegistration(channel);
Set<NioChannel> registeredChannels = socketSelector.getRegisteredChannels();
assertEquals(1, registeredChannels.size());
assertTrue(registeredChannels.contains(channel));
}
public void testRegisterChannelFails() throws Exception {
socketSelector.registerSocketChannel(channel);
when(channel.register(socketSelector)).thenReturn(false);
socketSelector.doSelect(0);
verify(channel, times(0)).finishConnect();
Set<NioChannel> registeredChannels = socketSelector.getRegisteredChannels();
assertEquals(0, registeredChannels.size());
assertFalse(registeredChannels.contains(channel));
}
public void testRegisterChannelFailsDueToException() throws Exception {
socketSelector.registerSocketChannel(channel);
ClosedChannelException closedChannelException = new ClosedChannelException();
when(channel.register(socketSelector)).thenThrow(closedChannelException);
socketSelector.doSelect(0);
verify(eventHandler).registrationException(channel, closedChannelException);
verify(channel, times(0)).finishConnect();
Set<NioChannel> registeredChannels = socketSelector.getRegisteredChannels();
assertEquals(0, registeredChannels.size());
assertFalse(registeredChannels.contains(channel));
}
public void testSuccessfullyRegisterChannelWillConnect() throws Exception {
socketSelector.registerSocketChannel(channel);
when(channel.register(socketSelector)).thenReturn(true);
when(channel.finishConnect()).thenReturn(true);
socketSelector.doSelect(0);
verify(eventHandler).handleConnect(channel);
}
public void testConnectIncompleteWillNotNotify() throws Exception {
socketSelector.registerSocketChannel(channel);
when(channel.register(socketSelector)).thenReturn(true);
when(channel.finishConnect()).thenReturn(false);
socketSelector.doSelect(0);
verify(eventHandler, times(0)).handleConnect(channel);
}
public void testQueueWriteWhenNotRunning() throws Exception {
socketSelector.close(false);
socketSelector.queueWrite(new WriteOperation(channel, bufferReference, listener));
verify(listener).onFailure(any(ClosedSelectorException.class));
}
public void testQueueWriteChannelIsNoLongerWritable() throws Exception {
WriteOperation writeOperation = new WriteOperation(channel, bufferReference, listener);
socketSelector.queueWrite(writeOperation);
when(channel.isWritable()).thenReturn(false);
socketSelector.doSelect(0);
verify(writeContext, times(0)).queueWriteOperations(writeOperation);
verify(listener).onFailure(any(ClosedChannelException.class));
}
public void testQueueWriteSelectionKeyThrowsException() throws Exception {
SelectionKey selectionKey = mock(SelectionKey.class);
WriteOperation writeOperation = new WriteOperation(channel, bufferReference, listener);
CancelledKeyException cancelledKeyException = new CancelledKeyException();
socketSelector.queueWrite(writeOperation);
when(channel.isWritable()).thenReturn(true);
when(channel.getSelectionKey()).thenReturn(selectionKey);
when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException);
socketSelector.doSelect(0);
verify(writeContext, times(0)).queueWriteOperations(writeOperation);
verify(listener).onFailure(cancelledKeyException);
}
public void testQueueWriteSuccessful() throws Exception {
WriteOperation writeOperation = new WriteOperation(channel, bufferReference, listener);
socketSelector.queueWrite(writeOperation);
assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0);
when(channel.isWritable()).thenReturn(true);
socketSelector.doSelect(0);
verify(writeContext).queueWriteOperations(writeOperation);
assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0);
}
public void testQueueDirectlyInChannelBufferSuccessful() throws Exception {
WriteOperation writeOperation = new WriteOperation(channel, bufferReference, listener);
assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0);
when(channel.isWritable()).thenReturn(true);
socketSelector.queueWriteInChannelBuffer(writeOperation);
verify(writeContext).queueWriteOperations(writeOperation);
assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0);
}
public void testQueueDirectlyInChannelBufferSelectionKeyThrowsException() throws Exception {
SelectionKey selectionKey = mock(SelectionKey.class);
WriteOperation writeOperation = new WriteOperation(channel, bufferReference, listener);
CancelledKeyException cancelledKeyException = new CancelledKeyException();
when(channel.isWritable()).thenReturn(true);
when(channel.getSelectionKey()).thenReturn(selectionKey);
when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException);
socketSelector.queueWriteInChannelBuffer(writeOperation);
verify(writeContext, times(0)).queueWriteOperations(writeOperation);
verify(listener).onFailure(cancelledKeyException);
}
public void testConnectEvent() throws Exception {
keySet.add(selectionKey);
selectionKey.setReadyOps(SelectionKey.OP_CONNECT);
when(channel.finishConnect()).thenReturn(true);
socketSelector.doSelect(0);
verify(eventHandler).handleConnect(channel);
}
public void testConnectEventFinishUnsuccessful() throws Exception {
keySet.add(selectionKey);
selectionKey.setReadyOps(SelectionKey.OP_CONNECT);
when(channel.finishConnect()).thenReturn(false);
socketSelector.doSelect(0);
verify(eventHandler, times(0)).handleConnect(channel);
}
public void testConnectEventFinishThrowException() throws Exception {
keySet.add(selectionKey);
IOException ioException = new IOException();
selectionKey.setReadyOps(SelectionKey.OP_CONNECT);
when(channel.finishConnect()).thenThrow(ioException);
socketSelector.doSelect(0);
verify(eventHandler, times(0)).handleConnect(channel);
verify(eventHandler).connectException(channel, ioException);
}
public void testWillNotConsiderWriteOrReadUntilConnectionComplete() throws Exception {
keySet.add(selectionKey);
IOException ioException = new IOException();
selectionKey.setReadyOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
doThrow(ioException).when(eventHandler).handleWrite(channel);
when(channel.isConnectComplete()).thenReturn(false);
socketSelector.doSelect(0);
verify(eventHandler, times(0)).handleWrite(channel);
verify(eventHandler, times(0)).handleRead(channel);
}
public void testSuccessfulWriteEvent() throws Exception {
keySet.add(selectionKey);
selectionKey.setReadyOps(SelectionKey.OP_WRITE);
socketSelector.doSelect(0);
verify(eventHandler).handleWrite(channel);
}
public void testWriteEventWithException() throws Exception {
keySet.add(selectionKey);
IOException ioException = new IOException();
selectionKey.setReadyOps(SelectionKey.OP_WRITE);
doThrow(ioException).when(eventHandler).handleWrite(channel);
socketSelector.doSelect(0);
verify(eventHandler).writeException(channel, ioException);
}
public void testSuccessfulReadEvent() throws Exception {
keySet.add(selectionKey);
selectionKey.setReadyOps(SelectionKey.OP_READ);
socketSelector.doSelect(0);
verify(eventHandler).handleRead(channel);
}
public void testReadEventWithException() throws Exception {
keySet.add(selectionKey);
IOException ioException = new IOException();
selectionKey.setReadyOps(SelectionKey.OP_READ);
doThrow(ioException).when(eventHandler).handleRead(channel);
socketSelector.doSelect(0);
verify(eventHandler).readException(channel, ioException);
}
public void testCleanup() throws Exception {
NioSocketChannel unRegisteredChannel = mock(NioSocketChannel.class);
when(channel.register(socketSelector)).thenReturn(true);
socketSelector.registerSocketChannel(channel);
socketSelector.doSelect(0);
NetworkBytesReference networkBuffer = NetworkBytesReference.wrap(new BytesArray(new byte[1]));
socketSelector.queueWrite(new WriteOperation(mock(NioSocketChannel.class), networkBuffer, listener));
socketSelector.registerSocketChannel(unRegisteredChannel);
socketSelector.cleanup();
verify(listener).onFailure(any(ClosedSelectorException.class));
verify(eventHandler).handleClose(channel);
verify(eventHandler).handleClose(unRegisteredChannel);
}
}

View File

@ -0,0 +1,72 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import java.io.IOException;
import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;
import java.util.function.BiConsumer;
public class TestingSocketEventHandler extends SocketEventHandler {
private final Logger logger;
public TestingSocketEventHandler(Logger logger, BiConsumer<NioSocketChannel, Throwable> exceptionHandler) {
super(logger, exceptionHandler);
this.logger = logger;
}
private Set<NioSocketChannel> hasConnectedMap = Collections.newSetFromMap(new WeakHashMap<>());
public void handleConnect(NioSocketChannel channel) {
assert hasConnectedMap.contains(channel) == false : "handleConnect should only be called once per channel";
hasConnectedMap.add(channel);
super.handleConnect(channel);
}
private Set<NioSocketChannel> hasConnectExceptionMap = Collections.newSetFromMap(new WeakHashMap<>());
public void connectException(NioSocketChannel channel, Exception e) {
assert hasConnectExceptionMap.contains(channel) == false : "connectException should only called at maximum once per channel";
hasConnectExceptionMap.add(channel);
super.connectException(channel, e);
}
public void handleRead(NioSocketChannel channel) throws IOException {
super.handleRead(channel);
}
public void readException(NioSocketChannel channel, Exception e) {
super.readException(channel, e);
}
public void handleWrite(NioSocketChannel channel) throws IOException {
super.handleWrite(channel);
}
public void writeException(NioSocketChannel channel, Exception e) {
super.writeException(channel, e);
}
}

View File

@ -0,0 +1,78 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.junit.Before;
import java.io.IOException;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class WriteOperationTests extends ESTestCase {
private NioSocketChannel channel;
private ActionListener<NioChannel> listener;
@Before
@SuppressWarnings("unchecked")
public void setFields() {
channel = mock(NioSocketChannel.class);
listener = mock(ActionListener.class);
}
public void testFlush() throws IOException {
WriteOperation writeOp = new WriteOperation(channel, new BytesArray(new byte[10]), listener);
when(channel.write(any())).thenAnswer(invocationOnMock -> {
NetworkBytesReference[] refs = (NetworkBytesReference[]) invocationOnMock.getArguments()[0];
refs[0].incrementRead(10);
return 10;
});
writeOp.flush();
assertTrue(writeOp.isFullyFlushed());
}
public void testPartialFlush() throws IOException {
WriteOperation writeOp = new WriteOperation(channel, new BytesArray(new byte[10]), listener);
when(channel.write(any())).thenAnswer(invocationOnMock -> {
NetworkBytesReference[] refs = (NetworkBytesReference[]) invocationOnMock.getArguments()[0];
refs[0].incrementRead(5);
return 5;
});
writeOp.flush();
assertFalse(writeOp.isFullyFlushed());
assertEquals(5, writeOp.getByteReferences()[0].getReadRemaining());
}
}

View File

@ -0,0 +1,99 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.common.CheckedRunnable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.mocksocket.MockServerSocket;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.TcpReadHandler;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import static org.mockito.Mockito.mock;
public abstract class AbstractNioChannelTestCase extends ESTestCase {
ChannelFactory channelFactory = new ChannelFactory(Settings.EMPTY, mock(TcpReadHandler.class));
MockServerSocket mockServerSocket;
private Thread serverThread;
@Before
public void serverSocketSetup() throws IOException {
mockServerSocket = new MockServerSocket(0);
serverThread = new Thread(() -> {
while (!mockServerSocket.isClosed()) {
try {
Socket socket = mockServerSocket.accept();
InputStream inputStream = socket.getInputStream();
socket.close();
} catch (IOException e) {
}
}
});
serverThread.start();
}
@After
public void serverSocketTearDown() throws IOException {
serverThread.interrupt();
mockServerSocket.close();
}
public abstract NioChannel channelToClose() throws IOException;
public void testClose() throws IOException, TimeoutException, InterruptedException {
AtomicReference<NioChannel> ref = new AtomicReference<>();
CountDownLatch latch = new CountDownLatch(1);
NioChannel socketChannel = channelToClose();
CloseFuture closeFuture = socketChannel.getCloseFuture();
closeFuture.setListener((c) -> {ref.set(c); latch.countDown();});
assertFalse(closeFuture.isClosed());
assertTrue(socketChannel.getRawChannel().isOpen());
socketChannel.closeAsync();
closeFuture.awaitClose(100, TimeUnit.SECONDS);
assertFalse(socketChannel.getRawChannel().isOpen());
assertTrue(closeFuture.isClosed());
latch.await();
assertSame(socketChannel, ref.get());
}
protected Runnable wrappedRunnable(CheckedRunnable<Exception> runnable) {
return () -> {
try {
runnable.run();
} catch (Exception e) {
}
};
}
}

View File

@ -0,0 +1,44 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.transport.nio.ESSelector;
import org.elasticsearch.transport.nio.utils.TestSelectionKey;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel;
public class DoNotRegisterChannel extends NioSocketChannel {
public DoNotRegisterChannel(String profile, SocketChannel socketChannel) throws IOException {
super(profile, socketChannel);
}
@Override
public boolean register(ESSelector selector) throws ClosedChannelException {
if (markRegistered(selector)) {
setSelectionKey(new TestSelectionKey(0));
return true;
} else {
return false;
}
}
}

View File

@ -0,0 +1,44 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.transport.nio.ESSelector;
import org.elasticsearch.transport.nio.utils.TestSelectionKey;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ServerSocketChannel;
public class DoNotRegisterServerChannel extends NioServerSocketChannel {
public DoNotRegisterServerChannel(String profile, ServerSocketChannel channel, ChannelFactory channelFactory) throws IOException {
super(profile, channel, channelFactory);
}
@Override
public boolean register(ESSelector selector) throws ClosedChannelException {
if (markRegistered(selector)) {
setSelectionKey(new TestSelectionKey(0));
return true;
} else {
return false;
}
}
}

View File

@ -0,0 +1,33 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
public class NioServerSocketChannelTests extends AbstractNioChannelTestCase {
@Override
public NioChannel channelToClose() throws IOException {
return channelFactory.openNioServerSocketChannel("nio", new InetSocketAddress(InetAddress.getLoopbackAddress(),0));
}
}

View File

@ -0,0 +1,85 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.LockSupport;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
public class NioSocketChannelTests extends AbstractNioChannelTestCase {
private InetAddress loopbackAddress = InetAddress.getLoopbackAddress();
@Override
public NioChannel channelToClose() throws IOException {
return channelFactory.openNioChannel(new InetSocketAddress(loopbackAddress, mockServerSocket.getLocalPort()));
}
public void testConnectSucceeds() throws IOException, InterruptedException {
InetSocketAddress remoteAddress = new InetSocketAddress(loopbackAddress, mockServerSocket.getLocalPort());
NioSocketChannel socketChannel = channelFactory.openNioChannel(remoteAddress);
Thread thread = new Thread(wrappedRunnable(() -> ensureConnect(socketChannel)));
thread.start();
ConnectFuture connectFuture = socketChannel.getConnectFuture();
connectFuture.awaitConnectionComplete(100, TimeUnit.SECONDS);
assertTrue(socketChannel.isConnectComplete());
assertTrue(socketChannel.isOpen());
assertFalse(connectFuture.connectFailed());
assertNull(connectFuture.getException());
thread.join();
}
public void testConnectFails() throws IOException, InterruptedException {
mockServerSocket.close();
InetSocketAddress remoteAddress = new InetSocketAddress(loopbackAddress, mockServerSocket.getLocalPort());
NioSocketChannel socketChannel = channelFactory.openNioChannel(remoteAddress);
Thread thread = new Thread(wrappedRunnable(() -> ensureConnect(socketChannel)));
thread.start();
ConnectFuture connectFuture = socketChannel.getConnectFuture();
connectFuture.awaitConnectionComplete(100, TimeUnit.SECONDS);
assertFalse(socketChannel.isConnectComplete());
// Even if connection fails the channel is 'open' until close() is called
assertTrue(socketChannel.isOpen());
assertTrue(connectFuture.connectFailed());
assertThat(connectFuture.getException(), instanceOf(ConnectException.class));
assertThat(connectFuture.getException().getMessage(), containsString("Connection refused"));
thread.join();
}
private void ensureConnect(NioSocketChannel nioSocketChannel) throws IOException {
for (;;) {
boolean isConnected = nioSocketChannel.finishConnect();
if (isConnected) {
return;
}
LockSupport.parkNanos(TimeUnit.MILLISECONDS.toNanos(1));
}
}
}

View File

@ -0,0 +1,169 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.TcpTransport;
import java.io.IOException;
import java.io.StreamCorruptedException;
import static org.hamcrest.Matchers.instanceOf;
public class TcpFrameDecoderTests extends ESTestCase {
private TcpFrameDecoder frameDecoder = new TcpFrameDecoder();
public void testDefaultExceptedMessageLengthIsNegative1() {
assertEquals(-1, frameDecoder.expectedMessageLength());
}
public void testDecodeWithIncompleteHeader() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.write(1);
streamOutput.write(1);
streamOutput.write(0);
streamOutput.write(0);
assertNull(frameDecoder.decode(streamOutput.bytes(), 4));
assertEquals(-1, frameDecoder.expectedMessageLength());
}
public void testDecodePing() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(-1);
BytesReference message = frameDecoder.decode(streamOutput.bytes(), 6);
assertEquals(-1, frameDecoder.expectedMessageLength());
assertEquals(streamOutput.bytes(), message);
}
public void testDecodePingWithStartOfSecondMessage() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(-1);
streamOutput.write('E');
streamOutput.write('S');
BytesReference message = frameDecoder.decode(streamOutput.bytes(), 8);
assertEquals(6, message.length());
assertEquals(streamOutput.bytes().slice(0, 6), message);
}
public void testDecodeMessage() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(2);
streamOutput.write('M');
streamOutput.write('A');
BytesReference message = frameDecoder.decode(streamOutput.bytes(), 8);
assertEquals(-1, frameDecoder.expectedMessageLength());
assertEquals(streamOutput.bytes(), message);
}
public void testDecodeIncompleteMessage() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(3);
streamOutput.write('M');
streamOutput.write('A');
BytesReference message = frameDecoder.decode(streamOutput.bytes(), 8);
assertEquals(9, frameDecoder.expectedMessageLength());
assertNull(message);
}
public void testInvalidLength() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('S');
streamOutput.writeInt(-2);
streamOutput.write('M');
streamOutput.write('A');
try {
frameDecoder.decode(streamOutput.bytes(), 8);
fail("Expected exception");
} catch (Exception ex) {
assertThat(ex, instanceOf(StreamCorruptedException.class));
assertEquals("invalid data length: -2", ex.getMessage());
}
}
public void testInvalidHeader() throws IOException {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
streamOutput.write('E');
streamOutput.write('C');
byte byte1 = randomByte();
byte byte2 = randomByte();
streamOutput.write(byte1);
streamOutput.write(byte2);
streamOutput.write(randomByte());
streamOutput.write(randomByte());
streamOutput.write(randomByte());
try {
frameDecoder.decode(streamOutput.bytes(), 7);
fail("Expected exception");
} catch (Exception ex) {
assertThat(ex, instanceOf(StreamCorruptedException.class));
String expected = "invalid internal transport message format, got (45,43,"
+ Integer.toHexString(byte1 & 0xFF) + ","
+ Integer.toHexString(byte2 & 0xFF) + ")";
assertEquals(expected, ex.getMessage());
}
}
public void testHTTPHeader() throws IOException {
String[] httpHeaders = {"GET", "POST", "PUT", "HEAD", "DELETE", "OPTIONS", "PATCH", "TRACE"};
for (String httpHeader : httpHeaders) {
BytesStreamOutput streamOutput = new BytesStreamOutput(1 << 14);
for (char c : httpHeader.toCharArray()) {
streamOutput.write((byte) c);
}
streamOutput.write(new byte[6]);
try {
BytesReference bytes = streamOutput.bytes();
frameDecoder.decode(bytes, bytes.length());
fail("Expected exception");
} catch (Exception ex) {
assertThat(ex, instanceOf(TcpTransport.HttpOnTransportException.class));
assertEquals("This is not a HTTP port", ex.getMessage());
}
}
}
}

View File

@ -0,0 +1,150 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.NetworkBytesReference;
import org.elasticsearch.transport.nio.TcpReadHandler;
import org.junit.Before;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
public class TcpReadContextTests extends ESTestCase {
private static String PROFILE = "profile";
private TcpReadHandler handler;
private int messageLength;
private NioSocketChannel channel;
private TcpReadContext readContext;
@Before
public void init() throws IOException {
handler = mock(TcpReadHandler.class);
messageLength = randomInt(96) + 4;
channel = mock(NioSocketChannel.class);
readContext = new TcpReadContext(channel, handler);
when(channel.getProfile()).thenReturn(PROFILE);
}
public void testSuccessfulRead() throws IOException {
byte[] bytes = createMessage(messageLength);
byte[] fullMessage = combineMessageAndHeader(bytes);
final AtomicInteger bufferCapacity = new AtomicInteger();
when(channel.read(any(NetworkBytesReference.class))).thenAnswer(invocationOnMock -> {
NetworkBytesReference reference = (NetworkBytesReference) invocationOnMock.getArguments()[0];
ByteBuffer buffer = reference.getWriteByteBuffer();
bufferCapacity.set(reference.getWriteRemaining());
buffer.put(fullMessage);
reference.incrementWrite(fullMessage.length);
return fullMessage.length;
});
readContext.read();
verify(handler).handleMessage(new BytesArray(bytes), channel, PROFILE, messageLength);
assertEquals(1024 * 16, bufferCapacity.get());
BytesArray bytesArray = new BytesArray(new byte[10]);
bytesArray.slice(5, 5);
bytesArray.slice(5, 0);
}
public void testPartialRead() throws IOException {
byte[] part1 = createMessage(messageLength);
byte[] fullPart1 = combineMessageAndHeader(part1, messageLength + messageLength);
byte[] part2 = createMessage(messageLength);
final AtomicInteger bufferCapacity = new AtomicInteger();
final AtomicReference<byte[]> bytes = new AtomicReference<>();
when(channel.read(any(NetworkBytesReference.class))).thenAnswer(invocationOnMock -> {
NetworkBytesReference reference = (NetworkBytesReference) invocationOnMock.getArguments()[0];
ByteBuffer buffer = reference.getWriteByteBuffer();
bufferCapacity.set(reference.getWriteRemaining());
buffer.put(bytes.get());
reference.incrementWrite(bytes.get().length);
return bytes.get().length;
});
bytes.set(fullPart1);
readContext.read();
assertEquals(1024 * 16, bufferCapacity.get());
verifyZeroInteractions(handler);
bytes.set(part2);
readContext.read();
assertEquals(1024 * 16 - fullPart1.length, bufferCapacity.get());
CompositeBytesReference reference = new CompositeBytesReference(new BytesArray(part1), new BytesArray(part2));
verify(handler).handleMessage(reference, channel, PROFILE, messageLength + messageLength);
}
public void testReadThrowsIOException() throws IOException {
IOException ioException = new IOException();
when(channel.read(any())).thenThrow(ioException);
try {
readContext.read();
fail("Expected exception");
} catch (Exception ex) {
assertSame(ioException, ex);
}
}
private static byte[] combineMessageAndHeader(byte[] bytes) {
return combineMessageAndHeader(bytes, bytes.length);
}
private static byte[] combineMessageAndHeader(byte[] bytes, int messageLength) {
byte[] fullMessage = new byte[bytes.length + 6];
ByteBuffer wrapped = ByteBuffer.wrap(fullMessage);
wrapped.put((byte) 'E');
wrapped.put((byte) 'S');
wrapped.putInt(messageLength);
wrapped.put(bytes);
return fullMessage;
}
private static byte[] createMessage(int length) {
byte[] bytes = new byte[length];
for (int i = 0; i < length; ++i) {
bytes[i] = randomByte();
}
return bytes;
}
}

View File

@ -0,0 +1,296 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.SocketSelector;
import org.elasticsearch.transport.nio.WriteOperation;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class TcpWriteContextTests extends ESTestCase {
private SocketSelector selector;
private ActionListener<NioChannel> listener;
private TcpWriteContext writeContext;
private NioSocketChannel channel;
@Before
@SuppressWarnings("unchecked")
public void setUp() throws Exception {
super.setUp();
selector = mock(SocketSelector.class);
listener = mock(ActionListener.class);
channel = mock(NioSocketChannel.class);
writeContext = new TcpWriteContext(channel);
when(channel.getSelector()).thenReturn(selector);
when(selector.isOnCurrentThread()).thenReturn(true);
}
public void testWriteFailsIfChannelNotWritable() throws Exception {
when(channel.isWritable()).thenReturn(false);
writeContext.sendMessage(new BytesArray(generateBytes(10)), listener);
verify(listener).onFailure(any(ClosedChannelException.class));
}
public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception {
byte[] bytes = generateBytes(10);
BytesArray bytesArray = new BytesArray(bytes);
ArgumentCaptor<WriteOperation> writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class);
when(selector.isOnCurrentThread()).thenReturn(false);
when(channel.isWritable()).thenReturn(true);
writeContext.sendMessage(bytesArray, listener);
verify(selector).queueWrite(writeOpCaptor.capture());
WriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertEquals(ByteBuffer.wrap(bytes), writeOp.getByteReferences()[0].getReadByteBuffer());
}
public void testSendMessageFromSameThreadIsQueuedInChannel() throws Exception {
byte[] bytes = generateBytes(10);
BytesArray bytesArray = new BytesArray(bytes);
ArgumentCaptor<WriteOperation> writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class);
when(channel.isWritable()).thenReturn(true);
writeContext.sendMessage(bytesArray, listener);
verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture());
WriteOperation writeOp = writeOpCaptor.getValue();
assertSame(listener, writeOp.getListener());
assertSame(channel, writeOp.getChannel());
assertEquals(ByteBuffer.wrap(bytes), writeOp.getByteReferences()[0].getReadByteBuffer());
}
public void testWriteIsQueuedInChannel() throws Exception {
assertFalse(writeContext.hasQueuedWriteOps());
writeContext.queueWriteOperations(new WriteOperation(channel, new BytesArray(generateBytes(10)), listener));
assertTrue(writeContext.hasQueuedWriteOps());
}
public void testWriteOpsCanBeCleared() throws Exception {
assertFalse(writeContext.hasQueuedWriteOps());
writeContext.queueWriteOperations(new WriteOperation(channel, new BytesArray(generateBytes(10)), listener));
assertTrue(writeContext.hasQueuedWriteOps());
ClosedChannelException e = new ClosedChannelException();
writeContext.clearQueuedWriteOps(e);
verify(listener).onFailure(e);
assertFalse(writeContext.hasQueuedWriteOps());
}
public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
assertFalse(writeContext.hasQueuedWriteOps());
WriteOperation writeOperation = mock(WriteOperation.class);
writeContext.queueWriteOperations(writeOperation);
assertTrue(writeContext.hasQueuedWriteOps());
when(writeOperation.isFullyFlushed()).thenReturn(true);
when(writeOperation.getListener()).thenReturn(listener);
writeContext.flushChannel();
verify(writeOperation).flush();
verify(listener).onResponse(channel);
assertFalse(writeContext.hasQueuedWriteOps());
}
public void testPartialFlush() throws IOException {
assertFalse(writeContext.hasQueuedWriteOps());
WriteOperation writeOperation = mock(WriteOperation.class);
writeContext.queueWriteOperations(writeOperation);
assertTrue(writeContext.hasQueuedWriteOps());
when(writeOperation.isFullyFlushed()).thenReturn(false);
writeContext.flushChannel();
verify(listener, times(0)).onResponse(channel);
assertTrue(writeContext.hasQueuedWriteOps());
}
@SuppressWarnings("unchecked")
public void testMultipleWritesPartialFlushes() throws IOException {
assertFalse(writeContext.hasQueuedWriteOps());
ActionListener listener2 = mock(ActionListener.class);
WriteOperation writeOperation1 = mock(WriteOperation.class);
WriteOperation writeOperation2 = mock(WriteOperation.class);
when(writeOperation1.getListener()).thenReturn(listener);
when(writeOperation2.getListener()).thenReturn(listener2);
writeContext.queueWriteOperations(writeOperation1);
writeContext.queueWriteOperations(writeOperation2);
assertTrue(writeContext.hasQueuedWriteOps());
when(writeOperation1.isFullyFlushed()).thenReturn(true);
when(writeOperation2.isFullyFlushed()).thenReturn(false);
writeContext.flushChannel();
verify(listener).onResponse(channel);
verify(listener2, times(0)).onResponse(channel);
assertTrue(writeContext.hasQueuedWriteOps());
when(writeOperation2.isFullyFlushed()).thenReturn(true);
writeContext.flushChannel();
verify(listener2).onResponse(channel);
assertFalse(writeContext.hasQueuedWriteOps());
}
private class ConsumeAllChannel extends NioSocketChannel {
private byte[] bytes;
private byte[] bytes2;
ConsumeAllChannel() throws IOException {
super("", mock(SocketChannel.class));
}
public int write(ByteBuffer buffer) throws IOException {
bytes = new byte[buffer.remaining()];
buffer.get(bytes);
return bytes.length;
}
public long vectorizedWrite(ByteBuffer[] buffer) throws IOException {
if (buffer.length != 2) {
throw new IOException("Only allows 2 buffers");
}
bytes = new byte[buffer[0].remaining()];
buffer[0].get(bytes);
bytes2 = new byte[buffer[1].remaining()];
buffer[1].get(bytes2);
return bytes.length + bytes2.length;
}
}
private class HalfConsumeChannel extends NioSocketChannel {
private byte[] bytes;
private byte[] bytes2;
HalfConsumeChannel() throws IOException {
super("", mock(SocketChannel.class));
}
public int write(ByteBuffer buffer) throws IOException {
bytes = new byte[buffer.limit() / 2];
buffer.get(bytes);
return bytes.length;
}
public long vectorizedWrite(ByteBuffer[] buffers) throws IOException {
if (buffers.length != 2) {
throw new IOException("Only allows 2 buffers");
}
if (bytes == null) {
bytes = new byte[buffers[0].remaining()];
bytes2 = new byte[buffers[1].remaining()];
}
if (buffers[0].remaining() != 0) {
buffers[0].get(bytes);
return bytes.length;
} else {
buffers[1].get(bytes2);
return bytes2.length;
}
}
}
private class MultiWriteChannel extends NioSocketChannel {
private byte[] write1Bytes;
private byte[] write1Bytes2;
private byte[] write2Bytes1;
private byte[] write2Bytes2;
MultiWriteChannel() throws IOException {
super("", mock(SocketChannel.class));
}
public long vectorizedWrite(ByteBuffer[] buffers) throws IOException {
if (buffers.length != 4 && write1Bytes == null) {
throw new IOException("Only allows 4 buffers");
} else if (buffers.length != 2 && write1Bytes != null) {
throw new IOException("Only allows 2 buffers on second write");
}
if (write1Bytes == null) {
write1Bytes = new byte[buffers[0].remaining()];
write1Bytes2 = new byte[buffers[1].remaining()];
write2Bytes1 = new byte[buffers[2].remaining()];
write2Bytes2 = new byte[buffers[3].remaining()];
}
if (buffers[0].remaining() != 0) {
buffers[0].get(write1Bytes);
buffers[1].get(write1Bytes2);
buffers[2].get(write2Bytes1);
return write1Bytes.length + write1Bytes2.length + write2Bytes1.length;
} else {
buffers[1].get(write2Bytes2);
return write2Bytes2.length;
}
}
}
private byte[] generateBytes(int n) {
n += 10;
byte[] bytes = new byte[n];
for (int i = 0; i < n; ++i) {
bytes[i] = randomByte();
}
return bytes;
}
}

View File

@ -0,0 +1,65 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.nio.utils;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.spi.AbstractSelectionKey;
public class TestSelectionKey extends AbstractSelectionKey {
private int ops = 0;
private int readyOps;
public TestSelectionKey(int ops) {
this.ops = ops;
}
@Override
public SelectableChannel channel() {
return null;
}
@Override
public Selector selector() {
return null;
}
@Override
public int interestOps() {
return ops;
}
@Override
public SelectionKey interestOps(int ops) {
this.ops = ops;
return this;
}
@Override
public int readyOps() {
return readyOps;
}
public void setReadyOps(int readyOps) {
this.readyOps = readyOps;
}
}