Add sni name to SSLEngine in nio transport (#35920)

This commit is related to #32517. It allows an "sni_server_name"
attribute on a DiscoveryNode to be propagated to the server using
the TLS SNI extentsion. Prior to this commit, this functionality
was only support for the netty transport. This commit adds this
functionality to the security nio transport.
This commit is contained in:
Tim Brooks 2018-11-27 09:06:52 -07:00 committed by GitHub
parent adc0b560c0
commit b6ed6ef189
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 225 additions and 197 deletions

View File

@ -146,7 +146,7 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
} }
} }
protected static class RawChannelFactory { public static class RawChannelFactory {
private final boolean tcpNoDelay; private final boolean tcpNoDelay;
private final boolean tcpKeepAlive; private final boolean tcpKeepAlive;

View File

@ -51,6 +51,7 @@ import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
@ -67,7 +68,7 @@ public class NioTransport extends TcpTransport {
protected final PageCacheRecycler pageCacheRecycler; protected final PageCacheRecycler pageCacheRecycler;
private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap(); private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap();
private volatile NioGroup nioGroup; private volatile NioGroup nioGroup;
private volatile TcpChannelFactory clientChannelFactory; private volatile Function<DiscoveryNode, TcpChannelFactory> clientChannelFactory;
protected NioTransport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, protected NioTransport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
@ -85,8 +86,7 @@ public class NioTransport extends TcpTransport {
@Override @Override
protected NioTcpChannel initiateChannel(DiscoveryNode node) throws IOException { protected NioTcpChannel initiateChannel(DiscoveryNode node) throws IOException {
InetSocketAddress address = node.getAddress().address(); InetSocketAddress address = node.getAddress().address();
NioTcpChannel channel = nioGroup.openChannel(address, clientChannelFactory); return nioGroup.openChannel(address, clientChannelFactory.apply(node));
return channel;
} }
@Override @Override
@ -97,13 +97,13 @@ public class NioTransport extends TcpTransport {
NioTransport.NIO_WORKER_COUNT.get(settings), (s) -> new EventHandler(this::onNonChannelException, s)); NioTransport.NIO_WORKER_COUNT.get(settings), (s) -> new EventHandler(this::onNonChannelException, s));
ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default"); ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default");
clientChannelFactory = channelFactory(clientProfileSettings, true); clientChannelFactory = clientChannelFactoryFunction(clientProfileSettings);
if (NetworkService.NETWORK_SERVER.get(settings)) { if (NetworkService.NETWORK_SERVER.get(settings)) {
// loop through all profiles and start them up, special handling for default one // loop through all profiles and start them up, special handling for default one
for (ProfileSettings profileSettings : profileSettings) { for (ProfileSettings profileSettings : profileSettings) {
String profileName = profileSettings.profileName; String profileName = profileSettings.profileName;
TcpChannelFactory factory = channelFactory(profileSettings, false); TcpChannelFactory factory = serverChannelFactory(profileSettings);
profileToChannelFactory.putIfAbsent(profileName, factory); profileToChannelFactory.putIfAbsent(profileName, factory);
bindServer(profileSettings); bindServer(profileSettings);
} }
@ -134,8 +134,12 @@ public class NioTransport extends TcpTransport {
serverAcceptedChannel((NioTcpChannel) channel); serverAcceptedChannel((NioTcpChannel) channel);
} }
protected TcpChannelFactory channelFactory(ProfileSettings settings, boolean isClient) { protected TcpChannelFactory serverChannelFactory(ProfileSettings profileSettings) {
return new TcpChannelFactoryImpl(settings); return new TcpChannelFactoryImpl(profileSettings);
}
protected Function<DiscoveryNode, TcpChannelFactory> clientChannelFactoryFunction(ProfileSettings profileSettings) {
return (n) -> new TcpChannelFactoryImpl(profileSettings);
} }
protected abstract class TcpChannelFactory extends ChannelFactory<NioTcpServerChannel, NioTcpChannel> { protected abstract class TcpChannelFactory extends ChannelFactory<NioTcpServerChannel, NioTcpChannel> {

View File

@ -262,14 +262,14 @@ public class SSLService {
/** /**
* Returns the {@link SSLContext} for the global configuration. Mainly used for testing * Returns the {@link SSLContext} for the global configuration. Mainly used for testing
*/ */
SSLContext sslContext() { public SSLContext sslContext() {
return sslContextHolder(globalSSLConfiguration).sslContext(); return sslContextHolder(globalSSLConfiguration).sslContext();
} }
/** /**
* Returns the {@link SSLContext} for the configuration * Returns the {@link SSLContext} for the configuration. Mainly used for testing
*/ */
SSLContext sslContext(SSLConfiguration configuration) { public SSLContext sslContext(SSLConfiguration configuration) {
return sslContextHolder(configuration).sslContext(); return sslContextHolder(configuration).sslContext();
} }

View File

@ -9,6 +9,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.CloseableChannel; import org.elasticsearch.common.network.CloseableChannel;
@ -19,12 +20,14 @@ import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.nio.NioTcpChannel; import org.elasticsearch.transport.nio.NioTcpChannel;
@ -38,7 +41,9 @@ import org.elasticsearch.xpack.core.ssl.SSLConfiguration;
import org.elasticsearch.xpack.core.ssl.SSLService; import org.elasticsearch.xpack.core.ssl.SSLService;
import org.elasticsearch.xpack.security.transport.filter.IPFilter; import org.elasticsearch.xpack.security.transport.filter.IPFilter;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -47,6 +52,7 @@ import java.nio.channels.SocketChannel;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
import static org.elasticsearch.xpack.core.security.SecurityField.setting; import static org.elasticsearch.xpack.core.security.SecurityField.setting;
@ -128,8 +134,29 @@ public class SecurityNioTransport extends NioTransport {
} }
@Override @Override
protected TcpChannelFactory channelFactory(ProfileSettings profileSettings, boolean isClient) { protected TcpChannelFactory serverChannelFactory(ProfileSettings profileSettings) {
return new SecurityTcpChannelFactory(profileSettings, isClient); return new SecurityTcpChannelFactory(profileSettings, false);
}
@Override
protected Function<DiscoveryNode, TcpChannelFactory> clientChannelFactoryFunction(ProfileSettings profileSettings) {
return (node) -> {
final ChannelFactory.RawChannelFactory rawChannelFactory = new ChannelFactory.RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive, profileSettings.reuseAddress, Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes()));
SNIHostName serverName;
String configuredServerName = node.getAttributes().get("server_name");
if (configuredServerName != null) {
try {
serverName = new SNIHostName(configuredServerName);
} catch (IllegalArgumentException e) {
throw new ConnectTransportException(node, "invalid DiscoveryNode server_name [" + configuredServerName + "]", e);
}
} else {
serverName = null;
}
return new SecurityClientTcpChannelFactory(rawChannelFactory, serverName);
};
} }
private class SecurityTcpChannelFactory extends TcpChannelFactory { private class SecurityTcpChannelFactory extends TcpChannelFactory {
@ -139,12 +166,16 @@ public class SecurityNioTransport extends NioTransport {
private final NioIPFilter ipFilter; private final NioIPFilter ipFilter;
private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) { private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) {
super(new RawChannelFactory(profileSettings.tcpNoDelay, this(new RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive, profileSettings.tcpKeepAlive,
profileSettings.reuseAddress, profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()), Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes()))); Math.toIntExact(profileSettings.receiveBufferSize.getBytes())), profileSettings.profileName, isClient);
this.profileName = profileSettings.profileName; }
private SecurityTcpChannelFactory(RawChannelFactory rawChannelFactory, String profileName, boolean isClient) {
super(rawChannelFactory);
this.profileName = profileName;
this.isClient = isClient; this.isClient = isClient;
this.ipFilter = new NioIPFilter(authenticator, profileName); this.ipFilter = new NioIPFilter(authenticator, profileName);
} }
@ -162,18 +193,7 @@ public class SecurityNioTransport extends NioTransport {
SocketChannelContext context; SocketChannelContext context;
if (sslEnabled) { if (sslEnabled) {
SSLEngine sslEngine; SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), isClient);
SSLConfiguration defaultConfig = profileConfiguration.get(TcpTransport.DEFAULT_PROFILE);
SSLConfiguration sslConfig = profileConfiguration.getOrDefault(profileName, defaultConfig);
boolean hostnameVerificationEnabled = sslConfig.verificationMode().isHostnameVerificationEnabled();
if (hostnameVerificationEnabled) {
InetSocketAddress inetSocketAddress = (InetSocketAddress) channel.getRemoteAddress();
// we create the socket based on the name given. don't reverse DNS
sslEngine = sslService.createSSLEngine(sslConfig, inetSocketAddress.getHostString(), inetSocketAddress.getPort());
} else {
sslEngine = sslService.createSSLEngine(sslConfig, null, -1);
}
SSLDriver sslDriver = new SSLDriver(sslEngine, isClient);
context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, buffer, ipFilter); context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, buffer, ipFilter);
} else { } else {
context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, buffer, ipFilter); context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, buffer, ipFilter);
@ -192,5 +212,46 @@ public class SecurityNioTransport extends NioTransport {
nioChannel.setContext(context); nioChannel.setContext(context);
return nioChannel; return nioChannel;
} }
protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException {
SSLEngine sslEngine;
SSLConfiguration defaultConfig = profileConfiguration.get(TcpTransport.DEFAULT_PROFILE);
SSLConfiguration sslConfig = profileConfiguration.getOrDefault(profileName, defaultConfig);
boolean hostnameVerificationEnabled = sslConfig.verificationMode().isHostnameVerificationEnabled();
if (hostnameVerificationEnabled) {
InetSocketAddress inetSocketAddress = (InetSocketAddress) channel.getRemoteAddress();
// we create the socket based on the name given. don't reverse DNS
sslEngine = sslService.createSSLEngine(sslConfig, inetSocketAddress.getHostString(), inetSocketAddress.getPort());
} else {
sslEngine = sslService.createSSLEngine(sslConfig, null, -1);
}
return sslEngine;
}
}
private class SecurityClientTcpChannelFactory extends SecurityTcpChannelFactory {
private final SNIHostName serverName;
private SecurityClientTcpChannelFactory(RawChannelFactory rawChannelFactory, SNIHostName serverName) {
super(rawChannelFactory, TcpTransport.DEFAULT_PROFILE, true);
this.serverName = serverName;
}
@Override
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) {
throw new AssertionError("Cannot create TcpServerChannel with client factory");
}
@Override
protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException {
SSLEngine sslEngine = super.createSSLEngine(channel);
if (serverName != null) {
SSLParameters sslParameters = sslEngine.getSSLParameters();
sslParameters.setServerNames(Collections.singletonList(serverName));
sslEngine.setSSLParameters(sslParameters);
}
return sslEngine;
}
} }
} }

View File

@ -13,6 +13,7 @@ import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.MockSecureSettings; import org.elasticsearch.common.settings.MockSecureSettings;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.node.Node; import org.elasticsearch.node.Node;
import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.test.transport.MockTransportService;
@ -21,6 +22,7 @@ import org.elasticsearch.transport.BindTransportException;
import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.common.socket.SocketAccess; import org.elasticsearch.xpack.core.common.socket.SocketAccess;
import org.elasticsearch.xpack.core.ssl.SSLConfiguration; import org.elasticsearch.xpack.core.ssl.SSLConfiguration;
@ -28,12 +30,24 @@ import org.elasticsearch.xpack.core.ssl.SSLService;
import javax.net.SocketFactory; import javax.net.SocketFactory;
import javax.net.ssl.HandshakeCompletedListener; import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIMatcher;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocket;
import java.io.IOException; import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException; import java.net.SocketTimeoutException;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@ -44,6 +58,19 @@ import static org.hamcrest.Matchers.instanceOf;
public abstract class AbstractSimpleSecurityTransportTestCase extends AbstractSimpleTransportTestCase { public abstract class AbstractSimpleSecurityTransportTestCase extends AbstractSimpleTransportTestCase {
private static final ConnectionProfile SINGLE_CHANNEL_PROFILE;
static {
ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
builder.addConnections(1,
TransportRequestOptions.Type.BULK,
TransportRequestOptions.Type.PING,
TransportRequestOptions.Type.RECOVERY,
TransportRequestOptions.Type.REG,
TransportRequestOptions.Type.STATE);
SINGLE_CHANNEL_PROFILE = builder.build();
}
protected SSLService createSSLService() { protected SSLService createSSLService() {
return createSSLService(Settings.EMPTY); return createSSLService(Settings.EMPTY);
} }
@ -54,11 +81,11 @@ public abstract class AbstractSimpleSecurityTransportTestCase extends AbstractSi
MockSecureSettings secureSettings = new MockSecureSettings(); MockSecureSettings secureSettings = new MockSecureSettings();
secureSettings.setString("xpack.ssl.secure_key_passphrase", "testnode"); secureSettings.setString("xpack.ssl.secure_key_passphrase", "testnode");
Settings settings1 = Settings.builder() Settings settings1 = Settings.builder()
.put(settings)
.put("xpack.security.transport.ssl.enabled", true) .put("xpack.security.transport.ssl.enabled", true)
.put("xpack.ssl.key", testnodeKey) .put("xpack.ssl.key", testnodeKey)
.put("xpack.ssl.certificate", testnodeCert) .put("xpack.ssl.certificate", testnodeCert)
.put("path.home", createTempDir()) .put("path.home", createTempDir())
.put(settings)
.setSecureSettings(secureSettings) .setSecureSettings(secureSettings)
.build(); .build();
try { try {
@ -167,4 +194,108 @@ public abstract class AbstractSimpleSecurityTransportTestCase extends AbstractSi
stream.flush(); stream.flush();
} }
} }
public void testSNIServerNameIsPropagated() throws Exception {
SSLService sslService = createSSLService();
final SSLConfiguration sslConfiguration = sslService.getSSLConfiguration("xpack.ssl");
SSLContext sslContext = sslService.sslContext(sslConfiguration);
final SSLServerSocketFactory serverSocketFactory = sslContext.getServerSocketFactory();
final String sniIp = "sni-hostname";
final SNIHostName sniHostName = new SNIHostName(sniIp);
final CountDownLatch latch = new CountDownLatch(2);
try (SSLServerSocket sslServerSocket = (SSLServerSocket) serverSocketFactory.createServerSocket()) {
SSLParameters sslParameters = sslServerSocket.getSSLParameters();
sslParameters.setSNIMatchers(Collections.singletonList(new SNIMatcher(0) {
@Override
public boolean matches(SNIServerName sniServerName) {
if (sniHostName.equals(sniServerName)) {
latch.countDown();
return true;
} else {
return false;
}
}
}));
sslServerSocket.setSSLParameters(sslParameters);
SocketAccess.doPrivileged(() -> sslServerSocket.bind(getLocalEphemeral()));
new Thread(() -> {
try {
SSLSocket acceptedSocket = (SSLSocket) SocketAccess.doPrivileged(sslServerSocket::accept);
// A read call will execute the handshake
int byteRead = acceptedSocket.getInputStream().read();
assertEquals('E', byteRead);
latch.countDown();
IOUtils.closeWhileHandlingException(acceptedSocket);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}).start();
InetSocketAddress serverAddress = (InetSocketAddress) SocketAccess.doPrivileged(sslServerSocket::getLocalSocketAddress);
Settings settings = Settings.builder().put("name", "TS_TEST").put("xpack.ssl.verification_mode", "none").build();
try (MockTransportService serviceC = build(settings, version0, null, true)) {
serviceC.acceptIncomingRequests();
HashMap<String, String> attributes = new HashMap<>();
attributes.put("server_name", sniIp);
DiscoveryNode node = new DiscoveryNode("server_node_id", new TransportAddress(serverAddress), attributes,
EnumSet.allOf(DiscoveryNode.Role.class), Version.CURRENT);
new Thread(() -> {
try {
serviceC.connectToNode(node, SINGLE_CHANNEL_PROFILE);
} catch (ConnectTransportException ex) {
// Ignore. The other side is not setup to do the ES handshake. So this will fail.
}
}).start();
latch.await();
}
}
}
public void testInvalidSNIServerName() throws Exception {
SSLService sslService = createSSLService();
final SSLConfiguration sslConfiguration = sslService.getSSLConfiguration("xpack.ssl");
SSLContext sslContext = sslService.sslContext(sslConfiguration);
final SSLServerSocketFactory serverSocketFactory = sslContext.getServerSocketFactory();
final String sniIp = "invalid_hostname";
try (SSLServerSocket sslServerSocket = (SSLServerSocket) serverSocketFactory.createServerSocket()) {
SocketAccess.doPrivileged(() -> sslServerSocket.bind(getLocalEphemeral()));
new Thread(() -> {
try {
SocketAccess.doPrivileged(sslServerSocket::accept);
} catch (IOException e) {
// We except an IOException from the `accept` call because the server socket will be
// closed before the call returns.
}
}).start();
InetSocketAddress serverAddress = (InetSocketAddress) SocketAccess.doPrivileged(sslServerSocket::getLocalSocketAddress);
Settings settings = Settings.builder().put("name", "TS_TEST").put("xpack.ssl.verification_mode", "none").build();
try (MockTransportService serviceC = build(settings, version0, null, true)) {
serviceC.acceptIncomingRequests();
HashMap<String, String> attributes = new HashMap<>();
attributes.put("server_name", sniIp);
DiscoveryNode node = new DiscoveryNode("server_node_id", new TransportAddress(serverAddress), attributes,
EnumSet.allOf(DiscoveryNode.Role.class), Version.CURRENT);
ConnectTransportException connectException = expectThrows(ConnectTransportException.class,
() -> serviceC.connectToNode(node, SINGLE_CHANNEL_PROFILE));
assertThat(connectException.getMessage(), containsString("invalid DiscoveryNode server_name [invalid_hostname]"));
}
}
}
} }

View File

@ -5,13 +5,6 @@
*/ */
package org.elasticsearch.xpack.security.transport.netty4; package org.elasticsearch.xpack.security.transport.netty4;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.ssl.SslHandler;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
@ -19,51 +12,20 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ssl.SSLService;
import org.elasticsearch.xpack.security.transport.AbstractSimpleSecurityTransportTestCase; import org.elasticsearch.xpack.security.transport.AbstractSimpleSecurityTransportTestCase;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIMatcher;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.net.InetSocketAddress;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.security.SecurityField.setting;
import static org.hamcrest.Matchers.containsString;
public class SimpleSecurityNetty4ServerTransportTests extends AbstractSimpleSecurityTransportTestCase { public class SimpleSecurityNetty4ServerTransportTests extends AbstractSimpleSecurityTransportTestCase {
private static final ConnectionProfile SINGLE_CHANNEL_PROFILE;
static {
ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
builder.addConnections(1,
TransportRequestOptions.Type.BULK,
TransportRequestOptions.Type.PING,
TransportRequestOptions.Type.RECOVERY,
TransportRequestOptions.Type.REG,
TransportRequestOptions.Type.STATE);
SINGLE_CHANNEL_PROFILE = builder.build();
}
public MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version, public MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version,
ClusterSettings clusterSettings, boolean doHandshake) { ClusterSettings clusterSettings, boolean doHandshake) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
@ -103,134 +65,4 @@ public class SimpleSecurityNetty4ServerTransportTests extends AbstractSimpleSecu
transportService.start(); transportService.start();
return transportService; return transportService;
} }
public void testSNIServerNameIsPropagated() throws Exception {
SSLService sslService = createSSLService();
final ServerBootstrap serverBootstrap = new ServerBootstrap();
boolean success = false;
try {
serverBootstrap.group(new NioEventLoopGroup(1));
serverBootstrap.channel(NioServerSocketChannel.class);
final String sniIp = "sni-hostname";
final SNIHostName sniHostName = new SNIHostName(sniIp);
final CountDownLatch latch = new CountDownLatch(2);
serverBootstrap.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
SSLEngine serverEngine = sslService.createSSLEngine(sslService.getSSLConfiguration(setting("transport.ssl.")),
null, -1);
serverEngine.setUseClientMode(false);
SSLParameters sslParameters = serverEngine.getSSLParameters();
sslParameters.setSNIMatchers(Collections.singletonList(new SNIMatcher(0) {
@Override
public boolean matches(SNIServerName sniServerName) {
if (sniHostName.equals(sniServerName)) {
latch.countDown();
return true;
} else {
return false;
}
}
}));
serverEngine.setSSLParameters(sslParameters);
final SslHandler sslHandler = new SslHandler(serverEngine);
sslHandler.handshakeFuture().addListener(future -> latch.countDown());
ch.pipeline().addFirst("sslhandler", sslHandler);
}
});
serverBootstrap.validate();
ChannelFuture serverFuture = serverBootstrap.bind(getLocalEphemeral());
serverFuture.await();
InetSocketAddress serverAddress = (InetSocketAddress) serverFuture.channel().localAddress();
try (MockTransportService serviceC = build(
Settings.builder()
.put("name", "TS_TEST")
.put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "")
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.build(),
version0,
null, true)) {
serviceC.acceptIncomingRequests();
HashMap<String, String> attributes = new HashMap<>();
attributes.put("server_name", sniIp);
DiscoveryNode node = new DiscoveryNode("server_node_id", new TransportAddress(serverAddress), attributes,
EnumSet.allOf(DiscoveryNode.Role.class), Version.CURRENT);
new Thread(() -> {
try {
serviceC.connectToNode(node, SINGLE_CHANNEL_PROFILE);
} catch (ConnectTransportException ex) {
// Ignore. The other side is not setup to do the ES handshake. So this will fail.
}
}).start();
latch.await();
serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS);
success = true;
}
} finally {
if (success == false) {
serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS);
}
}
}
public void testInvalidSNIServerName() throws Exception {
SSLService sslService = createSSLService();
final ServerBootstrap serverBootstrap = new ServerBootstrap();
boolean success = false;
try {
serverBootstrap.group(new NioEventLoopGroup(1));
serverBootstrap.channel(NioServerSocketChannel.class);
final String sniIp = "invalid_hostname";
serverBootstrap.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
SSLEngine serverEngine = sslService.createSSLEngine(sslService.getSSLConfiguration(setting("transport.ssl.")),
null, -1);
serverEngine.setUseClientMode(false);
final SslHandler sslHandler = new SslHandler(serverEngine);
ch.pipeline().addFirst("sslhandler", sslHandler);
}
});
serverBootstrap.validate();
ChannelFuture serverFuture = serverBootstrap.bind(getLocalEphemeral());
serverFuture.await();
InetSocketAddress serverAddress = (InetSocketAddress) serverFuture.channel().localAddress();
try (MockTransportService serviceC = build(
Settings.builder()
.put("name", "TS_TEST")
.put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "")
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.build(),
version0,
null, true)) {
serviceC.acceptIncomingRequests();
HashMap<String, String> attributes = new HashMap<>();
attributes.put("server_name", sniIp);
DiscoveryNode node = new DiscoveryNode("server_node_id", new TransportAddress(serverAddress), attributes,
EnumSet.allOf(DiscoveryNode.Role.class), Version.CURRENT);
ConnectTransportException connectException = expectThrows(ConnectTransportException.class,
() -> serviceC.connectToNode(node, SINGLE_CHANNEL_PROFILE));
assertThat(connectException.getMessage(), containsString("invalid DiscoveryNode server_name [invalid_hostname]"));
serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS);
success = true;
}
} finally {
if (success == false) {
serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS);
}
}
}
} }