Support TLS/SSL renegotiation (elastic/x-pack-elasticsearch#3600)

This commit is related to elastic/x-pack-elasticsearch#3246. It adds support for receiving TLS/SSL
renegotiation requests for peers.

Original commit: elastic/x-pack-elasticsearch@c22c16b3bc
This commit is contained in:
Tim Brooks 2018-01-18 10:59:44 -07:00 committed by GitHub
parent 0ea43c1aa1
commit fb12a0e383
3 changed files with 128 additions and 28 deletions

View File

@ -72,12 +72,30 @@ public class SSLDriver implements AutoCloseable {
public void init() throws SSLException { public void init() throws SSLException {
engine.setUseClientMode(isClientMode); engine.setUseClientMode(isClientMode);
if (currentMode.isHandshake()) { if (currentMode.isHandshake()) {
engine.beginHandshake();
((HandshakeMode) currentMode).startHandshake(); ((HandshakeMode) currentMode).startHandshake();
} else { } else {
throw new AssertionError("Attempted to init outside from non-handshaking mode: " + currentMode.modeName()); throw new AssertionError("Attempted to init outside from non-handshaking mode: " + currentMode.modeName());
} }
} }
/**
* Requests a TLS renegotiation. This means the we will request that the peer performs another handshake
* prior to the continued exchange of application data. This can only be requested if we are currently in
* APPLICATION mode.
*
* @throws SSLException if the handshake cannot be initiated
*/
public void renegotiate() throws SSLException {
if (currentMode.isApplication()) {
currentMode = new HandshakeMode();
engine.beginHandshake();
((HandshakeMode) currentMode).startHandshake();
} else {
throw new IllegalStateException("Attempted to renegotiate while in invalid mode: " + currentMode.modeName());
}
}
public boolean hasFlushPending() { public boolean hasFlushPending() {
return networkWriteBuffer.hasRemaining(); return networkWriteBuffer.hasRemaining();
} }
@ -223,15 +241,6 @@ public class SSLDriver implements AutoCloseable {
} }
} }
private boolean checkRenegotiation(SSLEngineResult.HandshakeStatus newStatus) {
if (isHandshaking() == false && newStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING
&& newStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
// TODO: Iron out the specifics of renegotiation
throw new IllegalStateException("We do not support renegotiation");
}
return false;
}
private void closingInternal() { private void closingInternal() {
// This check prevents us from attempting to send close_notify twice // This check prevents us from attempting to send close_notify twice
if (currentMode.isClose() == false) { if (currentMode.isClose() == false) {
@ -306,7 +315,6 @@ public class SSLDriver implements AutoCloseable {
private SSLEngineResult.HandshakeStatus handshakeStatus; private SSLEngineResult.HandshakeStatus handshakeStatus;
private void startHandshake() throws SSLException { private void startHandshake() throws SSLException {
engine.beginHandshake();
handshakeStatus = engine.getHandshakeStatus(); handshakeStatus = engine.getHandshakeStatus();
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP && if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) { handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) {
@ -434,7 +442,7 @@ public class SSLDriver implements AutoCloseable {
networkReadBuffer.flip(); networkReadBuffer.flip();
SSLEngineResult result = unwrap(buffer); SSLEngineResult result = unwrap(buffer);
boolean renegotiationRequested = result.getStatus() != SSLEngineResult.Status.CLOSED boolean renegotiationRequested = result.getStatus() != SSLEngineResult.Status.CLOSED
&& checkRenegotiation(result.getHandshakeStatus()); && maybeRenegotiation(result.getHandshakeStatus());
continueUnwrap = result.bytesProduced() > 0 && renegotiationRequested == false; continueUnwrap = result.bytesProduced() > 0 && renegotiationRequested == false;
} }
} }
@ -442,10 +450,19 @@ public class SSLDriver implements AutoCloseable {
@Override @Override
public int write(ByteBuffer[] buffers) throws SSLException { public int write(ByteBuffer[] buffers) throws SSLException {
SSLEngineResult result = wrap(buffers); SSLEngineResult result = wrap(buffers);
checkRenegotiation(result.getHandshakeStatus()); maybeRenegotiation(result.getHandshakeStatus());
return result.bytesConsumed(); return result.bytesConsumed();
} }
private boolean maybeRenegotiation(SSLEngineResult.HandshakeStatus newStatus) throws SSLException {
if (newStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING && newStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
renegotiate();
return true;
} else {
return false;
}
}
@Override @Override
public boolean needsNonApplicationWrite() { public boolean needsNonApplicationWrite() {
return false; return false;

View File

@ -54,6 +54,40 @@ public class SSLDriverTests extends ESTestCase {
normalClose(clientDriver, serverDriver); normalClose(clientDriver, serverDriver);
} }
public void testRenegotiate() throws Exception {
SSLContext sslContext = getSSLContext();
SSLDriver clientDriver = getDriver(sslContext.createSSLEngine(), true);
SSLDriver serverDriver = getDriver(sslContext.createSSLEngine(), false);
handshake(clientDriver, serverDriver);
ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
sendAppData(clientDriver, serverDriver, buffers);
serverDriver.read(serverBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
clientDriver.renegotiate();
assertTrue(clientDriver.isHandshaking());
assertFalse(clientDriver.readyForApplicationWrites());
// This tests that the client driver can still receive data based on the prior handshake
ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
sendAppData(serverDriver, clientDriver, buffers2);
clientDriver.read(clientBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
handshake(clientDriver, serverDriver, true);
sendAppData(clientDriver, serverDriver, buffers);
serverDriver.read(serverBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
sendAppData(serverDriver, clientDriver, buffers2);
clientDriver.read(clientBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
normalClose(clientDriver, serverDriver);
}
public void testBigAppData() throws Exception { public void testBigAppData() throws Exception {
SSLContext sslContext = getSSLContext(); SSLContext sslContext = getSSLContext();
@ -220,10 +254,16 @@ public class SSLDriverTests extends ESTestCase {
} }
private void handshake(SSLDriver clientDriver, SSLDriver serverDriver) throws IOException { private void handshake(SSLDriver clientDriver, SSLDriver serverDriver) throws IOException {
handshake(clientDriver, serverDriver, false);
}
private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean isRenegotiation) throws IOException {
if (isRenegotiation == false) {
clientDriver.init(); clientDriver.init();
serverDriver.init(); serverDriver.init();
}
assertTrue(clientDriver.needsNonApplicationWrite()); assertTrue(clientDriver.needsNonApplicationWrite() || clientDriver.hasFlushPending());
assertFalse(serverDriver.needsNonApplicationWrite()); assertFalse(serverDriver.needsNonApplicationWrite());
sendHandshakeMessages(clientDriver, serverDriver); sendHandshakeMessages(clientDriver, serverDriver);

View File

@ -7,7 +7,12 @@ package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.MockSecureSettings; import org.elasticsearch.common.settings.MockSecureSettings;
@ -18,28 +23,45 @@ import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler; import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.mocksocket.MockServerSocket;
import org.elasticsearch.mocksocket.MockSocket;
import org.elasticsearch.node.Node; import org.elasticsearch.node.Node;
import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.AbstractSimpleTransportTestCase; import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
import org.elasticsearch.transport.BindTransportException; import org.elasticsearch.transport.BindTransportException;
import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.common.socket.SocketAccess;
import org.elasticsearch.xpack.ssl.SSLService; import org.elasticsearch.xpack.ssl.SSLService;
import javax.net.SocketFactory;
import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import java.io.IOException; import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.nio.channels.SocketChannel;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.Collections; import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import static java.util.Collections.emptyMap; import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet; import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;
public class SimpleSecurityNioTransportTests extends AbstractSimpleTransportTestCase { public class SimpleSecurityNioTransportTests extends AbstractSimpleTransportTestCase {
@ -146,22 +168,43 @@ public class SimpleSecurityNioTransportTests extends AbstractSimpleTransportTest
assertEquals("Failed to bind to ["+ port + "]", bindTransportException.getMessage()); assertEquals("Failed to bind to ["+ port + "]", bindTransportException.getMessage());
} }
@SuppressForbidden(reason = "Need to open socket connection")
public void testRenegotiation() throws Exception {
SSLService sslService = createSSLService();
SocketFactory factory = sslService.sslSocketFactory(Settings.EMPTY);
try (SSLSocket socket = (SSLSocket) factory.createSocket()) {
SocketAccess.doPrivileged(() -> socket.connect(serviceA.boundAddress().publishAddress().address()));
CountDownLatch handshakeLatch = new CountDownLatch(1);
HandshakeCompletedListener firstListener = event -> handshakeLatch.countDown();
socket.addHandshakeCompletedListener(firstListener);
handshakeLatch.countDown();
socket.removeHandshakeCompletedListener(firstListener);
OutputStreamStreamOutput stream = new OutputStreamStreamOutput(socket.getOutputStream());
stream.writeByte((byte) 'E');
stream.writeByte((byte)'S');
stream.writeInt(-1);
stream.flush();
socket.startHandshake();
CountDownLatch renegotiationLatch = new CountDownLatch(1);
HandshakeCompletedListener secondListener = event -> renegotiationLatch.countDown();
socket.addHandshakeCompletedListener(secondListener);
renegotiationLatch.countDown();
socket.removeHandshakeCompletedListener(secondListener);
stream.writeByte((byte) 'E');
stream.writeByte((byte)'S');
stream.writeInt(-1);
stream.flush();
}
}
// TODO: These tests currently rely on plaintext transports // TODO: These tests currently rely on plaintext transports
@Override @Override
@AwaitsFix(bugUrl = "") @AwaitsFix(bugUrl = "")
public void testTcpHandshake() throws IOException, InterruptedException { public void testTcpHandshake() throws IOException, InterruptedException {
} }
@Override
@AwaitsFix(bugUrl = "")
public void testHandshakeWithIncompatVersion() {
}
@Override
@AwaitsFix(bugUrl = "")
public void testHandshakeUpdatesVersion() throws IOException {
}
} }