SSL/TLS: Do not allow writes before handshake is complete

SSLEngine will throw various SSLExceptions when the application initiates a write prior
to the handshake being completed. The NettySecuredTransport marks a channel as ready
for use once it is connected, even though the handshake has not completed. A handler
has been added that performs the handshake and queues writes until the handshake has
completed. Additionally, fix SslMultiPortTests to always connect to the proper client
profile port.

Closes elastic/elasticsearch#390. Closes elastic/elasticsearch#393. Closes elastic/elasticsearch#394. Closes elastic/elasticsearch#395. Closes elastic/elasticsearch#414

Original commit: elastic/x-pack-elasticsearch@1bb3218373
This commit is contained in:
jaymode 2014-12-17 18:51:23 -05:00
parent 76735579d1
commit 0a9f51f3f5
6 changed files with 373 additions and 35 deletions

View File

@ -6,11 +6,11 @@
package org.elasticsearch.shield.ssl;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.collect.Maps;
import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.ImmutableSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.shield.ShieldSettingsException;
import javax.net.ssl.*;
@ -27,7 +27,7 @@ public class SSLService extends AbstractComponent {
static final String[] DEFAULT_CIPHERS = new String[]{ "TLS_RSA_WITH_AES_128_CBC_SHA256", "TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_DHE_RSA_WITH_AES_128_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA" };
private Map<String, SSLContext> sslContexts = Maps.newHashMapWithExpectedSize(3);
private Map<String, SSLContext> sslContexts = ConcurrentCollections.newConcurrentMap();
@Inject
public SSLService(Settings settings) {

View File

@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.shield.transport.netty;
import org.elasticsearch.common.logging.ESLogger;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.netty.channel.*;
import org.elasticsearch.common.netty.handler.ssl.SslHandler;
import java.util.LinkedList;
import java.util.Queue;
/**
* Netty requires that nothing be written to the channel prior to the handshake. Writing before the handshake
* completes, results in odd SSLExceptions being thrown. Channel writes can happen from any thread that
* can access the channel and Netty does not provide a way to ensure the handshake has occurred before the
* application writes to the channel. This handler will queue up writes until the handshake has occurred and
* then will pass the writes through the pipeline. After all writes have been completed, this handler removes
* itself from the pipeline.
*
* NOTE: This class assumes that the transport will not use a closed channel again or attempt to reconnect, which
* is the way that NettyTransport currently works
*/
public class HandshakeWaitingHandler extends SimpleChannelHandler {
private static final ESLogger logger = Loggers.getLogger(HandshakeWaitingHandler.class);
private boolean handshaken = false;
private Queue<MessageEvent> pendingWrites = new LinkedList<>();
@Override
public void channelConnected(final ChannelHandlerContext ctx, final ChannelStateEvent e) throws Exception {
SslHandler sslHandler = ctx.getPipeline().get(SslHandler.class);
final ChannelFuture handshakeFuture = sslHandler.handshake();
handshakeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture channelFuture) throws Exception {
if (handshakeFuture.isSuccess()) {
logger.debug("SSL / TLS handshake completed for channel");
// We synchronize here to allow all pending writes to be processed prior to any writes coming from
// another thread
synchronized (HandshakeWaitingHandler.this) {
handshaken = true;
while (!pendingWrites.isEmpty()) {
MessageEvent event = pendingWrites.remove();
ctx.sendDownstream(event);
}
ctx.getPipeline().remove(HandshakeWaitingHandler.class);
}
ctx.sendUpstream(e);
} else {
Throwable cause = handshakeFuture.getCause();
if (logger.isDebugEnabled()) {
logger.debug("SSL / TLS handshake failed, closing channel: {}", cause, cause.getMessage());
} else {
logger.error("SSL / TLS handshake failed, closing channel: {}", cause.getMessage());
}
synchronized (HandshakeWaitingHandler.this) {
// Set failure on the futures of each message so that listeners are called
while (!pendingWrites.isEmpty()) {
DownstreamMessageEvent event = (DownstreamMessageEvent) pendingWrites.remove();
event.getFuture().setFailure(cause);
}
// Some writes may be waiting to acquire lock, if so the SetFailureOnAddQueue will set
// failure on their futures
pendingWrites = new SetFailureOnAddQueue(cause);
handshakeFuture.getChannel().close();
}
}
}
});
}
@Override
public synchronized void writeRequested(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
// Writes can come from any thread so we need to ensure that we do not let any through
// until handshake has completed
if (!handshaken) {
pendingWrites.add(e);
return;
}
ctx.sendDownstream(e);
}
synchronized boolean hasPendingWrites() {
return !pendingWrites.isEmpty();
}
private static class SetFailureOnAddQueue extends LinkedList<MessageEvent> {
private final Throwable cause;
SetFailureOnAddQueue(Throwable cause) {
super();
this.cause = cause;
}
@Override
public boolean add(MessageEvent messageEvent) {
DownstreamMessageEvent event = (DownstreamMessageEvent) messageEvent;
event.getFuture().setFailure(cause);
return false;
}
}
}

View File

@ -123,6 +123,7 @@ public class NettySecuredTransport extends NettyTransport {
sslEngine.setUseClientMode(true);
ctx.getPipeline().replace(this, "ssl", new SslHandler(sslEngine));
ctx.getPipeline().addAfter("ssl", "handshake", new HandshakeWaitingHandler());
ctx.sendDownstream(e);
}

View File

@ -10,7 +10,6 @@ import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.logging.ESLogger;
import org.elasticsearch.common.netty.channel.*;
import org.elasticsearch.common.netty.handler.ssl.SslHandler;
import org.elasticsearch.common.transport.InetSocketTransportAddress;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.threadpool.ThreadPool;
@ -33,36 +32,6 @@ public class SecuredMessageChannelHandler extends MessageChannelHandler {
this.profileName = profileName;
}
@Override
public void channelConnected(final ChannelHandlerContext ctx, final ChannelStateEvent e) throws Exception {
SslHandler sslHandler = ctx.getPipeline().get(SslHandler.class);
// Make sure handler is present and we are the client
if (sslHandler == null || !sslHandler.getEngine().getUseClientMode()) {
return;
}
final ChannelFuture handshakeFuture = sslHandler.handshake();
// Get notified when SSL handshake is done.
handshakeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
logger.debug("SSL / TLS handshake completed for channel");
ctx.sendUpstream(e);
} else {
if (logger.isDebugEnabled()) {
logger.debug("SSL / TLS handshake failed, closing channel: {}", future.getCause(), future.getCause().getMessage());
} else {
logger.error("SSL / TLS handshake failed, closing channel: {}", future.getCause().getMessage());
}
future.getChannel().close();
}
}
});
}
// TODO ADD PREPROCESSING
/**

View File

@ -0,0 +1,238 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.shield.transport.netty;
import org.elasticsearch.common.netty.bootstrap.ClientBootstrap;
import org.elasticsearch.common.netty.bootstrap.ServerBootstrap;
import org.elasticsearch.common.netty.buffer.ChannelBuffer;
import org.elasticsearch.common.netty.buffer.ChannelBuffers;
import org.elasticsearch.common.netty.channel.*;
import org.elasticsearch.common.netty.channel.socket.nio.NioClientSocketChannelFactory;
import org.elasticsearch.common.netty.channel.socket.nio.NioServerSocketChannelFactory;
import org.elasticsearch.common.netty.handler.ssl.SslHandler;
import org.elasticsearch.shield.ssl.SSLService;
import org.elasticsearch.test.ElasticsearchTestCase;
import org.junit.*;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.nio.file.Paths;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.elasticsearch.common.settings.ImmutableSettings.settingsBuilder;
import static org.hamcrest.Matchers.*;
public class HandshakeWaitingHandlerTests extends ElasticsearchTestCase {
private static final int CONCURRENT_CLIENT_REQUESTS = 20;
private static ServerBootstrap serverBootstrap;
private static ClientBootstrap clientBootstrap;
private static SSLContext sslContext;
private static int iterations;
private final AtomicBoolean failed = new AtomicBoolean(false);
private volatile Throwable failureCause = null;
@BeforeClass
public static void setup() throws Exception {
SSLService sslService = new SSLService(settingsBuilder()
.put("shield.ssl.keystore.path", Paths.get(HandshakeWaitingHandlerTests.class.getResource("/org/elasticsearch/shield/transport/ssl/certs/simple/testnode.jks").toURI()))
.put("shield.ssl.keystore.password", "testnode")
.build());
sslContext = sslService.getSslContext();
ChannelFactory factory = new NioServerSocketChannelFactory(
Executors.newFixedThreadPool(1),
Executors.newFixedThreadPool(1));
serverBootstrap = new ServerBootstrap(factory);
ChannelFactory clientFactory = new NioClientSocketChannelFactory(
Executors.newCachedThreadPool(),
Executors.newCachedThreadPool());
clientBootstrap = new ClientBootstrap(clientFactory);
iterations = randomIntBetween(10, 100);
}
@After
public void reset() {
failed.set(false);
failureCause = null;
}
@Test
public void testWriteBeforeHandshakeFailsWithoutHandler() throws Exception {
serverBootstrap.setPipelineFactory(getServerFactory());
final int randomPort = randomIntBetween(49000, 65500);
serverBootstrap.bind(new InetSocketAddress("localhost", randomPort));
clientBootstrap.setPipelineFactory(new ChannelPipelineFactory() {
@Override
public ChannelPipeline getPipeline() throws Exception {
final SSLEngine engine = sslContext.createSSLEngine();
engine.setUseClientMode(true);
return Channels.pipeline(
new SslHandler(engine));
}
});
ExecutorService threadPoolExecutor = Executors.newFixedThreadPool(CONCURRENT_CLIENT_REQUESTS);
try {
List<Callable<ChannelFuture>> callables = new ArrayList<>(CONCURRENT_CLIENT_REQUESTS);
for (int i = 0; i < CONCURRENT_CLIENT_REQUESTS; i++) {
callables.add(new WriteBeforeHandshakeCompletedCallable(clientBootstrap, randomPort));
}
for (int i = 0; i < iterations; i++) {
List<Future<ChannelFuture>> futures = threadPoolExecutor.invokeAll(callables);
for (Future<ChannelFuture> future : futures) {
ChannelFuture handshakeFuture = future.get();
handshakeFuture.await();
handshakeFuture.getChannel().close();
}
if (failed.get()) {
assertThat(failureCause, anyOf(instanceOf(SSLException.class), instanceOf(AssertionError.class)));
break;
}
}
if (!failed.get()) {
fail("Expected this test to fail with an SSLException or AssertionError");
}
} finally {
threadPoolExecutor.shutdown();
}
}
@Test
public void testWriteBeforeHandshakePassesWithHandshakeWaitingHandler() throws Exception {
serverBootstrap.setPipelineFactory(getServerFactory());
final int randomPort = randomIntBetween(49000, 65500);
serverBootstrap.bind(new InetSocketAddress("localhost", randomPort));
clientBootstrap.setPipelineFactory(new ChannelPipelineFactory() {
@Override
public ChannelPipeline getPipeline() throws Exception {
final SSLEngine engine = sslContext.createSSLEngine();
engine.setUseClientMode(true);
return Channels.pipeline(
new SslHandler(engine),
new HandshakeWaitingHandler());
}
});
ExecutorService threadPoolExecutor = Executors.newFixedThreadPool(CONCURRENT_CLIENT_REQUESTS);
try {
List<Callable<ChannelFuture>> callables = new ArrayList<>(CONCURRENT_CLIENT_REQUESTS);
for (int i = 0; i < CONCURRENT_CLIENT_REQUESTS; i++) {
callables.add(new WriteBeforeHandshakeCompletedCallable(clientBootstrap, randomPort));
}
for (int i = 0; i < iterations; i++) {
List<Future<ChannelFuture>> futures = threadPoolExecutor.invokeAll(callables);
for (Future<ChannelFuture> future : futures) {
ChannelFuture handshakeFuture = future.get();
handshakeFuture.await();
// Wait for pending writes to prevent IOExceptions
Channel channel = handshakeFuture.getChannel();
HandshakeWaitingHandler handler = channel.getPipeline().get(HandshakeWaitingHandler.class);
while (handler != null && handler.hasPendingWrites()) {
Thread.sleep(10);
}
channel.close();
}
if (failed.get()) {
failureCause.printStackTrace();
fail("Expected this test to always pass with the HandshakeWaitingHandler in pipeline");
}
}
} finally {
threadPoolExecutor.shutdown();
}
}
private ChannelPipelineFactory getServerFactory() {
return new ChannelPipelineFactory() {
public ChannelPipeline getPipeline() throws Exception {
final SSLEngine sslEngine = sslContext.createSSLEngine();
sslEngine.setUseClientMode(false);
return Channels.pipeline(new SslHandler(sslEngine),
new SimpleChannelHandler() {
@Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) {
// Sink the message
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) {
Throwable cause = e.getCause();
// Only save first cause
if (failed.compareAndSet(false, true)) {
failureCause = cause;
}
ctx.getChannel().close();
}
});
}
};
}
@AfterClass
public static void cleanUp() {
clientBootstrap.shutdown();
serverBootstrap.shutdown();
clientBootstrap.releaseExternalResources();
serverBootstrap.releaseExternalResources();
clientBootstrap = null;
serverBootstrap = null;
sslContext = null;
}
private static class WriteBeforeHandshakeCompletedCallable implements Callable<ChannelFuture> {
private final ClientBootstrap bootstrap;
private final int port;
WriteBeforeHandshakeCompletedCallable(ClientBootstrap bootstrap, int port) {
this.bootstrap = bootstrap;
this.port = port;
}
@Override
public ChannelFuture call() throws Exception {
ChannelBuffer buffer = ChannelBuffers.buffer(8);
buffer.writeLong(SecureRandom.getInstanceStrong().nextLong());
// Connect and wait, then immediately start writing
ChannelFuture future = bootstrap.connect(new InetSocketAddress("localhost", port));
future.awaitUninterruptibly();
Channel channel = future.getChannel();
// Do not call handshake before writing as it will most likely succeed before a write begins
// in the test
ChannelFuture handshakeFuture = null;
for (int i = 0; i < 100; i++) {
channel.write(buffer);
handshakeFuture = channel.getPipeline().get(SslHandler.class).handshake();
}
return handshakeFuture;
}
}
}

View File

@ -7,6 +7,7 @@ package org.elasticsearch.shield.transport.ssl;
import org.elasticsearch.client.transport.NoNodeAvailableException;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.netty.channel.Channel;
import org.elasticsearch.common.settings.ImmutableSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.InetSocketTransportAddress;
@ -14,10 +15,14 @@ import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.test.ShieldIntegrationTest;
import org.elasticsearch.test.ShieldSettingsSource;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.netty.NettyTransport;
import org.junit.BeforeClass;
import org.junit.Test;
import java.io.File;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.util.Map;
import static org.elasticsearch.common.settings.ImmutableSettings.settingsBuilder;
import static org.elasticsearch.test.ElasticsearchIntegrationTest.ClusterScope;
@ -78,7 +83,7 @@ public class SslMultiPortTests extends ShieldIntegrationTest {
@Test(expected = NoNodeAvailableException.class)
public void testThatStandardTransportClientCannotConnectToClientProfile() throws Exception {
try(TransportClient transportClient = createTransportClient(ImmutableSettings.EMPTY)) {
transportClient.addTransportAddress(new InetSocketTransportAddress("localhost", randomClientPort));
transportClient.addTransportAddress(new InetSocketTransportAddress("localhost", getClientProfilePort()));
transportClient.admin().cluster().prepareHealth().get();
}
}
@ -87,7 +92,7 @@ public class SslMultiPortTests extends ShieldIntegrationTest {
public void testThatProfileTransportClientCanConnectToClientProfile() throws Exception {
Settings settings = ShieldSettingsSource.getSSLSettingsForStore("/org/elasticsearch/shield/transport/ssl/certs/simple/testclient-client-profile.jks", "testclient-client-profile");
try (TransportClient transportClient = createTransportClient(settings)) {
transportClient.addTransportAddress(new InetSocketTransportAddress("localhost", randomClientPort));
transportClient.addTransportAddress(new InetSocketTransportAddress("localhost", getClientProfilePort()));
assertGreenClusterState(transportClient);
}
}
@ -101,4 +106,17 @@ public class SslMultiPortTests extends ShieldIntegrationTest {
transportClient.admin().cluster().prepareHealth().get();
}
}
/*
* Gets the actual port that the client profile in this test environment is listening on as the randomClientPort
* may actually be bound by some other node
*/
private int getClientProfilePort() throws Exception {
NettyTransport transport = (NettyTransport) internalCluster().getInstance(Transport.class);
Field channels = NettyTransport.class.getDeclaredField("serverChannels");
channels.setAccessible(true);
Map<String, Channel> serverChannels = (Map<String, Channel>) channels.get(transport);
Channel clientProfileChannel = serverChannels.get("client");
return ((InetSocketAddress) clientProfileChannel.getLocalAddress()).getPort();
}
}