Move nio ip filter rule to be a channel handler (#43507)

Currently nio implements ip filtering at the channel context level. This
is kind of a hack as the application logic should be implemented at the
handler level. This commit moves the ip filtering into a channel
handler. This requires adding an indicator to the channel handler to
show when a channel should be closed.
This commit is contained in:
Tim Brooks 2019-06-24 11:35:46 -04:00
parent fac7efba9a
commit 38516a4dd5
No known key found for this signature in database
GPG Key ID: C2AA3BB91A889E77
15 changed files with 194 additions and 105 deletions

View File

@ -21,19 +21,12 @@ package org.elasticsearch.nio;
import java.io.IOException;
import java.util.function.Consumer;
import java.util.function.Predicate;
public class BytesChannelContext extends SocketChannelContext {
public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler handler, InboundChannelBuffer channelBuffer) {
this(channel, selector, exceptionHandler, handler, channelBuffer, ALWAYS_ALLOW_CHANNEL);
}
public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler handler, InboundChannelBuffer channelBuffer,
Predicate<NioSocketChannel> allowChannelPredicate) {
super(channel, selector, exceptionHandler, handler, channelBuffer, allowChannelPredicate);
NioChannelHandler handler, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, handler, channelBuffer);
}
@Override

View File

@ -24,7 +24,7 @@ import java.util.Collections;
import java.util.List;
import java.util.function.BiConsumer;
public abstract class BytesWriteHandler implements ReadWriteHandler {
public abstract class BytesWriteHandler implements NioChannelHandler {
private static final List<FlushOperation> EMPTY_LIST = Collections.emptyList();
@ -48,6 +48,11 @@ public abstract class BytesWriteHandler implements ReadWriteHandler {
return EMPTY_LIST;
}
@Override
public boolean closeNow() {
return false;
}
@Override
public void close() {}
}

View File

@ -0,0 +1,68 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.io.IOException;
import java.util.List;
import java.util.function.BiConsumer;
public abstract class DelegatingHandler implements NioChannelHandler {
private NioChannelHandler delegate;
public DelegatingHandler(NioChannelHandler delegate) {
this.delegate = delegate;
}
@Override
public void channelRegistered() {
this.delegate.channelRegistered();
}
@Override
public WriteOperation createWriteOperation(SocketChannelContext context, Object message, BiConsumer<Void, Exception> listener) {
return delegate.createWriteOperation(context, message, listener);
}
@Override
public List<FlushOperation> writeToBytes(WriteOperation writeOperation) {
return delegate.writeToBytes(writeOperation);
}
@Override
public List<FlushOperation> pollFlushOperations() {
return delegate.pollFlushOperations();
}
@Override
public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException {
return delegate.consumeReads(channelBuffer);
}
@Override
public boolean closeNow() {
return delegate.closeNow();
}
@Override
public void close() throws IOException {
delegate.close();
}
}

View File

@ -24,9 +24,9 @@ import java.util.List;
import java.util.function.BiConsumer;
/**
* Implements the application specific logic for handling inbound and outbound messages for a channel.
* Implements the application specific logic for handling channel operations.
*/
public interface ReadWriteHandler {
public interface NioChannelHandler {
/**
* This method is called when the channel is registered with its selector.
@ -72,5 +72,12 @@ public interface ReadWriteHandler {
*/
int consumeReads(InboundChannelBuffer channelBuffer) throws IOException;
/**
* This method indicates if the underlying channel should be closed.
*
* @return if the channel should be closed
*/
boolean closeNow();
void close() throws IOException;
}

View File

@ -32,7 +32,6 @@ import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Predicate;
/**
* This context should implement the specific logic for a channel. When a channel receives a notification
@ -45,13 +44,10 @@ import java.util.function.Predicate;
*/
public abstract class SocketChannelContext extends ChannelContext<SocketChannel> {
protected static final Predicate<NioSocketChannel> ALWAYS_ALLOW_CHANNEL = (c) -> true;
protected final NioSocketChannel channel;
protected final InboundChannelBuffer channelBuffer;
protected final AtomicBoolean isClosing = new AtomicBoolean(false);
private final ReadWriteHandler readWriteHandler;
private final Predicate<NioSocketChannel> allowChannelPredicate;
private final NioChannelHandler readWriteHandler;
private final NioSelector selector;
private final CompletableContext<Void> connectContext = new CompletableContext<>();
private final LinkedList<FlushOperation> pendingFlushes = new LinkedList<>();
@ -59,14 +55,12 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
private Exception connectException;
protected SocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer,
Predicate<NioSocketChannel> allowChannelPredicate) {
NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
super(channel.getRawChannel(), exceptionHandler);
this.selector = selector;
this.channel = channel;
this.readWriteHandler = readWriteHandler;
this.channelBuffer = channelBuffer;
this.allowChannelPredicate = allowChannelPredicate;
}
@Override
@ -171,9 +165,6 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
protected void register() throws IOException {
super.register();
readWriteHandler.channelRegistered();
if (allowChannelPredicate.test(channel) == false) {
closeNow = true;
}
}
@Override
@ -233,7 +224,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
public abstract boolean selectorShouldClose();
protected boolean closeNow() {
return closeNow;
return closeNow || readWriteHandler.closeNow();
}
protected void setCloseNow() {

View File

@ -23,7 +23,7 @@ import java.util.function.BiConsumer;
/**
* This is a basic write operation that can be queued with a channel. The only requirements of a write
* operation is that is has a listener and a reference to its channel. The actual conversion of the write
* operation implementation to bytes will be performed by the {@link ReadWriteHandler}.
* operation implementation to bytes will be performed by the {@link NioChannelHandler}.
*/
public interface WriteOperation {

View File

@ -44,7 +44,7 @@ public class EventHandlerTests extends ESTestCase {
private Consumer<Exception> channelExceptionHandler;
private Consumer<Exception> genericExceptionHandler;
private ReadWriteHandler readWriteHandler;
private NioChannelHandler readWriteHandler;
private EventHandler handler;
private DoNotRegisterSocketContext context;
private DoNotRegisterServerContext serverContext;
@ -56,7 +56,7 @@ public class EventHandlerTests extends ESTestCase {
public void setUpHandler() throws IOException {
channelExceptionHandler = mock(Consumer.class);
genericExceptionHandler = mock(Consumer.class);
readWriteHandler = mock(ReadWriteHandler.class);
readWriteHandler = mock(NioChannelHandler.class);
channelFactory = mock(ChannelFactory.class);
NioSelector selector = mock(NioSelector.class);
ArrayList<NioSelector> selectors = new ArrayList<>();
@ -260,7 +260,7 @@ public class EventHandlerTests extends ESTestCase {
DoNotRegisterSocketContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler handler) {
NioChannelHandler handler) {
super(channel, selector, exceptionHandler, handler, InboundChannelBuffer.allocatingInstance());
}

View File

@ -35,7 +35,6 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.IntFunction;
import java.util.function.Predicate;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
@ -54,7 +53,7 @@ public class SocketChannelContextTests extends ESTestCase {
private NioSocketChannel channel;
private BiConsumer<Void, Exception> listener;
private NioSelector selector;
private ReadWriteHandler readWriteHandler;
private NioChannelHandler readWriteHandler;
private ByteBuffer ioBuffer = ByteBuffer.allocate(1024);
@SuppressWarnings("unchecked")
@ -68,7 +67,7 @@ public class SocketChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(rawChannel);
exceptionHandler = mock(Consumer.class);
selector = mock(NioSelector.class);
readWriteHandler = mock(ReadWriteHandler.class);
readWriteHandler = mock(NioChannelHandler.class);
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
@ -102,22 +101,6 @@ public class SocketChannelContextTests extends ESTestCase {
assertTrue(context.closeNow());
}
public void testValidateInRegisterCanSucceed() throws IOException {
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, (c) -> true);
assertFalse(context.closeNow());
context.register();
assertFalse(context.closeNow());
}
public void testValidateInRegisterCanFail() throws IOException {
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, (c) -> false);
assertFalse(context.closeNow());
context.register();
assertTrue(context.closeNow());
}
public void testConnectSucceeds() throws IOException {
AtomicBoolean listenerCalled = new AtomicBoolean(false);
when(rawChannel.finishConnect()).thenReturn(false, true);
@ -394,14 +377,8 @@ public class SocketChannelContextTests extends ESTestCase {
private static class TestSocketChannelContext extends SocketChannelContext {
private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
this(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, ALWAYS_ALLOW_CHANNEL);
}
private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer,
Predicate<NioSocketChannel> allowChannelPredicate) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
}
@Override

View File

@ -38,7 +38,7 @@ import org.elasticsearch.http.nio.cors.NioCorsConfig;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.ReadWriteHandler;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.TaskScheduler;
import org.elasticsearch.nio.WriteOperation;
@ -50,7 +50,7 @@ import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.LongSupplier;
public class HttpReadWriteHandler implements ReadWriteHandler {
public class HttpReadWriteHandler implements NioChannelHandler {
private final NettyAdaptor adaptor;
private final NioHttpChannel nioHttpChannel;
@ -140,6 +140,11 @@ public class HttpReadWriteHandler implements ReadWriteHandler {
return copiedOperations;
}
@Override
public boolean closeNow() {
return false;
}
@Override
public void close() throws IOException {
try {

View File

@ -49,7 +49,7 @@ import org.elasticsearch.nio.NioSelectorGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ReadWriteHandler;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.tasks.Task;
@ -207,7 +207,7 @@ class NioHttpClient implements Closeable {
}
}
private static class HttpClientHandler implements ReadWriteHandler {
private static class HttpClientHandler implements NioChannelHandler {
private final NettyAdaptor adaptor;
private final CountDownLatch latch;
@ -277,6 +277,11 @@ class NioHttpClient implements Closeable {
return bytesConsumed;
}
@Override
public boolean closeNow() {
return false;
}
@Override
public void close() throws IOException {
try {

View File

@ -5,28 +5,49 @@
*/
package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.DelegatingHandler;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.xpack.security.transport.filter.IPFilter;
import java.util.function.Predicate;
import java.io.IOException;
import java.net.InetSocketAddress;
public final class NioIPFilter implements Predicate<NioSocketChannel> {
public final class NioIPFilter extends DelegatingHandler {
private final InetSocketAddress remoteAddress;
private final IPFilter filter;
private final String profile;
private boolean denied = false;
NioIPFilter(@Nullable IPFilter filter, String profile) {
NioIPFilter(NioChannelHandler delegate, InetSocketAddress remoteAddress, IPFilter filter, String profile) {
super(delegate);
this.remoteAddress = remoteAddress;
this.filter = filter;
this.profile = profile;
}
@Override
public boolean test(NioSocketChannel nioChannel) {
if (filter != null) {
return filter.accept(profile, nioChannel.getRemoteAddress());
public void channelRegistered() {
if (filter.accept(profile, remoteAddress)) {
super.channelRegistered();
} else {
return true;
denied = true;
}
}
@Override
public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException {
if (denied) {
// Do not consume any reads if channel is disallowed
return 0;
} else {
return super.consumeReads(channelBuffer);
}
}
@Override
public boolean closeNow() {
return denied;
}
}

View File

@ -9,9 +9,9 @@ import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ReadWriteHandler;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.WriteOperation;
@ -23,12 +23,11 @@ import java.util.LinkedList;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Predicate;
/**
* Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake
* with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted
* before being passed to the {@link ReadWriteHandler}. Outbound data will be encrypted before being flushed
* before being passed to the {@link NioChannelHandler}. Outbound data will be encrypted before being flushed
* to the channel.
*/
public final class SSLChannelContext extends SocketChannelContext {
@ -43,15 +42,14 @@ public final class SSLChannelContext extends SocketChannelContext {
private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
ReadWriteHandler readWriteHandler, InboundChannelBuffer applicationBuffer) {
NioChannelHandler readWriteHandler, InboundChannelBuffer applicationBuffer) {
this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(),
applicationBuffer, ALWAYS_ALLOW_CHANNEL);
applicationBuffer);
}
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
ReadWriteHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer,
Predicate<NioSocketChannel> allowChannelPredicate) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
NioChannelHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
this.sslDriver = sslDriver;
this.networkReadBuffer = networkReadBuffer;
}

View File

@ -19,6 +19,7 @@ import org.elasticsearch.http.nio.NioHttpServerTransport;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ServerChannelContext;
@ -44,7 +45,6 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
private final SecurityHttpExceptionHandler securityExceptionHandler;
private final IPFilter ipFilter;
private final NioIPFilter nioIpFilter;
private final SSLService sslService;
private final SSLConfiguration sslConfiguration;
private final boolean sslEnabled;
@ -56,7 +56,6 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
super(settings, networkService, bigArrays, pageCacheRecycler, threadPool, xContentRegistry, dispatcher, nioGroupFactory);
this.securityExceptionHandler = new SecurityHttpExceptionHandler(logger, lifecycle, (c, e) -> super.onException(c, e));
this.ipFilter = ipFilter;
this.nioIpFilter = new NioIPFilter(ipFilter, IPFilter.HTTP_PROFILE_NAME);
this.sslEnabled = HTTP_SSL_ENABLED.get(settings);
this.sslService = sslService;
if (sslEnabled) {
@ -91,6 +90,13 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
final NioChannelHandler handler;
if (ipFilter != null) {
handler = new NioIPFilter(httpHandler, httpChannel.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME);
} else {
handler = httpHandler;
}
InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator);
Consumer<Exception> exceptionHandler = (e) -> securityExceptionHandler.accept(httpChannel, e);
@ -107,10 +113,10 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
}
SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false);
InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, networkBuffer,
applicationBuffer, nioIpFilter);
context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer,
applicationBuffer);
} else {
context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, networkBuffer, nioIpFilter);
context = new BytesChannelContext(httpChannel, selector, exceptionHandler, handler, networkBuffer);
}
httpChannel.setContext(context);

View File

@ -18,6 +18,7 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ServerChannelContext;
@ -65,19 +66,19 @@ public class SecurityNioTransport extends NioTransport {
private static final Logger logger = LogManager.getLogger(SecurityNioTransport.class);
private final SecurityTransportExceptionHandler exceptionHandler;
private final IPFilter authenticator;
private final IPFilter ipFilter;
private final SSLService sslService;
private final Map<String, SSLConfiguration> profileConfiguration;
private final boolean sslEnabled;
public SecurityNioTransport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService,
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
CircuitBreakerService circuitBreakerService, @Nullable final IPFilter authenticator,
CircuitBreakerService circuitBreakerService, @Nullable final IPFilter ipFilter,
SSLService sslService, NioGroupFactory groupFactory) {
super(settings, version, threadPool, networkService, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService,
groupFactory);
this.exceptionHandler = new SecurityTransportExceptionHandler(logger, lifecycle, (c, e) -> super.onException(c, e));
this.authenticator = authenticator;
this.ipFilter = ipFilter;
this.sslService = sslService;
this.sslEnabled = XPackSettings.TRANSPORT_SSL_ENABLED.get(settings);
if (sslEnabled) {
@ -92,8 +93,8 @@ public class SecurityNioTransport extends NioTransport {
@Override
protected void doStart() {
super.doStart();
if (authenticator != null) {
authenticator.setBoundTransportAddress(boundAddress(), profileBoundAddresses());
if (ipFilter != null) {
ipFilter.setBoundTransportAddress(boundAddress(), profileBoundAddresses());
}
}
@ -132,7 +133,6 @@ public class SecurityNioTransport extends NioTransport {
private final String profileName;
private final boolean isClient;
private final NioIPFilter ipFilter;
private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) {
this(new RawChannelFactory(profileSettings.tcpNoDelay,
@ -146,13 +146,18 @@ public class SecurityNioTransport extends NioTransport {
super(rawChannelFactory);
this.profileName = profileName;
this.isClient = isClient;
this.ipFilter = new NioIPFilter(authenticator, profileName);
}
@Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this);
final NioChannelHandler handler;
if (ipFilter != null) {
handler = new NioIPFilter(readWriteHandler, nioChannel.getRemoteAddress(), ipFilter, profileName);
} else {
handler = readWriteHandler;
}
InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator);
Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);
@ -160,10 +165,10 @@ public class SecurityNioTransport extends NioTransport {
if (sslEnabled) {
SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), pageAllocator, isClient);
InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, networkBuffer,
applicationBuffer, ipFilter);
context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer,
applicationBuffer);
} else {
context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, networkBuffer, ipFilter);
context = new BytesChannelContext(nioChannel, selector, exceptionHandler, handler, networkBuffer);
}
nioChannel.setContext(context);

View File

@ -13,7 +13,7 @@ import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.xpack.security.audit.AuditTrailService;
@ -26,13 +26,15 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class NioIPFilterTests extends ESTestCase {
private NioIPFilter nioIPFilter;
private IPFilter ipFilter;
private String profile;
@Before
public void init() throws Exception {
@ -59,7 +61,7 @@ public class NioIPFilterTests extends ESTestCase {
XPackLicenseState licenseState = mock(XPackLicenseState.class);
when(licenseState.isIpFilteringAllowed()).thenReturn(true);
AuditTrailService auditTrailService = new AuditTrailService(Collections.emptyList(), licenseState);
IPFilter ipFilter = new IPFilter(settings, auditTrailService, clusterSettings, licenseState);
ipFilter = new IPFilter(settings, auditTrailService, clusterSettings, licenseState);
ipFilter.setBoundTransportAddress(transport.boundAddress(), transport.profileBoundAddresses());
if (isHttpEnabled) {
HttpServerTransport httpTransport = mock(HttpServerTransport.class);
@ -70,21 +72,27 @@ public class NioIPFilterTests extends ESTestCase {
}
if (isHttpEnabled) {
nioIPFilter = new NioIPFilter(ipFilter, IPFilter.HTTP_PROFILE_NAME);
profile = IPFilter.HTTP_PROFILE_NAME;
} else {
nioIPFilter = new NioIPFilter(ipFilter, "default");
profile = "default";
}
}
public void testThatFilteringWorksByIp() throws Exception {
public void testThatFilterCanPass() throws Exception {
InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("127.0.0.1"), 12345);
NioSocketChannel channel1 = mock(NioSocketChannel.class);
when(channel1.getRemoteAddress()).thenReturn(localhostAddr);
assertThat(nioIPFilter.test(channel1), is(true));
NioChannelHandler delegate = mock(NioChannelHandler.class);
NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile);
nioIPFilter.channelRegistered();
verify(delegate).channelRegistered();
assertFalse(nioIPFilter.closeNow());
}
InetSocketAddress remoteAddr = new InetSocketAddress(InetAddresses.forString("10.0.0.8"), 12345);
NioSocketChannel channel2 = mock(NioSocketChannel.class);
when(channel2.getRemoteAddress()).thenReturn(remoteAddr);
assertThat(nioIPFilter.test(channel2), is(false));
public void testThatFilterCanFail() throws Exception {
InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("10.0.0.8"), 12345);
NioChannelHandler delegate = mock(NioChannelHandler.class);
NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile);
nioIPFilter.channelRegistered();
verify(delegate, times(0)).channelRegistered();
assertTrue(nioIPFilter.closeNow());
}
}