diff --git a/jetty-client/src/test/java/org/eclipse/jetty/client/SslBytesClientTest.java b/jetty-client/src/test/java/org/eclipse/jetty/client/SslBytesClientTest.java new file mode 100644 index 00000000000..fa917512bfa --- /dev/null +++ b/jetty-client/src/test/java/org/eclipse/jetty/client/SslBytesClientTest.java @@ -0,0 +1,145 @@ +package org.eclipse.jetty.client; + +import java.io.BufferedReader; +import java.io.File; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSocket; + +import org.eclipse.jetty.http.HttpMethods; +import org.eclipse.jetty.http.HttpStatus; +import org.eclipse.jetty.toolchain.test.MavenTestingUtils; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class SslBytesClientTest extends SslBytesTest +{ + private ExecutorService threadPool; + private HttpClient client; + private SimpleProxy proxy; + private SSLServerSocket acceptor; + + @Before + public void init() throws Exception + { + threadPool = Executors.newCachedThreadPool(); + + client = new HttpClient(); + client.setConnectorType(HttpClient.CONNECTOR_SELECT_CHANNEL); + File keyStore = MavenTestingUtils.getTestResourceFile("keystore"); + SslContextFactory cf = client.getSslContextFactory(); + cf.setKeyStorePath(keyStore.getAbsolutePath()); + cf.setKeyStorePassword("storepwd"); + cf.setKeyManagerPassword("keypwd"); + client.start(); + + SSLContext sslContext = cf.getSslContext(); + acceptor = (SSLServerSocket)sslContext.getServerSocketFactory().createServerSocket(5870); + + int serverPort = acceptor.getLocalPort(); + + proxy = new SimpleProxy(threadPool, "localhost", serverPort); + proxy.start(); + logger.debug(":{} <==> :{}", proxy.getPort(), serverPort); + } + + @After + public void destroy() throws Exception + { + if (acceptor != null) + acceptor.close(); + if (proxy != null) + proxy.stop(); + if (client != null) + client.stop(); + if (threadPool != null) + threadPool.shutdownNow(); + } + + @Test + public void testHandshake() throws Exception + { + ContentExchange exchange = new ContentExchange(true); + exchange.setURL("https://localhost:" + proxy.getPort()); + String method = HttpMethods.GET; + exchange.setMethod(method); + client.send(exchange); + + final SSLSocket server = (SSLSocket)acceptor.accept(); + server.setUseClientMode(false); + + Future handshake = threadPool.submit(new Callable() + { + public Object call() throws Exception + { + server.startHandshake(); + return null; + } + }); + + // Client Hello + TLSRecord record = proxy.readFromClient(); + Assert.assertEquals(TLSRecord.Type.HANDSHAKE, record.getType()); + proxy.flushToServer(record); + + // Server Hello + Certificate + Server Done + record = proxy.readFromServer(); + Assert.assertEquals(TLSRecord.Type.HANDSHAKE, record.getType()); + proxy.flushToClient(record); + + // Client Key Exchange + record = proxy.readFromClient(); + Assert.assertEquals(TLSRecord.Type.HANDSHAKE, record.getType()); + proxy.flushToServer(record); + + // Change Cipher Spec + record = proxy.readFromClient(); + Assert.assertEquals(TLSRecord.Type.CHANGE_CIPHER_SPEC, record.getType()); + proxy.flushToServer(record); + + // Client Done + record = proxy.readFromClient(); + Assert.assertEquals(TLSRecord.Type.HANDSHAKE, record.getType()); + proxy.flushToServer(record); + + // Change Cipher Spec + record = proxy.readFromServer(); + Assert.assertEquals(TLSRecord.Type.CHANGE_CIPHER_SPEC, record.getType()); + proxy.flushToClient(record); + + // Server Done + record = proxy.readFromServer(); + Assert.assertEquals(TLSRecord.Type.HANDSHAKE, record.getType()); + proxy.flushToClient(record); + + Assert.assertNull(handshake.get(5, TimeUnit.SECONDS)); + + SimpleProxy.AutomaticFlow automaticProxyFlow = proxy.startAutomaticFlow(); + // Read request + BufferedReader reader = new BufferedReader(new InputStreamReader(server.getInputStream(), "UTF-8")); + String line = reader.readLine(); + Assert.assertTrue(line.startsWith(method)); + while (line.length() > 0) + line = reader.readLine(); + // Write response + OutputStream output = server.getOutputStream(); + output.write(("HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n").getBytes("UTF-8")); + output.flush(); + Assert.assertTrue(automaticProxyFlow.stop(5, TimeUnit.SECONDS)); + + Assert.assertEquals(HttpExchange.STATUS_COMPLETED, exchange.waitForDone()); + Assert.assertEquals(HttpStatus.OK_200, exchange.getResponseStatus()); + } +} diff --git a/jetty-client/src/test/java/org/eclipse/jetty/client/SslBytesServerTest.java b/jetty-client/src/test/java/org/eclipse/jetty/client/SslBytesServerTest.java index af8ea78f1c5..4deda7fdf73 100644 --- a/jetty-client/src/test/java/org/eclipse/jetty/client/SslBytesServerTest.java +++ b/jetty-client/src/test/java/org/eclipse/jetty/client/SslBytesServerTest.java @@ -1,22 +1,14 @@ package org.eclipse.jetty.client; import java.io.BufferedReader; -import java.io.EOFException; import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.io.InputStreamReader; -import java.io.InterruptedIOException; import java.io.OutputStream; -import java.net.ServerSocket; -import java.net.Socket; import java.net.SocketTimeoutException; import java.nio.channels.SocketChannel; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import java.util.concurrent.Callable; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -44,8 +36,6 @@ import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.handler.AbstractHandler; import org.eclipse.jetty.server.ssl.SslSelectChannelConnector; import org.eclipse.jetty.toolchain.test.MavenTestingUtils; -import org.eclipse.jetty.util.log.Log; -import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.util.ssl.SslContextFactory; import org.junit.After; import org.junit.Assert; @@ -57,9 +47,8 @@ import org.junit.Test; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; -public class SslBytesServerTest +public class SslBytesServerTest extends SslBytesTest { - private final Logger logger = Log.getLogger(getClass()); private final AtomicInteger sslHandles = new AtomicInteger(); private final AtomicInteger httpParses = new AtomicInteger(); private ExecutorService threadPool; @@ -68,7 +57,7 @@ public class SslBytesServerTest private SimpleProxy proxy; @Before - public void startServer() throws Exception + public void init() throws Exception { threadPool = Executors.newCachedThreadPool(); server = new Server(); @@ -136,16 +125,17 @@ public class SslBytesServerTest } }); server.start(); + int serverPort = connector.getLocalPort(); sslContext = cf.getSslContext(); - proxy = new SimpleProxy(threadPool, "localhost", connector.getLocalPort()); + proxy = new SimpleProxy(threadPool, "localhost", serverPort); proxy.start(); - logger.debug(":{} <==> :{}", proxy.getPort(), connector.getLocalPort()); + logger.debug(":{} <==> :{}", proxy.getPort(), serverPort); } @After - public void stopServer() throws Exception + public void destroy() throws Exception { if (proxy != null) proxy.stop(); @@ -1233,315 +1223,4 @@ public class SslBytesServerTest proxy.flushToClient(record); } - public class SimpleProxy implements Runnable - { - private final CountDownLatch latch = new CountDownLatch(1); - private final ExecutorService threadPool; - private final String serverHost; - private final int serverPort; - private volatile ServerSocket serverSocket; - private volatile Socket server; - private volatile Socket client; - - public SimpleProxy(ExecutorService threadPool, String serverHost, int serverPort) - { - this.threadPool = threadPool; - this.serverHost = serverHost; - this.serverPort = serverPort; - } - - public void start() throws Exception - { -// serverSocket = new ServerSocket(5871); - serverSocket = new ServerSocket(0); - Thread acceptor = new Thread(this); - acceptor.start(); - server = new Socket(serverHost, serverPort); - } - - public void stop() throws Exception - { - serverSocket.close(); - } - - public void run() - { - try - { - client = serverSocket.accept(); - latch.countDown(); - } - catch (IOException x) - { - x.printStackTrace(); - } - } - - public int getPort() - { - return serverSocket.getLocalPort(); - } - - public TLSRecord readFromClient() throws IOException - { - TLSRecord record = read(client); - logger.debug("C --> P {}", record); - return record; - } - - private TLSRecord read(Socket socket) throws IOException - { - InputStream input = socket.getInputStream(); - int first = -2; - while (true) - { - try - { - socket.setSoTimeout(500); - first = input.read(); - break; - } - catch (SocketTimeoutException x) - { - if (Thread.currentThread().isInterrupted()) - break; - } - } - if (first == -2) - throw new InterruptedIOException(); - else if (first == -1) - return null; - - if (first >= 0x80) - { - // SSLv2 Record - int hiLength = first & 0x3F; - int loLength = input.read(); - int length = (hiLength << 8) + loLength; - byte[] bytes = new byte[2 + length]; - bytes[0] = (byte)first; - bytes[1] = (byte)loLength; - return read(TLSRecord.Type.HANDSHAKE, input, bytes, 2, length); - } - else - { - // TLS Record - int major = input.read(); - int minor = input.read(); - int hiLength = input.read(); - int loLength = input.read(); - int length = (hiLength << 8) + loLength; - byte[] bytes = new byte[1 + 2 + 2 + length]; - bytes[0] = (byte)first; - bytes[1] = (byte)major; - bytes[2] = (byte)minor; - bytes[3] = (byte)hiLength; - bytes[4] = (byte)loLength; - return read(TLSRecord.Type.from(first), input, bytes, 5, length); - } - } - - private TLSRecord read(TLSRecord.Type type, InputStream input, byte[] bytes, int offset, int length) throws IOException - { - while (length > 0) - { - int read = input.read(bytes, offset, length); - if (read < 0) - throw new EOFException(); - offset += read; - length -= read; - } - return new TLSRecord(type, bytes); - } - - public void flushToServer(TLSRecord record) throws IOException - { - if (record == null) - { - server.shutdownOutput(); - if (client.isOutputShutdown()) - { - client.close(); - server.close(); - } - } - else - { - flush(server, record.getBytes()); - } - } - - public void flushToServer(byte... bytes) throws IOException - { - flush(server, bytes); - } - - private void flush(Socket socket, byte... bytes) throws IOException - { - OutputStream output = socket.getOutputStream(); - output.write(bytes); - output.flush(); - } - - public TLSRecord readFromServer() throws IOException - { - TLSRecord record = read(server); - logger.debug("P <-- S {}", record); - return record; - } - - public void flushToClient(TLSRecord record) throws IOException - { - if (record == null) - { - client.shutdownOutput(); - if (server.isOutputShutdown()) - { - server.close(); - client.close(); - } - } - else - { - flush(client, record.getBytes()); - } - } - - public AutomaticFlow startAutomaticFlow() throws InterruptedException - { - final CountDownLatch startLatch = new CountDownLatch(2); - final CountDownLatch stopLatch = new CountDownLatch(2); - Future clientToServer = threadPool.submit(new Callable() - { - public Object call() throws Exception - { - startLatch.countDown(); - logger.debug("Automatic flow C --> S started"); - try - { - while (true) - { - flushToServer(readFromClient()); - } - } - catch (InterruptedIOException x) - { - return null; - } - finally - { - stopLatch.countDown(); - logger.debug("Automatic flow C --> S finished"); - } - } - }); - Future serverToClient = threadPool.submit(new Callable() - { - public Object call() throws Exception - { - startLatch.countDown(); - logger.debug("Automatic flow C <-- S started"); - try - { - while (true) - { - flushToClient(readFromServer()); - } - } - catch (InterruptedIOException x) - { - return null; - } - finally - { - stopLatch.countDown(); - logger.debug("Automatic flow C <-- S finished"); - } - } - }); - Assert.assertTrue(startLatch.await(5, TimeUnit.SECONDS)); - return new AutomaticFlow(stopLatch, clientToServer, serverToClient); - } - - public boolean awaitClient(int time, TimeUnit unit) throws InterruptedException - { - return latch.await(time, unit); - } - - public class AutomaticFlow - { - private final CountDownLatch stopLatch; - private final Future clientToServer; - private final Future serverToClient; - - public AutomaticFlow(CountDownLatch stopLatch, Future clientToServer, Future serverToClient) - { - this.stopLatch = stopLatch; - this.clientToServer = clientToServer; - this.serverToClient = serverToClient; - } - - public boolean stop(long time, TimeUnit unit) throws InterruptedException - { - clientToServer.cancel(true); - serverToClient.cancel(true); - return stopLatch.await(time, unit); - } - } - } - - public static class TLSRecord - { - private final Type type; - private final byte[] bytes; - - public TLSRecord(Type type, byte[] bytes) - { - this.type = type; - this.bytes = bytes; - } - - public Type getType() - { - return type; - } - - public byte[] getBytes() - { - return bytes; - } - - @Override - public String toString() - { - return "TLSRecord [" + type + "] " + bytes.length + " bytes"; - } - - public enum Type - { - CHANGE_CIPHER_SPEC(20), ALERT(21), HANDSHAKE(22), APPLICATION(23); - - private int code; - - private Type(int code) - { - this.code = code; - Mapper.codes.put(this.code, this); - } - - public static Type from(int code) - { - Type result = Mapper.codes.get(code); - if (result == null) - throw new IllegalArgumentException("Invalid TLSRecord.Type " + code); - return result; - } - - private static class Mapper - { - private static final Map codes = new HashMap(); - } - } - } - } diff --git a/jetty-client/src/test/java/org/eclipse/jetty/client/SslBytesTest.java b/jetty-client/src/test/java/org/eclipse/jetty/client/SslBytesTest.java new file mode 100644 index 00000000000..f66c464297b --- /dev/null +++ b/jetty-client/src/test/java/org/eclipse/jetty/client/SslBytesTest.java @@ -0,0 +1,337 @@ +package org.eclipse.jetty.client; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.InterruptedIOException; +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketTimeoutException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import org.eclipse.jetty.util.log.Log; +import org.eclipse.jetty.util.log.Logger; +import org.junit.Assert; + +public abstract class SslBytesTest +{ + protected final Logger logger = Log.getLogger(getClass()); + + public static class TLSRecord + { + private final SslBytesServerTest.TLSRecord.Type type; + private final byte[] bytes; + + public TLSRecord(SslBytesServerTest.TLSRecord.Type type, byte[] bytes) + { + this.type = type; + this.bytes = bytes; + } + + public SslBytesServerTest.TLSRecord.Type getType() + { + return type; + } + + public byte[] getBytes() + { + return bytes; + } + + @Override + public String toString() + { + return "TLSRecord [" + type + "] " + bytes.length + " bytes"; + } + + public enum Type + { + CHANGE_CIPHER_SPEC(20), ALERT(21), HANDSHAKE(22), APPLICATION(23); + + private int code; + + private Type(int code) + { + this.code = code; + SslBytesServerTest.TLSRecord.Type.Mapper.codes.put(this.code, this); + } + + public static SslBytesServerTest.TLSRecord.Type from(int code) + { + SslBytesServerTest.TLSRecord.Type result = SslBytesServerTest.TLSRecord.Type.Mapper.codes.get(code); + if (result == null) + throw new IllegalArgumentException("Invalid TLSRecord.Type " + code); + return result; + } + + private static class Mapper + { + private static final Map codes = new HashMap(); + } + } + } + + public class SimpleProxy implements Runnable + { + private final CountDownLatch latch = new CountDownLatch(1); + private final ExecutorService threadPool; + private final String serverHost; + private final int serverPort; + private volatile ServerSocket serverSocket; + private volatile Socket server; + private volatile Socket client; + + public SimpleProxy(ExecutorService threadPool, String serverHost, int serverPort) + { + this.threadPool = threadPool; + this.serverHost = serverHost; + this.serverPort = serverPort; + } + + public void start() throws Exception + { + serverSocket = new ServerSocket(5871); +// serverSocket = new ServerSocket(0); + Thread acceptor = new Thread(this); + acceptor.start(); + server = new Socket(serverHost, serverPort); + } + + public void stop() throws Exception + { + serverSocket.close(); + } + + public void run() + { + try + { + client = serverSocket.accept(); + latch.countDown(); + } + catch (IOException x) + { + x.printStackTrace(); + } + } + + public int getPort() + { + return serverSocket.getLocalPort(); + } + + public TLSRecord readFromClient() throws IOException + { + TLSRecord record = read(client); + logger.debug("C --> P {}", record); + return record; + } + + private TLSRecord read(Socket socket) throws IOException + { + InputStream input = socket.getInputStream(); + int first = -2; + while (true) + { + try + { + socket.setSoTimeout(500); + first = input.read(); + break; + } + catch (SocketTimeoutException x) + { + if (Thread.currentThread().isInterrupted()) + break; + } + } + if (first == -2) + throw new InterruptedIOException(); + else if (first == -1) + return null; + + if (first >= 0x80) + { + // SSLv2 Record + int hiLength = first & 0x3F; + int loLength = input.read(); + int length = (hiLength << 8) + loLength; + byte[] bytes = new byte[2 + length]; + bytes[0] = (byte)first; + bytes[1] = (byte)loLength; + return read(TLSRecord.Type.HANDSHAKE, input, bytes, 2, length); + } + else + { + // TLS Record + int major = input.read(); + int minor = input.read(); + int hiLength = input.read(); + int loLength = input.read(); + int length = (hiLength << 8) + loLength; + byte[] bytes = new byte[1 + 2 + 2 + length]; + bytes[0] = (byte)first; + bytes[1] = (byte)major; + bytes[2] = (byte)minor; + bytes[3] = (byte)hiLength; + bytes[4] = (byte)loLength; + return read(TLSRecord.Type.from(first), input, bytes, 5, length); + } + } + + private TLSRecord read(SslBytesServerTest.TLSRecord.Type type, InputStream input, byte[] bytes, int offset, int length) throws IOException + { + while (length > 0) + { + int read = input.read(bytes, offset, length); + if (read < 0) + throw new EOFException(); + offset += read; + length -= read; + } + return new TLSRecord(type, bytes); + } + + public void flushToServer(TLSRecord record) throws IOException + { + if (record == null) + { + server.shutdownOutput(); + if (client.isOutputShutdown()) + { + client.close(); + server.close(); + } + } + else + { + flush(server, record.getBytes()); + } + } + + public void flushToServer(byte... bytes) throws IOException + { + flush(server, bytes); + } + + private void flush(Socket socket, byte... bytes) throws IOException + { + OutputStream output = socket.getOutputStream(); + output.write(bytes); + output.flush(); + } + + public TLSRecord readFromServer() throws IOException + { + TLSRecord record = read(server); + logger.debug("P <-- S {}", record); + return record; + } + + public void flushToClient(TLSRecord record) throws IOException + { + if (record == null) + { + client.shutdownOutput(); + if (server.isOutputShutdown()) + { + server.close(); + client.close(); + } + } + else + { + flush(client, record.getBytes()); + } + } + + public SslBytesServerTest.SimpleProxy.AutomaticFlow startAutomaticFlow() throws InterruptedException + { + final CountDownLatch startLatch = new CountDownLatch(2); + final CountDownLatch stopLatch = new CountDownLatch(2); + Future clientToServer = threadPool.submit(new Callable() + { + public Object call() throws Exception + { + startLatch.countDown(); + logger.debug("Automatic flow C --> S started"); + try + { + while (true) + { + flushToServer(readFromClient()); + } + } + catch (InterruptedIOException x) + { + return null; + } + finally + { + stopLatch.countDown(); + logger.debug("Automatic flow C --> S finished"); + } + } + }); + Future serverToClient = threadPool.submit(new Callable() + { + public Object call() throws Exception + { + startLatch.countDown(); + logger.debug("Automatic flow C <-- S started"); + try + { + while (true) + { + flushToClient(readFromServer()); + } + } + catch (InterruptedIOException x) + { + return null; + } + finally + { + stopLatch.countDown(); + logger.debug("Automatic flow C <-- S finished"); + } + } + }); + Assert.assertTrue(startLatch.await(5, TimeUnit.SECONDS)); + return new SslBytesServerTest.SimpleProxy.AutomaticFlow(stopLatch, clientToServer, serverToClient); + } + + public boolean awaitClient(int time, TimeUnit unit) throws InterruptedException + { + return latch.await(time, unit); + } + + public class AutomaticFlow + { + private final CountDownLatch stopLatch; + private final Future clientToServer; + private final Future serverToClient; + + public AutomaticFlow(CountDownLatch stopLatch, Future clientToServer, Future serverToClient) + { + this.stopLatch = stopLatch; + this.clientToServer = clientToServer; + this.serverToClient = serverToClient; + } + + public boolean stop(long time, TimeUnit unit) throws InterruptedException + { + clientToServer.cancel(true); + serverToClient.cancel(true); + return stopLatch.await(time, unit); + } + } + } +}