Replace ExpectedException @Rules with AssertJ

Replace JUnit ExpectedException @Rules with AssertJ calls.
This commit is contained in:
Phillip Webb 2020-09-10 18:40:27 -07:00 committed by Josh Cummings
parent 910b81928f
commit 20baa7d409
24 changed files with 383 additions and 543 deletions

View File

@ -17,9 +17,7 @@
package org.springframework.security.config; package org.springframework.security.config;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.powermock.api.mockito.PowerMockito; import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PowerMockIgnore;
@ -48,9 +46,6 @@ import static org.mockito.Mockito.verifyZeroInteractions;
@PowerMockIgnore({ "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*" }) @PowerMockIgnore({ "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*" })
public class SecurityNamespaceHandlerTests { public class SecurityNamespaceHandlerTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
// @formatter:off // @formatter:off
private static final String XML_AUTHENTICATION_MANAGER = "<authentication-manager>" private static final String XML_AUTHENTICATION_MANAGER = "<authentication-manager>"
+ " <authentication-provider>" + " <authentication-provider>"
@ -103,12 +98,12 @@ public class SecurityNamespaceHandlerTests {
@Test @Test
public void filterNoClassDefFoundError() throws Exception { public void filterNoClassDefFoundError() throws Exception {
String className = "javax.servlet.Filter"; String className = "javax.servlet.Filter";
this.thrown.expect(BeanDefinitionParsingException.class);
this.thrown.expectMessage("NoClassDefFoundError: " + className);
PowerMockito.spy(ClassUtils.class); PowerMockito.spy(ClassUtils.class);
PowerMockito.doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", PowerMockito.doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName",
eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class));
new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK); assertThatExceptionOfType(BeanDefinitionParsingException.class)
.isThrownBy(() -> new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK))
.withMessageContaining("NoClassDefFoundError: " + className);
} }
@Test @Test
@ -124,12 +119,12 @@ public class SecurityNamespaceHandlerTests {
@Test @Test
public void filterChainProxyClassNotFoundException() throws Exception { public void filterChainProxyClassNotFoundException() throws Exception {
String className = FILTER_CHAIN_PROXY_CLASSNAME; String className = FILTER_CHAIN_PROXY_CLASSNAME;
this.thrown.expect(BeanDefinitionParsingException.class);
this.thrown.expectMessage("ClassNotFoundException: " + className);
PowerMockito.spy(ClassUtils.class); PowerMockito.spy(ClassUtils.class);
PowerMockito.doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", PowerMockito.doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName",
eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class));
new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK); assertThatExceptionOfType(BeanDefinitionParsingException.class)
.isThrownBy(() -> new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK))
.withMessageContaining("ClassNotFoundException: " + className);
} }
@Test @Test

View File

@ -25,7 +25,6 @@ import javax.sql.DataSource;
import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInterceptor;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.springframework.beans.BeansException; import org.springframework.beans.BeansException;
@ -80,9 +79,6 @@ public class GlobalMethodSecurityConfigurationTests {
@Rule @Rule
public final SpringTestRule spring = new SpringTestRule(); public final SpringTestRule spring = new SpringTestRule();
@Rule
public ExpectedException thrown = ExpectedException.none();
@Autowired(required = false) @Autowired(required = false)
private MethodSecurityService service; private MethodSecurityService service;
@ -98,8 +94,8 @@ public class GlobalMethodSecurityConfigurationTests {
@Test @Test
public void configureWhenGlobalMethodSecurityIsMissingMetadataSourceThenException() { public void configureWhenGlobalMethodSecurityIsMissingMetadataSourceThenException() {
this.thrown.expect(UnsatisfiedDependencyException.class); assertThatExceptionOfType(UnsatisfiedDependencyException.class)
this.spring.register(IllegalStateGlobalMethodSecurityConfig.class).autowire(); .isThrownBy(() -> this.spring.register(IllegalStateGlobalMethodSecurityConfig.class).autowire());
} }
@Test @Test

View File

@ -16,11 +16,10 @@
package org.springframework.security.crypto.codec; package org.springframework.security.crypto.codec;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
/** /**
* Test cases for {@link Hex}. * Test cases for {@link Hex}.
@ -29,9 +28,6 @@ import static org.assertj.core.api.Assertions.assertThat;
*/ */
public class HexTests { public class HexTests {
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Test @Test
public void encode() { public void encode() {
assertThat(Hex.encode(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', (byte) 'D' })) assertThat(Hex.encode(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', (byte) 'D' }))
@ -55,30 +51,26 @@ public class HexTests {
@Test @Test
public void decodeNotEven() { public void decodeNotEven() {
this.expectedException.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("414243444"))
this.expectedException.expectMessage("Hex-encoded string must have an even number of characters"); .withMessage("Hex-encoded string must have an even number of characters");
Hex.decode("414243444");
} }
@Test @Test
public void decodeExistNonHexCharAtFirst() { public void decodeExistNonHexCharAtFirst() {
this.expectedException.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("G0"))
this.expectedException.expectMessage("Detected a Non-hex character at 1 or 2 position"); .withMessage("Detected a Non-hex character at 1 or 2 position");
Hex.decode("G0");
} }
@Test @Test
public void decodeExistNonHexCharAtSecond() { public void decodeExistNonHexCharAtSecond() {
this.expectedException.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("410G"))
this.expectedException.expectMessage("Detected a Non-hex character at 3 or 4 position"); .withMessage("Detected a Non-hex character at 3 or 4 position");
Hex.decode("410G");
} }
@Test @Test
public void decodeExistNonHexCharAtBoth() { public void decodeExistNonHexCharAtBoth() {
this.expectedException.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("4142GG"))
this.expectedException.expectMessage("Detected a Non-hex character at 5 or 6 position"); .withMessage("Detected a Non-hex character at 5 or 6 position");
Hex.decode("4142GG");
} }
} }

View File

@ -17,21 +17,18 @@
package org.springframework.security; package org.springframework.security;
import org.junit.After; import org.junit.After;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.beans.factory.BeanDefinitionStoreException; import org.springframework.beans.factory.BeanDefinitionStoreException;
import org.springframework.context.support.ClassPathXmlApplicationContext; import org.springframework.context.support.ClassPathXmlApplicationContext;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
/** /**
* @author Eddú Meléndez * @author Eddú Meléndez
*/ */
public class LdapServerBeanDefinitionParserTests { public class LdapServerBeanDefinitionParserTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
private ClassPathXmlApplicationContext context; private ClassPathXmlApplicationContext context;
@After @After
@ -44,10 +41,9 @@ public class LdapServerBeanDefinitionParserTests {
@Test @Test
public void apacheDirectoryServerIsStartedByDefault() { public void apacheDirectoryServerIsStartedByDefault() {
this.thrown.expect(BeanDefinitionStoreException.class); assertThatExceptionOfType(BeanDefinitionStoreException.class)
this.thrown.expectMessage("Embedded LDAP server is not provided"); .isThrownBy(() -> this.context = new ClassPathXmlApplicationContext("applicationContext-security.xml"))
.withMessageContaining("Embedded LDAP server is not provided");
this.context = new ClassPathXmlApplicationContext("applicationContext-security.xml");
} }
} }

View File

@ -30,14 +30,8 @@ import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult; import javax.naming.directory.SearchResult;
import org.apache.directory.shared.ldap.util.EmptyEnumeration; import org.apache.directory.shared.ldap.util.EmptyEnumeration;
import org.hamcrest.BaseMatcher;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.dao.IncorrectResultSizeDataAccessException;
@ -71,9 +65,6 @@ public class ActiveDirectoryLdapAuthenticationProviderTests {
public static final String NON_EXISTING_LDAP_PROVIDER = "ldap://192.168.1.201/"; public static final String NON_EXISTING_LDAP_PROVIDER = "ldap://192.168.1.201/";
@Rule
public ExpectedException thrown = ExpectedException.none();
ActiveDirectoryLdapAuthenticationProvider provider; ActiveDirectoryLdapAuthenticationProvider provider;
UsernamePasswordAuthenticationToken joe = new UsernamePasswordAuthenticationToken("joe", "password"); UsernamePasswordAuthenticationToken joe = new UsernamePasswordAuthenticationToken("joe", "password");
@ -245,29 +236,10 @@ public class ActiveDirectoryLdapAuthenticationProviderTests {
this.provider.contextFactory = createContextFactoryThrowing( this.provider.contextFactory = createContextFactoryThrowing(
new AuthenticationException(msg + dataCode + ", xxxx]")); new AuthenticationException(msg + dataCode + ", xxxx]"));
this.provider.setConvertSubErrorCodesToExceptions(true); this.provider.setConvertSubErrorCodesToExceptions(true);
this.thrown.expect(BadCredentialsException.class); assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> this.provider.authenticate(this.joe))
this.thrown.expect(new BaseMatcher<BadCredentialsException>() { .withCauseInstanceOf(ActiveDirectoryAuthenticationException.class)
private Matcher<Object> causeInstance = CoreMatchers .satisfies((ex) -> assertThat(((ActiveDirectoryAuthenticationException) ex.getCause()).getDataCode())
.instanceOf(ActiveDirectoryAuthenticationException.class); .isEqualTo(dataCode));
private Matcher<String> causeDataCode = CoreMatchers.equalTo(dataCode);
@Override
public boolean matches(Object that) {
Throwable t = (Throwable) that;
ActiveDirectoryAuthenticationException cause = (ActiveDirectoryAuthenticationException) t.getCause();
return this.causeInstance.matches(cause) && this.causeDataCode.matches(cause.getDataCode());
}
@Override
public void describeTo(Description desc) {
desc.appendText("getCause() ");
this.causeInstance.describeTo(desc);
desc.appendText("getCause().getDataCode() ");
this.causeDataCode.describeTo(desc);
}
});
this.provider.authenticate(this.joe);
} }
@Test(expected = CredentialsExpiredException.class) @Test(expected = CredentialsExpiredException.class)

View File

@ -25,9 +25,7 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
@ -52,7 +50,8 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization
import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyCollection; import static org.mockito.ArgumentMatchers.anyCollection;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
@ -79,9 +78,6 @@ public class OAuth2LoginAuthenticationProviderTests {
private OAuth2LoginAuthenticationProvider authenticationProvider; private OAuth2LoginAuthenticationProvider authenticationProvider;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before @Before
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void setUp() { public void setUp() {
@ -98,20 +94,19 @@ public class OAuth2LoginAuthenticationProviderTests {
@Test @Test
public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() { public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
new OAuth2LoginAuthenticationProvider(null, this.userService); .isThrownBy(() -> new OAuth2LoginAuthenticationProvider(null, this.userService));
} }
@Test @Test
public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() { public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, null); .isThrownBy(() -> new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, null));
} }
@Test @Test
public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() { public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.authenticationProvider.setAuthoritiesMapper(null));
this.authenticationProvider.setAuthoritiesMapper(null);
} }
@Test @Test
@ -132,26 +127,26 @@ public class OAuth2LoginAuthenticationProviderTests {
@Test @Test
public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() { public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST));
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error() OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error()
.errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build(); .errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
authorizationResponse); authorizationResponse);
this.authenticationProvider assertThatExceptionOfType(OAuth2AuthenticationException.class)
.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); .isThrownBy(() -> this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)))
.withMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST);
} }
@Test @Test
public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() { public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_state_parameter"));
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().state("67890") OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().state("67890")
.build(); .build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
authorizationResponse); authorizationResponse);
this.authenticationProvider assertThatExceptionOfType(OAuth2AuthenticationException.class)
.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); .isThrownBy(() -> this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)))
.withMessageContaining("invalid_state_parameter");
} }
@Test @Test

View File

@ -21,9 +21,7 @@ import java.time.Instant;
import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.MockWebServer;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
@ -40,7 +38,8 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
/** /**
* Tests for {@link NimbusAuthorizationCodeTokenResponseClient}. * Tests for {@link NimbusAuthorizationCodeTokenResponseClient}.
@ -59,9 +58,6 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
private NimbusAuthorizationCodeTokenResponseClient tokenResponseClient = new NimbusAuthorizationCodeTokenResponseClient(); private NimbusAuthorizationCodeTokenResponseClient tokenResponseClient = new NimbusAuthorizationCodeTokenResponseClient();
@Rule
public ExpectedException exception = ExpectedException.none();
@Before @Before
public void setUp() { public void setUp() {
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration() this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration()
@ -109,29 +105,27 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() { public void getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
String redirectUri = "http:\\example.com"; String redirectUri = "http:\\example.com";
OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.redirectUri(redirectUri).build(); .redirectUri(redirectUri).build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
this.authorizationResponse); this.authorizationResponse);
this.tokenResponseClient.getTokenResponse( assertThatIllegalArgumentException()
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistrationBuilder.build(), authorizationExchange)); .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), authorizationExchange)));
} }
@Test @Test
public void getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() { public void getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
String tokenUri = "http:\\provider.com\\oauth2\\token"; String tokenUri = "http:\\provider.com\\oauth2\\token";
this.clientRegistrationBuilder.tokenUri(tokenUri); this.clientRegistrationBuilder.tokenUri(tokenUri);
this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( assertThatIllegalArgumentException()
this.clientRegistrationBuilder.build(), this.authorizationExchange)); .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange)));
} }
@Test @Test
public void getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthorizationException() throws Exception { public void getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthorizationException() throws Exception {
this.exception.expect(OAuth2AuthorizationException.class);
this.exception.expectMessage(containsString("invalid_token_response"));
MockWebServer server = new MockWebServer(); MockWebServer server = new MockWebServer();
// @formatter:off // @formatter:off
String accessTokenSuccessResponse = "{\n" String accessTokenSuccessResponse = "{\n"
@ -149,8 +143,10 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
String tokenUri = server.url("/oauth2/token").toString(); String tokenUri = server.url("/oauth2/token").toString();
this.clientRegistrationBuilder.tokenUri(tokenUri); this.clientRegistrationBuilder.tokenUri(tokenUri);
try { try {
this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( assertThatExceptionOfType(OAuth2AuthorizationException.class)
this.clientRegistrationBuilder.build(), this.authorizationExchange)); .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange)))
.withMessageContaining("invalid_token_response");
} }
finally { finally {
server.shutdown(); server.shutdown();
@ -159,17 +155,15 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() {
this.exception.expect(OAuth2AuthorizationException.class);
String tokenUri = "https://invalid-provider.com/oauth2/token"; String tokenUri = "https://invalid-provider.com/oauth2/token";
this.clientRegistrationBuilder.tokenUri(tokenUri); this.clientRegistrationBuilder.tokenUri(tokenUri);
this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( assertThatExceptionOfType(OAuth2AuthorizationException.class)
this.clientRegistrationBuilder.build(), this.authorizationExchange)); .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange)));
} }
@Test @Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() throws Exception { public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() throws Exception {
this.exception.expect(OAuth2AuthorizationException.class);
this.exception.expectMessage(containsString("unauthorized_client"));
MockWebServer server = new MockWebServer(); MockWebServer server = new MockWebServer();
// @formatter:off // @formatter:off
String accessTokenErrorResponse = "{\n" String accessTokenErrorResponse = "{\n"
@ -182,8 +176,10 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
String tokenUri = server.url("/oauth2/token").toString(); String tokenUri = server.url("/oauth2/token").toString();
this.clientRegistrationBuilder.tokenUri(tokenUri); this.clientRegistrationBuilder.tokenUri(tokenUri);
try { try {
this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( assertThatExceptionOfType(OAuth2AuthorizationException.class)
this.clientRegistrationBuilder.build(), this.authorizationExchange)); .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange)))
.withMessageContaining("unauthorized_client");
} }
finally { finally {
server.shutdown(); server.shutdown();
@ -193,16 +189,16 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
// gh-5594 // gh-5594
@Test @Test
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() throws Exception { public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() throws Exception {
this.exception.expect(OAuth2AuthorizationException.class);
this.exception.expectMessage(containsString("server_error"));
MockWebServer server = new MockWebServer(); MockWebServer server = new MockWebServer();
server.enqueue(new MockResponse().setResponseCode(500)); server.enqueue(new MockResponse().setResponseCode(500));
server.start(); server.start();
String tokenUri = server.url("/oauth2/token").toString(); String tokenUri = server.url("/oauth2/token").toString();
this.clientRegistrationBuilder.tokenUri(tokenUri); this.clientRegistrationBuilder.tokenUri(tokenUri);
try { try {
this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( assertThatExceptionOfType(OAuth2AuthorizationException.class)
this.clientRegistrationBuilder.build(), this.authorizationExchange)); .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange)))
.withMessageContaining("server_error");
} }
finally { finally {
server.shutdown(); server.shutdown();
@ -212,8 +208,6 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException()
throws Exception { throws Exception {
this.exception.expect(OAuth2AuthorizationException.class);
this.exception.expectMessage(containsString("invalid_token_response"));
MockWebServer server = new MockWebServer(); MockWebServer server = new MockWebServer();
// @formatter:off // @formatter:off
String accessTokenSuccessResponse = "{\n" String accessTokenSuccessResponse = "{\n"
@ -228,8 +222,10 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
String tokenUri = server.url("/oauth2/token").toString(); String tokenUri = server.url("/oauth2/token").toString();
this.clientRegistrationBuilder.tokenUri(tokenUri); this.clientRegistrationBuilder.tokenUri(tokenUri);
try { try {
this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( assertThatExceptionOfType(OAuth2AuthorizationException.class)
this.clientRegistrationBuilder.build(), this.authorizationExchange)); .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange)))
.withMessageContaining("invalid_token_response");
} }
finally { finally {
server.shutdown(); server.shutdown();

View File

@ -28,9 +28,7 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
@ -64,7 +62,8 @@ import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.oauth2.jwt.TestJwts;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyCollection; import static org.mockito.ArgumentMatchers.anyCollection;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
@ -100,9 +99,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
private String nonceHash; private String nonceHash;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before @Before
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void setUp() { public void setUp() {
@ -138,26 +134,24 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
@Test @Test
public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() { public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
new OidcAuthorizationCodeAuthenticationProvider(null, this.userService); .isThrownBy(() -> new OidcAuthorizationCodeAuthenticationProvider(null, this.userService));
} }
@Test @Test
public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() { public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(
new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, null); () -> new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, null));
} }
@Test @Test
public void setJwtDecoderFactoryWhenNullThenThrowIllegalArgumentException() { public void setJwtDecoderFactoryWhenNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.authenticationProvider.setJwtDecoderFactory(null));
this.authenticationProvider.setJwtDecoderFactory(null);
} }
@Test @Test
public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() { public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.authenticationProvider.setAuthoritiesMapper(null));
this.authenticationProvider.setAuthoritiesMapper(null);
} }
@Test @Test
@ -181,8 +175,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
@Test @Test
public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() { public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_SCOPE));
// @formatter:off // @formatter:off
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error() OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error()
.errorCode(OAuth2ErrorCodes.INVALID_SCOPE) .errorCode(OAuth2ErrorCodes.INVALID_SCOPE)
@ -190,14 +182,14 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
// @formatter:on // @formatter:on
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
authorizationResponse); authorizationResponse);
this.authenticationProvider assertThatExceptionOfType(OAuth2AuthenticationException.class)
.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); .isThrownBy(() -> this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)))
.withMessageContaining(OAuth2ErrorCodes.INVALID_SCOPE);
} }
@Test @Test
public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() { public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_state_parameter"));
// @formatter:off // @formatter:off
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
.state("89012") .state("89012")
@ -205,14 +197,14 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
// @formatter:on // @formatter:on
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
authorizationResponse); authorizationResponse);
this.authenticationProvider assertThatExceptionOfType(OAuth2AuthenticationException.class)
.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); .isThrownBy(() -> this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)))
.withMessageContaining("invalid_state_parameter");
} }
@Test @Test
public void authenticateWhenTokenResponseDoesNotContainIdTokenThenThrowOAuth2AuthenticationException() { public void authenticateWhenTokenResponseDoesNotContainIdTokenThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token"));
// @formatter:off // @formatter:off
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
.withResponse(this.accessTokenSuccessResponse()) .withResponse(this.accessTokenSuccessResponse())
@ -220,38 +212,38 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
.build(); .build();
// @formatter:on // @formatter:on
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
this.authenticationProvider assertThatExceptionOfType(OAuth2AuthenticationException.class)
.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); .isThrownBy(() -> this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)))
.withMessageContaining("invalid_id_token");
} }
@Test @Test
public void authenticateWhenJwkSetUriNotSetThenThrowOAuth2AuthenticationException() { public void authenticateWhenJwkSetUriNotSetThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("missing_signature_verifier"));
// @formatter:off // @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.jwkSetUri(null) .jwkSetUri(null)
.build(); .build();
// @formatter:on // @formatter:on
this.authenticationProvider assertThatExceptionOfType(OAuth2AuthenticationException.class)
.authenticate(new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange)); .isThrownBy(() -> this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange)))
.withMessageContaining("missing_signature_verifier");
} }
@Test @Test
public void authenticateWhenIdTokenValidationErrorThenThrowOAuth2AuthenticationException() { public void authenticateWhenIdTokenValidationErrorThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("[invalid_id_token] ID Token Validation Error"));
JwtDecoder jwtDecoder = mock(JwtDecoder.class); JwtDecoder jwtDecoder = mock(JwtDecoder.class);
given(jwtDecoder.decode(anyString())).willThrow(new JwtException("ID Token Validation Error")); given(jwtDecoder.decode(anyString())).willThrow(new JwtException("ID Token Validation Error"));
this.authenticationProvider.setJwtDecoderFactory((registration) -> jwtDecoder); this.authenticationProvider.setJwtDecoderFactory((registration) -> jwtDecoder);
this.authenticationProvider assertThatExceptionOfType(OAuth2AuthenticationException.class)
.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); .isThrownBy(() -> this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)))
.withMessageContaining("[invalid_id_token] ID Token Validation Error");
} }
@Test @Test
public void authenticateWhenIdTokenInvalidNonceThenThrowOAuth2AuthenticationException() { public void authenticateWhenIdTokenInvalidNonceThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("[invalid_nonce]"));
Map<String, Object> claims = new HashMap<>(); Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com"); claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1"); claims.put(IdTokenClaimNames.SUB, "subject1");
@ -259,8 +251,10 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
claims.put(IdTokenClaimNames.AZP, "client1"); claims.put(IdTokenClaimNames.AZP, "client1");
claims.put(IdTokenClaimNames.NONCE, "invalid-nonce-hash"); claims.put(IdTokenClaimNames.NONCE, "invalid-nonce-hash");
this.setUpIdToken(claims); this.setUpIdToken(claims);
this.authenticationProvider assertThatExceptionOfType(OAuth2AuthenticationException.class)
.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); .isThrownBy(() -> this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)))
.withMessageContaining("[invalid_nonce]");
} }
@Test @Test

View File

@ -29,9 +29,7 @@ import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest; import okhttp3.mockwebserver.RecordedRequest;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
@ -56,8 +54,8 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.ArgumentMatchers.same; import static org.mockito.ArgumentMatchers.same;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -80,9 +78,6 @@ public class OidcUserServiceTests {
private MockWebServer server; private MockWebServer server;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before @Before
public void setup() throws Exception { public void setup() throws Exception {
this.server = new MockWebServer(); this.server = new MockWebServer();
@ -133,8 +128,7 @@ public class OidcUserServiceTests {
@Test @Test
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
this.userService.loadUser(null);
} }
@Test @Test
@ -260,8 +254,6 @@ public class OidcUserServiceTests {
// gh-5447 // gh-5447
@Test @Test
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
// @formatter:off // @formatter:off
String userInfoResponse = "{\n" String userInfoResponse = "{\n"
+ " \"email\": \"full_name@provider.com\",\n" + " \"email\": \"full_name@provider.com\",\n"
@ -272,25 +264,26 @@ public class OidcUserServiceTests {
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
.userNameAttributeName(StandardClaimNames.EMAIL).build(); .userNameAttributeName(StandardClaimNames.EMAIL).build();
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.userService
.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)))
.withMessageContaining("invalid_user_info_response");
} }
@Test @Test
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
String userInfoResponse = "{\n" + " \"sub\": \"other-subject\"\n" + "}\n"; String userInfoResponse = "{\n" + " \"sub\": \"other-subject\"\n" + "}\n";
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.userService
.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)))
.withMessageContaining("invalid_user_info_response");
} }
@Test @Test
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
// @formatter:off // @formatter:off
String userInfoResponse = "{\n" String userInfoResponse = "{\n"
+ " \"sub\": \"subject1\",\n" + " \"sub\": \"subject1\",\n"
@ -304,28 +297,34 @@ public class OidcUserServiceTests {
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.userService
.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
} }
@Test @Test
public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
this.server.enqueue(new MockResponse().setResponseCode(500)); this.server.enqueue(new MockResponse().setResponseCode(500));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.userService
.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error");
} }
@Test @Test
public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
String userInfoUri = "https://invalid-provider.com/user"; String userInfoUri = "https://invalid-provider.com/user";
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.userService
.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
} }
@Test @Test

View File

@ -26,9 +26,7 @@ import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.MockWebServer;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
@ -43,7 +41,8 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
/** /**
* Tests for {@link CustomUserTypesOAuth2UserService}. * Tests for {@link CustomUserTypesOAuth2UserService}.
@ -61,9 +60,6 @@ public class CustomUserTypesOAuth2UserServiceTests {
private MockWebServer server; private MockWebServer server;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
this.server = new MockWebServer(); this.server = new MockWebServer();
@ -86,32 +82,28 @@ public class CustomUserTypesOAuth2UserServiceTests {
@Test @Test
public void constructorWhenCustomUserTypesIsNullThenThrowIllegalArgumentException() { public void constructorWhenCustomUserTypesIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> new CustomUserTypesOAuth2UserService(null));
new CustomUserTypesOAuth2UserService(null);
} }
@Test @Test
public void constructorWhenCustomUserTypesIsEmptyThenThrowIllegalArgumentException() { public void constructorWhenCustomUserTypesIsEmptyThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
new CustomUserTypesOAuth2UserService(Collections.emptyMap()); .isThrownBy(() -> new CustomUserTypesOAuth2UserService(Collections.emptyMap()));
} }
@Test @Test
public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() { public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRequestEntityConverter(null));
this.userService.setRequestEntityConverter(null);
} }
@Test @Test
public void setRestOperationsWhenNullThenThrowIllegalArgumentException() { public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRestOperations(null));
this.userService.setRestOperations(null);
} }
@Test @Test
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
this.userService.loadUser(null);
} }
@Test @Test
@ -151,9 +143,6 @@ public class CustomUserTypesOAuth2UserServiceTests {
@Test @Test
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
// @formatter:off // @formatter:off
String userInfoResponse = "{\n" String userInfoResponse = "{\n"
+ " \"id\": \"12345\",\n" + " \"id\": \"12345\",\n"
@ -166,28 +155,34 @@ public class CustomUserTypesOAuth2UserServiceTests {
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
} }
@Test @Test
public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
this.server.enqueue(new MockResponse().setResponseCode(500)); this.server.enqueue(new MockResponse().setResponseCode(500));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error");
} }
@Test @Test
public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
String userInfoUri = "https://invalid-provider.com/user"; String userInfoUri = "https://invalid-provider.com/user";
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
} }
private ClientRegistration.Builder withRegistrationId(String registrationId) { private ClientRegistration.Builder withRegistrationId(String registrationId) {

View File

@ -26,9 +26,7 @@ import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest; import okhttp3.mockwebserver.RecordedRequest;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
@ -51,7 +49,8 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestOperations;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
@ -73,9 +72,6 @@ public class DefaultOAuth2UserServiceTests {
private MockWebServer server; private MockWebServer server;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before @Before
public void setup() throws Exception { public void setup() throws Exception {
this.server = new MockWebServer(); this.server = new MockWebServer();
@ -95,40 +91,39 @@ public class DefaultOAuth2UserServiceTests {
@Test @Test
public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() { public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRequestEntityConverter(null));
this.userService.setRequestEntityConverter(null);
} }
@Test @Test
public void setRestOperationsWhenNullThenThrowIllegalArgumentException() { public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRestOperations(null));
this.userService.setRestOperations(null);
} }
@Test @Test
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
this.userService.loadUser(null);
} }
@Test @Test
public void loadUserWhenUserInfoUriIsNullThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoUriIsNullThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("missing_user_info_uri"));
ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining("missing_user_info_uri");
} }
@Test @Test
public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("missing_user_name_attribute"));
// @formatter:off // @formatter:off
ClientRegistration clientRegistration = this.clientRegistrationBuilder ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri("https://provider.com/user") .userInfoUri("https://provider.com/user")
.build(); .build();
// @formatter:on // @formatter:on
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining("missing_user_name_attribute");
} }
@Test @Test
@ -165,9 +160,6 @@ public class DefaultOAuth2UserServiceTests {
@Test @Test
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
// @formatter:off // @formatter:off
String userInfoResponse = "{\n" String userInfoResponse = "{\n"
+ " \"user-name\": \"user1\",\n" + " \"user-name\": \"user1\",\n"
@ -182,16 +174,15 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
} }
@Test @Test
public void loadUserWhenUserInfoErrorResponseWwwAuthenticateHeaderThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoErrorResponseWwwAuthenticateHeaderThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
this.exception.expectMessage(
containsString("Error Code: insufficient_scope, Error Description: The access token expired"));
String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\""; String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\"";
MockResponse response = new MockResponse(); MockResponse response = new MockResponse();
response.setHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader); response.setHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader);
@ -200,15 +191,16 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")
.withMessageContaining("Error Code: insufficient_scope, Error Description: The access token expired");
} }
@Test @Test
public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
this.exception.expectMessage(containsString("Error Code: invalid_token"));
// @formatter:off // @formatter:off
String userInfoErrorResponse = "{\n" String userInfoErrorResponse = "{\n"
+ " \"error\": \"invalid_token\"\n" + " \"error\": \"invalid_token\"\n"
@ -218,30 +210,37 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")
.withMessageContaining("Error Code: invalid_token");
} }
@Test @Test
public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
this.server.enqueue(new MockResponse().setResponseCode(500)); this.server.enqueue(new MockResponse().setResponseCode(500));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error");
} }
@Test @Test
public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
String userInfoUri = "https://invalid-provider.com/user"; String userInfoUri = "https://invalid-provider.com/user";
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
} }
// gh-5294 // gh-5294
@ -348,17 +347,18 @@ public class DefaultOAuth2UserServiceTests {
@Test @Test
public void loadUserWhenUserInfoSuccessResponseInvalidContentTypeThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoSuccessResponseInvalidContentTypeThenThrowOAuth2AuthenticationException() {
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource "
+ "from '" + userInfoUri + "': response contains invalid content type 'text/plain'."));
MockResponse response = new MockResponse(); MockResponse response = new MockResponse();
response.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE); response.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE);
response.setBody("invalid content type"); response.setBody("invalid content type");
this.server.enqueue(response); this.server.enqueue(response);
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(
() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
.withMessageContaining(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource "
+ "from '" + userInfoUri + "': response contains invalid content type 'text/plain'.");
} }
private DefaultOAuth2UserService withMockResponse(Map<String, Object> response) { private DefaultOAuth2UserService withMockResponse(Map<String, Object> response) {

View File

@ -23,17 +23,15 @@ import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.security.converter.RsaKeyConverters; import org.springframework.security.converter.RsaKeyConverters;
import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType; import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType;
public class Saml2X509CredentialTests { import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
@Rule public class Saml2X509CredentialTests {
public ExpectedException exception = ExpectedException.none();
private PrivateKey key; private PrivateKey key;
@ -99,98 +97,90 @@ public class Saml2X509CredentialTests {
@Test @Test
public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() { public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(
new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING); () -> new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING));
} }
@Test @Test
public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() { public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING); .isThrownBy(() -> new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING));
} }
@Test @Test
public void constructorWhenRelyingPartyWithoutCertificateThenItFails() { public void constructorWhenRelyingPartyWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING); .isThrownBy(() -> new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING));
} }
@Test @Test
public void constructorWhenAssertingPartyWithoutCertificateThenItFails() { public void constructorWhenAssertingPartyWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING); .isThrownBy(() -> new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING));
} }
@Test @Test
public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() { public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() {
this.exception.expect(IllegalStateException.class); assertThatIllegalStateException().isThrownBy(
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION); () -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION));
} }
@Test @Test
public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() { public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() {
this.exception.expect(IllegalStateException.class); assertThatIllegalStateException().isThrownBy(
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION); () -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION));
} }
@Test @Test
public void constructorWhenAssertingPartyWithSigningUsageThenItFails() { public void constructorWhenAssertingPartyWithSigningUsageThenItFails() {
this.exception.expect(IllegalStateException.class); assertThatIllegalStateException()
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING); .isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING));
} }
@Test @Test
public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() { public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() {
this.exception.expect(IllegalStateException.class); assertThatIllegalStateException()
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION); .isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION));
} }
@Test @Test
public void factoryWhenRelyingPartyForSigningWithoutCredentialsThenItFails() { public void factoryWhenRelyingPartyForSigningWithoutCredentialsThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.signing(null, null));
Saml2X509Credential.signing(null, null);
} }
@Test @Test
public void factoryWhenRelyingPartyForSigningWithoutPrivateKeyThenItFails() { public void factoryWhenRelyingPartyForSigningWithoutPrivateKeyThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.signing(null, this.certificate));
Saml2X509Credential.signing(null, this.certificate);
} }
@Test @Test
public void factoryWhenRelyingPartyForSigningWithoutCertificateThenItFails() { public void factoryWhenRelyingPartyForSigningWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.signing(this.key, null));
Saml2X509Credential.signing(this.key, null);
} }
@Test @Test
public void factoryWhenRelyingPartyForDecryptionWithoutCredentialsThenItFails() { public void factoryWhenRelyingPartyForDecryptionWithoutCredentialsThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.decryption(null, null));
Saml2X509Credential.decryption(null, null);
} }
@Test @Test
public void factoryWhenRelyingPartyForDecryptionWithoutPrivateKeyThenItFails() { public void factoryWhenRelyingPartyForDecryptionWithoutPrivateKeyThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.decryption(null, this.certificate));
Saml2X509Credential.decryption(null, this.certificate);
} }
@Test @Test
public void factoryWhenRelyingPartyForDecryptionWithoutCertificateThenItFails() { public void factoryWhenRelyingPartyForDecryptionWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.decryption(this.key, null));
Saml2X509Credential.decryption(this.key, null);
} }
@Test @Test
public void factoryWhenAssertingPartyForVerificationWithoutCertificateThenItFails() { public void factoryWhenAssertingPartyForVerificationWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.verification(null));
Saml2X509Credential.verification(null);
} }
@Test @Test
public void factoryWhenAssertingPartyForEncryptionWithoutCertificateThenItFails() { public void factoryWhenAssertingPartyForEncryptionWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.encryption(null));
Saml2X509Credential.encryption(null);
} }
} }

View File

@ -23,17 +23,15 @@ import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.security.converter.RsaKeyConverters; import org.springframework.security.converter.RsaKeyConverters;
import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType; import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType;
public class Saml2X509CredentialTests { import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
@Rule public class Saml2X509CredentialTests {
public ExpectedException exception = ExpectedException.none();
private Saml2X509Credential credential; private Saml2X509Credential credential;
@ -97,50 +95,50 @@ public class Saml2X509CredentialTests {
@Test @Test
public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() { public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(
new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING); () -> new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING));
} }
@Test @Test
public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() { public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING); .isThrownBy(() -> new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING));
} }
@Test @Test
public void constructorWhenRelyingPartyWithoutCertificateThenItFails() { public void constructorWhenRelyingPartyWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING); .isThrownBy(() -> new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING));
} }
@Test @Test
public void constructorWhenAssertingPartyWithoutCertificateThenItFails() { public void constructorWhenAssertingPartyWithoutCertificateThenItFails() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING); .isThrownBy(() -> new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING));
} }
@Test @Test
public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() { public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() {
this.exception.expect(IllegalStateException.class); assertThatIllegalStateException().isThrownBy(
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION); () -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION));
} }
@Test @Test
public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() { public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() {
this.exception.expect(IllegalStateException.class); assertThatIllegalStateException().isThrownBy(
new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION); () -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION));
} }
@Test @Test
public void constructorWhenAssertingPartyWithSigningUsageThenItFails() { public void constructorWhenAssertingPartyWithSigningUsageThenItFails() {
this.exception.expect(IllegalStateException.class); assertThatIllegalStateException()
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING); .isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING));
} }
@Test @Test
public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() { public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() {
this.exception.expect(IllegalStateException.class); assertThatIllegalStateException()
new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION); .isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION));
} }
} }

View File

@ -26,18 +26,14 @@ import java.util.HashMap;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer;
import javax.xml.namespace.QName; import javax.xml.namespace.QName;
import net.shibboleth.utilities.java.support.xml.SerializeSupport; import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.joda.time.DateTime; import org.joda.time.DateTime;
import org.joda.time.Duration; import org.joda.time.Duration;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller; import org.opensaml.core.xml.io.Marshaller;
@ -93,9 +89,6 @@ public class OpenSamlAuthenticationProviderTests {
private Saml2Authentication authentication = new Saml2Authentication(this.principal, "response", private Saml2Authentication authentication = new Saml2Authentication(this.principal, "response",
Collections.emptyList()); Collections.emptyList());
@Rule
public ExpectedException exception = ExpectedException.none();
@Test @Test
public void supportsWhenSaml2AuthenticationTokenThenReturnTrue() { public void supportsWhenSaml2AuthenticationTokenThenReturnTrue() {
assertThat(this.provider.supports(Saml2AuthenticationToken.class)) assertThat(this.provider.supports(Saml2AuthenticationToken.class))
@ -113,53 +106,56 @@ public class OpenSamlAuthenticationProviderTests {
@Test @Test
public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() { public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
Assertion assertion = (Assertion) XMLObjectProviderRegistrySupport.getBuilderFactory() Assertion assertion = (Assertion) XMLObjectProviderRegistrySupport.getBuilderFactory()
.getBuilder(Assertion.DEFAULT_ELEMENT_NAME).buildObject(Assertion.DEFAULT_ELEMENT_NAME); .getBuilder(Assertion.DEFAULT_ELEMENT_NAME).buildObject(Assertion.DEFAULT_ELEMENT_NAME);
this.provider assertThatExceptionOfType(Saml2AuthenticationException.class)
.authenticate(token(serialize(assertion), TestSaml2X509Credentials.relyingPartyVerifyingCredential())); .isThrownBy(() -> this.provider.authenticate(
token(serialize(assertion), TestSaml2X509Credentials.relyingPartyVerifyingCredential())))
.satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
} }
@Test @Test
public void authenticateWhenXmlErrorThenThrowAuthenticationException() { public void authenticateWhenXmlErrorThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
Saml2AuthenticationToken token = token("invalid xml", Saml2AuthenticationToken token = token("invalid xml",
TestSaml2X509Credentials.relyingPartyVerifyingCredential()); TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token); assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
} }
@Test @Test
public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() { public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_DESTINATION));
Response response = TestOpenSamlObjects.response(DESTINATION + "invalid", ASSERTING_PARTY_ENTITY_ID); Response response = TestOpenSamlObjects.response(DESTINATION + "invalid", ASSERTING_PARTY_ENTITY_ID);
response.getAssertions().add(TestOpenSamlObjects.assertion()); response.getAssertions().add(TestOpenSamlObjects.assertion());
TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID); RELYING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token); assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.INVALID_DESTINATION));
} }
@Test @Test
public void authenticateWhenNoAssertionsPresentThenThrowAuthenticationException() { public void authenticateWhenNoAssertionsPresentThenThrowAuthenticationException() {
this.exception.expect(
authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response."));
Saml2AuthenticationToken token = token(TestOpenSamlObjects.response(), Saml2AuthenticationToken token = token(TestOpenSamlObjects.response(),
TestSaml2X509Credentials.assertingPartySigningCredential()); TestSaml2X509Credentials.assertingPartySigningCredential());
this.provider.authenticate(token); assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response."));
} }
@Test @Test
public void authenticateWhenInvalidSignatureOnAssertionThenThrowAuthenticationException() { public void authenticateWhenInvalidSignatureOnAssertionThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE));
Response response = TestOpenSamlObjects.response(); Response response = TestOpenSamlObjects.response();
response.getAssertions().add(TestOpenSamlObjects.assertion()); response.getAssertions().add(TestOpenSamlObjects.assertion());
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token); assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
} }
@Test @Test
public void authenticateWhenOpenSAMLValidationErrorThenThrowAuthenticationException() throws Exception { public void authenticateWhenOpenSAMLValidationErrorThenThrowAuthenticationException() throws Exception {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_ASSERTION));
Response response = TestOpenSamlObjects.response(); Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion(); Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getSubjectConfirmations().get(0).getSubjectConfirmationData() assertion.getSubject().getSubjectConfirmations().get(0).getSubjectConfirmationData()
@ -168,12 +164,13 @@ public class OpenSamlAuthenticationProviderTests {
RELYING_PARTY_ENTITY_ID); RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token); assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.INVALID_ASSERTION));
} }
@Test @Test
public void authenticateWhenMissingSubjectThenThrowAuthenticationException() { public void authenticateWhenMissingSubjectThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
Response response = TestOpenSamlObjects.response(); Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion(); Assertion assertion = TestOpenSamlObjects.assertion();
assertion.setSubject(null); assertion.setSubject(null);
@ -181,12 +178,13 @@ public class OpenSamlAuthenticationProviderTests {
RELYING_PARTY_ENTITY_ID); RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token); assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
} }
@Test @Test
public void authenticateWhenUsernameMissingThenThrowAuthenticationException() throws Exception { public void authenticateWhenUsernameMissingThenThrowAuthenticationException() throws Exception {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
Response response = TestOpenSamlObjects.response(); Response response = TestOpenSamlObjects.response();
Assertion assertion = TestOpenSamlObjects.assertion(); Assertion assertion = TestOpenSamlObjects.assertion();
assertion.getSubject().getNameID().setValue(null); assertion.getSubject().getNameID().setValue(null);
@ -194,7 +192,9 @@ public class OpenSamlAuthenticationProviderTests {
RELYING_PARTY_ENTITY_ID); RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token); assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
} }
@Test @Test
@ -236,13 +236,14 @@ public class OpenSamlAuthenticationProviderTests {
@Test @Test
public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() throws Exception { public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() throws Exception {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE));
Response response = TestOpenSamlObjects.response(); Response response = TestOpenSamlObjects.response();
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartyEncryptingCredential()); TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyDecryptingCredential()); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyDecryptingCredential());
this.provider.authenticate(token); assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
} }
@Test @Test
@ -290,28 +291,28 @@ public class OpenSamlAuthenticationProviderTests {
@Test @Test
public void authenticateWhenDecryptionKeysAreMissingThenThrowAuthenticationException() throws Exception { public void authenticateWhenDecryptionKeysAreMissingThenThrowAuthenticationException() throws Exception {
this.exception
.expect(authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData"));
Response response = TestOpenSamlObjects.response(); Response response = TestOpenSamlObjects.response();
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartyEncryptingCredential()); TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(serialize(response), Saml2AuthenticationToken token = token(serialize(response),
TestSaml2X509Credentials.relyingPartyVerifyingCredential()); TestSaml2X509Credentials.relyingPartyVerifyingCredential());
this.provider.authenticate(token); assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData"));
} }
@Test @Test
public void authenticateWhenDecryptionKeysAreWrongThenThrowAuthenticationException() throws Exception { public void authenticateWhenDecryptionKeysAreWrongThenThrowAuthenticationException() throws Exception {
this.exception
.expect(authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData"));
Response response = TestOpenSamlObjects.response(); Response response = TestOpenSamlObjects.response();
EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
TestSaml2X509Credentials.assertingPartyEncryptingCredential()); TestSaml2X509Credentials.assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(serialize(response), Saml2AuthenticationToken token = token(serialize(response),
TestSaml2X509Credentials.assertingPartyPrivateCredential()); TestSaml2X509Credentials.assertingPartyPrivateCredential());
this.provider.authenticate(token); assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token))
.satisfies(errorOf(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData"));
} }
@Test @Test
@ -487,33 +488,15 @@ public class OpenSamlAuthenticationProviderTests {
} }
} }
private Matcher<Saml2AuthenticationException> authenticationMatcher(String code) { private Consumer<Saml2AuthenticationException> errorOf(String errorCode) {
return authenticationMatcher(code, null); return errorOf(errorCode, null);
} }
private Matcher<Saml2AuthenticationException> authenticationMatcher(String code, String description) { private Consumer<Saml2AuthenticationException> errorOf(String errorCode, String description) {
return new BaseMatcher<Saml2AuthenticationException>() { return (ex) -> {
@Override assertThat(ex.getError().getErrorCode()).isEqualTo(errorCode);
public boolean matches(Object item) { if (StringUtils.hasText(description)) {
if (!(item instanceof Saml2AuthenticationException)) { assertThat(ex.getError().getDescription()).isEqualTo(description);
return false;
}
Saml2AuthenticationException ex = (Saml2AuthenticationException) item;
if (!code.equals(ex.getError().getErrorCode())) {
return false;
}
if (StringUtils.hasText(description)) {
if (!description.equals(ex.getError().getDescription())) {
return false;
}
}
return true;
}
@Override
public void describeTo(Description desc) {
String excepting = "Saml2AuthenticationException[code=" + code + "; description=" + description + "]";
desc.appendText(excepting);
} }
}; };
} }

View File

@ -21,9 +21,7 @@ import java.nio.charset.StandardCharsets;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.AuthnRequest;
@ -39,7 +37,6 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -61,9 +58,6 @@ public class OpenSamlAuthenticationRequestFactoryTests {
private AuthnRequestUnmarshaller unmarshaller; private AuthnRequestUnmarshaller unmarshaller;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before @Before
public void setUp() { public void setUp() {
this.relyingPartyRegistrationBuilder = RelyingPartyRegistration.withRegistrationId("id") this.relyingPartyRegistrationBuilder = RelyingPartyRegistration.withRegistrationId("id")
@ -160,9 +154,8 @@ public class OpenSamlAuthenticationRequestFactoryTests {
@Test @Test
public void createAuthenticationRequestWhenSetUnsupportredUriThenThrowsIllegalArgumentException() { public void createAuthenticationRequestWhenSetUnsupportredUriThenThrowsIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.factory.setProtocolBinding("my-invalid-binding"))
this.exception.expectMessage(containsString("my-invalid-binding")); .withMessageContaining("my-invalid-binding");
this.factory.setProtocolBinding("my-invalid-binding");
} }
@Test @Test

View File

@ -20,9 +20,7 @@ import javax.servlet.http.HttpServletResponse;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
@ -43,9 +41,6 @@ public class Saml2WebSsoAuthenticationFilterTests {
private HttpServletResponse response = new MockHttpServletResponse(); private HttpServletResponse response = new MockHttpServletResponse();
@Rule
public ExpectedException exception = ExpectedException.none();
@Before @Before
public void setup() { public void setup() {
this.filter = new Saml2WebSsoAuthenticationFilter(this.repository); this.filter = new Saml2WebSsoAuthenticationFilter(this.repository);
@ -55,9 +50,9 @@ public class Saml2WebSsoAuthenticationFilterTests {
@Test @Test
public void constructingFilterWithMissingRegistrationIdVariableThenThrowsException() { public void constructingFilterWithMissingRegistrationIdVariableThenThrowsException() {
this.exception.expect(IllegalArgumentException.class); assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(
this.exception.expectMessage("filterProcessesUrl must contain a {registrationId} match variable"); () -> this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/url/missing/variable"))
this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/url/missing/variable"); .withMessage("filterProcessesUrl must contain a {registrationId} match variable");
} }
@Test @Test

View File

@ -22,9 +22,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
@ -35,6 +33,7 @@ import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.CredentialsExpiredException; import org.springframework.security.authentication.CredentialsExpiredException;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
@ -48,9 +47,6 @@ import static org.mockito.Mockito.verifyZeroInteractions;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class DelegatingAuthenticationFailureHandlerTests { public class DelegatingAuthenticationFailureHandlerTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Mock @Mock
private AuthenticationFailureHandler handler1; private AuthenticationFailureHandler handler1;
@ -110,24 +106,24 @@ public class DelegatingAuthenticationFailureHandlerTests {
@Test @Test
public void handlersIsNull() { public void handlersIsNull() {
this.thrown.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
this.thrown.expectMessage("handlers cannot be null or empty"); .isThrownBy(() -> new DelegatingAuthenticationFailureHandler(null, this.defaultHandler))
new DelegatingAuthenticationFailureHandler(null, this.defaultHandler); .withMessage("handlers cannot be null or empty");
} }
@Test @Test
public void handlersIsEmpty() { public void handlersIsEmpty() {
this.thrown.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
this.thrown.expectMessage("handlers cannot be null or empty"); .isThrownBy(() -> new DelegatingAuthenticationFailureHandler(this.handlers, this.defaultHandler))
new DelegatingAuthenticationFailureHandler(this.handlers, this.defaultHandler); .withMessage("handlers cannot be null or empty");
} }
@Test @Test
public void defaultHandlerIsNull() { public void defaultHandlerIsNull() {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("defaultHandler cannot be null");
this.handlers.put(BadCredentialsException.class, this.handler1); this.handlers.put(BadCredentialsException.class, this.handler1);
new DelegatingAuthenticationFailureHandler(this.handlers, null); assertThatIllegalArgumentException()
.isThrownBy(() -> new DelegatingAuthenticationFailureHandler(this.handlers, null))
.withMessage("defaultHandler cannot be null");
} }
} }

View File

@ -22,9 +22,7 @@ import java.util.List;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.InOrder; import org.mockito.InOrder;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -45,14 +43,10 @@ import static org.mockito.Mockito.verify;
*/ */
public class CompositeLogoutHandlerTests { public class CompositeLogoutHandlerTests {
@Rule
public ExpectedException exception = ExpectedException.none();
@Test @Test
public void buildEmptyCompositeLogoutHandlerThrowsException() { public void buildEmptyCompositeLogoutHandlerThrowsException() {
this.exception.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> new CompositeLogoutHandler())
this.exception.expectMessage("LogoutHandlers are required"); .withMessage("LogoutHandlers are required");
new CompositeLogoutHandler();
} }
@Test @Test

View File

@ -16,15 +16,14 @@
package org.springframework.security.web.authentication.logout; package org.springframework.security.web.authentication.logout;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
/** /**
@ -34,23 +33,18 @@ import static org.mockito.Mockito.mock;
*/ */
public class ForwardLogoutSuccessHandlerTests { public class ForwardLogoutSuccessHandlerTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test @Test
public void invalidTargetUrl() { public void invalidTargetUrl() {
String targetUrl = "not.valid"; String targetUrl = "not.valid";
this.thrown.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> new ForwardLogoutSuccessHandler(targetUrl))
this.thrown.expectMessage("'" + targetUrl + "' is not a valid target URL"); .withMessage("'" + targetUrl + "' is not a valid target URL");
new ForwardLogoutSuccessHandler(targetUrl);
} }
@Test @Test
public void emptyTargetUrl() { public void emptyTargetUrl() {
String targetUrl = " "; String targetUrl = " ";
this.thrown.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> new ForwardLogoutSuccessHandler(targetUrl))
this.thrown.expectMessage("'" + targetUrl + "' is not a valid target URL"); .withMessage("'" + targetUrl + "' is not a valid target URL");
new ForwardLogoutSuccessHandler(targetUrl);
} }
@Test @Test

View File

@ -17,15 +17,14 @@
package org.springframework.security.web.authentication.logout; package org.springframework.security.web.authentication.logout;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.web.header.HeaderWriter; import org.springframework.security.web.header.HeaderWriter;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -40,9 +39,6 @@ public class HeaderWriterLogoutHandlerTests {
private MockHttpServletRequest request; private MockHttpServletRequest request;
@Rule
public ExpectedException thrown = ExpectedException.none();
@Before @Before
public void setup() { public void setup() {
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
@ -51,9 +47,8 @@ public class HeaderWriterLogoutHandlerTests {
@Test @Test
public void constructorWhenHeaderWriterIsNullThenThrowsException() { public void constructorWhenHeaderWriterIsNullThenThrowsException() {
this.thrown.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> new HeaderWriterLogoutHandler(null))
this.thrown.expectMessage("headerWriter cannot be null"); .withMessage("headerWriter cannot be null");
new HeaderWriterLogoutHandler(null);
} }
@Test @Test

View File

@ -23,9 +23,7 @@ import javax.servlet.FilterChain;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
@ -49,6 +47,7 @@ import org.springframework.security.web.authentication.SimpleUrlAuthenticationSu
import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.security.web.util.matcher.AnyRequestMatcher;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -64,9 +63,6 @@ public class SwitchUserFilterTests {
private static final List<GrantedAuthority> ROLES_12 = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"); private static final List<GrantedAuthority> ROLES_12 = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO");
@Rule
public ExpectedException thrown = ExpectedException.none();
@Before @Before
public void authenticateCurrentUser() { public void authenticateCurrentUser() {
UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken("dano", "hawaii50"); UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken("dano", "hawaii50");
@ -437,9 +433,8 @@ public class SwitchUserFilterTests {
// gh-3697 // gh-3697
@Test @Test
public void switchAuthorityRoleCannotBeNull() { public void switchAuthorityRoleCannotBeNull() {
this.thrown.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> switchToUserWithAuthorityRole("dano", null))
this.thrown.expectMessage("switchAuthorityRole cannot be null"); .withMessage("switchAuthorityRole cannot be null");
switchToUserWithAuthorityRole("dano", null);
} }
// gh-3697 // gh-3697

View File

@ -20,9 +20,7 @@ import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -35,8 +33,7 @@ import static org.mockito.Mockito.verify;
*/ */
public class FirewalledResponseTests { public class FirewalledResponseTests {
@Rule private static final String CRLF_MESSAGE = "Invalid characters (CR/LF)";
public ExpectedException expectedException = ExpectedException.none();
private HttpServletResponse response; private HttpServletResponse response;
@ -62,8 +59,8 @@ public class FirewalledResponseTests {
@Test @Test
public void sendRedirectWhenHasCrlfThenThrowsException() throws Exception { public void sendRedirectWhenHasCrlfThenThrowsException() throws Exception {
expectCrlfValidationException(); assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.sendRedirect("/theURL\r\nsomething"))
this.fwResponse.sendRedirect("/theURL\r\nsomething"); .withMessageContaining(CRLF_MESSAGE);
} }
@Test @Test
@ -80,14 +77,16 @@ public class FirewalledResponseTests {
@Test @Test
public void addHeaderWhenHeaderValueHasCrlfThenException() { public void addHeaderWhenHeaderValueHasCrlfThenException() {
expectCrlfValidationException(); assertThatIllegalArgumentException()
this.fwResponse.addHeader("foo", "abc\r\nContent-Length:100"); .isThrownBy(() -> this.fwResponse.addHeader("foo", "abc\r\nContent-Length:100"))
.withMessageContaining(CRLF_MESSAGE);
} }
@Test @Test
public void addHeaderWhenHeaderNameHasCrlfThenException() { public void addHeaderWhenHeaderNameHasCrlfThenException() {
expectCrlfValidationException(); assertThatIllegalArgumentException()
this.fwResponse.addHeader("abc\r\nContent-Length:100", "bar"); .isThrownBy(() -> this.fwResponse.addHeader("abc\r\nContent-Length:100", "bar"))
.withMessageContaining(CRLF_MESSAGE);
} }
@Test @Test
@ -115,39 +114,39 @@ public class FirewalledResponseTests {
return "foo\r\nbar"; return "foo\r\nbar";
} }
}; };
expectCrlfValidationException(); assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie))
this.fwResponse.addCookie(cookie); .withMessageContaining(CRLF_MESSAGE);
} }
@Test @Test
public void addCookieWhenCookieValueContainsCrlfThenException() { public void addCookieWhenCookieValueContainsCrlfThenException() {
Cookie cookie = new Cookie("foo", "foo\r\nbar"); Cookie cookie = new Cookie("foo", "foo\r\nbar");
expectCrlfValidationException(); assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie))
this.fwResponse.addCookie(cookie); .withMessageContaining(CRLF_MESSAGE);
} }
@Test @Test
public void addCookieWhenCookiePathContainsCrlfThenException() { public void addCookieWhenCookiePathContainsCrlfThenException() {
Cookie cookie = new Cookie("foo", "bar"); Cookie cookie = new Cookie("foo", "bar");
cookie.setPath("/foo\r\nbar"); cookie.setPath("/foo\r\nbar");
expectCrlfValidationException(); assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie))
this.fwResponse.addCookie(cookie); .withMessageContaining(CRLF_MESSAGE);
} }
@Test @Test
public void addCookieWhenCookieDomainContainsCrlfThenException() { public void addCookieWhenCookieDomainContainsCrlfThenException() {
Cookie cookie = new Cookie("foo", "bar"); Cookie cookie = new Cookie("foo", "bar");
cookie.setDomain("foo\r\nbar"); cookie.setDomain("foo\r\nbar");
expectCrlfValidationException(); assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie))
this.fwResponse.addCookie(cookie); .withMessageContaining(CRLF_MESSAGE);
} }
@Test @Test
public void addCookieWhenCookieCommentContainsCrlfThenException() { public void addCookieWhenCookieCommentContainsCrlfThenException() {
Cookie cookie = new Cookie("foo", "bar"); Cookie cookie = new Cookie("foo", "bar");
cookie.setComment("foo\r\nbar"); cookie.setComment("foo\r\nbar");
expectCrlfValidationException(); assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie))
this.fwResponse.addCookie(cookie); .withMessageContaining(CRLF_MESSAGE);
} }
@Test @Test
@ -160,11 +159,6 @@ public class FirewalledResponseTests {
validateLineEnding("foo\nbar", "bar"); validateLineEnding("foo\nbar", "bar");
} }
private void expectCrlfValidationException() {
this.expectedException.expect(IllegalArgumentException.class);
this.expectedException.expectMessage("Invalid characters (CR/LF)");
}
private void validateLineEnding(String name, String value) { private void validateLineEnding(String name, String value) {
assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.validateCrlf(name, value)); assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.validateCrlf(name, value));
} }

View File

@ -17,15 +17,14 @@
package org.springframework.security.web.header.writers; package org.springframework.security.web.header.writers;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive; import org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
/** /**
* @author Rafiullah Hamedy * @author Rafiullah Hamedy
@ -40,9 +39,6 @@ public class ClearSiteDataHeaderWriterTests {
private MockHttpServletResponse response; private MockHttpServletResponse response;
@Rule
public ExpectedException thrown = ExpectedException.none();
@Before @Before
public void setup() { public void setup() {
this.request = new MockHttpServletRequest(); this.request = new MockHttpServletRequest();
@ -52,9 +48,8 @@ public class ClearSiteDataHeaderWriterTests {
@Test @Test
public void createInstanceWhenMissingSourceThenThrowsException() { public void createInstanceWhenMissingSourceThenThrowsException() {
this.thrown.expect(Exception.class); assertThatExceptionOfType(Exception.class).isThrownBy(() -> new ClearSiteDataHeaderWriter())
this.thrown.expectMessage("directives cannot be empty or null"); .withMessage("directives cannot be empty or null");
new ClearSiteDataHeaderWriter();
} }
@Test @Test

View File

@ -20,9 +20,7 @@ import java.security.Principal;
import java.util.Collections; import java.util.Collections;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
@ -55,7 +53,8 @@ import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
@ -83,9 +82,6 @@ public class SwitchUserWebFilterTests {
@Mock @Mock
private ServerSecurityContextRepository serverSecurityContextRepository; private ServerSecurityContextRepository serverSecurityContextRepository;
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
@Before @Before
public void setUp() { public void setUp() {
this.switchUserWebFilter = new SwitchUserWebFilter(this.userDetailsService, this.successHandler, this.switchUserWebFilter = new SwitchUserWebFilter(this.userDetailsService, this.successHandler,
@ -183,11 +179,12 @@ public class SwitchUserWebFilterTests {
.from(MockServerHttpRequest.post("/login/impersonate")); .from(MockServerHttpRequest.post("/login/impersonate"));
final WebFilterChain chain = mock(WebFilterChain.class); final WebFilterChain chain = mock(WebFilterChain.class);
final SecurityContextImpl securityContext = new SecurityContextImpl(mock(Authentication.class)); final SecurityContextImpl securityContext = new SecurityContextImpl(mock(Authentication.class));
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
this.exceptionRule.expectMessage("The userName can not be null."); .isThrownBy(() -> this.switchUserWebFilter.filter(exchange, chain)
this.switchUserWebFilter.filter(exchange, chain) .subscriberContext(
.subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
.block(); .block())
.withMessage("The userName can not be null.");
verifyNoInteractions(chain); verifyNoInteractions(chain);
} }
@ -219,10 +216,12 @@ public class SwitchUserWebFilterTests {
final SecurityContextImpl securityContext = new SecurityContextImpl(mock(Authentication.class)); final SecurityContextImpl securityContext = new SecurityContextImpl(mock(Authentication.class));
final UserDetails switchUserDetails = switchUserDetails(targetUsername, false); final UserDetails switchUserDetails = switchUserDetails(targetUsername, false);
given(this.userDetailsService.findByUsername(any(String.class))).willReturn(Mono.just(switchUserDetails)); given(this.userDetailsService.findByUsername(any(String.class))).willReturn(Mono.just(switchUserDetails));
this.exceptionRule.expect(DisabledException.class); assertThatExceptionOfType(DisabledException.class)
this.switchUserWebFilter.filter(exchange, chain) .isThrownBy(
.subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) () -> this.switchUserWebFilter.filter(exchange, chain)
.block(); .subscriberContext(
ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
.block());
verifyNoInteractions(chain); verifyNoInteractions(chain);
} }
@ -265,11 +264,12 @@ public class SwitchUserWebFilterTests {
"origCredentials"); "origCredentials");
final WebFilterChain chain = mock(WebFilterChain.class); final WebFilterChain chain = mock(WebFilterChain.class);
final SecurityContextImpl securityContext = new SecurityContextImpl(originalAuthentication); final SecurityContextImpl securityContext = new SecurityContextImpl(originalAuthentication);
this.exceptionRule.expect(AuthenticationCredentialsNotFoundException.class); assertThatExceptionOfType(AuthenticationCredentialsNotFoundException.class)
this.exceptionRule.expectMessage("Could not find original Authentication object"); .isThrownBy(() -> this.switchUserWebFilter.filter(exchange, chain)
this.switchUserWebFilter.filter(exchange, chain) .subscriberContext(
.subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
.block(); .block())
.withMessage("Could not find original Authentication object");
verifyNoInteractions(chain); verifyNoInteractions(chain);
} }
@ -278,34 +278,35 @@ public class SwitchUserWebFilterTests {
final MockServerWebExchange exchange = MockServerWebExchange final MockServerWebExchange exchange = MockServerWebExchange
.from(MockServerHttpRequest.post("/logout/impersonate")); .from(MockServerHttpRequest.post("/logout/impersonate"));
final WebFilterChain chain = mock(WebFilterChain.class); final WebFilterChain chain = mock(WebFilterChain.class);
this.exceptionRule.expect(AuthenticationCredentialsNotFoundException.class); assertThatExceptionOfType(AuthenticationCredentialsNotFoundException.class)
this.exceptionRule.expectMessage("No current user associated with this request"); .isThrownBy(() -> this.switchUserWebFilter.filter(exchange, chain).block())
this.switchUserWebFilter.filter(exchange, chain).block(); .withMessage("No current user associated with this request");
verifyNoInteractions(chain); verifyNoInteractions(chain);
} }
@Test @Test
public void constructorUserDetailsServiceRequired() { public void constructorUserDetailsServiceRequired() {
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
this.exceptionRule.expectMessage("userDetailsService must be specified"); .isThrownBy(() -> this.switchUserWebFilter = new SwitchUserWebFilter(null,
this.switchUserWebFilter = new SwitchUserWebFilter(null, mock(ServerAuthenticationSuccessHandler.class), mock(ServerAuthenticationSuccessHandler.class), mock(ServerAuthenticationFailureHandler.class)))
mock(ServerAuthenticationFailureHandler.class)); .withMessage("userDetailsService must be specified");
} }
@Test @Test
public void constructorServerAuthenticationSuccessHandlerRequired() { public void constructorServerAuthenticationSuccessHandlerRequired() {
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
this.exceptionRule.expectMessage("successHandler must be specified"); .isThrownBy(
this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class), null, () -> this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class),
mock(ServerAuthenticationFailureHandler.class)); null, mock(ServerAuthenticationFailureHandler.class)))
.withMessage("successHandler must be specified");
} }
@Test @Test
public void constructorSuccessTargetUrlRequired() { public void constructorSuccessTargetUrlRequired() {
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(
this.exceptionRule.expectMessage("successTargetUrl must be specified"); () -> this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class), null,
this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class), null, "failure/target/url"))
"failure/target/url"); .withMessage("successTargetUrl must be specified");
} }
@Test @Test
@ -336,10 +337,9 @@ public class SwitchUserWebFilterTests {
@Test @Test
public void setSecurityContextRepositoryWhenNullThenThrowException() { public void setSecurityContextRepositoryWhenNullThenThrowException() {
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException()
this.exceptionRule.expectMessage("securityContextRepository cannot be null"); .isThrownBy(() -> this.switchUserWebFilter.setSecurityContextRepository(null))
this.switchUserWebFilter.setSecurityContextRepository(null); .withMessage("securityContextRepository cannot be null");
fail("Test should fail with exception");
} }
@Test @Test
@ -357,18 +357,14 @@ public class SwitchUserWebFilterTests {
@Test @Test
public void setExitUserUrlWhenNullThenThrowException() { public void setExitUserUrlWhenNullThenThrowException() {
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setExitUserUrl(null))
this.exceptionRule.expectMessage("exitUserUrl cannot be empty and must be a valid redirect URL"); .withMessage("exitUserUrl cannot be empty and must be a valid redirect URL");
this.switchUserWebFilter.setExitUserUrl(null);
fail("Test should fail with exception");
} }
@Test @Test
public void setExitUserUrlWhenInvalidUrlThenThrowException() { public void setExitUserUrlWhenInvalidUrlThenThrowException() {
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setExitUserUrl("wrongUrl"))
this.exceptionRule.expectMessage("exitUserUrl cannot be empty and must be a valid redirect URL"); .withMessage("exitUserUrl cannot be empty and must be a valid redirect URL");
this.switchUserWebFilter.setExitUserUrl("wrongUrl");
fail("Test should fail with exception");
} }
@Test @Test
@ -387,10 +383,8 @@ public class SwitchUserWebFilterTests {
@Test @Test
public void setExitUserMatcherWhenNullThenThrowException() { public void setExitUserMatcherWhenNullThenThrowException() {
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setExitUserMatcher(null))
this.exceptionRule.expectMessage("exitUserMatcher cannot be null"); .withMessage("exitUserMatcher cannot be null");
this.switchUserWebFilter.setExitUserMatcher(null);
fail("Test should fail with exception");
} }
@Test @Test
@ -410,18 +404,14 @@ public class SwitchUserWebFilterTests {
@Test @Test
public void setSwitchUserUrlWhenNullThenThrowException() { public void setSwitchUserUrlWhenNullThenThrowException() {
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setSwitchUserUrl(null))
this.exceptionRule.expectMessage("switchUserUrl cannot be empty and must be a valid redirect URL"); .withMessage("switchUserUrl cannot be empty and must be a valid redirect URL");
this.switchUserWebFilter.setSwitchUserUrl(null);
fail("Test should fail with exception");
} }
@Test @Test
public void setSwitchUserUrlWhenInvalidThenThrowException() { public void setSwitchUserUrlWhenInvalidThenThrowException() {
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setSwitchUserUrl("wrongUrl"))
this.exceptionRule.expectMessage("switchUserUrl cannot be empty and must be a valid redirect URL"); .withMessage("switchUserUrl cannot be empty and must be a valid redirect URL");
this.switchUserWebFilter.setSwitchUserUrl("wrongUrl");
fail("Test should fail with exception");
} }
@Test @Test
@ -440,10 +430,8 @@ public class SwitchUserWebFilterTests {
@Test @Test
public void setSwitchUserMatcherWhenNullThenThrowException() { public void setSwitchUserMatcherWhenNullThenThrowException() {
this.exceptionRule.expect(IllegalArgumentException.class); assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setSwitchUserMatcher(null))
this.exceptionRule.expectMessage("switchUserMatcher cannot be null"); .withMessage("switchUserMatcher cannot be null");
this.switchUserWebFilter.setSwitchUserMatcher(null);
fail("Test should fail with exception");
} }
@Test @Test