Refactored channel entry points to use a common base clase since the functionality is almost exactlythe same (apart from the function called on the PortMapper).

This commit is contained in:
Luke Taylor 2008-01-15 17:56:21 +00:00
parent afded24b62
commit 60b7e2d4f2
5 changed files with 115 additions and 199 deletions

View File

@ -0,0 +1,92 @@
package org.springframework.security.securechannel;
import org.springframework.security.util.PortMapper;
import org.springframework.security.util.PortResolver;
import org.springframework.security.util.PortMapperImpl;
import org.springframework.security.util.PortResolverImpl;
import org.springframework.util.Assert;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
/**
* @author Luke Taylor
* @version $Id$
*/
public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint {
//~ Static fields/initializers =====================================================================================
private static final Log logger = LogFactory.getLog(RetryWithHttpEntryPoint.class);
//~ Instance fields ================================================================================================
private PortMapper portMapper = new PortMapperImpl();
private PortResolver portResolver = new PortResolverImpl();
/** The scheme ("http://" or "https://") */
private String scheme;
/** The standard port for the scheme (80 for http, 443 for https) */
private int standardPort;
//~ Constructors ===================================================================================================
public AbstractRetryEntryPoint(String scheme, int standardPort) {
this.scheme = scheme;
this.standardPort = standardPort;
}
//~ Methods ========================================================================================================
public void commence(ServletRequest req, ServletResponse res) throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
String pathInfo = request.getPathInfo();
String queryString = request.getQueryString();
String contextPath = request.getContextPath();
String destination = request.getServletPath() + ((pathInfo == null) ? "" : pathInfo)
+ ((queryString == null) ? "" : ("?" + queryString));
String redirectUrl = contextPath;
Integer currentPort = new Integer(portResolver.getServerPort(request));
Integer redirectPort = getMappedPort(currentPort);
if (redirectPort != null) {
boolean includePort = redirectPort.intValue() != standardPort;
redirectUrl = scheme + request.getServerName() + ((includePort) ? (":" + redirectPort) : "") + contextPath
+ destination;
}
if (logger.isDebugEnabled()) {
logger.debug("Redirecting to: " + redirectUrl);
}
((HttpServletResponse) res).sendRedirect(((HttpServletResponse) res).encodeRedirectURL(redirectUrl));
}
protected abstract Integer getMappedPort(Integer mapFromPort);
protected PortMapper getPortMapper() {
return portMapper;
}
protected PortResolver getPortResolver() {
return portResolver;
}
public void setPortMapper(PortMapper portMapper) {
Assert.notNull(portMapper, "portMapper cannot be null");
this.portMapper = portMapper;
}
public void setPortResolver(PortResolver portResolver) {
Assert.notNull(portResolver, "portResolver cannot be null");
this.portResolver = portResolver;
}
}

View File

@ -15,98 +15,23 @@
package org.springframework.security.securechannel;
import org.springframework.security.util.PortMapper;
import org.springframework.security.util.PortMapperImpl;
import org.springframework.security.util.PortResolver;
import org.springframework.security.util.PortResolverImpl;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Commences an insecure channel by retrying the original request using HTTP.<P>This entry point should suffice in
* most circumstances. However, it is not intended to properly handle HTTP POSTs or other usage where a standard
* redirect would cause an issue.</p>
* Commences an insecure channel by retrying the original request using HTTP.
* <p>
* This entry point should suffice in most circumstances. However, it is not intended to properly handle HTTP POSTs or
* other usage where a standard redirect would cause an issue.
*
* @author Ben Alex
* @version $Id$
*/
public class RetryWithHttpEntryPoint implements InitializingBean, ChannelEntryPoint {
//~ Static fields/initializers =====================================================================================
public class RetryWithHttpEntryPoint extends AbstractRetryEntryPoint {
private static final Log logger = LogFactory.getLog(RetryWithHttpEntryPoint.class);
//~ Instance fields ================================================================================================
private PortMapper portMapper = new PortMapperImpl();
private PortResolver portResolver = new PortResolverImpl();
//~ Methods ========================================================================================================
public void afterPropertiesSet() throws Exception {
Assert.notNull(portMapper, "portMapper is required");
Assert.notNull(portResolver, "portResolver is required");
public RetryWithHttpEntryPoint() {
super("http://", 80);
}
public void commence(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
HttpServletRequest req = (HttpServletRequest) request;
String pathInfo = req.getPathInfo();
String queryString = req.getQueryString();
String contextPath = req.getContextPath();
String destination = req.getServletPath() + ((pathInfo == null) ? "" : pathInfo)
+ ((queryString == null) ? "" : ("?" + queryString));
String redirectUrl = contextPath;
Integer httpsPort = new Integer(portResolver.getServerPort(req));
Integer httpPort = portMapper.lookupHttpPort(httpsPort);
if (httpPort != null) {
boolean includePort = true;
if (httpPort.intValue() == 80) {
includePort = false;
}
redirectUrl = "http://" + req.getServerName() + ((includePort) ? (":" + httpPort) : "") + contextPath
+ destination;
}
if (logger.isDebugEnabled()) {
logger.debug("Redirecting to: " + redirectUrl);
}
((HttpServletResponse) response).sendRedirect(((HttpServletResponse) response).encodeRedirectURL(redirectUrl));
}
public PortMapper getPortMapper() {
return portMapper;
}
public PortResolver getPortResolver() {
return portResolver;
}
public void setPortMapper(PortMapper portMapper) {
this.portMapper = portMapper;
}
public void setPortResolver(PortResolver portResolver) {
this.portResolver = portResolver;
protected Integer getMappedPort(Integer mapFromPort) {
return getPortMapper().lookupHttpPort(mapFromPort);
}
}

View File

@ -15,98 +15,22 @@
package org.springframework.security.securechannel;
import org.springframework.security.util.PortMapper;
import org.springframework.security.util.PortMapperImpl;
import org.springframework.security.util.PortResolver;
import org.springframework.security.util.PortResolverImpl;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Commences a secure channel by retrying the original request using HTTPS.<P>This entry point should suffice in
* most circumstances. However, it is not intended to properly handle HTTP POSTs or other usage where a standard
* redirect would cause an issue.</p>
* Commences a secure channel by retrying the original request using HTTPS.
* <p>
* This entry point should suffice in most circumstances. However, it is not intended to properly handle HTTP POSTs
* or other usage where a standard redirect would cause an issue.</p>
*
* @author Ben Alex
* @version $Id$
*/
public class RetryWithHttpsEntryPoint implements InitializingBean, ChannelEntryPoint {
//~ Static fields/initializers =====================================================================================
public class RetryWithHttpsEntryPoint extends AbstractRetryEntryPoint {
private static final Log logger = LogFactory.getLog(RetryWithHttpsEntryPoint.class);
//~ Instance fields ================================================================================================
private PortMapper portMapper = new PortMapperImpl();
private PortResolver portResolver = new PortResolverImpl();
//~ Methods ========================================================================================================
public void afterPropertiesSet() throws Exception {
Assert.notNull(portMapper, "portMapper is required");
Assert.notNull(portResolver, "portResolver is required");
public RetryWithHttpsEntryPoint() {
super("https://", 443);
}
public void commence(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
HttpServletRequest req = (HttpServletRequest) request;
String pathInfo = req.getPathInfo();
String queryString = req.getQueryString();
String contextPath = req.getContextPath();
String destination = req.getServletPath() + ((pathInfo == null) ? "" : pathInfo)
+ ((queryString == null) ? "" : ("?" + queryString));
String redirectUrl = contextPath;
Integer httpPort = new Integer(portResolver.getServerPort(req));
Integer httpsPort = portMapper.lookupHttpsPort(httpPort);
if (httpsPort != null) {
boolean includePort = true;
if (httpsPort.intValue() == 443) {
includePort = false;
}
redirectUrl = "https://" + req.getServerName() + ((includePort) ? (":" + httpsPort) : "") + contextPath
+ destination;
}
if (logger.isDebugEnabled()) {
logger.debug("Redirecting to: " + redirectUrl);
}
((HttpServletResponse) response).sendRedirect(((HttpServletResponse) response).encodeRedirectURL(redirectUrl));
}
public PortMapper getPortMapper() {
return portMapper;
}
public PortResolver getPortResolver() {
return portResolver;
}
public void setPortMapper(PortMapper portMapper) {
this.portMapper = portMapper;
}
public void setPortResolver(PortResolver portResolver) {
this.portResolver = portResolver;
protected Integer getMappedPort(Integer mapFromPort) {
return getPortMapper().lookupHttpsPort(mapFromPort);
}
}

View File

@ -37,35 +37,23 @@ import java.util.Map;
public class RetryWithHttpEntryPointTests extends TestCase {
//~ Methods ========================================================================================================
public static void main(String[] args) {
junit.textui.TestRunner.run(RetryWithHttpEntryPointTests.class);
}
public final void setUp() throws Exception {
super.setUp();
}
public void testDetectsMissingPortMapper() throws Exception {
RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
ep.setPortMapper(null);
try {
ep.afterPropertiesSet();
ep.setPortMapper(null);
fail("Should have thrown IllegalArgumentException");
} catch (IllegalArgumentException expected) {
assertEquals("portMapper is required", expected.getMessage());
}
}
public void testDetectsMissingPortResolver() throws Exception {
RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
ep.setPortResolver(null);
try {
ep.afterPropertiesSet();
ep.setPortResolver(null);
fail("Should have thrown IllegalArgumentException");
} catch (IllegalArgumentException expected) {
assertEquals("portResolver is required", expected.getMessage());
}
}
@ -92,7 +80,6 @@ public class RetryWithHttpEntryPointTests extends TestCase {
RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
ep.setPortMapper(new PortMapperImpl());
ep.setPortResolver(new MockPortResolver(80, 443));
ep.afterPropertiesSet();
ep.commence(request, response);
assertEquals("http://www.example.com/bigWebApp/hello/pathInfo.html?open=true", response.getRedirectedUrl());
@ -113,7 +100,6 @@ public class RetryWithHttpEntryPointTests extends TestCase {
RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
ep.setPortMapper(new PortMapperImpl());
ep.setPortResolver(new MockPortResolver(80, 443));
ep.afterPropertiesSet();
ep.commence(request, response);
assertEquals("http://www.example.com/bigWebApp/hello", response.getRedirectedUrl());
@ -135,7 +121,6 @@ public class RetryWithHttpEntryPointTests extends TestCase {
RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
ep.setPortMapper(new PortMapperImpl());
ep.setPortResolver(new MockPortResolver(8768, 1234));
ep.afterPropertiesSet();
ep.commence(request, response);
assertEquals("/bigWebApp", response.getRedirectedUrl());
@ -161,7 +146,6 @@ public class RetryWithHttpEntryPointTests extends TestCase {
RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
ep.setPortResolver(new MockPortResolver(8888, 9999));
ep.setPortMapper(portMapper);
ep.afterPropertiesSet();
ep.commence(request, response);
assertEquals("http://www.example.com:8888/bigWebApp/hello/pathInfo.html?open=true", response.getRedirectedUrl());

View File

@ -47,25 +47,21 @@ public class RetryWithHttpsEntryPointTests extends TestCase {
public void testDetectsMissingPortMapper() throws Exception {
RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
ep.setPortMapper(null);
try {
ep.afterPropertiesSet();
ep.setPortMapper(null);
fail("Should have thrown IllegalArgumentException");
} catch (IllegalArgumentException expected) {
assertEquals("portMapper is required", expected.getMessage());
}
}
public void testDetectsMissingPortResolver() throws Exception {
RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
ep.setPortResolver(null);
try {
ep.afterPropertiesSet();
ep.setPortResolver(null);
fail("Should have thrown IllegalArgumentException");
} catch (IllegalArgumentException expected) {
assertEquals("portResolver is required", expected.getMessage());
}
}
@ -92,7 +88,6 @@ public class RetryWithHttpsEntryPointTests extends TestCase {
RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
ep.setPortMapper(new PortMapperImpl());
ep.setPortResolver(new MockPortResolver(80, 443));
ep.afterPropertiesSet();
ep.commence(request, response);
assertEquals("https://www.example.com/bigWebApp/hello/pathInfo.html?open=true", response.getRedirectedUrl());
@ -113,14 +108,12 @@ public class RetryWithHttpsEntryPointTests extends TestCase {
RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
ep.setPortMapper(new PortMapperImpl());
ep.setPortResolver(new MockPortResolver(80, 443));
ep.afterPropertiesSet();
ep.commence(request, response);
assertEquals("https://www.example.com/bigWebApp/hello", response.getRedirectedUrl());
}
public void testOperationWhenTargetPortIsUnknown()
throws Exception {
public void testOperationWhenTargetPortIsUnknown() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("open=true");
request.setScheme("http");
@ -135,7 +128,6 @@ public class RetryWithHttpsEntryPointTests extends TestCase {
RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
ep.setPortMapper(new PortMapperImpl());
ep.setPortResolver(new MockPortResolver(8768, 1234));
ep.afterPropertiesSet();
ep.commence(request, response);
assertEquals("/bigWebApp", response.getRedirectedUrl());
@ -161,7 +153,6 @@ public class RetryWithHttpsEntryPointTests extends TestCase {
RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
ep.setPortResolver(new MockPortResolver(8888, 9999));
ep.setPortMapper(portMapper);
ep.afterPropertiesSet();
ep.commence(request, response);
assertEquals("https://www.example.com:9999/bigWebApp/hello/pathInfo.html?open=true", response.getRedirectedUrl());