Support TLS/SSL renegotiation (elastic/x-pack-elasticsearch#3600)

This commit is related to elastic/x-pack-elasticsearch#3246. It adds support for receiving TLS/SSL
renegotiation requests for peers.

Original commit: elastic/x-pack-elasticsearch@c22c16b3bc
This commit is contained in:
Tim Brooks 2018-01-18 10:59:44 -07:00 committed by GitHub
parent 0ea43c1aa1
commit fb12a0e383
3 changed files with 128 additions and 28 deletions

View File

@ -72,12 +72,30 @@ public class SSLDriver implements AutoCloseable {
public void init() throws SSLException {
engine.setUseClientMode(isClientMode);
if (currentMode.isHandshake()) {
engine.beginHandshake();
((HandshakeMode) currentMode).startHandshake();
} else {
throw new AssertionError("Attempted to init outside from non-handshaking mode: " + currentMode.modeName());
}
}
/**
* Requests a TLS renegotiation. This means the we will request that the peer performs another handshake
* prior to the continued exchange of application data. This can only be requested if we are currently in
* APPLICATION mode.
*
* @throws SSLException if the handshake cannot be initiated
*/
public void renegotiate() throws SSLException {
if (currentMode.isApplication()) {
currentMode = new HandshakeMode();
engine.beginHandshake();
((HandshakeMode) currentMode).startHandshake();
} else {
throw new IllegalStateException("Attempted to renegotiate while in invalid mode: " + currentMode.modeName());
}
}
public boolean hasFlushPending() {
return networkWriteBuffer.hasRemaining();
}
@ -223,15 +241,6 @@ public class SSLDriver implements AutoCloseable {
}
}
private boolean checkRenegotiation(SSLEngineResult.HandshakeStatus newStatus) {
if (isHandshaking() == false && newStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING
&& newStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
// TODO: Iron out the specifics of renegotiation
throw new IllegalStateException("We do not support renegotiation");
}
return false;
}
private void closingInternal() {
// This check prevents us from attempting to send close_notify twice
if (currentMode.isClose() == false) {
@ -306,7 +315,6 @@ public class SSLDriver implements AutoCloseable {
private SSLEngineResult.HandshakeStatus handshakeStatus;
private void startHandshake() throws SSLException {
engine.beginHandshake();
handshakeStatus = engine.getHandshakeStatus();
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) {
@ -434,7 +442,7 @@ public class SSLDriver implements AutoCloseable {
networkReadBuffer.flip();
SSLEngineResult result = unwrap(buffer);
boolean renegotiationRequested = result.getStatus() != SSLEngineResult.Status.CLOSED
&& checkRenegotiation(result.getHandshakeStatus());
&& maybeRenegotiation(result.getHandshakeStatus());
continueUnwrap = result.bytesProduced() > 0 && renegotiationRequested == false;
}
}
@ -442,10 +450,19 @@ public class SSLDriver implements AutoCloseable {
@Override
public int write(ByteBuffer[] buffers) throws SSLException {
SSLEngineResult result = wrap(buffers);
checkRenegotiation(result.getHandshakeStatus());
maybeRenegotiation(result.getHandshakeStatus());
return result.bytesConsumed();
}
private boolean maybeRenegotiation(SSLEngineResult.HandshakeStatus newStatus) throws SSLException {
if (newStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING && newStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
renegotiate();
return true;
} else {
return false;
}
}
@Override
public boolean needsNonApplicationWrite() {
return false;

View File

@ -54,6 +54,40 @@ public class SSLDriverTests extends ESTestCase {
normalClose(clientDriver, serverDriver);
}
public void testRenegotiate() throws Exception {
SSLContext sslContext = getSSLContext();
SSLDriver clientDriver = getDriver(sslContext.createSSLEngine(), true);
SSLDriver serverDriver = getDriver(sslContext.createSSLEngine(), false);
handshake(clientDriver, serverDriver);
ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
sendAppData(clientDriver, serverDriver, buffers);
serverDriver.read(serverBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
clientDriver.renegotiate();
assertTrue(clientDriver.isHandshaking());
assertFalse(clientDriver.readyForApplicationWrites());
// This tests that the client driver can still receive data based on the prior handshake
ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
sendAppData(serverDriver, clientDriver, buffers2);
clientDriver.read(clientBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
handshake(clientDriver, serverDriver, true);
sendAppData(clientDriver, serverDriver, buffers);
serverDriver.read(serverBuffer);
assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
sendAppData(serverDriver, clientDriver, buffers2);
clientDriver.read(clientBuffer);
assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
normalClose(clientDriver, serverDriver);
}
public void testBigAppData() throws Exception {
SSLContext sslContext = getSSLContext();
@ -220,10 +254,16 @@ public class SSLDriverTests extends ESTestCase {
}
private void handshake(SSLDriver clientDriver, SSLDriver serverDriver) throws IOException {
clientDriver.init();
serverDriver.init();
handshake(clientDriver, serverDriver, false);
}
assertTrue(clientDriver.needsNonApplicationWrite());
private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean isRenegotiation) throws IOException {
if (isRenegotiation == false) {
clientDriver.init();
serverDriver.init();
}
assertTrue(clientDriver.needsNonApplicationWrite() || clientDriver.hasFlushPending());
assertFalse(serverDriver.needsNonApplicationWrite());
sendHandshakeMessages(clientDriver, serverDriver);
@ -305,4 +345,4 @@ public class SSLDriverTests extends ESTestCase {
private SSLDriver getDriver(SSLEngine engine, boolean isClient) {
return new SSLDriver(engine, isClient);
}
}
}

View File

@ -7,7 +7,12 @@ package org.elasticsearch.xpack.security.transport.nio;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.MockSecureSettings;
@ -18,28 +23,45 @@ import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.mocksocket.MockServerSocket;
import org.elasticsearch.mocksocket.MockSocket;
import org.elasticsearch.node.Node;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
import org.elasticsearch.transport.BindTransportException;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.common.socket.SocketAccess;
import org.elasticsearch.xpack.ssl.SSLService;
import javax.net.SocketFactory;
import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.UnknownHostException;
import java.nio.channels.SocketChannel;
import java.nio.file.Path;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;
public class SimpleSecurityNioTransportTests extends AbstractSimpleTransportTestCase {
@ -146,22 +168,43 @@ public class SimpleSecurityNioTransportTests extends AbstractSimpleTransportTest
assertEquals("Failed to bind to ["+ port + "]", bindTransportException.getMessage());
}
@SuppressForbidden(reason = "Need to open socket connection")
public void testRenegotiation() throws Exception {
SSLService sslService = createSSLService();
SocketFactory factory = sslService.sslSocketFactory(Settings.EMPTY);
try (SSLSocket socket = (SSLSocket) factory.createSocket()) {
SocketAccess.doPrivileged(() -> socket.connect(serviceA.boundAddress().publishAddress().address()));
CountDownLatch handshakeLatch = new CountDownLatch(1);
HandshakeCompletedListener firstListener = event -> handshakeLatch.countDown();
socket.addHandshakeCompletedListener(firstListener);
handshakeLatch.countDown();
socket.removeHandshakeCompletedListener(firstListener);
OutputStreamStreamOutput stream = new OutputStreamStreamOutput(socket.getOutputStream());
stream.writeByte((byte) 'E');
stream.writeByte((byte)'S');
stream.writeInt(-1);
stream.flush();
socket.startHandshake();
CountDownLatch renegotiationLatch = new CountDownLatch(1);
HandshakeCompletedListener secondListener = event -> renegotiationLatch.countDown();
socket.addHandshakeCompletedListener(secondListener);
renegotiationLatch.countDown();
socket.removeHandshakeCompletedListener(secondListener);
stream.writeByte((byte) 'E');
stream.writeByte((byte)'S');
stream.writeInt(-1);
stream.flush();
}
}
// TODO: These tests currently rely on plaintext transports
@Override
@AwaitsFix(bugUrl = "")
public void testTcpHandshake() throws IOException, InterruptedException {
}
@Override
@AwaitsFix(bugUrl = "")
public void testHandshakeWithIncompatVersion() {
}
@Override
@AwaitsFix(bugUrl = "")
public void testHandshakeUpdatesVersion() throws IOException {
}
}