diff --git a/openid/src/main/java/org/springframework/security/openid/OpenID4JavaConsumer.java b/openid/src/main/java/org/springframework/security/openid/OpenID4JavaConsumer.java index 1a803af100..1863fea7c5 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenID4JavaConsumer.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenID4JavaConsumer.java @@ -37,12 +37,14 @@ import org.openid4java.message.ParameterList; import org.openid4java.message.ax.AxMessage; import org.openid4java.message.ax.FetchRequest; import org.openid4java.message.ax.FetchResponse; +import org.springframework.util.StringUtils; /** * @author Ray Krueger * @author Luke Taylor */ +@SuppressWarnings("unchecked") public class OpenID4JavaConsumer implements OpenIDConsumer { private static final String DISCOVERY_INFO_KEY = DiscoveryInformation.class.getName(); private static final String ATTRIBUTE_LIST_KEY = "SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST"; @@ -93,7 +95,6 @@ public class OpenID4JavaConsumer implements OpenIDConsumer { //~ Methods ======================================================================================================== - @SuppressWarnings("unchecked") public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl, String realm) throws OpenIDConsumerException { List discoveries; @@ -136,9 +137,7 @@ public class OpenID4JavaConsumer implements OpenIDConsumer { return authReq.getDestinationUrl(true); } - @SuppressWarnings("unchecked") public OpenIDAuthenticationToken endConsumption(HttpServletRequest request) throws OpenIDConsumerException { - final boolean debug = logger.isDebugEnabled(); // extract the parameters from the authentication response // (which comes in as a HTTP request from the OpenID provider) ParameterList openidResp = new ParameterList(request.getParameterMap()); @@ -154,7 +153,7 @@ public class OpenID4JavaConsumer implements OpenIDConsumer { StringBuffer receivingURL = request.getRequestURL(); String queryString = request.getQueryString(); - if ((queryString != null) && (queryString.length() > 0)) { + if (StringUtils.hasLength(queryString)) { receivingURL.append("?").append(request.getQueryString()); } @@ -171,8 +170,6 @@ public class OpenID4JavaConsumer implements OpenIDConsumer { throw new OpenIDConsumerException("Error verifying openid response", e); } - List attributes = new ArrayList(); - // examine the verification result and extract the verified identifier Identifier verified = verification.getVerifiedId(); @@ -180,39 +177,50 @@ public class OpenID4JavaConsumer implements OpenIDConsumer { Identifier id = discovered.getClaimedIdentifier(); return new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.FAILURE, id == null ? "Unknown" : id.getIdentifier(), - "Verification status message: [" + verification.getStatusMsg() + "]", attributes); + "Verification status message: [" + verification.getStatusMsg() + "]", + Collections.emptyList()); } - // fetch the attributesToFetch of the response - Message authSuccess = verification.getAuthResponse(); - - if (authSuccess.hasExtension(AxMessage.OPENID_NS_AX)) { - if (debug) { - logger.debug("Extracting attributes retrieved by attribute exchange"); - } - try { - MessageExtension ext = authSuccess.getExtension(AxMessage.OPENID_NS_AX); - if (ext instanceof FetchResponse) { - FetchResponse fetchResp = (FetchResponse) ext; - for (OpenIDAttribute attr : attributesToFetch) { - List values = fetchResp.getAttributeValues(attr.getName()); - if (!values.isEmpty()) { - OpenIDAttribute fetched = new OpenIDAttribute(attr.getName(), attr.getType(), values); - fetched.setRequired(attr.isRequired()); - attributes.add(fetched); - } - } - } - } catch (MessageException e) { - attributes.clear(); - throw new OpenIDConsumerException("Attribute retrieval failed", e); - } - if (debug) { - logger.debug("Retrieved attributes" + attributes); - } - } + List attributes = fetchAxAttributes(verification.getAuthResponse(), attributesToFetch); return new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.SUCCESS, verified.getIdentifier(), "some message", attributes); } + + List fetchAxAttributes(Message authSuccess, List attributesToFetch) + throws OpenIDConsumerException { + + if (!authSuccess.hasExtension(AxMessage.OPENID_NS_AX)) { + return Collections.emptyList(); + } + + logger.debug("Extracting attributes retrieved by attribute exchange"); + + List attributes = Collections.emptyList(); + + try { + MessageExtension ext = authSuccess.getExtension(AxMessage.OPENID_NS_AX); + if (ext instanceof FetchResponse) { + FetchResponse fetchResp = (FetchResponse) ext; + attributes = new ArrayList(attributesToFetch.size()); + + for (OpenIDAttribute attr : attributesToFetch) { + List values = fetchResp.getAttributeValues(attr.getName()); + if (!values.isEmpty()) { + OpenIDAttribute fetched = new OpenIDAttribute(attr.getName(), attr.getType(), values); + fetched.setRequired(attr.isRequired()); + attributes.add(fetched); + } + } + } + } catch (MessageException e) { + throw new OpenIDConsumerException("Attribute retrieval failed", e); + } + + if (logger.isDebugEnabled()) { + logger.debug("Retrieved attributes" + attributes); + } + + return attributes; + } } diff --git a/openid/src/test/java/org/springframework/security/openid/OpenID4JavaConsumerTests.java b/openid/src/test/java/org/springframework/security/openid/OpenID4JavaConsumerTests.java new file mode 100644 index 0000000000..e7212804e5 --- /dev/null +++ b/openid/src/test/java/org/springframework/security/openid/OpenID4JavaConsumerTests.java @@ -0,0 +1,197 @@ +package org.springframework.security.openid; + +import static org.junit.Assert.*; +import static org.mockito.Matchers.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.*; +import org.openid4java.association.AssociationException; +import org.openid4java.consumer.ConsumerException; +import org.openid4java.consumer.ConsumerManager; +import org.openid4java.consumer.VerificationResult; +import org.openid4java.discovery.DiscoveryException; +import org.openid4java.discovery.DiscoveryInformation; +import org.openid4java.discovery.Identifier; +import org.openid4java.message.AuthRequest; +import org.openid4java.message.Message; +import org.openid4java.message.MessageException; +import org.openid4java.message.ParameterList; +import org.openid4java.message.ax.AxMessage; +import org.openid4java.message.ax.FetchResponse; +import org.springframework.mock.web.MockHttpServletRequest; + +import java.util.*; + +/** + * @author Luke Taylor + */ +public class OpenID4JavaConsumerTests { + List attributes = Arrays.asList(new OpenIDAttribute("a","b"), new OpenIDAttribute("b","b", Arrays.asList("c"))); + + @Test + public void beginConsumptionCreatesExpectedSessionData() throws Exception { + ConsumerManager mgr = mock(ConsumerManager.class); + AuthRequest authReq = mock(AuthRequest.class); + DiscoveryInformation di = mock(DiscoveryInformation.class); + + when(mgr.authenticate(any(DiscoveryInformation.class), anyString(), anyString())).thenReturn(authReq); + when(mgr.associate(anyList())).thenReturn(di); + + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, attributes); + + MockHttpServletRequest request = new MockHttpServletRequest(); + consumer.beginConsumption(request, "", "", ""); + + assertEquals(attributes, request.getSession().getAttribute("SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST")); + assertSame(di, request.getSession().getAttribute(DiscoveryInformation.class.getName())); + + // Check with empty attribute fetch list + consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); + + request = new MockHttpServletRequest(); + consumer.beginConsumption(request, "", "", ""); + } + + @Test(expected = OpenIDConsumerException.class) + public void discoveryExceptionRaisesOpenIDException() throws Exception { + ConsumerManager mgr = mock(ConsumerManager.class); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); + when(mgr.discover(anyString())).thenThrow(new DiscoveryException("msg")); + consumer.beginConsumption(new MockHttpServletRequest(), "", "", ""); + } + + @Test + public void messageOrConsumerAuthenticationExceptionRaisesOpenIDException() throws Exception { + ConsumerManager mgr = mock(ConsumerManager.class); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); + + when(mgr.authenticate(any(DiscoveryInformation.class), anyString(), anyString())) + .thenThrow(new MessageException("msg"), new ConsumerException("msg")); + + try { + consumer.beginConsumption(new MockHttpServletRequest(), "", "", ""); + fail(); + } catch (OpenIDConsumerException expected) { + } + + try { + consumer.beginConsumption(new MockHttpServletRequest(), "", "", ""); + fail(); + } catch (OpenIDConsumerException expected) { + } + } + + @Test + public void failedVerificationReturnsFailedAuthenticationStatus() throws Exception { + ConsumerManager mgr = mock(ConsumerManager.class); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); + VerificationResult vr = mock(VerificationResult.class); + DiscoveryInformation di = mock(DiscoveryInformation.class); + + when(mgr.verify(anyString(), any(ParameterList.class), any(DiscoveryInformation.class))).thenReturn(vr); + + MockHttpServletRequest request = new MockHttpServletRequest(); + + request.getSession().setAttribute(DiscoveryInformation.class.getName(), di); + + OpenIDAuthenticationToken auth = consumer.endConsumption(request); + + assertEquals(OpenIDAuthenticationStatus.FAILURE, auth.getStatus()); + } + + @Test + public void verificationExceptionsRaiseOpenIDException() throws Exception { + ConsumerManager mgr = mock(ConsumerManager.class); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); + + when(mgr.verify(anyString(), any(ParameterList.class), any(DiscoveryInformation.class))) + .thenThrow(new MessageException("")) + .thenThrow(new AssociationException("")) + .thenThrow(new DiscoveryException("")); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setQueryString("x=5"); + + try { + consumer.endConsumption(request); + fail(); + } catch (OpenIDConsumerException expected) { + } + + try { + consumer.endConsumption(request); + fail(); + } catch (OpenIDConsumerException expected) { + } + + try { + consumer.endConsumption(request); + fail(); + } catch (OpenIDConsumerException expected) { + } + + } + + @Test + public void successfulVerificationReturnsExpectedAuthentication() throws Exception { + ConsumerManager mgr = mock(ConsumerManager.class); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); + VerificationResult vr = mock(VerificationResult.class); + DiscoveryInformation di = mock(DiscoveryInformation.class); + Identifier id = new Identifier() { + public String getIdentifier() { + return "id"; + } + }; + Message msg = mock(Message.class); + + when(mgr.verify(anyString(), any(ParameterList.class), any(DiscoveryInformation.class))).thenReturn(vr); + when(vr.getVerifiedId()).thenReturn(id); + when(vr.getAuthResponse()).thenReturn(msg); + + MockHttpServletRequest request = new MockHttpServletRequest(); + + request.getSession().setAttribute(DiscoveryInformation.class.getName(), di); + request.getSession().setAttribute("SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST", attributes); + + OpenIDAuthenticationToken auth = consumer.endConsumption(request); + + assertEquals(OpenIDAuthenticationStatus.SUCCESS, auth.getStatus()); + } + + @Test + public void fetchAttributesReturnsExpectedValues() throws Exception { + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(new NullAxFetchListFactory()); + Message msg = mock(Message.class); + FetchResponse fr = mock(FetchResponse.class); + when(msg.hasExtension(AxMessage.OPENID_NS_AX)).thenReturn(true); + when(msg.getExtension(AxMessage.OPENID_NS_AX)).thenReturn(fr); + when(fr.getAttributeValues("a")).thenReturn(Arrays.asList("x","y")); + + List fetched = consumer.fetchAxAttributes(msg, attributes); + + assertEquals(1, fetched.size()); + assertEquals(2, fetched.get(0).getValues().size()); + } + + @Test(expected = OpenIDConsumerException.class) + public void messageExceptionFetchingAttributesRaisesOpenIDException() throws Exception { + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(new NullAxFetchListFactory()); + Message msg = mock(Message.class); + FetchResponse fr = mock(FetchResponse.class); + when(msg.hasExtension(AxMessage.OPENID_NS_AX)).thenReturn(true); + when(msg.getExtension(AxMessage.OPENID_NS_AX)).thenThrow(new MessageException("")); + when(fr.getAttributeValues("a")).thenReturn(Arrays.asList("x","y")); + + consumer.fetchAxAttributes(msg, attributes); + } + + + @Test + public void additionalConstructorsWork() throws Exception { + new OpenID4JavaConsumer(); + new OpenID4JavaConsumer(attributes); + } + +} diff --git a/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationProviderTests.java b/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationProviderTests.java index d692e27bec..86b5537d41 100644 --- a/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationProviderTests.java +++ b/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationProviderTests.java @@ -22,8 +22,10 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.mapping.NullAuthoritiesMapper; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UserDetailsByNameServiceWrapper; import org.springframework.security.core.userdetails.UserDetailsService; @@ -45,6 +47,7 @@ public class OpenIDAuthenticationProviderTests extends TestCase { public void testAuthenticateCancel() { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); provider.setUserDetailsService(new MockUserDetailsService()); + provider.setAuthoritiesMapper(new NullAuthoritiesMapper()); Authentication preAuth = new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.CANCELLED, USERNAME, "" ,null); @@ -82,7 +85,7 @@ public class OpenIDAuthenticationProviderTests extends TestCase { */ public void testAuthenticateFailure() { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); - provider.setUserDetailsService(new MockUserDetailsService()); + provider.setAuthenticationUserDetailsService(new UserDetailsByNameServiceWrapper(new MockUserDetailsService())); Authentication preAuth = new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.FAILURE, USERNAME, "", null);