Merge branch 'master' into release-9.3.3

This commit is contained in:
Jesse McConnell 2015-08-27 08:31:22 -05:00
commit f5d1fb1058
12 changed files with 745 additions and 341 deletions

View File

@ -18,11 +18,6 @@
package org.eclipse.jetty.io; package org.eclipse.jetty.io;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File; import java.io.File;
@ -43,7 +38,6 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.io.ClientConnectionFactory.Helper;
import org.eclipse.jetty.toolchain.test.MavenTestingUtils; import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.toolchain.test.OS; import org.eclipse.jetty.toolchain.test.OS;
import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.BufferUtil;
@ -51,6 +45,11 @@ import org.eclipse.jetty.util.IO;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public class IOTest public class IOTest
{ {
@Test @Test
@ -488,14 +487,4 @@ public class IOTest
for (int i=0;i<buffers.length;i++) for (int i=0;i<buffers.length;i++)
assertEquals(0,buffers[i].remaining()); assertEquals(0,buffers[i].remaining());
} }
@Test
public void testDomain()
{
assertTrue(IO.isInDomain("foo.com","foo.com"));
assertTrue(IO.isInDomain("www.foo.com","foo.com"));
assertFalse(IO.isInDomain("foo.com","bar.com"));
assertFalse(IO.isInDomain("www.foo.com","bar.com"));
}
} }

View File

@ -29,17 +29,17 @@ import org.eclipse.jetty.http.BadMessageException;
import org.eclipse.jetty.http.HttpScheme; import org.eclipse.jetty.http.HttpScheme;
import org.eclipse.jetty.io.ssl.SslConnection; import org.eclipse.jetty.io.ssl.SslConnection;
import org.eclipse.jetty.io.ssl.SslConnection.DecryptedEndPoint; import org.eclipse.jetty.io.ssl.SslConnection.DecryptedEndPoint;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.TypeUtil; import org.eclipse.jetty.util.TypeUtil;
import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.ssl.SniX509ExtendedKeyManager;
import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.ssl.X509;
/**
/* ------------------------------------------------------------ */ * <p>Customizer that extracts the attribute from an {@link SSLContext}
/** Customizer that extracts the attribute from an {@link SSLContext}
* and sets them on the request with {@link ServletRequest#setAttribute(String, Object)} * and sets them on the request with {@link ServletRequest#setAttribute(String, Object)}
* according to Servlet Specification Requirements. * according to Servlet Specification Requirements.</p>
*/ */
public class SecureRequestCustomizer implements HttpConfiguration.Customizer public class SecureRequestCustomizer implements HttpConfiguration.Customizer
{ {
@ -77,10 +77,9 @@ public class SecureRequestCustomizer implements HttpConfiguration.Customizer
} }
} }
/* ------------------------------------------------------------ */ /**
/* * <p>Customizes the request attributes to be set for SSL requests.</p>
* Customise the request attributes to be set for SSL requests. <br> * <p>The requirements of the Servlet specs are:</p>
* The requirements of the Servlet specs are:
* <ul> * <ul>
* <li> an attribute named "javax.servlet.request.ssl_session_id" of type * <li> an attribute named "javax.servlet.request.ssl_session_id" of type
* String (since Servlet Spec 3.0).</li> * String (since Servlet Spec 3.0).</li>
@ -95,8 +94,7 @@ public class SecureRequestCustomizer implements HttpConfiguration.Customizer
* </li> * </li>
* </ul> * </ul>
* *
* @param request * @param request HttpRequest to be customized.
* HttpRequest to be customised.
*/ */
public void customize(SSLEngine sslEngine, Request request) public void customize(SSLEngine sslEngine, Request request)
{ {
@ -105,17 +103,18 @@ public class SecureRequestCustomizer implements HttpConfiguration.Customizer
if (_sniHostCheck) if (_sniHostCheck)
{ {
String sniName = (String)sslSession.getValue("org.eclipse.jetty.util.ssl.sniname");
if (sniName!=null && !sniName.equalsIgnoreCase(request.getServerName()))
{
String wild=(String)sslSession.getValue("org.eclipse.jetty.util.ssl.sniwild");
String name = request.getServerName(); String name = request.getServerName();
if (wild==null || !IO.isInDomain(name,wild)) @SuppressWarnings("unchecked")
X509 x509 = (X509)sslSession.getValue(SniX509ExtendedKeyManager.SNI_X509);
if (x509!=null && !x509.matches(name))
{ {
LOG.warn("Host does not match SNI Name: {}/{}!={}",sniName,wild,request.getServerName()); LOG.warn("Host {} does not match SNI {}",name,x509);
throw new BadMessageException(400,"Host does not match SNI"); throw new BadMessageException(400,"Host does not match SNI");
} }
}
if (LOG.isDebugEnabled())
LOG.debug("Host {} matched SNI {}",name,x509);
} }
try try
@ -134,7 +133,7 @@ public class SecureRequestCustomizer implements HttpConfiguration.Customizer
} }
else else
{ {
keySize=new Integer(SslContextFactory.deduceKeyLength(cipherSuite)); keySize=SslContextFactory.deduceKeyLength(cipherSuite);
certs=SslContextFactory.getCertChain(sslSession); certs=SslContextFactory.getCertChain(sslSession);
byte[] bytes = sslSession.getId(); byte[] bytes = sslSession.getId();
idStr = TypeUtil.toHexString(bytes); idStr = TypeUtil.toHexString(bytes);
@ -161,9 +160,6 @@ public class SecureRequestCustomizer implements HttpConfiguration.Customizer
return String.format("%s@%x",this.getClass().getSimpleName(),hashCode()); return String.format("%s@%x",this.getClass().getSimpleName(),hashCode());
} }
/* ------------------------------------------------------------ */
/* ------------------------------------------------------------ */
/* ------------------------------------------------------------ */
/** /**
* Simple bundle of information that is cached in the SSLSession. Stores the * Simple bundle of information that is cached in the SSLSession. Stores the
* effective keySize and the client certificate chain. * effective keySize and the client certificate chain.
@ -196,7 +192,4 @@ public class SecureRequestCustomizer implements HttpConfiguration.Customizer
return _idStr; return _idStr;
} }
} }
} }

View File

@ -34,7 +34,6 @@ import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.util.ByteArrayISO8859Writer; import org.eclipse.jetty.util.ByteArrayISO8859Writer;
import org.eclipse.jetty.util.IO; import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.Jetty;
import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.resource.Resource; import org.eclipse.jetty.util.resource.Resource;
@ -134,7 +133,7 @@ public class DefaultHandler extends AbstractHandler
{ {
writer.write("<li><a href=\""); writer.write("<li><a href=\"");
if (context.getVirtualHosts()!=null && context.getVirtualHosts().length>0) if (context.getVirtualHosts()!=null && context.getVirtualHosts().length>0)
writer.write("http://"+context.getVirtualHosts()[0]+":"+request.getLocalPort()); writer.write(request.getScheme()+"://"+context.getVirtualHosts()[0]+":"+request.getLocalPort());
writer.write(context.getContextPath()); writer.write(context.getContextPath());
if (context.getContextPath().length()>1 && context.getContextPath().endsWith("/")) if (context.getContextPath().length()>1 && context.getContextPath().endsWith("/"))
writer.write("/"); writer.write("/");

View File

@ -21,10 +21,13 @@ package org.eclipse.jetty.server.ssl;
import java.io.File; import java.io.File;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket; import java.net.Socket;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Queue; import java.util.Queue;
@ -50,6 +53,7 @@ import org.eclipse.jetty.server.SslConnectionFactory;
import org.eclipse.jetty.server.handler.AbstractHandler; import org.eclipse.jetty.server.handler.AbstractHandler;
import org.eclipse.jetty.util.ConcurrentArrayQueue; import org.eclipse.jetty.util.ConcurrentArrayQueue;
import org.eclipse.jetty.util.IO; import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.Utf8StringBuilder;
import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import org.junit.After; import org.junit.After;
@ -96,9 +100,10 @@ public class SniSslConnectionFactoryTest
@Override @Override
public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException
{ {
baseRequest.setHandled(true);
response.setStatus(200); response.setStatus(200);
response.getWriter().write("url=" + request.getRequestURI() + "\nhost=" + request.getServerName()); response.setHeader("X-URL", request.getRequestURI());
response.flushBuffer(); response.setHeader("X-HOST", request.getServerName());
} }
}); });
@ -117,7 +122,7 @@ public class SniSslConnectionFactoryTest
public void testConnect() throws Exception public void testConnect() throws Exception
{ {
String response = getResponse("127.0.0.1", null); String response = getResponse("127.0.0.1", null);
Assert.assertThat(response, Matchers.containsString("host=127.0.0.1")); Assert.assertThat(response, Matchers.containsString("X-HOST: 127.0.0.1"));
} }
@Test @Test
@ -142,39 +147,39 @@ public class SniSslConnectionFactoryTest
// The first entry in the keystore is www.example.com, and it will // The first entry in the keystore is www.example.com, and it will
// be returned by default, so make sure that here we don't ask for it. // be returned by default, so make sure that here we don't ask for it.
String response = getResponse("jetty.eclipse.org", "jetty.eclipse.org"); String response = getResponse("jetty.eclipse.org", "jetty.eclipse.org");
Assert.assertThat(response, Matchers.containsString("host=jetty.eclipse.org")); Assert.assertThat(response, Matchers.containsString("X-HOST: jetty.eclipse.org"));
} }
@Test @Test
public void testSNIConnect() throws Exception public void testSNIConnect() throws Exception
{ {
String response = getResponse("jetty.eclipse.org", "jetty.eclipse.org"); String response = getResponse("jetty.eclipse.org", "jetty.eclipse.org");
Assert.assertThat(response, Matchers.containsString("host=jetty.eclipse.org")); Assert.assertThat(response, Matchers.containsString("X-HOST: jetty.eclipse.org"));
response = getResponse("www.example.com", "www.example.com"); response = getResponse("www.example.com", "www.example.com");
Assert.assertThat(response, Matchers.containsString("host=www.example.com")); Assert.assertThat(response, Matchers.containsString("X-HOST: www.example.com"));
response = getResponse("foo.domain.com", "*.domain.com"); response = getResponse("foo.domain.com", "*.domain.com");
Assert.assertThat(response, Matchers.containsString("host=foo.domain.com")); Assert.assertThat(response, Matchers.containsString("X-HOST: foo.domain.com"));
response = getResponse("m.san.com", "san example"); response = getResponse("m.san.com", "san example");
Assert.assertThat(response, Matchers.containsString("host=m.san.com")); Assert.assertThat(response, Matchers.containsString("X-HOST: m.san.com"));
response = getResponse("www.san.com", "san example"); response = getResponse("www.san.com", "san example");
Assert.assertThat(response, Matchers.containsString("host=www.san.com")); Assert.assertThat(response, Matchers.containsString("X-HOST: www.san.com"));
} }
@Test @Test
public void testWildSNIConnect() throws Exception public void testWildSNIConnect() throws Exception
{ {
String response = getResponse("domain.com", "www.domain.com", "*.domain.com"); String response = getResponse("domain.com", "www.domain.com", "*.domain.com");
Assert.assertThat(response, Matchers.containsString("host=www.domain.com")); Assert.assertThat(response, Matchers.containsString("X-HOST: www.domain.com"));
response = getResponse("domain.com", "domain.com", "*.domain.com"); response = getResponse("domain.com", "domain.com", "*.domain.com");
Assert.assertThat(response, Matchers.containsString("host=domain.com")); Assert.assertThat(response, Matchers.containsString("X-HOST: domain.com"));
response = getResponse("www.domain.com", "www.domain.com", "*.domain.com"); response = getResponse("www.domain.com", "www.domain.com", "*.domain.com");
Assert.assertThat(response, Matchers.containsString("host=www.domain.com")); Assert.assertThat(response, Matchers.containsString("X-HOST: www.domain.com"));
} }
@Test @Test
@ -185,11 +190,138 @@ public class SniSslConnectionFactoryTest
Assert.assertThat(response, Matchers.containsString("Host does not match SNI")); Assert.assertThat(response, Matchers.containsString("Host does not match SNI"));
} }
@Test
public void testSameConnectionRequestsForManyDomains() throws Exception
{
SslContextFactory clientContextFactory = new SslContextFactory(true);
clientContextFactory.start();
SSLSocketFactory factory = clientContextFactory.getSslContext().getSocketFactory();
try (SSLSocket sslSocket = (SSLSocket)factory.createSocket("127.0.0.1", _port))
{
SNIHostName serverName = new SNIHostName("m.san.com");
SSLParameters params = sslSocket.getSSLParameters();
params.setServerNames(Collections.singletonList(serverName));
sslSocket.setSSLParameters(params);
sslSocket.startHandshake();
// The first request binds the socket to an alias.
String request = "" +
"GET /ctx/path HTTP/1.1\r\n" +
"Host: m.san.com\r\n" +
"\r\n";
OutputStream output = sslSocket.getOutputStream();
output.write(request.getBytes(StandardCharsets.UTF_8));
output.flush();
InputStream input = sslSocket.getInputStream();
String response = response(input);
Assert.assertTrue(response.startsWith("HTTP/1.1 200 "));
// Same socket, send a request for a different domain but same alias.
request = "" +
"GET /ctx/path HTTP/1.1\r\n" +
"Host: www.san.com\r\n" +
"\r\n";
output.write(request.getBytes(StandardCharsets.UTF_8));
output.flush();
response = response(input);
Assert.assertTrue(response.startsWith("HTTP/1.1 200 "));
// Same socket, send a request for a different domain but different alias.
request = "" +
"GET /ctx/path HTTP/1.1\r\n" +
"Host: www.example.com\r\n" +
"\r\n";
output.write(request.getBytes(StandardCharsets.UTF_8));
output.flush();
response = response(input);
Assert.assertTrue(response.startsWith("HTTP/1.1 400 "));
Assert.assertThat(response, Matchers.containsString("Host does not match SNI"));
}
finally
{
clientContextFactory.stop();
}
}
@Test
public void testSameConnectionRequestsForManyWildDomains() throws Exception
{
SslContextFactory clientContextFactory = new SslContextFactory(true);
clientContextFactory.start();
SSLSocketFactory factory = clientContextFactory.getSslContext().getSocketFactory();
try (SSLSocket sslSocket = (SSLSocket)factory.createSocket("127.0.0.1", _port))
{
SNIHostName serverName = new SNIHostName("www.domain.com");
SSLParameters params = sslSocket.getSSLParameters();
params.setServerNames(Collections.singletonList(serverName));
sslSocket.setSSLParameters(params);
sslSocket.startHandshake();
String request = "" +
"GET /ctx/path HTTP/1.1\r\n" +
"Host: www.domain.com\r\n" +
"\r\n";
OutputStream output = sslSocket.getOutputStream();
output.write(request.getBytes(StandardCharsets.UTF_8));
output.flush();
InputStream input = sslSocket.getInputStream();
String response = response(input);
Assert.assertTrue(response.startsWith("HTTP/1.1 200 "));
// Now, on the same socket, send a request for a different valid domain.
request = "" +
"GET /ctx/path HTTP/1.1\r\n" +
"Host: assets.domain.com\r\n" +
"\r\n";
output.write(request.getBytes(StandardCharsets.UTF_8));
output.flush();
response = response(input);
Assert.assertTrue(response.startsWith("HTTP/1.1 200 "));
// Now make a request for an invalid domain for this connection.
request = "" +
"GET /ctx/path HTTP/1.1\r\n" +
"Host: www.example.com\r\n" +
"\r\n";
output.write(request.getBytes(StandardCharsets.UTF_8));
output.flush();
response = response(input);
Assert.assertTrue(response.startsWith("HTTP/1.1 400 "));
Assert.assertThat(response, Matchers.containsString("Host does not match SNI"));
}
finally
{
clientContextFactory.stop();
}
}
private String response(InputStream input) throws IOException
{
Utf8StringBuilder buffer = new Utf8StringBuilder();
int crlfs = 0;
while (true)
{
int read = input.read();
Assert.assertTrue(read >= 0);
buffer.append((byte)read);
crlfs = (read == '\r' || read == '\n') ? crlfs + 1 : 0;
if (crlfs == 4)
break;
}
return buffer.toString();
}
private String getResponse(String host, String cn) throws Exception private String getResponse(String host, String cn) throws Exception
{ {
String response = getResponse(host, host, cn); String response = getResponse(host, host, cn);
Assert.assertThat(response, Matchers.startsWith("HTTP/1.1 200 OK")); Assert.assertThat(response, Matchers.startsWith("HTTP/1.1 200 "));
Assert.assertThat(response, Matchers.containsString("url=/ctx/path")); Assert.assertThat(response, Matchers.containsString("X-URL: /ctx/path"));
return response; return response;
} }
@ -198,8 +330,8 @@ public class SniSslConnectionFactoryTest
SslContextFactory clientContextFactory = new SslContextFactory(true); SslContextFactory clientContextFactory = new SslContextFactory(true);
clientContextFactory.start(); clientContextFactory.start();
SSLSocketFactory factory = clientContextFactory.getSslContext().getSocketFactory(); SSLSocketFactory factory = clientContextFactory.getSslContext().getSocketFactory();
SSLSocket sslSocket = (SSLSocket)factory.createSocket("127.0.0.1", _port); try (SSLSocket sslSocket = (SSLSocket)factory.createSocket("127.0.0.1", _port))
{
if (cn != null) if (cn != null)
{ {
SNIHostName serverName = new SNIHostName(sniHost); SNIHostName serverName = new SNIHostName(sniHost);
@ -215,16 +347,17 @@ public class SniSslConnectionFactoryTest
if (cn != null) if (cn != null)
{ {
X509Certificate cert = ((X509Certificate)sslSocket.getSession().getPeerCertificates()[0]); X509Certificate cert = ((X509Certificate)sslSocket.getSession().getPeerCertificates()[0]);
Assert.assertThat(cert.getSubjectX500Principal().getName("CANONICAL"), Matchers.startsWith("cn=" + cn)); Assert.assertThat(cert.getSubjectX500Principal().getName("CANONICAL"), Matchers.startsWith("cn=" + cn));
} }
sslSocket.getOutputStream().write(("GET /ctx/path HTTP/1.0\r\nHost: " + reqHost + ":" + _port + "\r\n\r\n").getBytes(StandardCharsets.ISO_8859_1)); String response = "GET /ctx/path HTTP/1.0\r\nHost: " + reqHost + ":" + _port + "\r\n\r\n";
String response = IO.toString(sslSocket.getInputStream()); sslSocket.getOutputStream().write(response.getBytes(StandardCharsets.ISO_8859_1));
return IO.toString(sslSocket.getInputStream());
sslSocket.close(); }
finally
{
clientContextFactory.stop(); clientContextFactory.stop();
return response; }
} }
@Test @Test
@ -260,7 +393,7 @@ public class SniSslConnectionFactoryTest
}); });
String response = getResponse("127.0.0.1", null); String response = getResponse("127.0.0.1", null);
Assert.assertThat(response, Matchers.containsString("host=127.0.0.1")); Assert.assertThat(response, Matchers.containsString("X-HOST: 127.0.0.1"));
Assert.assertEquals("customize connector class org.eclipse.jetty.io.ssl.SslConnection,false", history.poll()); Assert.assertEquals("customize connector class org.eclipse.jetty.io.ssl.SslConnection,false", history.poll());
Assert.assertEquals("customize ssl class org.eclipse.jetty.io.ssl.SslConnection,false", history.poll()); Assert.assertEquals("customize ssl class org.eclipse.jetty.io.ssl.SslConnection,false", history.poll());

View File

@ -471,21 +471,6 @@ public class IO
return total; return total;
} }
/* ------------------------------------------------------------ */
/**
* @param name A host name like www.foo.com
* @param domain A domain name like foo.com
* @return True if the host name is in the domain name
*/
public static boolean isInDomain(String name, String domain)
{
if (!name.endsWith(domain))
return false;
if (name.length()==domain.length())
return true;
return name.charAt(name.length()-domain.length()-1)=='.';
}
/* ------------------------------------------------------------ */ /* ------------------------------------------------------------ */
/** /**
* @return An outputstream to nowhere * @return An outputstream to nowhere

View File

@ -41,9 +41,8 @@ import org.eclipse.jetty.util.log.Logger;
*/ */
public class SniX509ExtendedKeyManager extends X509ExtendedKeyManager public class SniX509ExtendedKeyManager extends X509ExtendedKeyManager
{ {
public static final String SNI_NAME = "org.eclipse.jetty.util.ssl.sniname"; public static final String SNI_X509 = "org.eclipse.jetty.util.ssl.snix509";
public static final String SNI_WILD = "org.eclipse.jetty.util.ssl.sniwild"; private static final String NO_MATCHERS = "no_matchers";
public static final String NO_MATCHERS="No Matchers";
private static final Logger LOG = Log.getLogger(SniX509ExtendedKeyManager.class); private static final Logger LOG = Log.getLogger(SniX509ExtendedKeyManager.class);
private final X509ExtendedKeyManager _delegate; private final X509ExtendedKeyManager _delegate;
@ -72,10 +71,9 @@ public class SniX509ExtendedKeyManager extends X509ExtendedKeyManager
if (aliases==null || aliases.length==0) if (aliases==null || aliases.length==0)
return null; return null;
// Look for an SNI alias // Look for the SNI information.
String alias=null;
String host=null; String host=null;
String wild=null; X509 x509=null;
if (matchers!=null) if (matchers!=null)
{ {
for (SNIMatcher m : matchers) for (SNIMatcher m : matchers)
@ -83,28 +81,25 @@ public class SniX509ExtendedKeyManager extends X509ExtendedKeyManager
if (m instanceof SslContextFactory.AliasSNIMatcher) if (m instanceof SslContextFactory.AliasSNIMatcher)
{ {
SslContextFactory.AliasSNIMatcher matcher = (SslContextFactory.AliasSNIMatcher)m; SslContextFactory.AliasSNIMatcher matcher = (SslContextFactory.AliasSNIMatcher)m;
alias=matcher.getAlias(); host=matcher.getHost();
host=matcher.getServerName(); x509=matcher.getX509();
wild=matcher.getWildDomain();
break; break;
} }
} }
} }
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("matched {}/{} from {}",alias,host,Arrays.asList(aliases)); LOG.debug("Matched {} with {} from {}",host,x509,Arrays.asList(aliases));
// Check if the SNI selected alias is allowable // Check if the SNI selected alias is allowable
if (alias!=null) if (x509!=null)
{ {
for (String a:aliases) for (String a:aliases)
{ {
if (a.equals(alias)) if (a.equals(x509.getAlias()))
{ {
session.putValue(SNI_NAME,host); session.putValue(SNI_X509,x509);
if (wild!=null) return a;
session.putValue(SNI_WILD,wild);
return alias;
} }
} }
return null; return null;
@ -121,7 +116,7 @@ public class SniX509ExtendedKeyManager extends X509ExtendedKeyManager
if (alias==NO_MATCHERS) if (alias==NO_MATCHERS)
alias=_delegate.chooseServerAlias(keyType,issuers,socket); alias=_delegate.chooseServerAlias(keyType,issuers,socket);
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("chose {}/{} on {}",alias,keyType,socket); LOG.debug("Chose alias {}/{} on {}",alias,keyType,socket);
return alias; return alias;
} }
@ -132,7 +127,7 @@ public class SniX509ExtendedKeyManager extends X509ExtendedKeyManager
if (alias==NO_MATCHERS) if (alias==NO_MATCHERS)
alias=_delegate.chooseEngineServerAlias(keyType,issuers,engine); alias=_delegate.chooseEngineServerAlias(keyType,issuers,engine);
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("chose {}/{} on {}",alias,keyType,engine); LOG.debug("Chose alias {}/{} on {}",alias,keyType,engine);
return alias; return alias;
} }

View File

@ -38,7 +38,6 @@ import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.Enumeration;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
@ -48,8 +47,6 @@ import java.util.Set;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import javax.naming.ldap.LdapName;
import javax.naming.ldap.Rdn;
import javax.net.ssl.CertPathTrustManagerParameters; import javax.net.ssl.CertPathTrustManagerParameters;
import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.KeyManagerFactory;
@ -70,8 +67,8 @@ import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509ExtendedKeyManager; import javax.net.ssl.X509ExtendedKeyManager;
import javax.net.ssl.X509TrustManager; import javax.net.ssl.X509TrustManager;
import javax.security.auth.x500.X500Principal;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.component.AbstractLifeCycle; import org.eclipse.jetty.util.component.AbstractLifeCycle;
import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.util.log.Logger;
@ -110,18 +107,6 @@ public class SslContextFactory extends AbstractLifeCycle
static final Logger LOG = Log.getLogger(SslContextFactory.class); static final Logger LOG = Log.getLogger(SslContextFactory.class);
/*
* @see {@link X509Certificate#getKeyUsage()}
*/
private static final int KEY_USAGE__KEY_CERT_SIGN=5;
/*
*
* @see {@link X509Certificate#getSubjectAlternativeNames()}
*/
private static final int SUBJECT_ALTERNATIVE_NAMES__DNS_NAME=2;
public static final String DEFAULT_KEYMANAGERFACTORY_ALGORITHM = public static final String DEFAULT_KEYMANAGERFACTORY_ALGORITHM =
(Security.getProperty("ssl.KeyManagerFactory.algorithm") == null ? (Security.getProperty("ssl.KeyManagerFactory.algorithm") == null ?
KeyManagerFactory.getDefaultAlgorithm() : Security.getProperty("ssl.KeyManagerFactory.algorithm")); KeyManagerFactory.getDefaultAlgorithm() : Security.getProperty("ssl.KeyManagerFactory.algorithm"));
@ -149,7 +134,7 @@ public class SslContextFactory extends AbstractLifeCycle
private final Set<String> _excludeCipherSuites = new LinkedHashSet<>(); private final Set<String> _excludeCipherSuites = new LinkedHashSet<>();
/** Included cipher suites. */ /** Included cipher suites. */
private final List<String> _includeCipherSuites = new ArrayList<String>(); private final List<String> _includeCipherSuites = new ArrayList<>();
private boolean _useCipherSuitesOrder=true; private boolean _useCipherSuitesOrder=true;
/** Cipher comparator for ordering ciphers */ /** Cipher comparator for ordering ciphers */
@ -167,8 +152,9 @@ public class SslContextFactory extends AbstractLifeCycle
/** SSL certificate alias */ /** SSL certificate alias */
private String _certAlias; private String _certAlias;
private final Map<String,String> _certAliases = new HashMap<>(); private final Map<String,X509> _aliasX509 = new HashMap<>();
private final Map<String,String> _certWilds = new HashMap<>(); private final Map<String,X509> _certHosts = new HashMap<>();
private final Map<String,X509> _certWilds = new HashMap<>();
/** Truststore path */ /** Truststore path */
private Resource _trustStoreResource; private Resource _trustStoreResource;
@ -304,6 +290,16 @@ public class SslContextFactory extends AbstractLifeCycle
_cipherComparator = cipherComparator; _cipherComparator = cipherComparator;
} }
public Set<String> getAliases()
{
return Collections.unmodifiableSet(_aliasX509.keySet());
}
public X509 getX509(String alias)
{
return _aliasX509.get(alias);
}
/** /**
* Create the SSLContext object and start the lifecycle * Create the SSLContext object and start the lifecycle
* @see org.eclipse.jetty.util.component.AbstractLifeCycle#doStart() * @see org.eclipse.jetty.util.component.AbstractLifeCycle#doStart()
@ -343,35 +339,8 @@ public class SslContextFactory extends AbstractLifeCycle
Collection<? extends CRL> crls = loadCRL(_crlPath); Collection<? extends CRL> crls = loadCRL(_crlPath);
if (_validateCerts && keyStore != null)
{
if (_certAlias==null)
{
for (Enumeration<String> e=keyStore.aliases(); _certAlias==null && e.hasMoreElements(); )
{
String alias=e.nextElement();
Certificate c =keyStore.getCertificate(alias);
if (c!=null && "X.509".equals(c.getType()))
_certAlias=alias;
}
}
Certificate cert = _certAlias == null?null:keyStore.getCertificate(_certAlias);
if (cert==null || !"X.509".equals(cert.getType()))
{
throw new Exception("No X.509 certificate in the keystore" + (_certAlias==null ? "":" for alias " + _certAlias));
}
CertificateValidator validator = new CertificateValidator(trustStore, crls);
validator.setMaxCertPathLength(_maxCertPathLength);
validator.setEnableCRLDP(_enableCRLDP);
validator.setEnableOCSP(_enableOCSP);
validator.setOcspResponderURL(_ocspResponderURL);
validator.validate(keyStore, cert);
}
// Look for X.509 certificates to create alias map // Look for X.509 certificates to create alias map
_certAliases.clear(); _certHosts.clear();
if (keyStore!=null) if (keyStore!=null)
{ {
for (String alias : Collections.list(keyStore.aliases())) for (String alias : Collections.list(keyStore.aliases()))
@ -379,64 +348,37 @@ public class SslContextFactory extends AbstractLifeCycle
Certificate certificate = keyStore.getCertificate(alias); Certificate certificate = keyStore.getCertificate(alias);
if (certificate!=null && "X.509".equals(certificate.getType())) if (certificate!=null && "X.509".equals(certificate.getType()))
{ {
X509Certificate x509 = (X509Certificate)certificate; X509Certificate x509C = (X509Certificate)certificate;
// Exclude certificates with special uses // Exclude certificates with special uses
if (x509.getKeyUsage()!=null) if (X509.isCertSign(x509C))
{ {
boolean[] b=x509.getKeyUsage(); if (LOG.isDebugEnabled())
if (b[KEY_USAGE__KEY_CERT_SIGN]) LOG.debug("Skipping "+x509C);
continue; continue;
} }
X509 x509 = new X509(alias,x509C);
_aliasX509.put(alias,x509);
// Look for alternative name extensions if (_validateCerts)
boolean named=false;
Collection<List<?>> altNames = x509.getSubjectAlternativeNames();
if (altNames!=null)
{ {
for (List<?> list : altNames) CertificateValidator validator = new CertificateValidator(trustStore, crls);
{ validator.setMaxCertPathLength(_maxCertPathLength);
if (((Number)list.get(0)).intValue() == SUBJECT_ALTERNATIVE_NAMES__DNS_NAME) validator.setEnableCRLDP(_enableCRLDP);
{ validator.setEnableOCSP(_enableOCSP);
String cn = list.get(1).toString(); validator.setOcspResponderURL(_ocspResponderURL);
if (LOG.isDebugEnabled()) validator.validate(keyStore, x509C); // TODO what about truststore?
LOG.debug("Certificate SAN alias={} cn={} in {}",alias,cn,this);
if (cn!=null)
{
named=true;
_certAliases.put(cn,alias);
}
}
}
} }
// If no names found, look up the cn from the subject LOG.info("x509={} for {}",x509,this);
if (!named)
{
LdapName name=new LdapName(x509.getSubjectX500Principal().getName(X500Principal.RFC2253));
for (Rdn rdn : name.getRdns())
{
if (rdn.getType().equalsIgnoreCase("cn"))
{
String cn = rdn.getValue().toString();
if (LOG.isDebugEnabled())
LOG.debug("Certificate cn alias={} cn={} in {}",alias,cn,this);
if (cn!=null && cn.contains(".") && !cn.contains(" "))
_certAliases.put(cn,alias);
}
}
}
}
}
}
// find wild aliases for (String h:x509.getHosts())
_certWilds.clear(); _certHosts.put(h,x509);
for (String name : _certAliases.keySet()) for (String w:x509.getWilds())
if (name.startsWith("*.")) _certWilds.put(w,x509);
_certWilds.put(name.substring(2),_certAliases.get(name)); }
}
LOG.info("x509={} wild={} alias={} for {}",_certAliases,_certWilds,_certAlias,this); }
// Instantiate key and trust managers // Instantiate key and trust managers
KeyManager[] keyManagers = getKeyManagers(keyStore); KeyManager[] keyManagers = getKeyManagers(keyStore);
@ -462,7 +404,6 @@ public class SslContextFactory extends AbstractLifeCycle
LOG.debug("Selected Protocols {} of {}",Arrays.asList(_selectedProtocols),Arrays.asList(sslEngine.getSupportedProtocols())); LOG.debug("Selected Protocols {} of {}",Arrays.asList(_selectedProtocols),Arrays.asList(sslEngine.getSupportedProtocols()));
LOG.debug("Selected Ciphers {} of {}",Arrays.asList(_selectedCipherSuites),Arrays.asList(sslEngine.getSupportedCipherSuites())); LOG.debug("Selected Ciphers {} of {}",Arrays.asList(_selectedCipherSuites),Arrays.asList(sslEngine.getSupportedCipherSuites()));
} }
} }
@Override @Override
@ -470,8 +411,9 @@ public class SslContextFactory extends AbstractLifeCycle
{ {
_factory = null; _factory = null;
super.doStop(); super.doStop();
_certAliases.clear(); _certHosts.clear();
_certWilds.clear(); _certWilds.clear();
_aliasX509.clear();
} }
/** /**
@ -1140,7 +1082,7 @@ public class SslContextFactory extends AbstractLifeCycle
} }
} }
if (!_certAliases.isEmpty() || !_certWilds.isEmpty()) if (!_certHosts.isEmpty() || !_certWilds.isEmpty())
{ {
for (int idx = 0; idx < managers.length; idx++) for (int idx = 0; idx < managers.length; idx++)
{ {
@ -1615,7 +1557,7 @@ public class SslContextFactory extends AbstractLifeCycle
SSLParameters sslParams = sslEngine.getSSLParameters(); SSLParameters sslParams = sslEngine.getSSLParameters();
sslParams.setEndpointIdentificationAlgorithm(_endpointIdentificationAlgorithm); sslParams.setEndpointIdentificationAlgorithm(_endpointIdentificationAlgorithm);
sslParams.setUseCipherSuitesOrder(_useCipherSuitesOrder); sslParams.setUseCipherSuitesOrder(_useCipherSuitesOrder);
if (!_certAliases.isEmpty() || !_certWilds.isEmpty()) if (!_certHosts.isEmpty() || !_certWilds.isEmpty())
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Enable SNI matching {}",sslEngine); LOG.debug("Enable SNI matching {}",sslEngine);
@ -1727,8 +1669,6 @@ public class SslContextFactory extends AbstractLifeCycle
_trustStoreResource); _trustStoreResource);
} }
protected class Factory protected class Factory
{ {
final KeyStore _keyStore; final KeyStore _keyStore;
@ -1752,11 +1692,10 @@ public class SslContextFactory extends AbstractLifeCycle
class AliasSNIMatcher extends SNIMatcher class AliasSNIMatcher extends SNIMatcher
{ {
private String _alias; private String _host;
private String _wild; private X509 _x509;
private SNIHostName _name;
protected AliasSNIMatcher() AliasSNIMatcher()
{ {
super(StandardConstants.SNI_HOST_NAME); super(StandardConstants.SNI_HOST_NAME);
} }
@ -1764,66 +1703,57 @@ public class SslContextFactory extends AbstractLifeCycle
@Override @Override
public boolean matches(SNIServerName serverName) public boolean matches(SNIServerName serverName)
{ {
LOG.debug("matches={} for {}",serverName,this); if (LOG.isDebugEnabled())
LOG.debug("SNI matching for {}",serverName);
if (serverName instanceof SNIHostName) if (serverName instanceof SNIHostName)
{ {
_name=(SNIHostName)serverName; String host = _host = ((SNIHostName)serverName).getAsciiName();
host=StringUtil.asciiToLowerCase(host);
// If we don't have a SNI name, or didn't see any certificate aliases,
// just say true as it will either somehow work or fail elsewhere
if (_certAliases.size()==0)
return true;
// Try an exact match // Try an exact match
_alias = _certAliases.get(_name.getAsciiName()); _x509 = _certHosts.get(host);
if (_alias!=null)
{
if (LOG.isDebugEnabled())
LOG.debug("matched {}->{}",_name.getAsciiName(),_alias);
return true;
}
// Try wild card matches // Else try an exact wild match
String domain = _name.getAsciiName(); if (_x509==null)
_alias = _certWilds.get(domain);
if (_alias==null)
{ {
int dot=domain.indexOf('.'); _x509 = _certWilds.get(host);
// Else try an 1 deep wild match
if (_x509==null)
{
int dot=host.indexOf('.');
if (dot>=0) if (dot>=0)
{ {
domain=domain.substring(dot+1); String domain=host.substring(dot+1);
_alias = _certWilds.get(domain); _x509 = _certWilds.get(domain);
} }
} }
if (_alias!=null) }
if (LOG.isDebugEnabled())
LOG.debug("SNI matched {}->{}",host,_x509);
}
else
{ {
_wild=domain;
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("wild match {}->{}",_name.getAsciiName(),_alias); LOG.debug("SNI no match for {}", serverName);
return true;
} }
}
if (LOG.isDebugEnabled())
LOG.debug("No match for {}",_name.getAsciiName());
// Return true and allow the KeyManager to accept or reject when choosing a certificate. // Return true and allow the KeyManager to accept or reject when choosing a certificate.
// If we don't have a SNI host, or didn't see any certificate aliases,
// just say true as it will either somehow work or fail elsewhere.
return true; return true;
} }
public String getAlias() public String getHost()
{ {
return _alias; return _host;
} }
public String getWildDomain() public X509 getX509()
{ {
return _wild; return _x509;
}
public String getServerName()
{
return _name==null?null:_name.getAsciiName();
} }
} }
} }

View File

@ -0,0 +1,163 @@
//
// ========================================================================
// Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.util.ssl;
import java.security.cert.CertificateParsingException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.naming.InvalidNameException;
import javax.naming.ldap.LdapName;
import javax.naming.ldap.Rdn;
import javax.security.auth.x500.X500Principal;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
public class X509
{
private static final Logger LOG = Log.getLogger(X509.class);
/*
* @see {@link X509Certificate#getKeyUsage()}
*/
private static final int KEY_USAGE__KEY_CERT_SIGN=5;
/*
*
* @see {@link X509Certificate#getSubjectAlternativeNames()}
*/
private static final int SUBJECT_ALTERNATIVE_NAMES__DNS_NAME=2;
public static boolean isCertSign(X509Certificate x509)
{
boolean[] key_usage=x509.getKeyUsage();
return key_usage!=null && key_usage[KEY_USAGE__KEY_CERT_SIGN];
}
private final X509Certificate _x509;
private final String _alias;
private final List<String> _hosts=new ArrayList<>();
private final List<String> _wilds=new ArrayList<>();
public X509(String alias,X509Certificate x509) throws CertificateParsingException, InvalidNameException
{
_alias=alias;
_x509 = x509;
// Look for alternative name extensions
boolean named=false;
Collection<List<?>> altNames = x509.getSubjectAlternativeNames();
if (altNames!=null)
{
for (List<?> list : altNames)
{
if (((Number)list.get(0)).intValue() == SUBJECT_ALTERNATIVE_NAMES__DNS_NAME)
{
String cn = list.get(1).toString();
if (LOG.isDebugEnabled())
LOG.debug("Certificate SAN alias={} CN={} in {}",alias,cn,this);
if (cn!=null)
{
named=true;
addName(cn);
}
}
}
}
// If no names found, look up the CN from the subject
if (!named)
{
LdapName name=new LdapName(x509.getSubjectX500Principal().getName(X500Principal.RFC2253));
for (Rdn rdn : name.getRdns())
{
if (rdn.getType().equalsIgnoreCase("CN"))
{
String cn = rdn.getValue().toString();
if (LOG.isDebugEnabled())
LOG.debug("Certificate CN alias={} CN={} in {}",alias,cn,this);
if (cn!=null && cn.contains(".") && !cn.contains(" "))
addName(cn);
}
}
}
}
protected void addName(String cn)
{
cn=StringUtil.asciiToLowerCase(cn);
if (cn.startsWith("*."))
_wilds.add(cn.substring(2));
else
_hosts.add(cn);
}
public String getAlias()
{
return _alias;
}
public X509Certificate getCertificate()
{
return _x509;
}
public Set<String> getHosts()
{
return new HashSet<>(_hosts);
}
public Set<String> getWilds()
{
return new HashSet<>(_wilds);
}
public boolean matches(String host)
{
host=StringUtil.asciiToLowerCase(host);
if (_hosts.contains(host) || _wilds.contains(host))
return true;
int dot = host.indexOf('.');
if (dot>=0)
{
String domain=host.substring(dot+1);
if (_wilds.contains(domain))
return true;
}
return false;
}
@Override
public String toString()
{
return String.format("%s@%x(%s,h=%s,w=%s)",
getClass().getSimpleName(),
hashCode(),
_alias,
_hosts,
_wilds);
}
}

View File

@ -19,7 +19,9 @@
package org.eclipse.jetty.util.ssl; package org.eclipse.jetty.util.ssl;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@ -221,4 +223,45 @@ public class SslContextFactoryTest
assertNotNull(cf.getExcludeCipherSuites()); assertNotNull(cf.getExcludeCipherSuites());
assertNotNull(cf.getIncludeCipherSuites()); assertNotNull(cf.getIncludeCipherSuites());
} }
@Test
public void testSNICertificates() throws Exception
{
Resource keystoreResource = Resource.newSystemResource("snikeystore");
cf.setKeyStoreResource(keystoreResource);
cf.setKeyStorePassword("storepwd");
cf.setKeyManagerPassword("keypwd");
cf.start();
assertThat(cf.getAliases(),containsInAnyOrder("jetty","other","san","wild"));
assertThat(cf.getX509("jetty").getHosts(),containsInAnyOrder("jetty.eclipse.org"));
assertTrue(cf.getX509("jetty").getWilds().isEmpty());
assertTrue(cf.getX509("jetty").matches("JETTY.Eclipse.Org"));
assertFalse(cf.getX509("jetty").matches("m.jetty.eclipse.org"));
assertFalse(cf.getX509("jetty").matches("eclipse.org"));
assertThat(cf.getX509("other").getHosts(),containsInAnyOrder("www.example.com"));
assertTrue(cf.getX509("other").getWilds().isEmpty());
assertTrue(cf.getX509("other").matches("www.example.com"));
assertFalse(cf.getX509("other").matches("eclipse.org"));
assertThat(cf.getX509("san").getHosts(),containsInAnyOrder("www.san.com","m.san.com"));
assertTrue(cf.getX509("san").getWilds().isEmpty());
assertTrue(cf.getX509("san").matches("www.san.com"));
assertTrue(cf.getX509("san").matches("m.san.com"));
assertFalse(cf.getX509("san").matches("other.san.com"));
assertFalse(cf.getX509("san").matches("san.com"));
assertFalse(cf.getX509("san").matches("eclipse.org"));
assertTrue(cf.getX509("wild").getHosts().isEmpty());
assertThat(cf.getX509("wild").getWilds(),containsInAnyOrder("domain.com"));
assertTrue(cf.getX509("wild").matches("domain.com"));
assertTrue(cf.getX509("wild").matches("www.domain.com"));
assertTrue(cf.getX509("wild").matches("other.domain.com"));
assertFalse(cf.getX509("wild").matches("foo.bar.domain.com"));
assertFalse(cf.getX509("wild").matches("other.com"));
}
} }

Binary file not shown.

View File

@ -413,7 +413,7 @@ public abstract class AbstractWebSocketConnection extends AbstractConnection imp
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("{} onClose()",policy.getBehavior()); LOG.debug("{} onClose()",policy.getBehavior());
super.onClose(); super.onClose();
// ioState.onDisconnected(); ioState.onDisconnected();
flusher.close(); flusher.close();
} }

View File

@ -18,10 +18,12 @@
package org.eclipse.jetty.websocket.server; package org.eclipse.jetty.websocket.server;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -35,7 +37,9 @@ import org.eclipse.jetty.websocket.api.WebSocketAdapter;
import org.eclipse.jetty.websocket.common.CloseInfo; import org.eclipse.jetty.websocket.common.CloseInfo;
import org.eclipse.jetty.websocket.common.OpCode; import org.eclipse.jetty.websocket.common.OpCode;
import org.eclipse.jetty.websocket.common.WebSocketFrame; import org.eclipse.jetty.websocket.common.WebSocketFrame;
import org.eclipse.jetty.websocket.common.WebSocketSession;
import org.eclipse.jetty.websocket.common.events.AbstractEventDriver; import org.eclipse.jetty.websocket.common.events.AbstractEventDriver;
import org.eclipse.jetty.websocket.common.frames.TextFrame;
import org.eclipse.jetty.websocket.common.test.BlockheadClient; import org.eclipse.jetty.websocket.common.test.BlockheadClient;
import org.eclipse.jetty.websocket.server.helper.RFCSocket; import org.eclipse.jetty.websocket.server.helper.RFCSocket;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
@ -44,7 +48,6 @@ import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet; import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory; import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
@ -79,10 +82,16 @@ public class WebSocketCloseTest
@SuppressWarnings("serial") @SuppressWarnings("serial")
public static class CloseServlet extends WebSocketServlet implements WebSocketCreator public static class CloseServlet extends WebSocketServlet implements WebSocketCreator
{ {
private WebSocketServerFactory serverFactory;
@Override @Override
public void configure(WebSocketServletFactory factory) public void configure(WebSocketServletFactory factory)
{ {
factory.setCreator(this); factory.setCreator(this);
if (factory instanceof WebSocketServerFactory)
{
this.serverFactory = (WebSocketServerFactory)factory;
}
} }
@Override @Override
@ -99,10 +108,58 @@ public class WebSocketCloseTest
closeSocket = new FastFailSocket(); closeSocket = new FastFailSocket();
return closeSocket; return closeSocket;
} }
if (req.hasSubProtocol("container"))
{
closeSocket = new ContainerSocket(serverFactory);
return closeSocket;
}
return new RFCSocket(); return new RFCSocket();
} }
} }
/**
* On Message, return container information
*/
public static class ContainerSocket extends AbstractCloseSocket
{
private static final Logger LOG = Log.getLogger(WebSocketCloseTest.ContainerSocket.class);
private final WebSocketServerFactory container;
private Session session;
public ContainerSocket(WebSocketServerFactory container)
{
this.container = container;
}
@Override
public void onWebSocketText(String message)
{
LOG.debug("onWebSocketText({})",message);
if (message.equalsIgnoreCase("openSessions"))
{
Set<WebSocketSession> sessions = container.getOpenSessions();
StringBuilder ret = new StringBuilder();
ret.append("openSessions.size=").append(sessions.size()).append('\n');
int idx = 0;
for (WebSocketSession sess : sessions)
{
ret.append('[').append(idx++).append("] ").append(sess.toString()).append('\n');
}
session.getRemote().sendStringByFuture(ret.toString());
}
session.close(StatusCode.NORMAL,"ContainerSocket");
}
@Override
public void onWebSocketConnect(Session sess)
{
LOG.debug("onWebSocketConnect({})",sess);
this.session = sess;
}
}
/** /**
* On Connect, close socket * On Connect, close socket
*/ */
@ -155,7 +212,9 @@ public class WebSocketCloseTest
/** /**
* Test fast close (bug #403817) * Test fast close (bug #403817)
* @throws Exception on test failure *
* @throws Exception
* on test failure
*/ */
@Test @Test
public void testFastClose() throws Exception public void testFastClose() throws Exception
@ -171,22 +230,24 @@ public class WebSocketCloseTest
// Verify that client got close frame // Verify that client got close frame
EventQueue<WebSocketFrame> frames = client.readFrames(1,1,TimeUnit.SECONDS); EventQueue<WebSocketFrame> frames = client.readFrames(1,1,TimeUnit.SECONDS);
WebSocketFrame frame = frames.poll(); WebSocketFrame frame = frames.poll();
Assert.assertThat("frames[0].opcode",frame.getOpCode(),is(OpCode.CLOSE)); assertThat("frames[0].opcode",frame.getOpCode(),is(OpCode.CLOSE));
CloseInfo close = new CloseInfo(frame); CloseInfo close = new CloseInfo(frame);
Assert.assertThat("Close Status Code",close.getStatusCode(),is(StatusCode.NORMAL)); assertThat("Close Status Code",close.getStatusCode(),is(StatusCode.NORMAL));
// Notify server of close handshake // Notify server of close handshake
client.write(close.asFrame()); // respond with close client.write(close.asFrame()); // respond with close
// ensure server socket got close event // ensure server socket got close event
Assert.assertThat("Fast Close Latch",closeSocket.closeLatch.await(1,TimeUnit.SECONDS),is(true)); assertThat("Fast Close Latch",closeSocket.closeLatch.await(1,TimeUnit.SECONDS),is(true));
Assert.assertThat("Fast Close.statusCode",closeSocket.closeStatusCode,is(StatusCode.NORMAL)); assertThat("Fast Close.statusCode",closeSocket.closeStatusCode,is(StatusCode.NORMAL));
} }
} }
/** /**
* Test fast fail (bug #410537) * Test fast fail (bug #410537)
* @throws Exception on test failure *
* @throws Exception
* on test failure
*/ */
@Test @Test
public void testFastFail() throws Exception public void testFastFail() throws Exception
@ -203,16 +264,129 @@ public class WebSocketCloseTest
EventQueue<WebSocketFrame> frames = client.readFrames(1,1,TimeUnit.SECONDS); EventQueue<WebSocketFrame> frames = client.readFrames(1,1,TimeUnit.SECONDS);
WebSocketFrame frame = frames.poll(); WebSocketFrame frame = frames.poll();
Assert.assertThat("frames[0].opcode",frame.getOpCode(),is(OpCode.CLOSE)); assertThat("frames[0].opcode",frame.getOpCode(),is(OpCode.CLOSE));
CloseInfo close = new CloseInfo(frame); CloseInfo close = new CloseInfo(frame);
Assert.assertThat("Close Status Code",close.getStatusCode(),is(StatusCode.SERVER_ERROR)); assertThat("Close Status Code",close.getStatusCode(),is(StatusCode.SERVER_ERROR));
client.write(close.asFrame()); // respond with close client.write(close.asFrame()); // respond with close
// ensure server socket got close event // ensure server socket got close event
Assert.assertThat("Fast Fail Latch",closeSocket.closeLatch.await(1,TimeUnit.SECONDS),is(true)); assertThat("Fast Fail Latch",closeSocket.closeLatch.await(1,TimeUnit.SECONDS),is(true));
Assert.assertThat("Fast Fail.statusCode",closeSocket.closeStatusCode,is(StatusCode.SERVER_ERROR)); assertThat("Fast Fail.statusCode",closeSocket.closeStatusCode,is(StatusCode.SERVER_ERROR));
Assert.assertThat("Fast Fail.errors",closeSocket.errors.size(),is(1)); assertThat("Fast Fail.errors",closeSocket.errors.size(),is(1));
}
}
}
/**
* Test session open session cleanup (bug #474936)
*
* @throws Exception
* on test failure
*/
@Test
public void testOpenSessionCleanup() throws Exception
{
fastFail();
fastClose();
dropConnection();
try (BlockheadClient client = new BlockheadClient(server.getServerUri()))
{
client.setProtocols("container");
client.setTimeout(1,TimeUnit.SECONDS);
client.connect();
client.sendStandardRequest();
client.expectUpgradeResponse();
TextFrame text = new TextFrame();
text.setPayload("openSessions");
client.write(text);
EventQueue<WebSocketFrame> frames = client.readFrames(2,1,TimeUnit.SECONDS);
WebSocketFrame frame = frames.poll();
assertThat("frames[0].opcode",frame.getOpCode(),is(OpCode.TEXT));
String resp = frame.getPayloadAsUTF8();
assertThat("Should only have 1 open session",resp,containsString("openSessions.size=1\n"));
frame = frames.poll();
assertThat("frames[1].opcode",frame.getOpCode(),is(OpCode.CLOSE));
CloseInfo close = new CloseInfo(frame);
assertThat("Close Status Code",close.getStatusCode(),is(StatusCode.NORMAL));
client.write(close.asFrame()); // respond with close
// ensure server socket got close event
assertThat("Open Sessions Latch",closeSocket.closeLatch.await(1,TimeUnit.SECONDS),is(true));
assertThat("Open Sessions.statusCode",closeSocket.closeStatusCode,is(StatusCode.NORMAL));
assertThat("Open Sessions.errors",closeSocket.errors.size(),is(0));
}
}
private void fastClose() throws Exception
{
try (BlockheadClient client = new BlockheadClient(server.getServerUri()))
{
client.setProtocols("fastclose");
client.setTimeout(1,TimeUnit.SECONDS);
try (StacklessLogging scope = new StacklessLogging(WebSocketSession.class))
{
client.connect();
client.sendStandardRequest();
client.expectUpgradeResponse();
client.readFrames(1,1,TimeUnit.SECONDS);
CloseInfo close = new CloseInfo(StatusCode.NORMAL,"Normal");
assertThat("Close Status Code",close.getStatusCode(),is(StatusCode.NORMAL));
// Notify server of close handshake
client.write(close.asFrame()); // respond with close
// ensure server socket got close event
assertThat("Fast Close Latch",closeSocket.closeLatch.await(1,TimeUnit.SECONDS),is(true));
assertThat("Fast Close.statusCode",closeSocket.closeStatusCode,is(StatusCode.NORMAL));
}
}
}
private void fastFail() throws Exception
{
try (BlockheadClient client = new BlockheadClient(server.getServerUri()))
{
client.setProtocols("fastfail");
client.setTimeout(1,TimeUnit.SECONDS);
try (StacklessLogging scope = new StacklessLogging(WebSocketSession.class))
{
client.connect();
client.sendStandardRequest();
client.expectUpgradeResponse();
client.readFrames(1,1,TimeUnit.SECONDS);
CloseInfo close = new CloseInfo(StatusCode.NORMAL,"Normal");
client.write(close.asFrame()); // respond with close
// ensure server socket got close event
assertThat("Fast Fail Latch",closeSocket.closeLatch.await(1,TimeUnit.SECONDS),is(true));
assertThat("Fast Fail.statusCode",closeSocket.closeStatusCode,is(StatusCode.SERVER_ERROR));
assertThat("Fast Fail.errors",closeSocket.errors.size(),is(1));
}
}
}
private void dropConnection() throws Exception
{
try (BlockheadClient client = new BlockheadClient(server.getServerUri()))
{
client.setProtocols("container");
client.setTimeout(1,TimeUnit.SECONDS);
try (StacklessLogging scope = new StacklessLogging(WebSocketSession.class))
{
client.connect();
client.sendStandardRequest();
client.expectUpgradeResponse();
client.disconnect();
} }
} }
} }