Remove restricted static imports

Replace static imports with class referenced methods. With the exception
of a few well known static imports, checkstyle restricts the static
imports that a class can use. For example, `asList(...)` would be
replaced with `Arrays.asList(...)`.

Issue gh-8945
This commit is contained in:
Phillip Webb 2020-07-27 21:34:26 -07:00 committed by Rob Winch
parent 9a3fa6e812
commit e9130489a6
252 changed files with 2216 additions and 2222 deletions

View File

@ -43,10 +43,10 @@ import org.springframework.security.acls.model.ObjectIdentity;
import org.springframework.security.acls.model.Sid; import org.springframework.security.acls.model.Sid;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.AdditionalMatchers.aryEq;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
/** /**
@ -105,7 +105,7 @@ public class JdbcAclServiceTests {
List<ObjectIdentity> result = new ArrayList<>(); List<ObjectIdentity> result = new ArrayList<>();
result.add(new ObjectIdentityImpl(Object.class, "5577")); result.add(new ObjectIdentityImpl(Object.class, "5577"));
Object[] args = { "1", "org.springframework.security.acls.jdbc.JdbcAclServiceTests$MockLongIdDomainObject" }; Object[] args = { "1", "org.springframework.security.acls.jdbc.JdbcAclServiceTests$MockLongIdDomainObject" };
given(this.jdbcOperations.query(anyString(), aryEq(args), any(RowMapper.class))).willReturn(result); given(this.jdbcOperations.query(anyString(), eq(args), any(RowMapper.class))).willReturn(result);
ObjectIdentity objectIdentity = new ObjectIdentityImpl(MockLongIdDomainObject.class, 1L); ObjectIdentity objectIdentity = new ObjectIdentityImpl(MockLongIdDomainObject.class, 1L);
List<ObjectIdentity> objectIdentities = this.aclService.findChildren(objectIdentity); List<ObjectIdentity> objectIdentities = this.aclService.findChildren(objectIdentity);

View File

@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.authentication.ldap;
import java.io.IOException; import java.io.IOException;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.util.Collections;
import java.util.List; import java.util.List;
import javax.naming.directory.SearchControls; import javax.naming.directory.SearchControls;
@ -46,7 +47,6 @@ import org.springframework.security.ldap.userdetails.LdapAuthoritiesPopulator;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import static java.util.Collections.singleton;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated;
@ -117,8 +117,9 @@ public class LdapAuthenticationProviderBuilderSecurityBuilderTests {
public void bindAuthentication() throws Exception { public void bindAuthentication() throws Exception {
this.spring.register(BindAuthenticationConfig.class).autowire(); this.spring.register(BindAuthenticationConfig.class).autowire();
this.mockMvc.perform(formLogin().user("bob").password("bobspassword")).andExpect(authenticated() this.mockMvc.perform(formLogin().user("bob").password("bobspassword"))
.withUsername("bob").withAuthorities(singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS")))); .andExpect(authenticated().withUsername("bob")
.withAuthorities(Collections.singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS"))));
} }
// SEC-2472 // SEC-2472
@ -126,8 +127,9 @@ public class LdapAuthenticationProviderBuilderSecurityBuilderTests {
public void canUseCryptoPasswordEncoder() throws Exception { public void canUseCryptoPasswordEncoder() throws Exception {
this.spring.register(PasswordEncoderConfig.class).autowire(); this.spring.register(PasswordEncoderConfig.class).autowire();
this.mockMvc.perform(formLogin().user("bcrypt").password("password")).andExpect(authenticated() this.mockMvc.perform(formLogin().user("bcrypt").password("password"))
.withUsername("bcrypt").withAuthorities(singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS")))); .andExpect(authenticated().withUsername("bcrypt")
.withAuthorities(Collections.singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS"))));
} }
private LdapAuthenticationProvider ldapProvider() { private LdapAuthenticationProvider ldapProvider() {

View File

@ -16,6 +16,8 @@
package org.springframework.security.config.annotation.authentication.ldap; package org.springframework.security.config.annotation.authentication.ldap;
import java.util.Collections;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
@ -29,7 +31,6 @@ import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import static java.util.Collections.singleton;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated;
@ -54,16 +55,18 @@ public class LdapAuthenticationProviderConfigurerTests {
public void authenticationManagerSupportMultipleLdapContextWithDefaultRolePrefix() throws Exception { public void authenticationManagerSupportMultipleLdapContextWithDefaultRolePrefix() throws Exception {
this.spring.register(MultiLdapAuthenticationProvidersConfig.class).autowire(); this.spring.register(MultiLdapAuthenticationProvidersConfig.class).autowire();
this.mockMvc.perform(formLogin().user("bob").password("bobspassword")).andExpect(authenticated() this.mockMvc.perform(formLogin().user("bob").password("bobspassword"))
.withUsername("bob").withAuthorities(singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS")))); .andExpect(authenticated().withUsername("bob")
.withAuthorities(Collections.singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS"))));
} }
@Test @Test
public void authenticationManagerSupportMultipleLdapContextWithCustomRolePrefix() throws Exception { public void authenticationManagerSupportMultipleLdapContextWithCustomRolePrefix() throws Exception {
this.spring.register(MultiLdapWithCustomRolePrefixAuthenticationProvidersConfig.class).autowire(); this.spring.register(MultiLdapWithCustomRolePrefixAuthenticationProvidersConfig.class).autowire();
this.mockMvc.perform(formLogin().user("bob").password("bobspassword")).andExpect(authenticated() this.mockMvc.perform(formLogin().user("bob").password("bobspassword"))
.withUsername("bob").withAuthorities(singleton(new SimpleGrantedAuthority("ROL_DEVELOPERS")))); .andExpect(authenticated().withUsername("bob")
.withAuthorities(Collections.singleton(new SimpleGrantedAuthority("ROL_DEVELOPERS"))));
} }
@Test @Test

View File

@ -21,6 +21,7 @@ import java.util.List;
import io.rsocket.RSocketFactory; import io.rsocket.RSocketFactory;
import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.frame.decoder.PayloadDecoder;
import io.rsocket.metadata.WellKnownMimeType;
import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.CloseableChannel;
import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.transport.netty.server.TcpServerTransport;
import org.junit.After; import org.junit.After;
@ -51,7 +52,6 @@ import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.util.MimeType; import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils; import org.springframework.util.MimeTypeUtils;
import static io.rsocket.metadata.WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
@ -109,7 +109,8 @@ public class JwtITests {
@Test @Test
public void routeWhenAuthenticationBearerThenAuthorized() { public void routeWhenAuthenticationBearerThenAuthorized() {
MimeType authenticationMimeType = MimeTypeUtils.parseMimeType(MESSAGE_RSOCKET_AUTHENTICATION.getString()); MimeType authenticationMimeType = MimeTypeUtils
.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString());
BearerTokenMetadata credentials = new BearerTokenMetadata("token"); BearerTokenMetadata credentials = new BearerTokenMetadata("token");
given(this.decoder.decode(any())).willReturn(Mono.just(jwt())); given(this.decoder.decode(any())).willReturn(Mono.just(jwt()));

View File

@ -22,6 +22,7 @@ import java.util.List;
import io.rsocket.RSocketFactory; import io.rsocket.RSocketFactory;
import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.ApplicationErrorException;
import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.frame.decoder.PayloadDecoder;
import io.rsocket.metadata.WellKnownMimeType;
import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.CloseableChannel;
import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.transport.netty.server.TcpServerTransport;
import org.junit.After; import org.junit.After;
@ -50,7 +51,6 @@ import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.util.MimeType; import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils; import org.springframework.util.MimeTypeUtils;
import static io.rsocket.metadata.WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatCode;
@ -101,7 +101,8 @@ public class SimpleAuthenticationITests {
@Test @Test
public void retrieveMonoWhenAuthorizedThenGranted() { public void retrieveMonoWhenAuthorizedThenGranted() {
MimeType authenticationMimeType = MimeTypeUtils.parseMimeType(MESSAGE_RSOCKET_AUTHENTICATION.getString()); MimeType authenticationMimeType = MimeTypeUtils
.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString());
UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("rob", "password"); UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("rob", "password");
this.requester = RSocketRequester.builder().setupMetadata(credentials, authenticationMimeType) this.requester = RSocketRequester.builder().setupMetadata(credentials, authenticationMimeType)

View File

@ -36,11 +36,6 @@ import org.springframework.security.ldap.userdetails.PersonContextMapper;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser.INET_ORG_PERSON_MAPPER_CLASS;
import static org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser.LDAP_AUTHORITIES_POPULATOR_CLASS;
import static org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser.LDAP_SEARCH_CLASS;
import static org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser.LDAP_USER_MAPPER_CLASS;
import static org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser.PERSON_MAPPER_CLASS;
/** /**
* @author Luke Taylor * @author Luke Taylor
@ -61,11 +56,16 @@ public class LdapUserServiceBeanDefinitionParserTests {
@Test @Test
public void beanClassNamesAreCorrect() { public void beanClassNamesAreCorrect() {
assertThat(FilterBasedLdapUserSearch.class.getName()).isEqualTo(LDAP_SEARCH_CLASS); assertThat(FilterBasedLdapUserSearch.class.getName())
assertThat(PersonContextMapper.class.getName()).isEqualTo(PERSON_MAPPER_CLASS); .isEqualTo(LdapUserServiceBeanDefinitionParser.LDAP_SEARCH_CLASS);
assertThat(InetOrgPersonContextMapper.class.getName()).isEqualTo(INET_ORG_PERSON_MAPPER_CLASS); assertThat(PersonContextMapper.class.getName())
assertThat(LdapUserDetailsMapper.class.getName()).isEqualTo(LDAP_USER_MAPPER_CLASS); .isEqualTo(LdapUserServiceBeanDefinitionParser.PERSON_MAPPER_CLASS);
assertThat(DefaultLdapAuthoritiesPopulator.class.getName()).isEqualTo(LDAP_AUTHORITIES_POPULATOR_CLASS); assertThat(InetOrgPersonContextMapper.class.getName())
.isEqualTo(LdapUserServiceBeanDefinitionParser.INET_ORG_PERSON_MAPPER_CLASS);
assertThat(LdapUserDetailsMapper.class.getName())
.isEqualTo(LdapUserServiceBeanDefinitionParser.LDAP_USER_MAPPER_CLASS);
assertThat(DefaultLdapAuthoritiesPopulator.class.getName())
.isEqualTo(LdapUserServiceBeanDefinitionParser.LDAP_AUTHORITIES_POPULATOR_CLASS);
assertThat(new LdapUserServiceBeanDefinitionParser().getBeanClassName(mock(Element.class))) assertThat(new LdapUserServiceBeanDefinitionParser().getBeanClassName(mock(Element.class)))
.isEqualTo(LdapUserDetailsService.class.getName()); .isEqualTo(LdapUserDetailsService.class.getName());
} }

View File

@ -40,8 +40,6 @@ import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletRequestAttributes;
import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES;
/** /**
* {@link Configuration} that (potentially) adds a "decorating" {@code Publisher} for the * {@link Configuration} that (potentially) adds a "decorating" {@code Publisher} for the
* last operator created in every {@code Mono} or {@code Flux}. * last operator created in every {@code Mono} or {@code Flux}.
@ -88,7 +86,7 @@ class SecurityReactorContextConfiguration {
} }
<T> CoreSubscriber<T> createSubscriberIfNecessary(CoreSubscriber<T> delegate) { <T> CoreSubscriber<T> createSubscriberIfNecessary(CoreSubscriber<T> delegate) {
if (delegate.currentContext().hasKey(SECURITY_CONTEXT_ATTRIBUTES)) { if (delegate.currentContext().hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) {
// Already enriched. No need to create Subscriber so return original // Already enriched. No need to create Subscriber so return original
return delegate; return delegate;
} }

View File

@ -51,8 +51,6 @@ import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
/** /**
* *
* An {@link AbstractHttpConfigurer} for OAuth 2.0 Resource Server Support. * An {@link AbstractHttpConfigurer} for OAuth 2.0 Resource Server Support.
@ -367,7 +365,7 @@ public final class OAuth2ResourceServerConfigurer<H extends HttpSecurityBuilder<
} }
public JwtConfigurer jwkSetUri(String uri) { public JwtConfigurer jwkSetUri(String uri) {
this.decoder = withJwkSetUri(uri).build(); this.decoder = NimbusJwtDecoder.withJwkSetUri(uri).build();
return this; return this;
} }

View File

@ -47,8 +47,7 @@ import org.springframework.security.web.authentication.ui.DefaultLoginPageGenera
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import static org.springframework.util.StringUtils.hasText;
/** /**
* An {@link AbstractHttpConfigurer} for SAML 2.0 Login, which leverages the SAML 2.0 Web * An {@link AbstractHttpConfigurer} for SAML 2.0 Login, which leverages the SAML 2.0 Web
@ -215,7 +214,7 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter); setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter);
super.loginProcessingUrl(this.loginProcessingUrl); super.loginProcessingUrl(this.loginProcessingUrl);
if (hasText(this.loginPage)) { if (StringUtils.hasText(this.loginPage)) {
// Set custom login page // Set custom login page
super.loginPage(this.loginPage); super.loginPage(this.loginPage);
super.init(http); super.init(http);

View File

@ -68,22 +68,6 @@ import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils; import org.springframework.util.xml.DomUtils;
import static org.springframework.security.config.http.SecurityFilters.ANONYMOUS_FILTER;
import static org.springframework.security.config.http.SecurityFilters.BASIC_AUTH_FILTER;
import static org.springframework.security.config.http.SecurityFilters.BEARER_TOKEN_AUTH_FILTER;
import static org.springframework.security.config.http.SecurityFilters.EXCEPTION_TRANSLATION_FILTER;
import static org.springframework.security.config.http.SecurityFilters.FORM_LOGIN_FILTER;
import static org.springframework.security.config.http.SecurityFilters.LOGIN_PAGE_FILTER;
import static org.springframework.security.config.http.SecurityFilters.LOGOUT_FILTER;
import static org.springframework.security.config.http.SecurityFilters.LOGOUT_PAGE_FILTER;
import static org.springframework.security.config.http.SecurityFilters.OAUTH2_AUTHORIZATION_CODE_GRANT_FILTER;
import static org.springframework.security.config.http.SecurityFilters.OAUTH2_AUTHORIZATION_REQUEST_FILTER;
import static org.springframework.security.config.http.SecurityFilters.OAUTH2_LOGIN_FILTER;
import static org.springframework.security.config.http.SecurityFilters.OPENID_FILTER;
import static org.springframework.security.config.http.SecurityFilters.PRE_AUTH_FILTER;
import static org.springframework.security.config.http.SecurityFilters.REMEMBER_ME_FILTER;
import static org.springframework.security.config.http.SecurityFilters.X509_FILTER;
/** /**
* Handles creation of authentication mechanism filters and related beans for &lt;http&gt; * Handles creation of authentication mechanism filters and related beans for &lt;http&gt;
* parsing. * parsing.
@ -993,59 +977,64 @@ final class AuthenticationConfigBuilder {
List<OrderDecorator> filters = new ArrayList<>(); List<OrderDecorator> filters = new ArrayList<>();
if (this.anonymousFilter != null) { if (this.anonymousFilter != null) {
filters.add(new OrderDecorator(this.anonymousFilter, ANONYMOUS_FILTER)); filters.add(new OrderDecorator(this.anonymousFilter, SecurityFilters.ANONYMOUS_FILTER));
} }
if (this.rememberMeFilter != null) { if (this.rememberMeFilter != null) {
filters.add(new OrderDecorator(this.rememberMeFilter, REMEMBER_ME_FILTER)); filters.add(new OrderDecorator(this.rememberMeFilter, SecurityFilters.REMEMBER_ME_FILTER));
} }
if (this.logoutFilter != null) { if (this.logoutFilter != null) {
filters.add(new OrderDecorator(this.logoutFilter, LOGOUT_FILTER)); filters.add(new OrderDecorator(this.logoutFilter, SecurityFilters.LOGOUT_FILTER));
} }
if (this.x509Filter != null) { if (this.x509Filter != null) {
filters.add(new OrderDecorator(this.x509Filter, X509_FILTER)); filters.add(new OrderDecorator(this.x509Filter, SecurityFilters.X509_FILTER));
} }
if (this.jeeFilter != null) { if (this.jeeFilter != null) {
filters.add(new OrderDecorator(this.jeeFilter, PRE_AUTH_FILTER)); filters.add(new OrderDecorator(this.jeeFilter, SecurityFilters.PRE_AUTH_FILTER));
} }
if (this.formFilterId != null) { if (this.formFilterId != null) {
filters.add(new OrderDecorator(new RuntimeBeanReference(this.formFilterId), FORM_LOGIN_FILTER)); filters.add(
new OrderDecorator(new RuntimeBeanReference(this.formFilterId), SecurityFilters.FORM_LOGIN_FILTER));
} }
if (this.oauth2LoginFilterId != null) { if (this.oauth2LoginFilterId != null) {
filters.add(new OrderDecorator(new RuntimeBeanReference(this.oauth2LoginFilterId), OAUTH2_LOGIN_FILTER)); filters.add(new OrderDecorator(new RuntimeBeanReference(this.oauth2LoginFilterId),
SecurityFilters.OAUTH2_LOGIN_FILTER));
filters.add(new OrderDecorator(this.oauth2AuthorizationRequestRedirectFilter, filters.add(new OrderDecorator(this.oauth2AuthorizationRequestRedirectFilter,
OAUTH2_AUTHORIZATION_REQUEST_FILTER)); SecurityFilters.OAUTH2_AUTHORIZATION_REQUEST_FILTER));
} }
if (this.openIDFilterId != null) { if (this.openIDFilterId != null) {
filters.add(new OrderDecorator(new RuntimeBeanReference(this.openIDFilterId), OPENID_FILTER)); filters.add(
new OrderDecorator(new RuntimeBeanReference(this.openIDFilterId), SecurityFilters.OPENID_FILTER));
} }
if (this.loginPageGenerationFilter != null) { if (this.loginPageGenerationFilter != null) {
filters.add(new OrderDecorator(this.loginPageGenerationFilter, LOGIN_PAGE_FILTER)); filters.add(new OrderDecorator(this.loginPageGenerationFilter, SecurityFilters.LOGIN_PAGE_FILTER));
filters.add(new OrderDecorator(this.logoutPageGenerationFilter, LOGOUT_PAGE_FILTER)); filters.add(new OrderDecorator(this.logoutPageGenerationFilter, SecurityFilters.LOGOUT_PAGE_FILTER));
} }
if (this.basicFilter != null) { if (this.basicFilter != null) {
filters.add(new OrderDecorator(this.basicFilter, BASIC_AUTH_FILTER)); filters.add(new OrderDecorator(this.basicFilter, SecurityFilters.BASIC_AUTH_FILTER));
} }
if (this.bearerTokenAuthenticationFilter != null) { if (this.bearerTokenAuthenticationFilter != null) {
filters.add(new OrderDecorator(this.bearerTokenAuthenticationFilter, BEARER_TOKEN_AUTH_FILTER)); filters.add(
new OrderDecorator(this.bearerTokenAuthenticationFilter, SecurityFilters.BEARER_TOKEN_AUTH_FILTER));
} }
if (this.authorizationCodeGrantFilter != null) { if (this.authorizationCodeGrantFilter != null) {
filters.add(new OrderDecorator(this.authorizationRequestRedirectFilter, filters.add(new OrderDecorator(this.authorizationRequestRedirectFilter,
OAUTH2_AUTHORIZATION_REQUEST_FILTER.getOrder() + 1)); SecurityFilters.OAUTH2_AUTHORIZATION_REQUEST_FILTER.getOrder() + 1));
filters.add(new OrderDecorator(this.authorizationCodeGrantFilter, OAUTH2_AUTHORIZATION_CODE_GRANT_FILTER)); filters.add(new OrderDecorator(this.authorizationCodeGrantFilter,
SecurityFilters.OAUTH2_AUTHORIZATION_CODE_GRANT_FILTER));
} }
filters.add(new OrderDecorator(this.etf, EXCEPTION_TRANSLATION_FILTER)); filters.add(new OrderDecorator(this.etf, SecurityFilters.EXCEPTION_TRANSLATION_FILTER));
return filters; return filters;
} }

View File

@ -40,8 +40,6 @@ import org.springframework.security.web.access.intercept.FilterInvocationSecurit
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils; import org.springframework.util.xml.DomUtils;
import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_REQUEST_MATCHER_REF;
/** /**
* Allows for convenient creation of a {@link FilterInvocationSecurityMetadataSource} bean * Allows for convenient creation of a {@link FilterInvocationSecurityMetadataSource} bean
* for use with a FilterSecurityInterceptor. * for use with a FilterSecurityInterceptor.
@ -161,7 +159,7 @@ public class FilterInvocationSecurityMetadataSourceParser implements BeanDefinit
} }
String path = urlElt.getAttribute(ATT_PATTERN); String path = urlElt.getAttribute(ATT_PATTERN);
String matcherRef = urlElt.getAttribute(ATT_REQUEST_MATCHER_REF); String matcherRef = urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_REQUEST_MATCHER_REF);
boolean hasMatcherRef = StringUtils.hasText(matcherRef); boolean hasMatcherRef = StringUtils.hasText(matcherRef);
if (!hasMatcherRef && !StringUtils.hasText(path)) { if (!hasMatcherRef && !StringUtils.hasText(path)) {

View File

@ -74,24 +74,6 @@ import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils; import org.springframework.util.xml.DomUtils;
import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_FILTERS;
import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_HTTP_METHOD;
import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_PATH_PATTERN;
import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_REQUEST_MATCHER_REF;
import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_REQUIRES_CHANNEL;
import static org.springframework.security.config.http.SecurityFilters.CHANNEL_FILTER;
import static org.springframework.security.config.http.SecurityFilters.CONCURRENT_SESSION_FILTER;
import static org.springframework.security.config.http.SecurityFilters.CORS_FILTER;
import static org.springframework.security.config.http.SecurityFilters.CSRF_FILTER;
import static org.springframework.security.config.http.SecurityFilters.FILTER_SECURITY_INTERCEPTOR;
import static org.springframework.security.config.http.SecurityFilters.HEADERS_FILTER;
import static org.springframework.security.config.http.SecurityFilters.JAAS_API_SUPPORT_FILTER;
import static org.springframework.security.config.http.SecurityFilters.REQUEST_CACHE_FILTER;
import static org.springframework.security.config.http.SecurityFilters.SECURITY_CONTEXT_FILTER;
import static org.springframework.security.config.http.SecurityFilters.SERVLET_API_SUPPORT_FILTER;
import static org.springframework.security.config.http.SecurityFilters.SESSION_MANAGEMENT_FILTER;
import static org.springframework.security.config.http.SecurityFilters.WEB_ASYNC_MANAGER_FILTER;
/** /**
* Stateful class which helps HttpSecurityBDP to create the configuration for the * Stateful class which helps HttpSecurityBDP to create the configuration for the
* &lt;http&gt; element. * &lt;http&gt; element.
@ -197,7 +179,7 @@ class HttpConfigurationBuilder {
this.interceptUrls = DomUtils.getChildElementsByTagName(element, Elements.INTERCEPT_URL); this.interceptUrls = DomUtils.getChildElementsByTagName(element, Elements.INTERCEPT_URL);
for (Element urlElt : this.interceptUrls) { for (Element urlElt : this.interceptUrls) {
if (StringUtils.hasText(urlElt.getAttribute(ATT_FILTERS))) { if (StringUtils.hasText(urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_FILTERS))) {
pc.getReaderContext() pc.getReaderContext()
.error("The use of \"filters='none'\" is no longer supported. Please define a" .error("The use of \"filters='none'\" is no longer supported. Please define a"
+ " separate <http> element for the pattern you want to exclude and use the attribute" + " separate <http> element for the pattern you want to exclude and use the attribute"
@ -637,16 +619,16 @@ class HttpConfigurationBuilder {
ManagedMap<BeanMetadataElement, BeanDefinition> channelRequestMap = new ManagedMap<>(); ManagedMap<BeanMetadataElement, BeanDefinition> channelRequestMap = new ManagedMap<>();
for (Element urlElt : this.interceptUrls) { for (Element urlElt : this.interceptUrls) {
String path = urlElt.getAttribute(ATT_PATH_PATTERN); String path = urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_PATH_PATTERN);
String method = urlElt.getAttribute(ATT_HTTP_METHOD); String method = urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_HTTP_METHOD);
String matcherRef = urlElt.getAttribute(ATT_REQUEST_MATCHER_REF); String matcherRef = urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_REQUEST_MATCHER_REF);
boolean hasMatcherRef = StringUtils.hasText(matcherRef); boolean hasMatcherRef = StringUtils.hasText(matcherRef);
if (!hasMatcherRef && !StringUtils.hasText(path)) { if (!hasMatcherRef && !StringUtils.hasText(path)) {
this.pc.getReaderContext().error("pattern attribute cannot be empty or null", urlElt); this.pc.getReaderContext().error("pattern attribute cannot be empty or null", urlElt);
} }
String requiredChannel = urlElt.getAttribute(ATT_REQUIRES_CHANNEL); String requiredChannel = urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_REQUIRES_CHANNEL);
if (StringUtils.hasText(requiredChannel)) { if (StringUtils.hasText(requiredChannel)) {
BeanMetadataElement matcher = hasMatcherRef ? new RuntimeBeanReference(matcherRef) BeanMetadataElement matcher = hasMatcherRef ? new RuntimeBeanReference(matcherRef)
@ -805,47 +787,47 @@ class HttpConfigurationBuilder {
List<OrderDecorator> filters = new ArrayList<>(); List<OrderDecorator> filters = new ArrayList<>();
if (this.cpf != null) { if (this.cpf != null) {
filters.add(new OrderDecorator(this.cpf, CHANNEL_FILTER)); filters.add(new OrderDecorator(this.cpf, SecurityFilters.CHANNEL_FILTER));
} }
if (this.concurrentSessionFilter != null) { if (this.concurrentSessionFilter != null) {
filters.add(new OrderDecorator(this.concurrentSessionFilter, CONCURRENT_SESSION_FILTER)); filters.add(new OrderDecorator(this.concurrentSessionFilter, SecurityFilters.CONCURRENT_SESSION_FILTER));
} }
if (this.webAsyncManagerFilter != null) { if (this.webAsyncManagerFilter != null) {
filters.add(new OrderDecorator(this.webAsyncManagerFilter, WEB_ASYNC_MANAGER_FILTER)); filters.add(new OrderDecorator(this.webAsyncManagerFilter, SecurityFilters.WEB_ASYNC_MANAGER_FILTER));
} }
filters.add(new OrderDecorator(this.securityContextPersistenceFilter, SECURITY_CONTEXT_FILTER)); filters.add(new OrderDecorator(this.securityContextPersistenceFilter, SecurityFilters.SECURITY_CONTEXT_FILTER));
if (this.servApiFilter != null) { if (this.servApiFilter != null) {
filters.add(new OrderDecorator(this.servApiFilter, SERVLET_API_SUPPORT_FILTER)); filters.add(new OrderDecorator(this.servApiFilter, SecurityFilters.SERVLET_API_SUPPORT_FILTER));
} }
if (this.jaasApiFilter != null) { if (this.jaasApiFilter != null) {
filters.add(new OrderDecorator(this.jaasApiFilter, JAAS_API_SUPPORT_FILTER)); filters.add(new OrderDecorator(this.jaasApiFilter, SecurityFilters.JAAS_API_SUPPORT_FILTER));
} }
if (this.sfpf != null) { if (this.sfpf != null) {
filters.add(new OrderDecorator(this.sfpf, SESSION_MANAGEMENT_FILTER)); filters.add(new OrderDecorator(this.sfpf, SecurityFilters.SESSION_MANAGEMENT_FILTER));
} }
filters.add(new OrderDecorator(this.fsi, FILTER_SECURITY_INTERCEPTOR)); filters.add(new OrderDecorator(this.fsi, SecurityFilters.FILTER_SECURITY_INTERCEPTOR));
if (this.sessionPolicy != SessionCreationPolicy.STATELESS) { if (this.sessionPolicy != SessionCreationPolicy.STATELESS) {
filters.add(new OrderDecorator(this.requestCacheAwareFilter, REQUEST_CACHE_FILTER)); filters.add(new OrderDecorator(this.requestCacheAwareFilter, SecurityFilters.REQUEST_CACHE_FILTER));
} }
if (this.corsFilter != null) { if (this.corsFilter != null) {
filters.add(new OrderDecorator(this.corsFilter, CORS_FILTER)); filters.add(new OrderDecorator(this.corsFilter, SecurityFilters.CORS_FILTER));
} }
if (this.addHeadersFilter != null) { if (this.addHeadersFilter != null) {
filters.add(new OrderDecorator(this.addHeadersFilter, HEADERS_FILTER)); filters.add(new OrderDecorator(this.addHeadersFilter, SecurityFilters.HEADERS_FILTER));
} }
if (this.csrfFilter != null) { if (this.csrfFilter != null) {
filters.add(new OrderDecorator(this.csrfFilter, CSRF_FILTER)); filters.add(new OrderDecorator(this.csrfFilter, SecurityFilters.CSRF_FILTER));
} }
return filters; return filters;

View File

@ -31,11 +31,6 @@ import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepo
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils; import org.springframework.util.xml.DomUtils;
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createDefaultAuthorizedClientRepository;
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientRepository;
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientService;
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getClientRegistrationRepository;
/** /**
* @author Joe Grandja * @author Joe Grandja
* @since 5.3 * @since 5.3
@ -71,12 +66,15 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser {
public BeanDefinition parse(Element element, ParserContext parserContext) { public BeanDefinition parse(Element element, ParserContext parserContext) {
Element authorizationCodeGrantElt = DomUtils.getChildElementByTagName(element, ELT_AUTHORIZATION_CODE_GRANT); Element authorizationCodeGrantElt = DomUtils.getChildElementByTagName(element, ELT_AUTHORIZATION_CODE_GRANT);
BeanMetadataElement clientRegistrationRepository = getClientRegistrationRepository(element); BeanMetadataElement clientRegistrationRepository = OAuth2ClientBeanDefinitionParserUtils
BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element); .getClientRegistrationRepository(element);
BeanMetadataElement authorizedClientRepository = OAuth2ClientBeanDefinitionParserUtils
.getAuthorizedClientRepository(element);
if (authorizedClientRepository == null) { if (authorizedClientRepository == null) {
BeanMetadataElement authorizedClientService = getAuthorizedClientService(element); BeanMetadataElement authorizedClientService = OAuth2ClientBeanDefinitionParserUtils
this.defaultAuthorizedClientRepository = createDefaultAuthorizedClientRepository( .getAuthorizedClientService(element);
clientRegistrationRepository, authorizedClientService); this.defaultAuthorizedClientRepository = OAuth2ClientBeanDefinitionParserUtils
.createDefaultAuthorizedClientRepository(clientRegistrationRepository, authorizedClientService);
authorizedClientRepository = new RuntimeBeanReference(OAuth2AuthorizedClientRepository.class); authorizedClientRepository = new RuntimeBeanReference(OAuth2AuthorizedClientRepository.class);
} }
BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository( BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository(

View File

@ -68,11 +68,6 @@ import org.springframework.util.xml.DomUtils;
import org.springframework.web.accept.ContentNegotiationStrategy; import org.springframework.web.accept.ContentNegotiationStrategy;
import org.springframework.web.accept.HeaderContentNegotiationStrategy; import org.springframework.web.accept.HeaderContentNegotiationStrategy;
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createDefaultAuthorizedClientRepository;
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientRepository;
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientService;
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getClientRegistrationRepository;
/** /**
* @author Ruby Hartono * @author Ruby Hartono
* @since 5.3 * @since 5.3
@ -150,12 +145,15 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser {
.registerBeanComponent(new BeanComponentDefinition(oauth2LoginBeanConfig, oauth2LoginBeanConfigId)); .registerBeanComponent(new BeanComponentDefinition(oauth2LoginBeanConfig, oauth2LoginBeanConfigId));
// configure filter // configure filter
BeanMetadataElement clientRegistrationRepository = getClientRegistrationRepository(element); BeanMetadataElement clientRegistrationRepository = OAuth2ClientBeanDefinitionParserUtils
BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element); .getClientRegistrationRepository(element);
BeanMetadataElement authorizedClientRepository = OAuth2ClientBeanDefinitionParserUtils
.getAuthorizedClientRepository(element);
if (authorizedClientRepository == null) { if (authorizedClientRepository == null) {
BeanMetadataElement authorizedClientService = getAuthorizedClientService(element); BeanMetadataElement authorizedClientService = OAuth2ClientBeanDefinitionParserUtils
this.defaultAuthorizedClientRepository = createDefaultAuthorizedClientRepository( .getAuthorizedClientService(element);
clientRegistrationRepository, authorizedClientService); this.defaultAuthorizedClientRepository = OAuth2ClientBeanDefinitionParserUtils
.createDefaultAuthorizedClientRepository(clientRegistrationRepository, authorizedClientService);
authorizedClientRepository = new RuntimeBeanReference(OAuth2AuthorizedClientRepository.class); authorizedClientRepository = new RuntimeBeanReference(OAuth2AuthorizedClientRepository.class);
} }
BeanMetadataElement accessTokenResponseClient = getAccessTokenResponseClient(element); BeanMetadataElement accessTokenResponseClient = getAccessTokenResponseClient(element);

View File

@ -80,13 +80,6 @@ import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils; import org.springframework.util.xml.DomUtils;
import static org.springframework.security.config.Elements.EXPRESSION_HANDLER;
import static org.springframework.security.config.Elements.INVOCATION_ATTRIBUTE_FACTORY;
import static org.springframework.security.config.Elements.INVOCATION_HANDLING;
import static org.springframework.security.config.Elements.POST_INVOCATION_ADVICE;
import static org.springframework.security.config.Elements.PRE_INVOCATION_ADVICE;
import static org.springframework.security.config.Elements.PROTECT_POINTCUT;
/** /**
* Processes the top-level "global-method-security" element. * Processes the top-level "global-method-security" element.
* *
@ -150,12 +143,12 @@ public class GlobalMethodSecurityBeanDefinitionParser implements BeanDefinitionP
} }
if (prePostAnnotationsEnabled) { if (prePostAnnotationsEnabled) {
Element prePostElt = DomUtils.getChildElementByTagName(element, INVOCATION_HANDLING); Element prePostElt = DomUtils.getChildElementByTagName(element, Elements.INVOCATION_HANDLING);
Element expressionHandlerElt = DomUtils.getChildElementByTagName(element, EXPRESSION_HANDLER); Element expressionHandlerElt = DomUtils.getChildElementByTagName(element, Elements.EXPRESSION_HANDLER);
if (prePostElt != null && expressionHandlerElt != null) { if (prePostElt != null && expressionHandlerElt != null) {
pc.getReaderContext().error( pc.getReaderContext().error(Elements.INVOCATION_HANDLING + " and " + Elements.EXPRESSION_HANDLER
INVOCATION_HANDLING + " and " + EXPRESSION_HANDLER + " cannot be used together ", source); + " cannot be used together ", source);
} }
BeanDefinitionBuilder preInvocationVoterBldr = BeanDefinitionBuilder BeanDefinitionBuilder preInvocationVoterBldr = BeanDefinitionBuilder
@ -170,11 +163,12 @@ public class GlobalMethodSecurityBeanDefinitionParser implements BeanDefinitionP
if (prePostElt != null) { if (prePostElt != null) {
// Customized override of expression handling system // Customized override of expression handling system
String attributeFactoryRef = DomUtils.getChildElementByTagName(prePostElt, INVOCATION_ATTRIBUTE_FACTORY) String attributeFactoryRef = DomUtils
.getChildElementByTagName(prePostElt, Elements.INVOCATION_ATTRIBUTE_FACTORY)
.getAttribute("ref"); .getAttribute("ref");
String preAdviceRef = DomUtils.getChildElementByTagName(prePostElt, PRE_INVOCATION_ADVICE) String preAdviceRef = DomUtils.getChildElementByTagName(prePostElt, Elements.PRE_INVOCATION_ADVICE)
.getAttribute("ref"); .getAttribute("ref");
String postAdviceRef = DomUtils.getChildElementByTagName(prePostElt, POST_INVOCATION_ADVICE) String postAdviceRef = DomUtils.getChildElementByTagName(prePostElt, Elements.POST_INVOCATION_ADVICE)
.getAttribute("ref"); .getAttribute("ref");
mds.addConstructorArgReference(attributeFactoryRef); mds.addConstructorArgReference(attributeFactoryRef);
@ -257,7 +251,7 @@ public class GlobalMethodSecurityBeanDefinitionParser implements BeanDefinitionP
// Now create a Map<String, ConfigAttribute> for each <protect-pointcut> // Now create a Map<String, ConfigAttribute> for each <protect-pointcut>
// sub-element // sub-element
Map<String, List<ConfigAttribute>> pointcutMap = parseProtectPointcuts(pc, Map<String, List<ConfigAttribute>> pointcutMap = parseProtectPointcuts(pc,
DomUtils.getChildElementsByTagName(element, PROTECT_POINTCUT)); DomUtils.getChildElementsByTagName(element, Elements.PROTECT_POINTCUT));
if (pointcutMap.size() > 0) { if (pointcutMap.size() > 0) {
if (useAspectJ) { if (useAspectJ) {

View File

@ -54,8 +54,6 @@ import org.springframework.util.PathMatcher;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils; import org.springframework.util.xml.DomUtils;
import static org.springframework.security.config.Elements.EXPRESSION_HANDLER;
/** /**
* Parses Spring Security's websocket namespace support. A simple example is: * Parses Spring Security's websocket namespace support. A simple example is:
* *
@ -121,7 +119,7 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
ManagedMap<BeanDefinition, String> matcherToExpression = new ManagedMap<>(); ManagedMap<BeanDefinition, String> matcherToExpression = new ManagedMap<>();
String id = element.getAttribute(ID_ATTR); String id = element.getAttribute(ID_ATTR);
Element expressionHandlerElt = DomUtils.getChildElementByTagName(element, EXPRESSION_HANDLER); Element expressionHandlerElt = DomUtils.getChildElementByTagName(element, Elements.EXPRESSION_HANDLER);
String expressionHandlerRef = expressionHandlerElt == null ? null : expressionHandlerElt.getAttribute("ref"); String expressionHandlerRef = expressionHandlerElt == null ? null : expressionHandlerElt.getAttribute("ref");
boolean expressionHandlerDefined = StringUtils.hasText(expressionHandlerRef); boolean expressionHandlerDefined = StringUtils.hasText(expressionHandlerRef);

View File

@ -20,6 +20,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunner;
@ -34,11 +35,8 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.powermock.api.mockito.PowerMockito.doThrow; import static org.mockito.Mockito.mock;
import static org.powermock.api.mockito.PowerMockito.mock; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.powermock.api.mockito.PowerMockito.spy;
import static org.powermock.api.mockito.PowerMockito.verifyStatic;
import static org.powermock.api.mockito.PowerMockito.verifyZeroInteractions;
/** /**
* @author Luke Taylor * @author Luke Taylor
@ -88,9 +86,9 @@ public class SecurityNamespaceHandlerTests {
@Test @Test
public void initDoesNotLogErrorWhenFilterChainProxyFailsToLoad() throws Exception { public void initDoesNotLogErrorWhenFilterChainProxyFailsToLoad() throws Exception {
String className = "javax.servlet.Filter"; String className = "javax.servlet.Filter";
spy(ClassUtils.class); PowerMockito.spy(ClassUtils.class);
doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", eq(FILTER_CHAIN_PROXY_CLASSNAME), PowerMockito.doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName",
any(ClassLoader.class)); eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class));
Log logger = mock(Log.class); Log logger = mock(Log.class);
SecurityNamespaceHandler handler = new SecurityNamespaceHandler(); SecurityNamespaceHandler handler = new SecurityNamespaceHandler();
@ -98,7 +96,7 @@ public class SecurityNamespaceHandlerTests {
handler.init(); handler.init();
verifyStatic(ClassUtils.class); PowerMockito.verifyStatic(ClassUtils.class);
ClassUtils.forName(eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); ClassUtils.forName(eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class));
verifyZeroInteractions(logger); verifyZeroInteractions(logger);
} }
@ -108,18 +106,18 @@ public class SecurityNamespaceHandlerTests {
String className = "javax.servlet.Filter"; String className = "javax.servlet.Filter";
this.thrown.expect(BeanDefinitionParsingException.class); this.thrown.expect(BeanDefinitionParsingException.class);
this.thrown.expectMessage("NoClassDefFoundError: " + className); this.thrown.expectMessage("NoClassDefFoundError: " + className);
spy(ClassUtils.class); PowerMockito.spy(ClassUtils.class);
doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", eq(FILTER_CHAIN_PROXY_CLASSNAME), PowerMockito.doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName",
any(ClassLoader.class)); eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class));
new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK); new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK);
} }
@Test @Test
public void filterNoClassDefFoundErrorNoHttpBlock() throws Exception { public void filterNoClassDefFoundErrorNoHttpBlock() throws Exception {
String className = "javax.servlet.Filter"; String className = "javax.servlet.Filter";
spy(ClassUtils.class); PowerMockito.spy(ClassUtils.class);
doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", eq(FILTER_CHAIN_PROXY_CLASSNAME), PowerMockito.doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName",
any(ClassLoader.class)); eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class));
new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER); new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER);
// should load just fine since no http block // should load just fine since no http block
} }
@ -129,8 +127,8 @@ public class SecurityNamespaceHandlerTests {
String className = FILTER_CHAIN_PROXY_CLASSNAME; String className = FILTER_CHAIN_PROXY_CLASSNAME;
this.thrown.expect(BeanDefinitionParsingException.class); this.thrown.expect(BeanDefinitionParsingException.class);
this.thrown.expectMessage("ClassNotFoundException: " + className); this.thrown.expectMessage("ClassNotFoundException: " + className);
spy(ClassUtils.class); PowerMockito.spy(ClassUtils.class);
doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", PowerMockito.doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName",
eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class));
new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK); new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK);
} }
@ -138,8 +136,8 @@ public class SecurityNamespaceHandlerTests {
@Test @Test
public void filterChainProxyClassNotFoundExceptionNoHttpBlock() throws Exception { public void filterChainProxyClassNotFoundExceptionNoHttpBlock() throws Exception {
String className = FILTER_CHAIN_PROXY_CLASSNAME; String className = FILTER_CHAIN_PROXY_CLASSNAME;
spy(ClassUtils.class); PowerMockito.spy(ClassUtils.class);
doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", PowerMockito.doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName",
eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class));
new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER); new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER);
// should load just fine since no http block // should load just fine since no http block
@ -148,9 +146,9 @@ public class SecurityNamespaceHandlerTests {
@Test @Test
public void websocketNotFoundExceptionNoMessageBlock() throws Exception { public void websocketNotFoundExceptionNoMessageBlock() throws Exception {
String className = FILTER_CHAIN_PROXY_CLASSNAME; String className = FILTER_CHAIN_PROXY_CLASSNAME;
spy(ClassUtils.class); PowerMockito.spy(ClassUtils.class);
doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", eq(Message.class.getName()), PowerMockito.doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName",
any(ClassLoader.class)); eq(Message.class.getName()), any(ClassLoader.class));
new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER); new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER);
// should load just fine since no websocket block // should load just fine since no websocket block
} }

View File

@ -24,6 +24,8 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders;
import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -70,8 +72,8 @@ public class NamespaceAuthenticationManagerTests {
public void authenticationManagerWhenGlobalAndEraseCredentialsIsFalseThenCredentialsNotNull() throws Exception { public void authenticationManagerWhenGlobalAndEraseCredentialsIsFalseThenCredentialsNotNull() throws Exception {
this.spring.register(GlobalEraseCredentialsFalseConfig.class).autowire(); this.spring.register(GlobalEraseCredentialsFalseConfig.class).autowire();
this.mockMvc.perform(formLogin()) this.mockMvc.perform(SecurityMockMvcRequestBuilders.formLogin()).andExpect(SecurityMockMvcResultMatchers
.andExpect(authenticated().withAuthentication(a -> assertThat(a.getCredentials()).isNotNull())); .authenticated().withAuthentication(a -> assertThat(a.getCredentials()).isNotNull()));
} }
@EnableWebSecurity @EnableWebSecurity

View File

@ -34,7 +34,7 @@ import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.web.context.ServletContextAware; import org.springframework.web.context.ServletContextAware;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.isNotNull; import static org.mockito.ArgumentMatchers.isNotNull;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;

View File

@ -22,6 +22,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunner;
@ -48,8 +49,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.powermock.api.mockito.PowerMockito.spy;
import static org.powermock.api.mockito.PowerMockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
/** /**
@ -79,9 +78,10 @@ public class WebSecurityConfigurerAdapterPowermockTests {
@Test @Test
public void loadConfigWhenDefaultConfigurerAsSpringFactoryhenDefaultConfigurerApplied() { public void loadConfigWhenDefaultConfigurerAsSpringFactoryhenDefaultConfigurerApplied() {
spy(SpringFactoriesLoader.class); PowerMockito.spy(SpringFactoriesLoader.class);
DefaultConfigurer configurer = new DefaultConfigurer(); DefaultConfigurer configurer = new DefaultConfigurer();
when(SpringFactoriesLoader.loadFactories(AbstractHttpConfigurer.class, getClass().getClassLoader())) PowerMockito
.when(SpringFactoriesLoader.loadFactories(AbstractHttpConfigurer.class, getClass().getClassLoader()))
.thenReturn(Arrays.<AbstractHttpConfigurer>asList(configurer)); .thenReturn(Arrays.<AbstractHttpConfigurer>asList(configurer));
loadConfig(Config.class); loadConfig(Config.class);

View File

@ -55,7 +55,8 @@ import org.springframework.web.accept.HeaderContentNegotiationStrategy;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.ThrowableAssert.catchThrowable; import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -153,11 +154,9 @@ public class WebSecurityConfigurerAdapterTests {
MyFilter myFilter = this.spring.getContext().getBean(MyFilter.class); MyFilter myFilter = this.spring.getContext().getBean(MyFilter.class);
Throwable thrown = catchThrowable(() -> myFilter.userDetailsService.loadUserByUsername("user")); assertThatCode(() -> myFilter.userDetailsService.loadUserByUsername("user")).doesNotThrowAnyException();
assertThat(thrown).isNull(); assertThatExceptionOfType(UsernameNotFoundException.class)
.isThrownBy(() -> myFilter.userDetailsService.loadUserByUsername("admin"));
thrown = catchThrowable(() -> myFilter.userDetailsService.loadUserByUsername("admin"));
assertThat(thrown).isInstanceOf(UsernameNotFoundException.class);
} }
// SEC-2274: WebSecurityConfigurer adds ApplicationContext as a shared object // SEC-2274: WebSecurityConfigurer adds ApplicationContext as a shared object

View File

@ -38,8 +38,7 @@ import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.ThrowableAssert.catchThrowable;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -62,11 +61,11 @@ public class HttpConfigurationTests {
@Test @Test
public void configureWhenAddFilterUnregisteredThenThrowsBeanCreationException() { public void configureWhenAddFilterUnregisteredThenThrowsBeanCreationException() {
Throwable thrown = catchThrowable(() -> this.spring.register(UnregisteredFilterConfig.class).autowire()); assertThatExceptionOfType(BeanCreationException.class)
assertThat(thrown).isInstanceOf(BeanCreationException.class); .isThrownBy(() -> this.spring.register(UnregisteredFilterConfig.class).autowire())
assertThat(thrown.getMessage()).contains("The Filter class " + UnregisteredFilter.class.getName() .withMessageContaining("The Filter class " + UnregisteredFilter.class.getName()
+ " does not have a registered order and cannot be added without a specified order." + " does not have a registered order and cannot be added without a specified order."
+ " Consider using addFilterBefore or addFilterAfter instead."); + " Consider using addFilterBefore or addFilterAfter instead.");
} }
// https://github.com/spring-projects/spring-security-javaconfig/issues/104 // https://github.com/spring-projects/spring-security-javaconfig/issues/104

View File

@ -34,10 +34,12 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResp
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
@ -52,9 +54,6 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientCredentials;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@ -79,7 +78,8 @@ public class OAuth2ClientConfigurationTests {
TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
ClientRegistration clientRegistration = clientRegistration().registrationId(clientRegistrationId).build(); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.registrationId(clientRegistrationId).build();
given(clientRegistrationRepository.findByRegistrationId(eq(clientRegistrationId))) given(clientRegistrationRepository.findByRegistrationId(eq(clientRegistrationId)))
.willReturn(clientRegistration); .willReturn(clientRegistration);
@ -99,8 +99,10 @@ public class OAuth2ClientConfigurationTests {
OAuth2AuthorizedClientArgumentResolverConfig.ACCESS_TOKEN_RESPONSE_CLIENT = accessTokenResponseClient; OAuth2AuthorizedClientArgumentResolverConfig.ACCESS_TOKEN_RESPONSE_CLIENT = accessTokenResponseClient;
this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire(); this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire();
this.mockMvc.perform(get("/authorized-client").with(authentication(authentication))).andExpect(status().isOk()) this.mockMvc
.andExpect(content().string("resolved")); .perform(get("/authorized-client")
.with(SecurityMockMvcRequestPostProcessors.authentication(authentication)))
.andExpect(status().isOk()).andExpect(content().string("resolved"));
verifyZeroInteractions(accessTokenResponseClient); verifyZeroInteractions(accessTokenResponseClient);
} }
@ -115,7 +117,8 @@ public class OAuth2ClientConfigurationTests {
OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
OAuth2AccessTokenResponseClient accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); OAuth2AccessTokenResponseClient accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
ClientRegistration clientRegistration = clientCredentials().registrationId(clientRegistrationId).build(); ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
.registrationId(clientRegistrationId).build();
given(clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).willReturn(clientRegistration); given(clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).willReturn(clientRegistration);
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234") OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234")
@ -128,8 +131,10 @@ public class OAuth2ClientConfigurationTests {
OAuth2AuthorizedClientArgumentResolverConfig.ACCESS_TOKEN_RESPONSE_CLIENT = accessTokenResponseClient; OAuth2AuthorizedClientArgumentResolverConfig.ACCESS_TOKEN_RESPONSE_CLIENT = accessTokenResponseClient;
this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire(); this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire();
this.mockMvc.perform(get("/authorized-client").with(authentication(authentication))).andExpect(status().isOk()) this.mockMvc
.andExpect(content().string("resolved")); .perform(get("/authorized-client")
.with(SecurityMockMvcRequestPostProcessors.authentication(authentication)))
.andExpect(status().isOk()).andExpect(content().string("resolved"));
verify(accessTokenResponseClient, times(1)).getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class)); verify(accessTokenResponseClient, times(1)).getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class));
} }
@ -176,7 +181,8 @@ public class OAuth2ClientConfigurationTests {
OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
OAuth2AuthorizedClientManager authorizedClientManager = mock(OAuth2AuthorizedClientManager.class); OAuth2AuthorizedClientManager authorizedClientManager = mock(OAuth2AuthorizedClientManager.class);
ClientRegistration clientRegistration = clientRegistration().registrationId(clientRegistrationId).build(); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.registrationId(clientRegistrationId).build();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principalName, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principalName,
TestOAuth2AccessTokens.noScopes()); TestOAuth2AccessTokens.noScopes());
@ -187,8 +193,10 @@ public class OAuth2ClientConfigurationTests {
OAuth2AuthorizedClientManagerRegisteredConfig.AUTHORIZED_CLIENT_MANAGER = authorizedClientManager; OAuth2AuthorizedClientManagerRegisteredConfig.AUTHORIZED_CLIENT_MANAGER = authorizedClientManager;
this.spring.register(OAuth2AuthorizedClientManagerRegisteredConfig.class).autowire(); this.spring.register(OAuth2AuthorizedClientManagerRegisteredConfig.class).autowire();
this.mockMvc.perform(get("/authorized-client").with(authentication(authentication))).andExpect(status().isOk()) this.mockMvc
.andExpect(content().string("resolved")); .perform(get("/authorized-client")
.with(SecurityMockMvcRequestPostProcessors.authentication(authentication)))
.andExpect(status().isOk()).andExpect(content().string("resolved"));
verify(authorizedClientManager).authorize(any()); verify(authorizedClientManager).authorize(any());
verifyNoInteractions(clientRegistrationRepository); verifyNoInteractions(clientRegistrationRepository);

View File

@ -31,14 +31,14 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication; import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication;
import org.springframework.security.oauth2.server.resource.authentication.TestBearerTokenAuthentications;
import org.springframework.security.oauth2.server.resource.web.reactive.function.client.ServletBearerExchangeFilterFunction; import org.springframework.security.oauth2.server.resource.web.reactive.function.client.ServletBearerExchangeFilterFunction;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
import static org.springframework.security.oauth2.server.resource.authentication.TestBearerTokenAuthentications.bearer;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@ -60,21 +60,21 @@ public class SecurityReactorContextConfigurationResourceServerTests {
// gh-7418 // gh-7418
@Test @Test
public void requestWhenUsingFilterThenBearerTokenPropagated() throws Exception { public void requestWhenUsingFilterThenBearerTokenPropagated() throws Exception {
BearerTokenAuthentication authentication = bearer(); BearerTokenAuthentication authentication = TestBearerTokenAuthentications.bearer();
this.spring.register(BearerFilterConfig.class, WebServerConfig.class, Controller.class).autowire(); this.spring.register(BearerFilterConfig.class, WebServerConfig.class, Controller.class).autowire();
this.mockMvc.perform(get("/token").with(authentication(authentication))).andExpect(status().isOk()) this.mockMvc.perform(get("/token").with(SecurityMockMvcRequestPostProcessors.authentication(authentication)))
.andExpect(content().string("Bearer token")); .andExpect(status().isOk()).andExpect(content().string("Bearer token"));
} }
// gh-7418 // gh-7418
@Test @Test
public void requestWhenNotUsingFilterThenBearerTokenNotPropagated() throws Exception { public void requestWhenNotUsingFilterThenBearerTokenNotPropagated() throws Exception {
BearerTokenAuthentication authentication = bearer(); BearerTokenAuthentication authentication = TestBearerTokenAuthentications.bearer();
this.spring.register(BearerFilterlessConfig.class, WebServerConfig.class, Controller.class).autowire(); this.spring.register(BearerFilterlessConfig.class, WebServerConfig.class, Controller.class).autowire();
this.mockMvc.perform(get("/token").with(authentication(authentication))).andExpect(status().isOk()) this.mockMvc.perform(get("/token").with(SecurityMockMvcRequestPostProcessors.authentication(authentication)))
.andExpect(content().string("")); .andExpect(status().isOk()).andExpect(content().string(""));
} }
@EnableWebSecurity @EnableWebSecurity

View File

@ -33,11 +33,13 @@ import reactor.core.publisher.Operators;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import reactor.util.context.Context; import reactor.util.context.Context;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
@ -51,8 +53,6 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.entry; import static org.assertj.core.api.Assertions.entry;
import static org.springframework.http.HttpMethod.GET;
import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES;
/** /**
* Tests for {@link SecurityReactorContextConfiguration}. * Tests for {@link SecurityReactorContextConfiguration}.
@ -88,7 +88,7 @@ public class SecurityReactorContextConfigurationTests {
@Test @Test
public void createSubscriberIfNecessaryWhenSubscriberContextContainsSecurityContextAttributesThenReturnOriginalSubscriber() { public void createSubscriberIfNecessaryWhenSubscriberContextContainsSecurityContextAttributesThenReturnOriginalSubscriber() {
Context context = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>()); Context context = Context.of(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>());
BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() { BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {
@Override @Override
public Context currentContext() { public Context currentContext() {
@ -120,7 +120,8 @@ public class SecurityReactorContextConfigurationTests {
Context resultContext = subscriber.currentContext(); Context resultContext = subscriber.currentContext();
assertThat(resultContext.getOrEmpty(testKey)).hasValue(testValue); assertThat(resultContext.getOrEmpty(testKey)).hasValue(testValue);
Map<Object, Object> securityContextAttributes = resultContext.getOrDefault(SECURITY_CONTEXT_ATTRIBUTES, null); Map<Object, Object> securityContextAttributes = resultContext
.getOrDefault(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, null);
assertThat(securityContextAttributes).hasSize(3); assertThat(securityContextAttributes).hasSize(3);
assertThat(securityContextAttributes).contains(entry(HttpServletRequest.class, this.servletRequest), assertThat(securityContextAttributes).contains(entry(HttpServletRequest.class, this.servletRequest),
entry(HttpServletResponse.class, this.servletResponse), entry(HttpServletResponse.class, this.servletResponse),
@ -133,7 +134,8 @@ public class SecurityReactorContextConfigurationTests {
.setRequestAttributes(new ServletRequestAttributes(this.servletRequest, this.servletResponse)); .setRequestAttributes(new ServletRequestAttributes(this.servletRequest, this.servletResponse));
SecurityContextHolder.getContext().setAuthentication(this.authentication); SecurityContextHolder.getContext().setAuthentication(this.authentication);
Context parentContext = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>()); Context parentContext = Context.of(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES,
new HashMap<>());
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() { BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
@Override @Override
public Context currentContext() { public Context currentContext() {
@ -206,8 +208,9 @@ public class SecurityReactorContextConfigurationTests {
ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build(); ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build();
ExchangeFilterFunction filter = (req, next) -> Mono.subscriberContext() ExchangeFilterFunction filter = (req, next) -> Mono.subscriberContext()
.filter(ctx -> ctx.hasKey(SECURITY_CONTEXT_ATTRIBUTES)).map(ctx -> ctx.get(SECURITY_CONTEXT_ATTRIBUTES)) .filter(ctx -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
.cast(Map.class).map(attributes -> { .map(ctx -> ctx.get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)).cast(Map.class)
.map(attributes -> {
if (attributes.containsKey(HttpServletRequest.class) if (attributes.containsKey(HttpServletRequest.class)
&& attributes.containsKey(HttpServletResponse.class) && attributes.containsKey(HttpServletResponse.class)
&& attributes.containsKey(Authentication.class)) { && attributes.containsKey(Authentication.class)) {
@ -218,7 +221,7 @@ public class SecurityReactorContextConfigurationTests {
} }
}); });
ClientRequest clientRequest = ClientRequest.create(GET, URI.create("https://example.com")).build(); ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
MockExchangeFunction exchange = new MockExchangeFunction(); MockExchangeFunction exchange = new MockExchangeFunction();
Map<Object, Object> expectedContextAttributes = new HashMap<>(); Map<Object, Object> expectedContextAttributes = new HashMap<>();
@ -230,8 +233,8 @@ public class SecurityReactorContextConfigurationTests {
.flatMap(response -> filter.filter(clientRequest, exchange)); .flatMap(response -> filter.filter(clientRequest, exchange));
StepVerifier.create(clientResponseMono).expectAccessibleContext() StepVerifier.create(clientResponseMono).expectAccessibleContext()
.contains(SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes).then().expectNext(clientResponseOk) .contains(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes)
.verifyComplete(); .then().expectNext(clientResponseOk).verifyComplete();
} }
@EnableWebSecurity @EnableWebSecurity

View File

@ -20,6 +20,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpHeaders;
import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
@ -28,9 +29,6 @@ import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.web.header.HeaderWriterFilter; import org.springframework.security.web.header.HeaderWriterFilter;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import static org.springframework.http.HttpHeaders.CACHE_CONTROL;
import static org.springframework.http.HttpHeaders.EXPIRES;
import static org.springframework.http.HttpHeaders.PRAGMA;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header;
@ -54,8 +52,9 @@ public class HeadersConfigurerEagerHeadersTests {
this.mvc.perform(get("/").secure(true)).andExpect(header().string("X-Content-Type-Options", "nosniff")) this.mvc.perform(get("/").secure(true)).andExpect(header().string("X-Content-Type-Options", "nosniff"))
.andExpect(header().string("X-Frame-Options", "DENY")) .andExpect(header().string("X-Frame-Options", "DENY"))
.andExpect(header().string("Strict-Transport-Security", "max-age=31536000 ; includeSubDomains")) .andExpect(header().string("Strict-Transport-Security", "max-age=31536000 ; includeSubDomains"))
.andExpect(header().string(CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate"))
.andExpect(header().string(EXPIRES, "0")).andExpect(header().string(PRAGMA, "no-cache")) .andExpect(header().string(HttpHeaders.EXPIRES, "0"))
.andExpect(header().string(HttpHeaders.PRAGMA, "no-cache"))
.andExpect(header().string("X-XSS-Protection", "1; mode=block")); .andExpect(header().string("X-XSS-Protection", "1; mode=block"));
} }

View File

@ -39,7 +39,7 @@ import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.access.intercept.FilterSecurityInterceptor; import org.springframework.security.web.access.intercept.FilterSecurityInterceptor;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import static org.assertj.core.api.Java6Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* @author Rob Winch * @author Rob Winch

View File

@ -27,16 +27,13 @@ import org.springframework.security.config.annotation.web.configuration.WebSecur
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners;
import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors;
import org.springframework.security.web.authentication.logout.HeaderWriterLogoutHandler; import org.springframework.security.web.authentication.logout.HeaderWriterLogoutHandler;
import org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter; import org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter;
import org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive;
import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive.CACHE;
import static org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive.COOKIES;
import static org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive.EXECUTION_CONTEXTS;
import static org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive.STORAGE;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header;
@ -55,7 +52,8 @@ public class LogoutConfigurerClearSiteDataTests {
private static final String CLEAR_SITE_DATA_HEADER = "Clear-Site-Data"; private static final String CLEAR_SITE_DATA_HEADER = "Clear-Site-Data";
private static final ClearSiteDataHeaderWriter.Directive[] SOURCE = { CACHE, COOKIES, STORAGE, EXECUTION_CONTEXTS }; private static final Directive[] SOURCE = { Directive.CACHE, Directive.COOKIES, Directive.STORAGE,
Directive.EXECUTION_CONTEXTS };
private static final String HEADER_VALUE = "\"cache\", \"cookies\", \"storage\", \"executionContexts\""; private static final String HEADER_VALUE = "\"cache\", \"cookies\", \"storage\", \"executionContexts\"";
@ -70,7 +68,7 @@ public class LogoutConfigurerClearSiteDataTests {
public void logoutWhenRequestTypeGetThenHeaderNotPresentt() throws Exception { public void logoutWhenRequestTypeGetThenHeaderNotPresentt() throws Exception {
this.spring.register(HttpLogoutConfig.class).autowire(); this.spring.register(HttpLogoutConfig.class).autowire();
this.mvc.perform(get("/logout").secure(true).with(csrf())) this.mvc.perform(get("/logout").secure(true).with(SecurityMockMvcRequestPostProcessors.csrf()))
.andExpect(header().doesNotExist(CLEAR_SITE_DATA_HEADER)); .andExpect(header().doesNotExist(CLEAR_SITE_DATA_HEADER));
} }
@ -79,7 +77,8 @@ public class LogoutConfigurerClearSiteDataTests {
public void logoutWhenRequestTypePostAndNotSecureThenHeaderNotPresent() throws Exception { public void logoutWhenRequestTypePostAndNotSecureThenHeaderNotPresent() throws Exception {
this.spring.register(HttpLogoutConfig.class).autowire(); this.spring.register(HttpLogoutConfig.class).autowire();
this.mvc.perform(post("/logout").with(csrf())).andExpect(header().doesNotExist(CLEAR_SITE_DATA_HEADER)); this.mvc.perform(post("/logout").with(SecurityMockMvcRequestPostProcessors.csrf()))
.andExpect(header().doesNotExist(CLEAR_SITE_DATA_HEADER));
} }
@Test @Test
@ -87,7 +86,7 @@ public class LogoutConfigurerClearSiteDataTests {
public void logoutWhenRequestTypePostAndSecureThenHeaderIsPresent() throws Exception { public void logoutWhenRequestTypePostAndSecureThenHeaderIsPresent() throws Exception {
this.spring.register(HttpLogoutConfig.class).autowire(); this.spring.register(HttpLogoutConfig.class).autowire();
this.mvc.perform(post("/logout").secure(true).with(csrf())) this.mvc.perform(post("/logout").secure(true).with(SecurityMockMvcRequestPostProcessors.csrf()))
.andExpect(header().stringValues(CLEAR_SITE_DATA_HEADER, HEADER_VALUE)); .andExpect(header().stringValues(CLEAR_SITE_DATA_HEADER, HEADER_VALUE));
} }

View File

@ -26,6 +26,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.openid4java.consumer.ConsumerManager; import org.openid4java.consumer.ConsumerManager;
import org.openid4java.discovery.DiscoveryInformation; import org.openid4java.discovery.DiscoveryInformation;
import org.openid4java.discovery.yadis.YadisResolver;
import org.openid4java.message.AuthRequest; import org.openid4java.message.AuthRequest;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@ -63,7 +64,6 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.openid4java.discovery.yadis.YadisResolver.YADIS_XRDS_LOCATION;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@ -107,7 +107,7 @@ public class NamespaceHttpOpenIDLoginTests {
try (MockWebServer server = new MockWebServer()) { try (MockWebServer server = new MockWebServer()) {
String endpoint = server.url("/").toString(); String endpoint = server.url("/").toString();
server.enqueue(new MockResponse().addHeader(YADIS_XRDS_LOCATION, endpoint)); server.enqueue(new MockResponse().addHeader(YadisResolver.YADIS_XRDS_LOCATION, endpoint));
server.enqueue(new MockResponse() server.enqueue(new MockResponse()
.setBody(String.format("<XRDS><XRD><Service><URI>%s</URI></Service></XRD></XRDS>", endpoint))); .setBody(String.format("<XRDS><XRD><Service><URI>%s</URI></Service></XRD></XRDS>", endpoint)));

View File

@ -44,7 +44,7 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* @author Rob Winch * @author Rob Winch

View File

@ -69,6 +69,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser;
@ -80,6 +81,7 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
@ -92,8 +94,6 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken;
import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@ -602,7 +602,7 @@ public class OAuth2LoginConfigurerTests {
} }
private static OAuth2UserService<OidcUserRequest, OidcUser> createOidcUserService() { private static OAuth2UserService<OidcUserRequest, OidcUser> createOidcUserService() {
OidcIdToken idToken = idToken().build(); OidcIdToken idToken = TestOidcIdTokens.idToken().build();
return request -> new DefaultOidcUser(Collections.singleton(new OidcUserAuthority(idToken)), idToken); return request -> new DefaultOidcUser(Collections.singleton(new OidcUserAuthority(idToken)), idToken);
} }
@ -993,7 +993,7 @@ public class OAuth2LoginConfigurerTests {
claims.put(IdTokenClaimNames.ISS, "http://localhost/iss"); claims.put(IdTokenClaimNames.ISS, "http://localhost/iss");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d")); claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d"));
claims.put(IdTokenClaimNames.AZP, "clientId"); claims.put(IdTokenClaimNames.AZP, "clientId");
Jwt jwt = jwt().claims(c -> c.putAll(claims)).build(); Jwt jwt = TestJwts.jwt().claims(c -> c.putAll(claims)).build();
JwtDecoder jwtDecoder = mock(JwtDecoder.class); JwtDecoder jwtDecoder = mock(JwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(jwt); given(jwtDecoder.decode(any())).willReturn(jwt);
return jwtDecoder; return jwtDecoder;

View File

@ -94,13 +94,16 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimNames;
import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.JwtTimestampValidator; import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication; import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
@ -124,6 +127,7 @@ import org.springframework.util.MultiValueMap;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestOperations;
import org.springframework.web.context.support.GenericWebApplicationContext; import org.springframework.web.context.support.GenericWebApplicationContext;
@ -131,7 +135,7 @@ import org.springframework.web.context.support.GenericWebApplicationContext;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatCode;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.core.StringStartsWith.startsWith; import static org.hamcrest.CoreMatchers.startsWith;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
@ -140,12 +144,6 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS;
import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withPublicKey;
import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@ -154,8 +152,6 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import static org.springframework.web.bind.annotation.RequestMethod.GET;
import static org.springframework.web.bind.annotation.RequestMethod.POST;
/** /**
* Tests for {@link OAuth2ResourceServerConfigurer} * Tests for {@link OAuth2ResourceServerConfigurer}
@ -169,9 +165,9 @@ public class OAuth2ResourceServerConfigurerTests {
private static final String JWT_SUBJECT = "mock-test-subject"; private static final String JWT_SUBJECT = "mock-test-subject";
private static final Map<String, Object> JWT_CLAIMS = Collections.singletonMap(SUB, JWT_SUBJECT); private static final Map<String, Object> JWT_CLAIMS = Collections.singletonMap(JwtClaimNames.SUB, JWT_SUBJECT);
private static final Jwt JWT = jwt().build(); private static final Jwt JWT = TestJwts.jwt().build();
private static final String JWK_SET_URI = "https://mock.org"; private static final String JWK_SET_URI = "https://mock.org";
@ -185,8 +181,8 @@ public class OAuth2ResourceServerConfigurerTests {
private static final String CLIENT_SECRET = "client-secret"; private static final String CLIENT_SECRET = "client-secret";
private static final BearerTokenAuthentication INTROSPECTION_AUTHENTICATION_TOKEN = new BearerTokenAuthentication( private static final BearerTokenAuthentication INTROSPECTION_AUTHENTICATION_TOKEN = new BearerTokenAuthentication(
new DefaultOAuth2AuthenticatedPrincipal(JWT_CLAIMS, Collections.emptyList()), noScopes(), new DefaultOAuth2AuthenticatedPrincipal(JWT_CLAIMS, Collections.emptyList()),
Collections.emptyList()); TestOAuth2AccessTokens.noScopes(), Collections.emptyList());
@Autowired(required = false) @Autowired(required = false)
MockMvc mvc; MockMvc mvc;
@ -1361,8 +1357,8 @@ public class OAuth2ResourceServerConfigurerTests {
private String jwtFromIssuer(String issuer) throws Exception { private String jwtFromIssuer(String issuer) throws Exception {
Map<String, Object> claims = new HashMap<>(); Map<String, Object> claims = new HashMap<>();
claims.put(ISS, issuer); claims.put(JwtClaimNames.ISS, issuer);
claims.put(SUB, "test-subject"); claims.put(JwtClaimNames.SUB, "test-subject");
claims.put("scope", "message:read"); claims.put("scope", "message:read");
JWSObject jws = new JWSObject(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("1").build(), JWSObject jws = new JWSObject(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("1").build(),
new Payload(new JSONObject(claims))); new Payload(new JSONObject(claims)));
@ -2066,7 +2062,7 @@ public class OAuth2ResourceServerConfigurerTests {
JwtDecoder decoder() throws Exception { JwtDecoder decoder() throws Exception {
RSAPublicKey publicKey = (RSAPublicKey) KeyFactory.getInstance("RSA") RSAPublicKey publicKey = (RSAPublicKey) KeyFactory.getInstance("RSA")
.generatePublic(new X509EncodedKeySpec(this.spec)); .generatePublic(new X509EncodedKeySpec(this.spec));
return withPublicKey(publicKey).build(); return NimbusJwtDecoder.withPublicKey(publicKey).build();
} }
} }
@ -2285,7 +2281,7 @@ public class OAuth2ResourceServerConfigurerTests {
return "post"; return "post";
} }
@RequestMapping(value = "/authenticated", method = { GET, POST }) @RequestMapping(value = "/authenticated", method = { RequestMethod.GET, RequestMethod.POST })
public String authenticated(Authentication authentication) { public String authenticated(Authentication authentication) {
return authentication.getName(); return authentication.getName();
} }
@ -2365,7 +2361,8 @@ public class OAuth2ResourceServerConfigurerTests {
@Bean @Bean
NimbusJwtDecoder jwtDecoder() { NimbusJwtDecoder jwtDecoder() {
return withJwkSetUri("https://example.org/.well-known/jwks.json").restOperations(this.rest).build(); return NimbusJwtDecoder.withJwkSetUri("https://example.org/.well-known/jwks.json").restOperations(this.rest)
.build();
} }
@Bean @Bean

View File

@ -24,6 +24,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.openid4java.consumer.ConsumerManager; import org.openid4java.consumer.ConsumerManager;
import org.openid4java.discovery.DiscoveryInformation; import org.openid4java.discovery.DiscoveryInformation;
import org.openid4java.discovery.yadis.YadisResolver;
import org.openid4java.message.AuthRequest; import org.openid4java.message.AuthRequest;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@ -47,7 +48,6 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.openid4java.discovery.yadis.YadisResolver.YADIS_XRDS_LOCATION;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
@ -113,7 +113,7 @@ public class OpenIDLoginConfigurerTests {
try (MockWebServer server = new MockWebServer()) { try (MockWebServer server = new MockWebServer()) {
String endpoint = server.url("/").toString(); String endpoint = server.url("/").toString();
server.enqueue(new MockResponse().addHeader(YADIS_XRDS_LOCATION, endpoint)); server.enqueue(new MockResponse().addHeader(YadisResolver.YADIS_XRDS_LOCATION, endpoint));
server.enqueue(new MockResponse() server.enqueue(new MockResponse()
.setBody(String.format("<XRDS><XRD><Service><URI>%s</URI></Service></XRD></XRDS>", endpoint))); .setBody(String.format("<XRDS><XRD><Service><URI>%s</URI></Service></XRD></XRDS>", endpoint)));
@ -151,7 +151,7 @@ public class OpenIDLoginConfigurerTests {
try (MockWebServer server = new MockWebServer()) { try (MockWebServer server = new MockWebServer()) {
String endpoint = server.url("/").toString(); String endpoint = server.url("/").toString();
server.enqueue(new MockResponse().addHeader(YADIS_XRDS_LOCATION, endpoint)); server.enqueue(new MockResponse().addHeader(YadisResolver.YADIS_XRDS_LOCATION, endpoint));
server.enqueue(new MockResponse() server.enqueue(new MockResponse()
.setBody(String.format("<XRDS><XRD><Service><URI>%s</URI></Service></XRD></XRDS>", endpoint))); .setBody(String.format("<XRDS><XRD><Service><URI>%s</URI></Service></XRD></XRDS>", endpoint)));

View File

@ -19,6 +19,7 @@ package org.springframework.security.config.annotation.web.configurers.saml2;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.net.URLDecoder; import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
@ -60,14 +61,17 @@ import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
@ -81,7 +85,6 @@ import org.springframework.test.web.servlet.MvcResult;
import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
@ -89,10 +92,6 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
@ -171,7 +170,8 @@ public class Saml2LoginConfigurerTests {
public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() throws Exception { public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestContextResolver.class).autowire(); this.spring.register(CustomAuthenticationRequestContextResolver.class).autowire();
Saml2AuthenticationRequestContext context = authenticationRequestContext().build(); Saml2AuthenticationRequestContext context = TestSaml2AuthenticationRequestContexts
.authenticationRequestContext().build();
Saml2AuthenticationRequestContextResolver resolver = CustomAuthenticationRequestContextResolver.resolver; Saml2AuthenticationRequestContextResolver resolver = CustomAuthenticationRequestContextResolver.resolver;
given(resolver.resolve(any(HttpServletRequest.class))).willReturn(context); given(resolver.resolve(any(HttpServletRequest.class))).willReturn(context);
this.mvc.perform(get("/saml2/authenticate/registration-id")).andExpect(status().isFound()); this.mvc.perform(get("/saml2/authenticate/registration-id")).andExpect(status().isFound());
@ -193,9 +193,9 @@ public class Saml2LoginConfigurerTests {
@Test @Test
public void authenticateWhenCustomAuthenticationConverterThenUses() throws Exception { public void authenticateWhenCustomAuthenticationConverterThenUses() throws Exception {
this.spring.register(CustomAuthenticationConverter.class).autowire(); this.spring.register(CustomAuthenticationConverter.class).autowire();
RelyingPartyRegistration relyingPartyRegistration = noCredentials() RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.noCredentials()
.assertingPartyDetails( .assertingPartyDetails(party -> party.verificationX509Credentials(
party -> party.verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential()))) c -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential())))
.build(); .build();
String response = new String(samlDecode(SIGNED_RESPONSE)); String response = new String(samlDecode(SIGNED_RESPONSE));
given(CustomAuthenticationConverter.authenticationConverter.convert(any(HttpServletRequest.class))) given(CustomAuthenticationConverter.authenticationConverter.convert(any(HttpServletRequest.class)))
@ -254,7 +254,7 @@ public class Saml2LoginConfigurerTests {
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
iout.write(b); iout.write(b);
iout.finish(); iout.finish();
return new String(out.toByteArray(), UTF_8); return new String(out.toByteArray(), StandardCharsets.UTF_8);
} }
catch (IOException e) { catch (IOException e) {
throw new Saml2Exception("Unable to inflate string", e); throw new Saml2Exception("Unable to inflate string", e);
@ -387,7 +387,8 @@ public class Saml2LoginConfigurerTests {
@Bean @Bean
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository() { RelyingPartyRegistrationRepository relyingPartyRegistrationRepository() {
RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
given(repository.findByRegistrationId(anyString())).willReturn(relyingPartyRegistration().build()); given(repository.findByRegistrationId(anyString()))
.willReturn(TestRelyingPartyRegistrations.relyingPartyRegistration().build());
return repository; return repository;
} }

View File

@ -24,10 +24,7 @@ import java.security.cert.X509Certificate;
import org.springframework.security.converter.RsaKeyConverters; import org.springframework.security.converter.RsaKeyConverters;
import org.springframework.security.saml2.credentials.Saml2X509Credential; import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING;
import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION;
/** /**
* Preconfigured SAML credentials for SAML integration tests. * Preconfigured SAML credentials for SAML integration tests.
@ -58,7 +55,7 @@ public class TestSaml2Credentials {
+ "xbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6z\n" + "xbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6z\n"
+ "V9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3\n" + "V9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3\n"
+ "lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk\n" + "-----END CERTIFICATE-----"; + "lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk\n" + "-----END CERTIFICATE-----";
return new Saml2X509Credential(x509Certificate(certificate), VERIFICATION); return new Saml2X509Credential(x509Certificate(certificate), Saml2X509CredentialType.VERIFICATION);
} }
static X509Certificate x509Certificate(String source) { static X509Certificate x509Certificate(String source) {
@ -105,7 +102,7 @@ public class TestSaml2Credentials {
+ "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + "-----END CERTIFICATE-----"; + "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + "-----END CERTIFICATE-----";
PrivateKey pk = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(key.getBytes())); PrivateKey pk = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(key.getBytes()));
X509Certificate cert = x509Certificate(certificate); X509Certificate cert = x509Certificate(certificate);
return new Saml2X509Credential(pk, cert, SIGNING, DECRYPTION); return new Saml2X509Credential(pk, cert, Saml2X509CredentialType.SIGNING, Saml2X509CredentialType.DECRYPTION);
} }
} }

View File

@ -51,6 +51,7 @@ import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners;
import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers;
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
import org.springframework.security.web.reactive.result.method.annotation.AuthenticationPrincipalArgumentResolver; import org.springframework.security.web.reactive.result.method.annotation.AuthenticationPrincipalArgumentResolver;
import org.springframework.security.web.reactive.result.view.CsrfRequestDataValueProcessor; import org.springframework.security.web.reactive.result.view.CsrfRequestDataValueProcessor;
@ -71,7 +72,6 @@ import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.result.view.AbstractView; import org.springframework.web.reactive.result.view.AbstractView;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf;
/** /**
* @author Rob Winch * @author Rob Winch
@ -202,8 +202,9 @@ public class EnableWebFluxSecurityTests {
MultiValueMap<String, String> data = new LinkedMultiValueMap<>(); MultiValueMap<String, String> data = new LinkedMultiValueMap<>();
data.add("username", "user"); data.add("username", "user");
data.add("password", "password"); data.add("password", "password");
client.mutateWith(csrf()).post().uri("/login").body(BodyInserters.fromFormData(data)).exchange().expectStatus() client.mutateWith(SecurityMockServerConfigurers.csrf()).post().uri("/login")
.is3xxRedirection().expectHeader().valueMatches("Location", "/"); .body(BodyInserters.fromFormData(data)).exchange().expectStatus().is3xxRedirection().expectHeader()
.valueMatches("Location", "/");
} }
@Test @Test

View File

@ -46,8 +46,6 @@ import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.Assertions.fail;
import static org.springframework.messaging.simp.SimpMessageType.MESSAGE;
import static org.springframework.messaging.simp.SimpMessageType.SUBSCRIBE;
public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests { public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests {
@ -139,7 +137,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests {
.simpDestMatchers("/app/**").hasRole("USER") .simpDestMatchers("/app/**").hasRole("USER")
// <3> // <3>
.simpSubscribeDestMatchers("/user/**", "/topic/friends/*").hasRole("USER") // <4> .simpSubscribeDestMatchers("/user/**", "/topic/friends/*").hasRole("USER") // <4>
.simpTypeMatchers(MESSAGE, SUBSCRIBE).denyAll() // <5> .simpTypeMatchers(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE).denyAll() // <5>
.anyMessage().denyAll(); // <6> .anyMessage().denyAll(); // <6>
} }

View File

@ -26,7 +26,7 @@ import org.springframework.security.core.userdetails.ReactiveUserDetailsService;
import org.springframework.security.util.InMemoryResource; import org.springframework.security.util.InMemoryResource;
import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.context.junit4.SpringRunner;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* @author Rob Winch * @author Rob Winch

View File

@ -25,7 +25,7 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.core.userdetails.ReactiveUserDetailsService;
import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.context.junit4.SpringRunner;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* @author Rob Winch * @author Rob Winch

View File

@ -25,7 +25,7 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.core.userdetails.ReactiveUserDetailsService;
import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.context.junit4.SpringRunner;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* @author Rob Winch * @author Rob Winch

View File

@ -29,7 +29,7 @@ import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.util.InMemoryResource; import org.springframework.security.util.InMemoryResource;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
/** /**
* @author Rob Winch * @author Rob Winch

View File

@ -18,13 +18,12 @@ package org.springframework.security.config.debug;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.springframework.security.config.BeanIds;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.debug.DebugFilter; import org.springframework.security.web.debug.DebugFilter;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.config.BeanIds.FILTER_CHAIN_PROXY;
import static org.springframework.security.config.BeanIds.SPRING_SECURITY_FILTER_CHAIN;
/** /**
* @author Rob Winch * @author Rob Winch
@ -42,8 +41,9 @@ public class SecurityDebugBeanFactoryPostProcessorTests {
"classpath:org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessorTests-context.xml") "classpath:org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessorTests-context.xml")
.autowire(); .autowire();
assertThat(this.spring.getContext().getBean(SPRING_SECURITY_FILTER_CHAIN)).isInstanceOf(DebugFilter.class); assertThat(this.spring.getContext().getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN))
assertThat(this.spring.getContext().getBean(FILTER_CHAIN_PROXY)).isInstanceOf(FilterChainProxy.class); .isInstanceOf(DebugFilter.class);
assertThat(this.spring.getContext().getBean(BeanIds.FILTER_CHAIN_PROXY)).isInstanceOf(FilterChainProxy.class);
} }
} }

View File

@ -50,6 +50,7 @@ import org.springframework.test.web.servlet.ResultMatcher;
import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.servlet.support.RequestDataValueProcessor; import org.springframework.web.servlet.support.RequestDataValueProcessor;
@ -68,14 +69,6 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import static org.springframework.web.bind.annotation.RequestMethod.DELETE;
import static org.springframework.web.bind.annotation.RequestMethod.GET;
import static org.springframework.web.bind.annotation.RequestMethod.HEAD;
import static org.springframework.web.bind.annotation.RequestMethod.OPTIONS;
import static org.springframework.web.bind.annotation.RequestMethod.PATCH;
import static org.springframework.web.bind.annotation.RequestMethod.POST;
import static org.springframework.web.bind.annotation.RequestMethod.PUT;
import static org.springframework.web.bind.annotation.RequestMethod.TRACE;
/** /**
* @author Rob Winch * @author Rob Winch
@ -441,20 +434,22 @@ public class CsrfConfigTests {
@Controller @Controller
public static class RootController { public static class RootController {
@RequestMapping(value = "/csrf-in-header", method = { HEAD, TRACE, OPTIONS }) @RequestMapping(value = "/csrf-in-header",
method = { RequestMethod.HEAD, RequestMethod.TRACE, RequestMethod.OPTIONS })
@ResponseBody @ResponseBody
String csrfInHeaderAndBody(CsrfToken token, HttpServletResponse response) { String csrfInHeaderAndBody(CsrfToken token, HttpServletResponse response) {
response.setHeader(token.getHeaderName(), token.getToken()); response.setHeader(token.getHeaderName(), token.getToken());
return csrfInBody(token); return csrfInBody(token);
} }
@RequestMapping(value = "/csrf", method = { POST, PUT, PATCH, DELETE, GET }) @RequestMapping(value = "/csrf", method = { RequestMethod.POST, RequestMethod.PUT, RequestMethod.PATCH,
RequestMethod.DELETE, RequestMethod.GET })
@ResponseBody @ResponseBody
String csrfInBody(CsrfToken token) { String csrfInBody(CsrfToken token) {
return token.getToken(); return token.getToken();
} }
@RequestMapping(value = "/ok", method = { POST, GET }) @RequestMapping(value = "/ok", method = { RequestMethod.POST, RequestMethod.GET })
@ResponseBody @ResponseBody
String ok() { String ok() {
return "ok"; return "ok";

View File

@ -24,8 +24,8 @@ import org.springframework.security.web.WebAttributes;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.core.IsNot.not; import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.core.IsNull.nullValue; import static org.hamcrest.CoreMatchers.nullValue;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;

View File

@ -102,6 +102,7 @@ import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.RequestCacheAwareFilter; import org.springframework.security.web.savedrequest.RequestCacheAwareFilter;
import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter; import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter;
import org.springframework.security.web.session.SessionManagementFilter; import org.springframework.security.web.session.SessionManagementFilter;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.MvcResult;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
@ -120,7 +121,6 @@ import static org.mockito.Mockito.verify;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.x509; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.x509;
import static org.springframework.test.util.ReflectionTestUtils.getField;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@ -618,8 +618,8 @@ public class MiscHttpConfigTests {
this.mvc.perform(get("/details").session(session)).andExpect(content().string(details.getClass().getName())); this.mvc.perform(get("/details").session(session)).andExpect(content().string(details.getClass().getName()));
assertThat(getField(getFilter(OpenIDAuthenticationFilter.class), "authenticationDetailsSource")) assertThat(ReflectionTestUtils.getField(getFilter(OpenIDAuthenticationFilter.class),
.isEqualTo(source); "authenticationDetailsSource")).isEqualTo(source);
} }
@Test @Test

View File

@ -40,6 +40,7 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners;
import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
@ -55,7 +56,6 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
@ -153,7 +153,7 @@ public class OAuth2ClientBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
MultiValueMap<String, String> params = new LinkedMultiValueMap<>(); MultiValueMap<String, String> params = new LinkedMultiValueMap<>();
@ -183,7 +183,7 @@ public class OAuth2ClientBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
MultiValueMap<String, String> params = new LinkedMultiValueMap<>(); MultiValueMap<String, String> params = new LinkedMultiValueMap<>();

View File

@ -52,6 +52,7 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User;
@ -78,8 +79,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.oidcAccessTokenResponse;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
@ -214,7 +213,7 @@ public class OAuth2LoginBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
OAuth2User oauth2User = TestOAuth2Users.create(); OAuth2User oauth2User = TestOAuth2Users.create();
@ -243,7 +242,7 @@ public class OAuth2LoginBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
OAuth2User oauth2User = TestOAuth2Users.create(); OAuth2User oauth2User = TestOAuth2Users.create();
@ -269,7 +268,8 @@ public class OAuth2LoginBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
OAuth2AccessTokenResponse accessTokenResponse = oidcAccessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse()
.build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
Jwt jwt = TestJwts.user(); Jwt jwt = TestJwts.user();
@ -297,7 +297,7 @@ public class OAuth2LoginBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
OAuth2User oauth2User = TestOAuth2Users.create(); OAuth2User oauth2User = TestOAuth2Users.create();
@ -326,7 +326,7 @@ public class OAuth2LoginBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
accessTokenResponse = oidcAccessTokenResponse().build(); accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
Jwt jwt = TestJwts.user(); Jwt jwt = TestJwts.user();
@ -359,7 +359,7 @@ public class OAuth2LoginBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
OAuth2User oauth2User = TestOAuth2Users.create(); OAuth2User oauth2User = TestOAuth2Users.create();
@ -428,7 +428,7 @@ public class OAuth2LoginBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
OAuth2User oauth2User = TestOAuth2Users.create(); OAuth2User oauth2User = TestOAuth2Users.create();
@ -456,7 +456,7 @@ public class OAuth2LoginBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
OAuth2User oauth2User = TestOAuth2Users.create(); OAuth2User oauth2User = TestOAuth2Users.create();
@ -484,7 +484,7 @@ public class OAuth2LoginBeanDefinitionParserTests {
given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any()))
.willReturn(authorizationRequest); .willReturn(authorizationRequest);
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
OAuth2User oauth2User = TestOAuth2Users.create(); OAuth2User oauth2User = TestOAuth2Users.create();

View File

@ -76,9 +76,11 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimNames;
import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
import org.springframework.security.oauth2.server.resource.introspection.NimbusOpaqueTokenIntrospector; import org.springframework.security.oauth2.server.resource.introspection.NimbusOpaqueTokenIntrospector;
import org.springframework.security.oauth2.server.resource.introspection.OpaqueTokenIntrospector; import org.springframework.security.oauth2.server.resource.introspection.OpaqueTokenIntrospector;
@ -96,23 +98,15 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.core.StringStartsWith.startsWith; import static org.hamcrest.CoreMatchers.startsWith;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.powermock.api.mockito.PowerMockito.when;
import static org.springframework.security.config.http.OAuth2ResourceServerBeanDefinitionParser.AUTHENTICATION_MANAGER_RESOLVER_REF;
import static org.springframework.security.config.http.OAuth2ResourceServerBeanDefinitionParser.JwtBeanDefinitionParser.DECODER_REF;
import static org.springframework.security.config.http.OAuth2ResourceServerBeanDefinitionParser.JwtBeanDefinitionParser.JWK_SET_URI;
import static org.springframework.security.config.http.OAuth2ResourceServerBeanDefinitionParser.OpaqueTokenBeanDefinitionParser.INTROSPECTION_URI;
import static org.springframework.security.config.http.OAuth2ResourceServerBeanDefinitionParser.OpaqueTokenBeanDefinitionParser.INTROSPECTOR_REF;
import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS;
import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB;
import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@ -435,10 +429,10 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
.autowire(); .autowire();
JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class);
when(decoder.decode("token")).thenReturn(jwt().build()); given(decoder.decode("token")).willReturn(TestJwts.jwt().build());
BearerTokenResolver bearerTokenResolver = this.spring.getContext().getBean(BearerTokenResolver.class); BearerTokenResolver bearerTokenResolver = this.spring.getContext().getBean(BearerTokenResolver.class);
when(bearerTokenResolver.resolve(any(HttpServletRequest.class))).thenReturn("token"); given(bearerTokenResolver.resolve(any(HttpServletRequest.class))).willReturn("token");
this.mvc.perform(get("/")).andExpect(status().isNotFound()); this.mvc.perform(get("/")).andExpect(status().isNotFound());
@ -453,7 +447,7 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
this.spring.configLocations(xml("MockJwtDecoder"), xml("AllowBearerTokenInBody")).autowire(); this.spring.configLocations(xml("MockJwtDecoder"), xml("AllowBearerTokenInBody")).autowire();
JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class);
when(decoder.decode(anyString())).thenReturn(jwt().build()); given(decoder.decode(anyString())).willReturn(TestJwts.jwt().build());
this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token")) this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token"))
.andExpect(status().isNotFound()); .andExpect(status().isNotFound());
@ -468,7 +462,7 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
this.spring.configLocations(xml("MockJwtDecoder"), xml("AllowBearerTokenInQuery")).autowire(); this.spring.configLocations(xml("MockJwtDecoder"), xml("AllowBearerTokenInQuery")).autowire();
JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class);
Mockito.when(decoder.decode(anyString())).thenReturn(jwt().build()); given(decoder.decode(anyString())).willReturn(TestJwts.jwt().build());
this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token")) this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token"))
.andExpect(status().isNotFound()); .andExpect(status().isNotFound());
@ -517,7 +511,7 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class);
when(decoder.decode(anyString())).thenReturn(jwt().build()); given(decoder.decode(anyString())).willReturn(TestJwts.jwt().build());
this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token")) this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token"))
.andExpect(status().isNotFound()); .andExpect(status().isNotFound());
@ -552,7 +546,7 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
this.spring.configLocations(xml("MockJwtDecoder"), xml("AccessDeniedHandler")).autowire(); this.spring.configLocations(xml("MockJwtDecoder"), xml("AccessDeniedHandler")).autowire();
JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class);
Mockito.when(decoder.decode(anyString())).thenReturn(jwt().build()); given(decoder.decode(anyString())).willReturn(TestJwts.jwt().build());
this.mvc.perform(get("/authenticated").header("Authorization", "Bearer insufficiently_scoped")) this.mvc.perform(get("/authenticated").header("Authorization", "Bearer insufficiently_scoped"))
.andExpect(status().isForbidden()) .andExpect(status().isForbidden())
@ -572,7 +566,7 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
OAuth2Error error = new OAuth2Error("custom-error", "custom-description", "custom-uri"); OAuth2Error error = new OAuth2Error("custom-error", "custom-description", "custom-uri");
when(jwtValidator.validate(any(Jwt.class))).thenReturn(OAuth2TokenValidatorResult.failure(error)); given(jwtValidator.validate(any(Jwt.class))).willReturn(OAuth2TokenValidatorResult.failure(error));
this.mvc.perform(get("/").header("Authorization", "Bearer " + token)).andExpect(status().isUnauthorized()) this.mvc.perform(get("/").header("Authorization", "Bearer " + token)).andExpect(status().isUnauthorized())
.andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("custom-description"))); .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("custom-description")));
@ -609,11 +603,11 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
Converter<Jwt, JwtAuthenticationToken> jwtAuthenticationConverter = (Converter<Jwt, JwtAuthenticationToken>) this.spring Converter<Jwt, JwtAuthenticationToken> jwtAuthenticationConverter = (Converter<Jwt, JwtAuthenticationToken>) this.spring
.getContext().getBean("jwtAuthenticationConverter"); .getContext().getBean("jwtAuthenticationConverter");
when(jwtAuthenticationConverter.convert(any(Jwt.class))) given(jwtAuthenticationConverter.convert(any(Jwt.class)))
.thenReturn(new JwtAuthenticationToken(jwt().build(), Collections.emptyList())); .willReturn(new JwtAuthenticationToken(TestJwts.jwt().build(), Collections.emptyList()));
JwtDecoder jwtDecoder = this.spring.getContext().getBean(JwtDecoder.class); JwtDecoder jwtDecoder = this.spring.getContext().getBean(JwtDecoder.class);
Mockito.when(jwtDecoder.decode(anyString())).thenReturn(jwt().build()); given(jwtDecoder.decode(anyString())).willReturn(TestJwts.jwt().build());
this.mvc.perform(get("/").header("Authorization", "Bearer token")).andExpect(status().isNotFound()); this.mvc.perform(get("/").header("Authorization", "Bearer token")).andExpect(status().isNotFound());
@ -702,8 +696,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
AuthenticationManagerResolver<HttpServletRequest> authenticationManagerResolver = this.spring.getContext() AuthenticationManagerResolver<HttpServletRequest> authenticationManagerResolver = this.spring.getContext()
.getBean(AuthenticationManagerResolver.class); .getBean(AuthenticationManagerResolver.class);
when(authenticationManagerResolver.resolve(any(HttpServletRequest.class))) given(authenticationManagerResolver.resolve(any(HttpServletRequest.class))).willReturn(
.thenReturn(authentication -> new JwtAuthenticationToken(jwt().build(), Collections.emptyList())); authentication -> new JwtAuthenticationToken(TestJwts.jwt().build(), Collections.emptyList()));
this.mvc.perform(get("/").header("Authorization", "Bearer token")).andExpect(status().isNotFound()); this.mvc.perform(get("/").header("Authorization", "Bearer token")).andExpect(status().isNotFound());
@ -754,7 +748,7 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
this.spring.configLocations(xml("MockJwtDecoder"), xml("BasicAndResourceServer")).autowire(); this.spring.configLocations(xml("MockJwtDecoder"), xml("BasicAndResourceServer")).autowire();
JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class);
when(decoder.decode(anyString())).thenThrow(JwtException.class); given(decoder.decode(anyString())).willThrow(JwtException.class);
this.mvc.perform(get("/authenticated").with(httpBasic("some", "user"))).andExpect(status().isUnauthorized()) this.mvc.perform(get("/authenticated").with(httpBasic("some", "user"))).andExpect(status().isUnauthorized())
.andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Basic"))); .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Basic")));
@ -775,7 +769,7 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
this.spring.configLocations(xml("MockJwtDecoder"), xml("FormAndResourceServer")).autowire(); this.spring.configLocations(xml("MockJwtDecoder"), xml("FormAndResourceServer")).autowire();
JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class);
when(decoder.decode(anyString())).thenThrow(JwtException.class); given(decoder.decode(anyString())).willThrow(JwtException.class);
MvcResult result = this.mvc.perform(get("/authenticated")).andExpect(status().isUnauthorized()).andReturn(); MvcResult result = this.mvc.perform(get("/authenticated")).andExpect(status().isUnauthorized()).andReturn();
@ -827,7 +821,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
OAuth2ResourceServerBeanDefinitionParser parser = new OAuth2ResourceServerBeanDefinitionParser(null, null, null, OAuth2ResourceServerBeanDefinitionParser parser = new OAuth2ResourceServerBeanDefinitionParser(null, null, null,
null, null); null, null);
Element element = mock(Element.class); Element element = mock(Element.class);
when(element.hasAttribute(AUTHENTICATION_MANAGER_RESOLVER_REF)).thenReturn(true); given(element.hasAttribute(OAuth2ResourceServerBeanDefinitionParser.AUTHENTICATION_MANAGER_RESOLVER_REF))
.willReturn(true);
Element child = mock(Element.class); Element child = mock(Element.class);
ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class));
@ -844,7 +839,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
OAuth2ResourceServerBeanDefinitionParser parser = new OAuth2ResourceServerBeanDefinitionParser(null, null, null, OAuth2ResourceServerBeanDefinitionParser parser = new OAuth2ResourceServerBeanDefinitionParser(null, null, null,
null, null); null, null);
Element element = mock(Element.class); Element element = mock(Element.class);
when(element.hasAttribute(AUTHENTICATION_MANAGER_RESOLVER_REF)).thenReturn(false); given(element.hasAttribute(OAuth2ResourceServerBeanDefinitionParser.AUTHENTICATION_MANAGER_RESOLVER_REF))
.willReturn(false);
ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class));
parser.validateConfiguration(element, null, null, pc); parser.validateConfiguration(element, null, null, pc);
verify(pc.getReaderContext()).error(anyString(), eq(element)); verify(pc.getReaderContext()).error(anyString(), eq(element));
@ -854,8 +850,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
public void validateConfigurationWhenBothJwtAttributesThenError() { public void validateConfigurationWhenBothJwtAttributesThenError() {
JwtBeanDefinitionParser parser = new JwtBeanDefinitionParser(); JwtBeanDefinitionParser parser = new JwtBeanDefinitionParser();
Element element = mock(Element.class); Element element = mock(Element.class);
when(element.hasAttribute(JWK_SET_URI)).thenReturn(true); given(element.hasAttribute(JwtBeanDefinitionParser.JWK_SET_URI)).willReturn(true);
when(element.hasAttribute(DECODER_REF)).thenReturn(true); given(element.hasAttribute(JwtBeanDefinitionParser.DECODER_REF)).willReturn(true);
ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class));
parser.validateConfiguration(element, pc); parser.validateConfiguration(element, pc);
verify(pc.getReaderContext()).error(anyString(), eq(element)); verify(pc.getReaderContext()).error(anyString(), eq(element));
@ -865,8 +861,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
public void validateConfigurationWhenNoJwtAttributesThenError() { public void validateConfigurationWhenNoJwtAttributesThenError() {
JwtBeanDefinitionParser parser = new JwtBeanDefinitionParser(); JwtBeanDefinitionParser parser = new JwtBeanDefinitionParser();
Element element = mock(Element.class); Element element = mock(Element.class);
when(element.hasAttribute(JWK_SET_URI)).thenReturn(false); given(element.hasAttribute(JwtBeanDefinitionParser.JWK_SET_URI)).willReturn(false);
when(element.hasAttribute(DECODER_REF)).thenReturn(false); given(element.hasAttribute(JwtBeanDefinitionParser.DECODER_REF)).willReturn(false);
ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class));
parser.validateConfiguration(element, pc); parser.validateConfiguration(element, pc);
verify(pc.getReaderContext()).error(anyString(), eq(element)); verify(pc.getReaderContext()).error(anyString(), eq(element));
@ -876,8 +872,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
public void validateConfigurationWhenBothOpaqueTokenModesThenError() { public void validateConfigurationWhenBothOpaqueTokenModesThenError() {
OpaqueTokenBeanDefinitionParser parser = new OpaqueTokenBeanDefinitionParser(); OpaqueTokenBeanDefinitionParser parser = new OpaqueTokenBeanDefinitionParser();
Element element = mock(Element.class); Element element = mock(Element.class);
when(element.hasAttribute(INTROSPECTION_URI)).thenReturn(true); given(element.hasAttribute(OpaqueTokenBeanDefinitionParser.INTROSPECTION_URI)).willReturn(true);
when(element.hasAttribute(INTROSPECTOR_REF)).thenReturn(true); given(element.hasAttribute(OpaqueTokenBeanDefinitionParser.INTROSPECTOR_REF)).willReturn(true);
ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class));
parser.validateConfiguration(element, pc); parser.validateConfiguration(element, pc);
verify(pc.getReaderContext()).error(anyString(), eq(element)); verify(pc.getReaderContext()).error(anyString(), eq(element));
@ -887,8 +883,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
public void validateConfigurationWhenNoOpaqueTokenModeThenError() { public void validateConfigurationWhenNoOpaqueTokenModeThenError() {
OpaqueTokenBeanDefinitionParser parser = new OpaqueTokenBeanDefinitionParser(); OpaqueTokenBeanDefinitionParser parser = new OpaqueTokenBeanDefinitionParser();
Element element = mock(Element.class); Element element = mock(Element.class);
when(element.hasAttribute(INTROSPECTION_URI)).thenReturn(false); given(element.hasAttribute(OpaqueTokenBeanDefinitionParser.INTROSPECTION_URI)).willReturn(false);
when(element.hasAttribute(INTROSPECTOR_REF)).thenReturn(false); given(element.hasAttribute(OpaqueTokenBeanDefinitionParser.INTROSPECTOR_REF)).willReturn(false);
ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class));
parser.validateConfiguration(element, pc); parser.validateConfiguration(element, pc);
verify(pc.getReaderContext()).error(anyString(), eq(element)); verify(pc.getReaderContext()).error(anyString(), eq(element));
@ -920,8 +916,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
private String jwtFromIssuer(String issuer) throws Exception { private String jwtFromIssuer(String issuer) throws Exception {
Map<String, Object> claims = new HashMap<>(); Map<String, Object> claims = new HashMap<>();
claims.put(ISS, issuer); claims.put(JwtClaimNames.ISS, issuer);
claims.put(SUB, "test-subject"); claims.put(JwtClaimNames.SUB, "test-subject");
claims.put("scope", "message:read"); claims.put("scope", "message:read");
JWSObject jws = new JWSObject(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("1").build(), JWSObject jws = new JWSObject(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("1").build(),
new Payload(new JSONObject(claims))); new Payload(new JSONObject(claims)));
@ -939,7 +935,7 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON); headers.setContentType(MediaType.APPLICATION_JSON);
ResponseEntity<String> entity = new ResponseEntity<>(response, headers, HttpStatus.OK); ResponseEntity<String> entity = new ResponseEntity<>(response, headers, HttpStatus.OK);
Mockito.when(rest.exchange(any(RequestEntity.class), eq(String.class))).thenReturn(entity); given(rest.exchange(any(RequestEntity.class), eq(String.class))).willReturn(entity);
} }
private String json(String name) throws IOException { private String json(String name) throws IOException {

View File

@ -26,6 +26,7 @@ import okhttp3.mockwebserver.MockWebServer;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.openid4java.consumer.ConsumerManager; import org.openid4java.consumer.ConsumerManager;
import org.openid4java.discovery.yadis.YadisResolver;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.parsing.BeanDefinitionParsingException; import org.springframework.beans.factory.parsing.BeanDefinitionParsingException;
@ -48,7 +49,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.openid4java.discovery.yadis.YadisResolver.YADIS_XRDS_LOCATION;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
@ -147,7 +147,7 @@ public class OpenIDConfigTests {
try (MockWebServer server = new MockWebServer()) { try (MockWebServer server = new MockWebServer()) {
String endpoint = server.url("/").toString(); String endpoint = server.url("/").toString();
server.enqueue(new MockResponse().addHeader(YADIS_XRDS_LOCATION, endpoint)); server.enqueue(new MockResponse().addHeader(YadisResolver.YADIS_XRDS_LOCATION, endpoint));
server.enqueue(new MockResponse() server.enqueue(new MockResponse()
.setBody(String.format("<XRDS><XRD><Service><URI>%s</URI></Service></XRD></XRDS>", endpoint))); .setBody(String.format("<XRDS><XRD><Service><URI>%s</URI></Service></XRD></XRDS>", endpoint)));

View File

@ -30,6 +30,8 @@ import org.springframework.security.TestDataSource;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices;
import org.springframework.security.web.authentication.rememberme.JdbcTokenRepositoryImpl;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.ResultActions; import org.springframework.test.web.servlet.ResultActions;
@ -43,9 +45,6 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices.DEFAULT_PARAMETER;
import static org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY;
import static org.springframework.security.web.authentication.rememberme.JdbcTokenRepositoryImpl.CREATE_TABLE_SQL;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.cookie; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.cookie;
@ -73,7 +72,8 @@ public class RememberMeConfigTests {
this.spring.configLocations(this.xml("WithTokenRepository")).autowire(); this.spring.configLocations(this.xml("WithTokenRepository")).autowire();
MvcResult result = this.rememberAuthentication("user", "password") MvcResult result = this.rememberAuthentication("user", "password")
.andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)).andReturn(); .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false))
.andReturn();
Cookie cookie = rememberMeCookie(result); Cookie cookie = rememberMeCookie(result);
@ -91,10 +91,11 @@ public class RememberMeConfigTests {
TestDataSource dataSource = this.spring.getContext().getBean(TestDataSource.class); TestDataSource dataSource = this.spring.getContext().getBean(TestDataSource.class);
JdbcTemplate template = new JdbcTemplate(dataSource); JdbcTemplate template = new JdbcTemplate(dataSource);
template.execute(CREATE_TABLE_SQL); template.execute(JdbcTokenRepositoryImpl.CREATE_TABLE_SQL);
MvcResult result = this.rememberAuthentication("user", "password") MvcResult result = this.rememberAuthentication("user", "password")
.andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)).andReturn(); .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false))
.andReturn();
Cookie cookie = rememberMeCookie(result); Cookie cookie = rememberMeCookie(result);
@ -111,10 +112,11 @@ public class RememberMeConfigTests {
TestDataSource dataSource = this.spring.getContext().getBean(TestDataSource.class); TestDataSource dataSource = this.spring.getContext().getBean(TestDataSource.class);
JdbcTemplate template = new JdbcTemplate(dataSource); JdbcTemplate template = new JdbcTemplate(dataSource);
template.execute(CREATE_TABLE_SQL); template.execute(JdbcTokenRepositoryImpl.CREATE_TABLE_SQL);
MvcResult result = this.rememberAuthentication("user", "password") MvcResult result = this.rememberAuthentication("user", "password")
.andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)).andReturn(); .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false))
.andReturn();
Cookie cookie = rememberMeCookie(result); Cookie cookie = rememberMeCookie(result);
@ -130,8 +132,9 @@ public class RememberMeConfigTests {
this.spring.configLocations(this.xml("WithServicesRef")).autowire(); this.spring.configLocations(this.xml("WithServicesRef")).autowire();
MvcResult result = this.rememberAuthentication("user", "password") MvcResult result = this.rememberAuthentication("user", "password")
.andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)) .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false))
.andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 5000)).andReturn(); .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 5000))
.andReturn();
Cookie cookie = rememberMeCookie(result); Cookie cookie = rememberMeCookie(result);
@ -139,7 +142,8 @@ public class RememberMeConfigTests {
// SEC-909 // SEC-909
this.mvc.perform(post("/logout").cookie(cookie).with(csrf())) this.mvc.perform(post("/logout").cookie(cookie).with(csrf()))
.andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 0)).andReturn(); .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 0))
.andReturn();
} }
@Test @Test
@ -152,7 +156,7 @@ public class RememberMeConfigTests {
Cookie cookie = rememberMeCookie(result); Cookie cookie = rememberMeCookie(result);
this.mvc.perform(post("/logout").cookie(cookie).with(csrf())) this.mvc.perform(post("/logout").cookie(cookie).with(csrf()))
.andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 0)); .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 0));
} }
@Test @Test
@ -162,7 +166,8 @@ public class RememberMeConfigTests {
this.spring.configLocations(this.xml("TokenValidity")).autowire(); this.spring.configLocations(this.xml("TokenValidity")).autowire();
MvcResult result = this.rememberAuthentication("user", "password") MvcResult result = this.rememberAuthentication("user", "password")
.andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 10000)).andReturn(); .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 10000))
.andReturn();
Cookie cookie = rememberMeCookie(result); Cookie cookie = rememberMeCookie(result);
@ -175,7 +180,7 @@ public class RememberMeConfigTests {
this.spring.configLocations(this.xml("NegativeTokenValidity")).autowire(); this.spring.configLocations(this.xml("NegativeTokenValidity")).autowire();
this.rememberAuthentication("user", "password") this.rememberAuthentication("user", "password")
.andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, -1)); .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, -1));
} }
@Test @Test
@ -191,7 +196,7 @@ public class RememberMeConfigTests {
this.spring.configLocations(this.xml("Sec2165")).autowire(); this.spring.configLocations(this.xml("Sec2165")).autowire();
this.rememberAuthentication("user", "password") this.rememberAuthentication("user", "password")
.andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 30)); .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 30));
} }
@Test @Test
@ -200,7 +205,7 @@ public class RememberMeConfigTests {
this.spring.configLocations(this.xml("SecureCookie")).autowire(); this.spring.configLocations(this.xml("SecureCookie")).autowire();
this.rememberAuthentication("user", "password") this.rememberAuthentication("user", "password")
.andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, true)); .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, true));
} }
/** /**
@ -212,7 +217,7 @@ public class RememberMeConfigTests {
this.spring.configLocations(this.xml("Sec1827")).autowire(); this.spring.configLocations(this.xml("Sec1827")).autowire();
this.rememberAuthentication("user", "password") this.rememberAuthentication("user", "password")
.andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)); .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false));
} }
@Test @Test
@ -304,7 +309,8 @@ public class RememberMeConfigTests {
private ResultActions rememberAuthentication(String username, String password) throws Exception { private ResultActions rememberAuthentication(String username, String password) throws Exception {
return this.mvc.perform(login(username, password).param(DEFAULT_PARAMETER, "true").with(csrf())) return this.mvc.perform(
login(username, password).param(AbstractRememberMeServices.DEFAULT_PARAMETER, "true").with(csrf()))
.andExpect(redirectedUrl("/")); .andExpect(redirectedUrl("/"));
} }

View File

@ -39,7 +39,7 @@ import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.core.StringContains.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.cookie; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.cookie;

View File

@ -39,7 +39,7 @@ import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* @author Rob Winch * @author Rob Winch

View File

@ -48,6 +48,7 @@ import org.springframework.security.web.authentication.logout.LogoutHandler;
import org.springframework.security.web.authentication.logout.LogoutSuccessEventPublishingLogoutHandler; import org.springframework.security.web.authentication.logout.LogoutSuccessEventPublishingLogoutHandler;
import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.authentication.session.SessionAuthenticationException;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.session.ConcurrentSessionFilter; import org.springframework.security.web.session.ConcurrentSessionFilter;
import org.springframework.security.web.session.SessionManagementFilter; import org.springframework.security.web.session.SessionManagementFilter;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
@ -60,7 +61,6 @@ import org.springframework.web.context.WebApplicationContext;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
import static org.springframework.security.web.context.HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
@ -139,7 +139,8 @@ public class SessionManagementConfigTests {
assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_MOVED_TEMPORARILY); assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_MOVED_TEMPORARILY);
assertThat(request.getSession(false)).isNotNull(); assertThat(request.getSession(false)).isNotNull();
assertThat(request.getSession(false).getAttribute(SPRING_SECURITY_CONTEXT_KEY)).isNotNull(); assertThat(request.getSession(false)
.getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY)).isNotNull();
} }
@Test @Test
@ -169,7 +170,8 @@ public class SessionManagementConfigTests {
.session(new MockHttpSession()).with(csrf())) .session(new MockHttpSession()).with(csrf()))
.andExpect(status().isFound()).andExpect(session()).andReturn(); .andExpect(status().isFound()).andExpect(session()).andReturn();
assertThat(result.getRequest().getSession(false).getAttribute(SPRING_SECURITY_CONTEXT_KEY)).isNull(); assertThat(result.getRequest().getSession(false)
.getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY)).isNull();
} }
@Test @Test

View File

@ -37,7 +37,6 @@ import org.springframework.security.config.annotation.web.configuration.WebSecur
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.config.http.customconfigurer.CustomConfigurer.customConfigurer;
/** /**
* @author Rob Winch * @author Rob Winch
@ -126,7 +125,7 @@ public class CustomHttpSecurityConfigurerTests {
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
// @formatter:off // @formatter:off
http http
.apply(customConfigurer()) .apply(CustomConfigurer.customConfigurer())
.loginPage("/custom"); .loginPage("/custom");
// @formatter:on // @formatter:on
} }
@ -151,7 +150,7 @@ public class CustomHttpSecurityConfigurerTests {
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
// @formatter:off // @formatter:off
http http
.apply(customConfigurer()) .apply(CustomConfigurer.customConfigurer())
.and() .and()
.csrf().disable() .csrf().disable()
.formLogin() .formLogin()

View File

@ -59,7 +59,6 @@ import org.springframework.security.util.FieldUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.Assertions.fail;
import static org.springframework.security.config.ConfigTestUtils.AUTH_PROVIDER_XML;
/** /**
* @author Ben Alex * @author Ben Alex
@ -185,7 +184,8 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
+ "<global-method-security>" + " <protect-pointcut expression=" + "<global-method-security>" + " <protect-pointcut expression="
+ " 'execution(* org.springframework.security.access.annotation.BusinessService.*(..)) " + " 'execution(* org.springframework.security.access.annotation.BusinessService.*(..)) "
+ " and not execution(* org.springframework.security.access.annotation.BusinessService.someOther(String)))' " + " and not execution(* org.springframework.security.access.annotation.BusinessService.someOther(String)))' "
+ " access='ROLE_USER'/>" + "</global-method-security>" + AUTH_PROVIDER_XML); + " access='ROLE_USER'/>" + "</global-method-security>"
+ ConfigTestUtils.AUTH_PROVIDER_XML);
this.target = (BusinessService) this.appContext.getBean("target"); this.target = (BusinessService) this.appContext.getBean("target");
// String method should not be protected // String method should not be protected
this.target.someOther("somestring"); this.target.someOther("somestring");
@ -215,7 +215,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
+ "<b:bean id='businessService' class='org.springframework.remoting.httpinvoker.HttpInvokerProxyFactoryBean'>" + "<b:bean id='businessService' class='org.springframework.remoting.httpinvoker.HttpInvokerProxyFactoryBean'>"
+ " <b:property name='serviceUrl' value='http://localhost:8080/SomeService'/>" + " <b:property name='serviceUrl' value='http://localhost:8080/SomeService'/>"
+ " <b:property name='serviceInterface' value='org.springframework.security.access.annotation.BusinessService'/>" + " <b:property name='serviceInterface' value='org.springframework.security.access.annotation.BusinessService'/>"
+ "</b:bean>" + AUTH_PROVIDER_XML); + "</b:bean>" + ConfigTestUtils.AUTH_PROVIDER_XML);
UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password",
AuthorityUtils.createAuthorityList("ROLE_SOMEOTHERROLE")); AuthorityUtils.createAuthorityList("ROLE_SOMEOTHERROLE"));
@ -229,7 +229,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test @Test
public void expressionVoterAndAfterInvocationProviderUseSameExpressionHandlerInstance() throws Exception { public void expressionVoterAndAfterInvocationProviderUseSameExpressionHandlerInstance() throws Exception {
setContext("<global-method-security pre-post-annotations='enabled'/>" + AUTH_PROVIDER_XML); setContext("<global-method-security pre-post-annotations='enabled'/>" + ConfigTestUtils.AUTH_PROVIDER_XML);
AffirmativeBased adm = (AffirmativeBased) this.appContext.getBeansOfType(AffirmativeBased.class).values() AffirmativeBased adm = (AffirmativeBased) this.appContext.getBeansOfType(AffirmativeBased.class).values()
.toArray()[0]; .toArray()[0];
List voters = (List) FieldUtils.getFieldValue(adm, "decisionVoters"); List voters = (List) FieldUtils.getFieldValue(adm, "decisionVoters");
@ -247,7 +247,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
public void accessIsDeniedForHasRoleExpression() { public void accessIsDeniedForHasRoleExpression() {
setContext("<global-method-security pre-post-annotations='enabled'/>" setContext("<global-method-security pre-post-annotations='enabled'/>"
+ "<b:bean id='target' class='org.springframework.security.access.annotation.ExpressionProtectedBusinessServiceImpl'/>" + "<b:bean id='target' class='org.springframework.security.access.annotation.ExpressionProtectedBusinessServiceImpl'/>"
+ AUTH_PROVIDER_XML); + ConfigTestUtils.AUTH_PROVIDER_XML);
SecurityContextHolder.getContext().setAuthentication(this.bob); SecurityContextHolder.getContext().setAuthentication(this.bob);
this.target = (BusinessService) this.appContext.getBean("target"); this.target = (BusinessService) this.appContext.getBean("target");
this.target.someAdminMethod(); this.target.someAdminMethod();
@ -259,7 +259,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
+ "<b:bean id='number' class='java.lang.Integer'>" + " <b:constructor-arg value='1294'/>" + "<b:bean id='number' class='java.lang.Integer'>" + " <b:constructor-arg value='1294'/>"
+ "</b:bean>" + "</b:bean>"
+ "<b:bean id='target' class='org.springframework.security.access.annotation.ExpressionProtectedBusinessServiceImpl'/>" + "<b:bean id='target' class='org.springframework.security.access.annotation.ExpressionProtectedBusinessServiceImpl'/>"
+ AUTH_PROVIDER_XML); + ConfigTestUtils.AUTH_PROVIDER_XML);
SecurityContextHolder.getContext().setAuthentication(this.bob); SecurityContextHolder.getContext().setAuthentication(this.bob);
ExpressionProtectedBusinessServiceImpl target = (ExpressionProtectedBusinessServiceImpl) this.appContext ExpressionProtectedBusinessServiceImpl target = (ExpressionProtectedBusinessServiceImpl) this.appContext
.getBean("target"); .getBean("target");
@ -270,7 +270,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
public void preAndPostFilterAnnotationsWorkWithLists() { public void preAndPostFilterAnnotationsWorkWithLists() {
setContext("<global-method-security pre-post-annotations='enabled'/>" setContext("<global-method-security pre-post-annotations='enabled'/>"
+ "<b:bean id='target' class='org.springframework.security.access.annotation.ExpressionProtectedBusinessServiceImpl'/>" + "<b:bean id='target' class='org.springframework.security.access.annotation.ExpressionProtectedBusinessServiceImpl'/>"
+ AUTH_PROVIDER_XML); + ConfigTestUtils.AUTH_PROVIDER_XML);
SecurityContextHolder.getContext().setAuthentication(this.bob); SecurityContextHolder.getContext().setAuthentication(this.bob);
this.target = (BusinessService) this.appContext.getBean("target"); this.target = (BusinessService) this.appContext.getBean("target");
List<String> arg = new ArrayList<>(); List<String> arg = new ArrayList<>();
@ -289,7 +289,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
public void prePostFilterAnnotationWorksWithArrays() { public void prePostFilterAnnotationWorksWithArrays() {
setContext("<global-method-security pre-post-annotations='enabled'/>" setContext("<global-method-security pre-post-annotations='enabled'/>"
+ "<b:bean id='target' class='org.springframework.security.access.annotation.ExpressionProtectedBusinessServiceImpl'/>" + "<b:bean id='target' class='org.springframework.security.access.annotation.ExpressionProtectedBusinessServiceImpl'/>"
+ AUTH_PROVIDER_XML); + ConfigTestUtils.AUTH_PROVIDER_XML);
SecurityContextHolder.getContext().setAuthentication(this.bob); SecurityContextHolder.getContext().setAuthentication(this.bob);
this.target = (BusinessService) this.appContext.getBean("target"); this.target = (BusinessService) this.appContext.getBean("target");
Object[] arg = new String[] { "joe", "bob", "sam" }; Object[] arg = new String[] { "joe", "bob", "sam" };
@ -306,7 +306,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
+ "<b:bean id='expressionHandler' class='org.springframework.security.access.expression.method.DefaultMethodSecurityExpressionHandler'>" + "<b:bean id='expressionHandler' class='org.springframework.security.access.expression.method.DefaultMethodSecurityExpressionHandler'>"
+ " <b:property name='permissionEvaluator' ref='myPermissionEvaluator'/>" + "</b:bean>" + " <b:property name='permissionEvaluator' ref='myPermissionEvaluator'/>" + "</b:bean>"
+ "<b:bean id='myPermissionEvaluator' class='org.springframework.security.config.method.TestPermissionEvaluator'/>" + "<b:bean id='myPermissionEvaluator' class='org.springframework.security.config.method.TestPermissionEvaluator'/>"
+ AUTH_PROVIDER_XML); + ConfigTestUtils.AUTH_PROVIDER_XML);
} }
// SEC-1450 // SEC-1450
@ -317,7 +317,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
"<b:bean id='target' class='org.springframework.security.config.method.GlobalMethodSecurityBeanDefinitionParserTests$ConcreteFoo'/>" "<b:bean id='target' class='org.springframework.security.config.method.GlobalMethodSecurityBeanDefinitionParserTests$ConcreteFoo'/>"
+ "<global-method-security>" + "<global-method-security>"
+ " <protect-pointcut expression='execution(* org..*Foo.foo(..))' access='ROLE_USER'/>" + " <protect-pointcut expression='execution(* org..*Foo.foo(..))' access='ROLE_USER'/>"
+ "</global-method-security>" + AUTH_PROVIDER_XML); + "</global-method-security>" + ConfigTestUtils.AUTH_PROVIDER_XML);
Foo foo = (Foo) this.appContext.getBean("target"); Foo foo = (Foo) this.appContext.getBean("target");
foo.foo(new SecurityConfig("A")); foo.foo(new SecurityConfig("A"));
} }
@ -327,7 +327,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void genericsMethodArgumentNamesAreResolved() { public void genericsMethodArgumentNamesAreResolved() {
setContext("<b:bean id='target' class='" + ConcreteFoo.class.getName() + "'/>" setContext("<b:bean id='target' class='" + ConcreteFoo.class.getName() + "'/>"
+ "<global-method-security pre-post-annotations='enabled'/>" + AUTH_PROVIDER_XML); + "<global-method-security pre-post-annotations='enabled'/>" + ConfigTestUtils.AUTH_PROVIDER_XML);
SecurityContextHolder.getContext().setAuthentication(this.bob); SecurityContextHolder.getContext().setAuthentication(this.bob);
Foo foo = (Foo) this.appContext.getBean("target"); Foo foo = (Foo) this.appContext.getBean("target");
foo.foo(new SecurityConfig("A")); foo.foo(new SecurityConfig("A"));
@ -341,7 +341,8 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
parent.registerSingleton("runAsMgr", RunAsManagerImpl.class, props); parent.registerSingleton("runAsMgr", RunAsManagerImpl.class, props);
parent.refresh(); parent.refresh();
setContext("<global-method-security run-as-manager-ref='runAsMgr'/>" + AUTH_PROVIDER_XML, parent); setContext("<global-method-security run-as-manager-ref='runAsMgr'/>" + ConfigTestUtils.AUTH_PROVIDER_XML,
parent);
RunAsManagerImpl ram = (RunAsManagerImpl) this.appContext.getBean("runAsMgr"); RunAsManagerImpl ram = (RunAsManagerImpl) this.appContext.getBean("runAsMgr");
MethodSecurityMetadataSourceAdvisor msi = (MethodSecurityMetadataSourceAdvisor) this.appContext MethodSecurityMetadataSourceAdvisor msi = (MethodSecurityMetadataSourceAdvisor) this.appContext
.getBeansOfType(MethodSecurityMetadataSourceAdvisor.class).values().toArray()[0]; .getBeansOfType(MethodSecurityMetadataSourceAdvisor.class).values().toArray()[0];
@ -355,7 +356,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
+ "<method-security-metadata-source id='mds'>" + " <protect method='" + Foo.class.getName() + "<method-security-metadata-source id='mds'>" + " <protect method='" + Foo.class.getName()
+ ".foo' access='ROLE_ADMIN'/>" + "</method-security-metadata-source>" + ".foo' access='ROLE_ADMIN'/>" + "</method-security-metadata-source>"
+ "<global-method-security pre-post-annotations='enabled' metadata-source-ref='mds'/>" + "<global-method-security pre-post-annotations='enabled' metadata-source-ref='mds'/>"
+ AUTH_PROVIDER_XML); + ConfigTestUtils.AUTH_PROVIDER_XML);
// External MDS should take precedence over PreAuthorize // External MDS should take precedence over PreAuthorize
SecurityContextHolder.getContext().setAuthentication(this.bob); SecurityContextHolder.getContext().setAuthentication(this.bob);
Foo foo = (Foo) this.appContext.getBean("target"); Foo foo = (Foo) this.appContext.getBean("target");
@ -377,7 +378,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests {
+ ".foo' access='ROLE_ADMIN'/>" + "</method-security-metadata-source>" + ".foo' access='ROLE_ADMIN'/>" + "</method-security-metadata-source>"
+ "<global-method-security pre-post-annotations='enabled' metadata-source-ref='mds' authentication-manager-ref='customAuthMgr'/>" + "<global-method-security pre-post-annotations='enabled' metadata-source-ref='mds' authentication-manager-ref='customAuthMgr'/>"
+ "<b:bean id='customAuthMgr' class='org.springframework.security.config.method.GlobalMethodSecurityBeanDefinitionParserTests$CustomAuthManager'>" + "<b:bean id='customAuthMgr' class='org.springframework.security.config.method.GlobalMethodSecurityBeanDefinitionParserTests$CustomAuthManager'>"
+ " <b:constructor-arg value='authManager'/>" + "</b:bean>" + AUTH_PROVIDER_XML); + " <b:constructor-arg value='authManager'/>" + "</b:bean>" + ConfigTestUtils.AUTH_PROVIDER_XML);
SecurityContextHolder.getContext().setAuthentication(this.bob); SecurityContextHolder.getContext().setAuthentication(this.bob);
Foo foo = (Foo) this.appContext.getBean("target"); Foo foo = (Foo) this.appContext.getBean("target");
try { try {

View File

@ -26,7 +26,7 @@ import org.springframework.security.provisioning.UserDetailsManager;
import org.springframework.security.util.InMemoryResource; import org.springframework.security.util.InMemoryResource;
import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.context.junit4.SpringRunner;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* @author Rob Winch * @author Rob Winch

View File

@ -25,7 +25,7 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.security.provisioning.UserDetailsManager; import org.springframework.security.provisioning.UserDetailsManager;
import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.context.junit4.SpringRunner;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* @author Rob Winch * @author Rob Winch

View File

@ -25,7 +25,7 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.security.provisioning.UserDetailsManager; import org.springframework.security.provisioning.UserDetailsManager;
import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.context.junit4.SpringRunner;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* @author Rob Winch * @author Rob Winch

View File

@ -28,6 +28,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor;
import org.springframework.mock.web.MockServletConfig; import org.springframework.mock.web.MockServletConfig;
import org.springframework.mock.web.MockServletContext; import org.springframework.mock.web.MockServletContext;
import org.springframework.security.config.BeanIds;
import org.springframework.security.config.util.InMemoryXmlWebApplicationContext; import org.springframework.security.config.util.InMemoryXmlWebApplicationContext;
import org.springframework.test.context.web.GenericXmlWebContextLoader; import org.springframework.test.context.web.GenericXmlWebContextLoader;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
@ -41,7 +42,6 @@ import org.springframework.web.context.support.AnnotationConfigWebApplicationCon
import org.springframework.web.context.support.XmlWebApplicationContext; import org.springframework.web.context.support.XmlWebApplicationContext;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import static org.springframework.security.config.BeanIds.SPRING_SECURITY_FILTER_CHAIN;
import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity;
/** /**
@ -129,7 +129,7 @@ public class SpringTestContext implements Closeable {
this.context.setServletConfig(new MockServletConfig()); this.context.setServletConfig(new MockServletConfig());
this.context.refresh(); this.context.refresh();
if (this.context.containsBean(SPRING_SECURITY_FILTER_CHAIN)) { if (this.context.containsBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN)) {
MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()) MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity())
.apply(new AddFilter()).build(); .apply(new AddFilter()).build();
this.context.getBeanFactory().registerResolvableDependency(MockMvc.class, mockMvc); this.context.getBeanFactory().registerResolvableDependency(MockMvc.class, mockMvc);

View File

@ -23,10 +23,6 @@ import org.springframework.core.io.Resource;
import org.springframework.security.util.InMemoryResource; import org.springframework.security.util.InMemoryResource;
import org.springframework.web.context.support.AbstractRefreshableWebApplicationContext; import org.springframework.web.context.support.AbstractRefreshableWebApplicationContext;
import static org.springframework.security.config.util.InMemoryXmlApplicationContext.BEANS_CLOSE;
import static org.springframework.security.config.util.InMemoryXmlApplicationContext.BEANS_OPENING;
import static org.springframework.security.config.util.InMemoryXmlApplicationContext.SPRING_SECURITY_VERSION;
/** /**
* @author Joe Grandja * @author Joe Grandja
*/ */
@ -35,15 +31,16 @@ public class InMemoryXmlWebApplicationContext extends AbstractRefreshableWebAppl
private Resource inMemoryXml; private Resource inMemoryXml;
public InMemoryXmlWebApplicationContext(String xml) { public InMemoryXmlWebApplicationContext(String xml) {
this(xml, SPRING_SECURITY_VERSION, null); this(xml, InMemoryXmlApplicationContext.SPRING_SECURITY_VERSION, null);
} }
public InMemoryXmlWebApplicationContext(String xml, ApplicationContext parent) { public InMemoryXmlWebApplicationContext(String xml, ApplicationContext parent) {
this(xml, SPRING_SECURITY_VERSION, parent); this(xml, InMemoryXmlApplicationContext.SPRING_SECURITY_VERSION, parent);
} }
public InMemoryXmlWebApplicationContext(String xml, String secVersion, ApplicationContext parent) { public InMemoryXmlWebApplicationContext(String xml, String secVersion, ApplicationContext parent) {
String fullXml = BEANS_OPENING + secVersion + ".xsd'>\n" + xml + BEANS_CLOSE; String fullXml = InMemoryXmlApplicationContext.BEANS_OPENING + secVersion + ".xsd'>\n" + xml
+ InMemoryXmlApplicationContext.BEANS_CLOSE;
this.inMemoryXml = new InMemoryResource(fullXml); this.inMemoryXml = new InMemoryResource(fullXml);
setAllowBeanDefinitionOverriding(true); setAllowBeanDefinitionOverriding(true);
setParent(parent); setParent(parent);

View File

@ -37,7 +37,7 @@ import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.reactive.CorsConfigurationSource; import org.springframework.web.cors.reactive.CorsConfigurationSource;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;

View File

@ -39,7 +39,7 @@ import org.springframework.security.web.server.header.XXssProtectionServerHttpHe
import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.FluxExchangeResult;
import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.test.web.reactive.server.WebTestClient;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
/** /**

View File

@ -80,6 +80,7 @@ import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtValidationException; import org.springframework.security.oauth2.jwt.JwtValidationException;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.WebFilterChainProxy; import org.springframework.security.web.server.WebFilterChainProxy;
@ -108,7 +109,6 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
/** /**
* @author Rob Winch * @author Rob Winch
@ -680,7 +680,7 @@ public class OAuth2LoginTests {
claims.put(IdTokenClaimNames.ISS, "http://localhost/issuer"); claims.put(IdTokenClaimNames.ISS, "http://localhost/issuer");
claims.put(IdTokenClaimNames.AUD, Collections.singletonList("client")); claims.put(IdTokenClaimNames.AUD, Collections.singletonList("client"));
claims.put(IdTokenClaimNames.AZP, "client"); claims.put(IdTokenClaimNames.AZP, "client");
Jwt jwt = jwt().claims(c -> c.putAll(claims)).build(); Jwt jwt = TestJwts.jwt().claims(c -> c.putAll(claims)).build();
return Mono.just(jwt); return Mono.just(jwt);
}; };
} }

View File

@ -61,6 +61,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter;
import org.springframework.security.oauth2.server.resource.authentication.ReactiveJwtAuthenticationConverterAdapter; import org.springframework.security.oauth2.server.resource.authentication.ReactiveJwtAuthenticationConverterAdapter;
@ -80,13 +81,12 @@ import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatCode;
import static org.hamcrest.core.StringStartsWith.startsWith; import static org.hamcrest.CoreMatchers.startsWith;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
/** /**
* Tests for * Tests for
@ -108,7 +108,7 @@ public class OAuth2ResourceServerSpecTests {
+ " \"n\":\"0IUjrPZDz-3z0UE4ppcKU36v7hnh8FJjhu3lbJYj0qj9eZiwEJxi9HHUfSK1DhUQG7mJBbYTK1tPYCgre5EkfKh-64VhYUa-vz17zYCmuB8fFj4XHE3MLkWIG-AUn8hNbPzYYmiBTjfGnMKxLHjsbdTiF4mtn-85w366916R6midnAuiPD4HjZaZ1PAsuY60gr8bhMEDtJ8unz81hoQrozpBZJ6r8aR1PrsWb1OqPMloK9kAIutJNvWYKacp8WYAp2WWy72PxQ7Fb0eIA1br3A5dnp-Cln6JROJcZUIRJ-QvS6QONWeS2407uQmS-i-lybsqaH0ldYC7NBEBA5inPQ\"\n" + " \"n\":\"0IUjrPZDz-3z0UE4ppcKU36v7hnh8FJjhu3lbJYj0qj9eZiwEJxi9HHUfSK1DhUQG7mJBbYTK1tPYCgre5EkfKh-64VhYUa-vz17zYCmuB8fFj4XHE3MLkWIG-AUn8hNbPzYYmiBTjfGnMKxLHjsbdTiF4mtn-85w366916R6midnAuiPD4HjZaZ1PAsuY60gr8bhMEDtJ8unz81hoQrozpBZJ6r8aR1PrsWb1OqPMloK9kAIutJNvWYKacp8WYAp2WWy72PxQ7Fb0eIA1br3A5dnp-Cln6JROJcZUIRJ-QvS6QONWeS2407uQmS-i-lybsqaH0ldYC7NBEBA5inPQ\"\n"
+ " }\n" + " ]\n" + "}\n"; + " }\n" + " ]\n" + "}\n";
private Jwt jwt = jwt().build(); private Jwt jwt = TestJwts.jwt().build();
private String clientId = "client"; private String clientId = "client";

View File

@ -61,6 +61,7 @@ import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.reactive.server.EntityExchangeResult; import org.springframework.test.web.reactive.server.EntityExchangeResult;
import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.FluxExchangeResult;
import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.test.web.reactive.server.WebTestClient;
@ -78,7 +79,6 @@ import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.test.util.ReflectionTestUtils.getField;
/** /**
* @author Rob Winch * @author Rob Winch
@ -187,8 +187,8 @@ public class ServerHttpSecurityTests {
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)).isNotPresent(); assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)).isNotPresent();
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, .map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter,
"logoutHandler")); LogoutWebFilter.class, "logoutHandler"));
assertThat(logoutHandler).get().isExactlyInstanceOf(SecurityContextServerLogoutHandler.class); assertThat(logoutHandler).get().isExactlyInstanceOf(SecurityContextServerLogoutHandler.class);
} }
@ -199,17 +199,17 @@ public class ServerHttpSecurityTests {
.and().build(); .and().build();
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)).get() assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)).get()
.extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository")) .extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository"))
.isEqualTo(this.csrfTokenRepository); .isEqualTo(this.csrfTokenRepository);
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, .map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter,
"logoutHandler")); LogoutWebFilter.class, "logoutHandler"));
assertThat(logoutHandler).get().isExactlyInstanceOf(DelegatingServerLogoutHandler.class) assertThat(logoutHandler).get().isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
.extracting(delegatingLogoutHandler -> ((List<ServerLogoutHandler>) getField(delegatingLogoutHandler, .extracting(delegatingLogoutHandler -> ((List<ServerLogoutHandler>) ReflectionTestUtils
DelegatingServerLogoutHandler.class, "delegates")).stream().map(ServerLogoutHandler::getClass) .getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
.collect(Collectors.toList())) .map(ServerLogoutHandler::getClass).collect(Collectors.toList()))
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class)); .isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
} }
@ -439,8 +439,8 @@ public class ServerHttpSecurityTests {
OAuth2LoginAuthenticationWebFilter authenticationWebFilter = getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter authenticationWebFilter = getWebFilter(securityFilterChain,
OAuth2LoginAuthenticationWebFilter.class).get(); OAuth2LoginAuthenticationWebFilter.class).get();
Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler"); Object handler = ReflectionTestUtils.getField(authenticationWebFilter, "authenticationSuccessHandler");
assertThat(getField(handler, "requestCache")).isSameAs(requestCache); assertThat(ReflectionTestUtils.getField(handler, "requestCache")).isSameAs(requestCache);
} }
@Test @Test
@ -467,7 +467,7 @@ public class ServerHttpSecurityTests {
private boolean isX509Filter(WebFilter filter) { private boolean isX509Filter(WebFilter filter) {
try { try {
Object converter = getField(filter, "authenticationConverter"); Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class); return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
} }
catch (IllegalArgumentException e) { catch (IllegalArgumentException e) {

View File

@ -15,6 +15,7 @@
*/ */
package org.springframework.security.access.hierarchicalroles; package org.springframework.security.access.hierarchicalroles;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@ -23,7 +24,6 @@ import java.util.TreeMap;
import org.junit.Test; import org.junit.Test;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
@ -45,9 +45,9 @@ public class RoleHierarchyUtilsTests {
// @formatter:on // @formatter:on
Map<String, List<String>> roleHierarchyMap = new TreeMap<>(); Map<String, List<String>> roleHierarchyMap = new TreeMap<>();
roleHierarchyMap.put("ROLE_A", asList("ROLE_B", "ROLE_C")); roleHierarchyMap.put("ROLE_A", Arrays.asList("ROLE_B", "ROLE_C"));
roleHierarchyMap.put("ROLE_B", asList("ROLE_D")); roleHierarchyMap.put("ROLE_B", Arrays.asList("ROLE_D"));
roleHierarchyMap.put("ROLE_C", asList("ROLE_D")); roleHierarchyMap.put("ROLE_C", Arrays.asList("ROLE_D"));
String roleHierarchy = RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap); String roleHierarchy = RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap);
@ -67,7 +67,7 @@ public class RoleHierarchyUtilsTests {
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void roleHierarchyFromMapWhenRoleNullThenThrowsIllegalArgumentException() { public void roleHierarchyFromMapWhenRoleNullThenThrowsIllegalArgumentException() {
Map<String, List<String>> roleHierarchyMap = new HashMap<>(); Map<String, List<String>> roleHierarchyMap = new HashMap<>();
roleHierarchyMap.put(null, asList("ROLE_B", "ROLE_C")); roleHierarchyMap.put(null, Arrays.asList("ROLE_B", "ROLE_C"));
RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap); RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap);
} }
@ -75,7 +75,7 @@ public class RoleHierarchyUtilsTests {
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void roleHierarchyFromMapWhenRoleEmptyThenThrowsIllegalArgumentException() { public void roleHierarchyFromMapWhenRoleEmptyThenThrowsIllegalArgumentException() {
Map<String, List<String>> roleHierarchyMap = new HashMap<>(); Map<String, List<String>> roleHierarchyMap = new HashMap<>();
roleHierarchyMap.put("", asList("ROLE_B", "ROLE_C")); roleHierarchyMap.put("", Arrays.asList("ROLE_B", "ROLE_C"));
RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap); RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap);
} }

View File

@ -22,7 +22,7 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mock; import org.mockito.Mock;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -55,7 +55,7 @@ public abstract class AbstractDelegatingSecurityContextScheduledExecutorServiceT
given((ScheduledFuture<Object>) this.delegate.schedule(this.wrappedRunnable, 1, TimeUnit.SECONDS)) given((ScheduledFuture<Object>) this.delegate.schedule(this.wrappedRunnable, 1, TimeUnit.SECONDS))
.willReturn(this.expectedResult); .willReturn(this.expectedResult);
ScheduledFuture<?> result = this.executor.schedule(this.runnable, 1, TimeUnit.SECONDS); ScheduledFuture<?> result = this.executor.schedule(this.runnable, 1, TimeUnit.SECONDS);
assertThat(result).isEqualTo(this.expectedResult); assertThat((Object) result).isEqualTo(this.expectedResult);
verify(this.delegate).schedule(this.wrappedRunnable, 1, TimeUnit.SECONDS); verify(this.delegate).schedule(this.wrappedRunnable, 1, TimeUnit.SECONDS);
} }
@ -63,7 +63,7 @@ public abstract class AbstractDelegatingSecurityContextScheduledExecutorServiceT
public void scheduleCallable() { public void scheduleCallable() {
given(this.delegate.schedule(this.wrappedCallable, 1, TimeUnit.SECONDS)).willReturn(this.expectedResult); given(this.delegate.schedule(this.wrappedCallable, 1, TimeUnit.SECONDS)).willReturn(this.expectedResult);
ScheduledFuture<Object> result = this.executor.schedule(this.callable, 1, TimeUnit.SECONDS); ScheduledFuture<Object> result = this.executor.schedule(this.callable, 1, TimeUnit.SECONDS);
assertThat(result).isEqualTo(this.expectedResult); assertThat((Object) result).isEqualTo(this.expectedResult);
verify(this.delegate).schedule(this.wrappedCallable, 1, TimeUnit.SECONDS); verify(this.delegate).schedule(this.wrappedCallable, 1, TimeUnit.SECONDS);
} }
@ -73,7 +73,7 @@ public abstract class AbstractDelegatingSecurityContextScheduledExecutorServiceT
given((ScheduledFuture<Object>) this.delegate.scheduleAtFixedRate(this.wrappedRunnable, 1, 2, TimeUnit.SECONDS)) given((ScheduledFuture<Object>) this.delegate.scheduleAtFixedRate(this.wrappedRunnable, 1, 2, TimeUnit.SECONDS))
.willReturn(this.expectedResult); .willReturn(this.expectedResult);
ScheduledFuture<?> result = this.executor.scheduleAtFixedRate(this.runnable, 1, 2, TimeUnit.SECONDS); ScheduledFuture<?> result = this.executor.scheduleAtFixedRate(this.runnable, 1, 2, TimeUnit.SECONDS);
assertThat(result).isEqualTo(this.expectedResult); assertThat((Object) result).isEqualTo(this.expectedResult);
verify(this.delegate).scheduleAtFixedRate(this.wrappedRunnable, 1, 2, TimeUnit.SECONDS); verify(this.delegate).scheduleAtFixedRate(this.wrappedRunnable, 1, 2, TimeUnit.SECONDS);
} }
@ -83,7 +83,7 @@ public abstract class AbstractDelegatingSecurityContextScheduledExecutorServiceT
given((ScheduledFuture<Object>) this.delegate.scheduleWithFixedDelay(this.wrappedRunnable, 1, 2, given((ScheduledFuture<Object>) this.delegate.scheduleWithFixedDelay(this.wrappedRunnable, 1, 2,
TimeUnit.SECONDS)).willReturn(this.expectedResult); TimeUnit.SECONDS)).willReturn(this.expectedResult);
ScheduledFuture<?> result = this.executor.scheduleWithFixedDelay(this.runnable, 1, 2, TimeUnit.SECONDS); ScheduledFuture<?> result = this.executor.scheduleWithFixedDelay(this.runnable, 1, 2, TimeUnit.SECONDS);
assertThat(result).isEqualTo(this.expectedResult); assertThat((Object) result).isEqualTo(this.expectedResult);
verify(this.delegate).scheduleWithFixedDelay(this.wrappedRunnable, 1, 2, TimeUnit.SECONDS); verify(this.delegate).scheduleWithFixedDelay(this.wrappedRunnable, 1, 2, TimeUnit.SECONDS);
} }

View File

@ -23,6 +23,7 @@ import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.Captor; import org.mockito.Captor;
import org.mockito.Mock; import org.mockito.Mock;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunner;
@ -30,8 +31,6 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.powermock.api.mockito.PowerMockito.doReturn;
import static org.powermock.api.mockito.PowerMockito.spy;
/** /**
* Abstract base class for testing classes that extend * Abstract base class for testing classes that extend
@ -67,19 +66,21 @@ public abstract class AbstractDelegatingSecurityContextTestSupport {
protected Runnable wrappedRunnable; protected Runnable wrappedRunnable;
public final void explicitSecurityContextPowermockSetup() throws Exception { public final void explicitSecurityContextPowermockSetup() throws Exception {
spy(DelegatingSecurityContextCallable.class); PowerMockito.spy(DelegatingSecurityContextCallable.class);
doReturn(this.wrappedCallable).when(DelegatingSecurityContextCallable.class, "create", eq(this.callable), PowerMockito.doReturn(this.wrappedCallable).when(DelegatingSecurityContextCallable.class, "create",
this.securityContextCaptor.capture()); eq(this.callable), this.securityContextCaptor.capture());
spy(DelegatingSecurityContextRunnable.class); PowerMockito.spy(DelegatingSecurityContextRunnable.class);
doReturn(this.wrappedRunnable).when(DelegatingSecurityContextRunnable.class, "create", eq(this.runnable), PowerMockito.doReturn(this.wrappedRunnable).when(DelegatingSecurityContextRunnable.class, "create",
this.securityContextCaptor.capture()); eq(this.runnable), this.securityContextCaptor.capture());
} }
public final void currentSecurityContextPowermockSetup() throws Exception { public final void currentSecurityContextPowermockSetup() throws Exception {
spy(DelegatingSecurityContextCallable.class); PowerMockito.spy(DelegatingSecurityContextCallable.class);
doReturn(this.wrappedCallable).when(DelegatingSecurityContextCallable.class, "create", this.callable, null); PowerMockito.doReturn(this.wrappedCallable).when(DelegatingSecurityContextCallable.class, "create",
spy(DelegatingSecurityContextRunnable.class); this.callable, null);
doReturn(this.wrappedRunnable).when(DelegatingSecurityContextRunnable.class, "create", this.runnable, null); PowerMockito.spy(DelegatingSecurityContextRunnable.class);
PowerMockito.doReturn(this.wrappedRunnable).when(DelegatingSecurityContextRunnable.class, "create",
this.runnable, null);
} }
@Before @Before

View File

@ -21,6 +21,7 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunner;
import org.powermock.reflect.Whitebox; import org.powermock.reflect.Whitebox;
@ -33,8 +34,6 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.powermock.api.mockito.PowerMockito.doReturn;
import static org.powermock.api.mockito.PowerMockito.spy;
/** /**
* Checks that the embedded version information is up to date. * Checks that the embedded version information is up to date.
@ -83,10 +82,10 @@ public class SpringSecurityCoreVersionTests {
@Test @Test
public void noLoggingIfVersionsAreEqual() throws Exception { public void noLoggingIfVersionsAreEqual() throws Exception {
String version = "1"; String version = "1";
spy(SpringSecurityCoreVersion.class); PowerMockito.spy(SpringSecurityCoreVersion.class);
spy(SpringVersion.class); PowerMockito.spy(SpringVersion.class);
doReturn(version).when(SpringSecurityCoreVersion.class, "getVersion"); PowerMockito.doReturn(version).when(SpringSecurityCoreVersion.class, "getVersion");
doReturn(version).when(SpringVersion.class, "getVersion"); PowerMockito.doReturn(version).when(SpringVersion.class, "getVersion");
performChecks(); performChecks();
@ -95,10 +94,10 @@ public class SpringSecurityCoreVersionTests {
@Test @Test
public void noLoggingIfSpringVersionNull() throws Exception { public void noLoggingIfSpringVersionNull() throws Exception {
spy(SpringSecurityCoreVersion.class); PowerMockito.spy(SpringSecurityCoreVersion.class);
spy(SpringVersion.class); PowerMockito.spy(SpringVersion.class);
doReturn("1").when(SpringSecurityCoreVersion.class, "getVersion"); PowerMockito.doReturn("1").when(SpringSecurityCoreVersion.class, "getVersion");
doReturn(null).when(SpringVersion.class, "getVersion"); PowerMockito.doReturn(null).when(SpringVersion.class, "getVersion");
performChecks(); performChecks();
@ -107,10 +106,10 @@ public class SpringSecurityCoreVersionTests {
@Test @Test
public void warnIfSpringVersionTooSmall() throws Exception { public void warnIfSpringVersionTooSmall() throws Exception {
spy(SpringSecurityCoreVersion.class); PowerMockito.spy(SpringSecurityCoreVersion.class);
spy(SpringVersion.class); PowerMockito.spy(SpringVersion.class);
doReturn("3").when(SpringSecurityCoreVersion.class, "getVersion"); PowerMockito.doReturn("3").when(SpringSecurityCoreVersion.class, "getVersion");
doReturn("2").when(SpringVersion.class, "getVersion"); PowerMockito.doReturn("2").when(SpringVersion.class, "getVersion");
performChecks(); performChecks();
@ -119,10 +118,10 @@ public class SpringSecurityCoreVersionTests {
@Test @Test
public void noWarnIfSpringVersionLarger() throws Exception { public void noWarnIfSpringVersionLarger() throws Exception {
spy(SpringSecurityCoreVersion.class); PowerMockito.spy(SpringSecurityCoreVersion.class);
spy(SpringVersion.class); PowerMockito.spy(SpringVersion.class);
doReturn("4.0.0.RELEASE").when(SpringSecurityCoreVersion.class, "getVersion"); PowerMockito.doReturn("4.0.0.RELEASE").when(SpringSecurityCoreVersion.class, "getVersion");
doReturn("4.0.0.RELEASE").when(SpringVersion.class, "getVersion"); PowerMockito.doReturn("4.0.0.RELEASE").when(SpringVersion.class, "getVersion");
performChecks(); performChecks();
@ -133,10 +132,10 @@ public class SpringSecurityCoreVersionTests {
@Test @Test
public void noWarnIfSpringPatchVersionDoubleDigits() throws Exception { public void noWarnIfSpringPatchVersionDoubleDigits() throws Exception {
String minSpringVersion = "3.2.8.RELEASE"; String minSpringVersion = "3.2.8.RELEASE";
spy(SpringSecurityCoreVersion.class); PowerMockito.spy(SpringSecurityCoreVersion.class);
spy(SpringVersion.class); PowerMockito.spy(SpringVersion.class);
doReturn("3.2.0.RELEASE").when(SpringSecurityCoreVersion.class, "getVersion"); PowerMockito.doReturn("3.2.0.RELEASE").when(SpringSecurityCoreVersion.class, "getVersion");
doReturn("3.2.10.RELEASE").when(SpringVersion.class, "getVersion"); PowerMockito.doReturn("3.2.10.RELEASE").when(SpringVersion.class, "getVersion");
performChecks(minSpringVersion); performChecks(minSpringVersion);
@ -145,10 +144,10 @@ public class SpringSecurityCoreVersionTests {
@Test @Test
public void noLoggingIfPropertySet() throws Exception { public void noLoggingIfPropertySet() throws Exception {
spy(SpringSecurityCoreVersion.class); PowerMockito.spy(SpringSecurityCoreVersion.class);
spy(SpringVersion.class); PowerMockito.spy(SpringVersion.class);
doReturn("3").when(SpringSecurityCoreVersion.class, "getVersion"); PowerMockito.doReturn("3").when(SpringSecurityCoreVersion.class, "getVersion");
doReturn("2").when(SpringVersion.class, "getVersion"); PowerMockito.doReturn("2").when(SpringVersion.class, "getVersion");
System.setProperty(getDisableChecksProperty(), Boolean.TRUE.toString()); System.setProperty(getDisableChecksProperty(), Boolean.TRUE.toString());
performChecks(); performChecks();

View File

@ -20,6 +20,8 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonClassDescription;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonInclude.Value;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import org.json.JSONException; import org.json.JSONException;
import org.junit.Test; import org.junit.Test;
@ -29,10 +31,6 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio
import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.User;
import static com.fasterxml.jackson.annotation.JsonInclude.Include.ALWAYS;
import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_ABSENT;
import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;
import static com.fasterxml.jackson.annotation.JsonInclude.Value.construct;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
@ -181,7 +179,8 @@ public class UsernamePasswordAuthenticationTokenMixinTests extends AbstractMixin
@Test @Test
public void serializingThenDeserializingWithConfiguredObjectMapperShouldWork() throws IOException { public void serializingThenDeserializingWithConfiguredObjectMapperShouldWork() throws IOException {
this.mapper.setDefaultPropertyInclusion(construct(ALWAYS, NON_NULL)).setSerializationInclusion(NON_ABSENT); this.mapper.setDefaultPropertyInclusion(Value.construct(Include.ALWAYS, Include.NON_NULL))
.setSerializationInclusion(Include.NON_ABSENT);
UsernamePasswordAuthenticationToken original = new UsernamePasswordAuthenticationToken("Frodo", null); UsernamePasswordAuthenticationToken original = new UsernamePasswordAuthenticationToken("Frodo", null);
String serialized = this.mapper.writeValueAsString(original); String serialized = this.mapper.writeValueAsString(original);
UsernamePasswordAuthenticationToken deserialized = this.mapper.readValue(serialized, UsernamePasswordAuthenticationToken deserialized = this.mapper.readValue(serialized,

View File

@ -27,13 +27,7 @@ import javax.crypto.spec.SecretKeySpec;
import org.springframework.security.crypto.codec.Hex; import org.springframework.security.crypto.codec.Hex;
import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.BytesKeyGenerator;
import org.springframework.security.crypto.keygen.KeyGenerators; import org.springframework.security.crypto.keygen.KeyGenerators;
import org.springframework.security.crypto.util.EncodingUtils;
import static org.springframework.security.crypto.encrypt.CipherUtils.doFinal;
import static org.springframework.security.crypto.encrypt.CipherUtils.initCipher;
import static org.springframework.security.crypto.encrypt.CipherUtils.newCipher;
import static org.springframework.security.crypto.encrypt.CipherUtils.newSecretKey;
import static org.springframework.security.crypto.util.EncodingUtils.concatenate;
import static org.springframework.security.crypto.util.EncodingUtils.subArray;
/** /**
* Encryptor that uses AES encryption. * Encryptor that uses AES encryption.
@ -80,7 +74,7 @@ public final class AesBytesEncryptor implements BytesEncryptor {
} }
public Cipher createCipher() { public Cipher createCipher() {
return newCipher(this.toString()); return CipherUtils.newCipher(this.toString());
} }
public BytesKeyGenerator defaultIvGenerator() { public BytesKeyGenerator defaultIvGenerator() {
@ -98,8 +92,8 @@ public final class AesBytesEncryptor implements BytesEncryptor {
} }
public AesBytesEncryptor(String password, CharSequence salt, BytesKeyGenerator ivGenerator, CipherAlgorithm alg) { public AesBytesEncryptor(String password, CharSequence salt, BytesKeyGenerator ivGenerator, CipherAlgorithm alg) {
this(newSecretKey("PBKDF2WithHmacSHA1", new PBEKeySpec(password.toCharArray(), Hex.decode(salt), 1024, 256)), this(CipherUtils.newSecretKey("PBKDF2WithHmacSHA1",
ivGenerator, alg); new PBEKeySpec(password.toCharArray(), Hex.decode(salt), 1024, 256)), ivGenerator, alg);
} }
/** /**
@ -122,9 +116,9 @@ public final class AesBytesEncryptor implements BytesEncryptor {
public byte[] encrypt(byte[] bytes) { public byte[] encrypt(byte[] bytes) {
synchronized (this.encryptor) { synchronized (this.encryptor) {
byte[] iv = this.ivGenerator.generateKey(); byte[] iv = this.ivGenerator.generateKey();
initCipher(this.encryptor, Cipher.ENCRYPT_MODE, this.secretKey, this.alg.getParameterSpec(iv)); CipherUtils.initCipher(this.encryptor, Cipher.ENCRYPT_MODE, this.secretKey, this.alg.getParameterSpec(iv));
byte[] encrypted = doFinal(this.encryptor, bytes); byte[] encrypted = CipherUtils.doFinal(this.encryptor, bytes);
return this.ivGenerator != NULL_IV_GENERATOR ? concatenate(iv, encrypted) : encrypted; return this.ivGenerator != NULL_IV_GENERATOR ? EncodingUtils.concatenate(iv, encrypted) : encrypted;
} }
} }
@ -132,8 +126,8 @@ public final class AesBytesEncryptor implements BytesEncryptor {
public byte[] decrypt(byte[] encryptedBytes) { public byte[] decrypt(byte[] encryptedBytes) {
synchronized (this.decryptor) { synchronized (this.decryptor) {
byte[] iv = iv(encryptedBytes); byte[] iv = iv(encryptedBytes);
initCipher(this.decryptor, Cipher.DECRYPT_MODE, this.secretKey, this.alg.getParameterSpec(iv)); CipherUtils.initCipher(this.decryptor, Cipher.DECRYPT_MODE, this.secretKey, this.alg.getParameterSpec(iv));
return doFinal(this.decryptor, return CipherUtils.doFinal(this.decryptor,
this.ivGenerator != NULL_IV_GENERATOR ? encrypted(encryptedBytes, iv.length) : encryptedBytes); this.ivGenerator != NULL_IV_GENERATOR ? encrypted(encryptedBytes, iv.length) : encryptedBytes);
} }
} }
@ -141,12 +135,13 @@ public final class AesBytesEncryptor implements BytesEncryptor {
// internal helpers // internal helpers
private byte[] iv(byte[] encrypted) { private byte[] iv(byte[] encrypted) {
return this.ivGenerator != NULL_IV_GENERATOR ? subArray(encrypted, 0, this.ivGenerator.getKeyLength()) return this.ivGenerator != NULL_IV_GENERATOR
? EncodingUtils.subArray(encrypted, 0, this.ivGenerator.getKeyLength())
: NULL_IV_GENERATOR.generateKey(); : NULL_IV_GENERATOR.generateKey();
} }
private byte[] encrypted(byte[] encryptedBytes, int ivLength) { private byte[] encrypted(byte[] encryptedBytes, int ivLength) {
return subArray(encryptedBytes, ivLength, encryptedBytes.length); return EncodingUtils.subArray(encryptedBytes, ivLength, encryptedBytes.length);
} }
private static final BytesKeyGenerator NULL_IV_GENERATOR = new BytesKeyGenerator() { private static final BytesKeyGenerator NULL_IV_GENERATOR = new BytesKeyGenerator() {

View File

@ -24,9 +24,7 @@ import org.bouncycastle.crypto.params.ParametersWithIV;
import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm; import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm;
import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.BytesKeyGenerator;
import org.springframework.security.crypto.util.EncodingUtils;
import static org.springframework.security.crypto.util.EncodingUtils.concatenate;
import static org.springframework.security.crypto.util.EncodingUtils.subArray;
/** /**
* An Encryptor equivalent to {@link AesBytesEncryptor} using {@link CipherAlgorithm#CBC} * An Encryptor equivalent to {@link AesBytesEncryptor} using {@link CipherAlgorithm#CBC}
@ -55,13 +53,13 @@ public class BouncyCastleAesCbcBytesEncryptor extends BouncyCastleAesBytesEncryp
new CBCBlockCipher(new org.bouncycastle.crypto.engines.AESFastEngine()), new PKCS7Padding()); new CBCBlockCipher(new org.bouncycastle.crypto.engines.AESFastEngine()), new PKCS7Padding());
blockCipher.init(true, new ParametersWithIV(this.secretKey, iv)); blockCipher.init(true, new ParametersWithIV(this.secretKey, iv));
byte[] encrypted = process(blockCipher, bytes); byte[] encrypted = process(blockCipher, bytes);
return iv != null ? concatenate(iv, encrypted) : encrypted; return iv != null ? EncodingUtils.concatenate(iv, encrypted) : encrypted;
} }
@Override @Override
public byte[] decrypt(byte[] encryptedBytes) { public byte[] decrypt(byte[] encryptedBytes) {
byte[] iv = subArray(encryptedBytes, 0, this.ivGenerator.getKeyLength()); byte[] iv = EncodingUtils.subArray(encryptedBytes, 0, this.ivGenerator.getKeyLength());
encryptedBytes = subArray(encryptedBytes, this.ivGenerator.getKeyLength(), encryptedBytes.length); encryptedBytes = EncodingUtils.subArray(encryptedBytes, this.ivGenerator.getKeyLength(), encryptedBytes.length);
@SuppressWarnings("deprecation") @SuppressWarnings("deprecation")
PaddedBufferedBlockCipher blockCipher = new PaddedBufferedBlockCipher( PaddedBufferedBlockCipher blockCipher = new PaddedBufferedBlockCipher(

View File

@ -22,9 +22,7 @@ import org.bouncycastle.crypto.params.AEADParameters;
import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm; import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm;
import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.BytesKeyGenerator;
import org.springframework.security.crypto.util.EncodingUtils;
import static org.springframework.security.crypto.util.EncodingUtils.concatenate;
import static org.springframework.security.crypto.util.EncodingUtils.subArray;
/** /**
* An Encryptor equivalent to {@link AesBytesEncryptor} using {@link CipherAlgorithm#GCM} * An Encryptor equivalent to {@link AesBytesEncryptor} using {@link CipherAlgorithm#GCM}
@ -53,13 +51,13 @@ public class BouncyCastleAesGcmBytesEncryptor extends BouncyCastleAesBytesEncryp
blockCipher.init(true, new AEADParameters(this.secretKey, 128, iv, null)); blockCipher.init(true, new AEADParameters(this.secretKey, 128, iv, null));
byte[] encrypted = process(blockCipher, bytes); byte[] encrypted = process(blockCipher, bytes);
return iv != null ? concatenate(iv, encrypted) : encrypted; return iv != null ? EncodingUtils.concatenate(iv, encrypted) : encrypted;
} }
@Override @Override
public byte[] decrypt(byte[] encryptedBytes) { public byte[] decrypt(byte[] encryptedBytes) {
byte[] iv = subArray(encryptedBytes, 0, this.ivGenerator.getKeyLength()); byte[] iv = EncodingUtils.subArray(encryptedBytes, 0, this.ivGenerator.getKeyLength());
encryptedBytes = subArray(encryptedBytes, this.ivGenerator.getKeyLength(), encryptedBytes.length); encryptedBytes = EncodingUtils.subArray(encryptedBytes, this.ivGenerator.getKeyLength(), encryptedBytes.length);
@SuppressWarnings("deprecation") @SuppressWarnings("deprecation")
GCMBlockCipher blockCipher = new GCMBlockCipher(new org.bouncycastle.crypto.engines.AESFastEngine()); GCMBlockCipher blockCipher = new GCMBlockCipher(new org.bouncycastle.crypto.engines.AESFastEngine());

View File

@ -20,9 +20,7 @@ import java.security.MessageDigest;
import org.springframework.security.crypto.codec.Hex; import org.springframework.security.crypto.codec.Hex;
import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.BytesKeyGenerator;
import org.springframework.security.crypto.keygen.KeyGenerators; import org.springframework.security.crypto.keygen.KeyGenerators;
import org.springframework.security.crypto.util.EncodingUtils;
import static org.springframework.security.crypto.util.EncodingUtils.concatenate;
import static org.springframework.security.crypto.util.EncodingUtils.subArray;
/** /**
* Abstract base class for password encoders * Abstract base class for password encoders
@ -47,14 +45,14 @@ public abstract class AbstractPasswordEncoder implements PasswordEncoder {
@Override @Override
public boolean matches(CharSequence rawPassword, String encodedPassword) { public boolean matches(CharSequence rawPassword, String encodedPassword) {
byte[] digested = Hex.decode(encodedPassword); byte[] digested = Hex.decode(encodedPassword);
byte[] salt = subArray(digested, 0, this.saltGenerator.getKeyLength()); byte[] salt = EncodingUtils.subArray(digested, 0, this.saltGenerator.getKeyLength());
return matches(digested, encodeAndConcatenate(rawPassword, salt)); return matches(digested, encodeAndConcatenate(rawPassword, salt));
} }
protected abstract byte[] encode(CharSequence rawPassword, byte[] salt); protected abstract byte[] encode(CharSequence rawPassword, byte[] salt);
protected byte[] encodeAndConcatenate(CharSequence rawPassword, byte[] salt) { protected byte[] encodeAndConcatenate(CharSequence rawPassword, byte[] salt) {
return concatenate(salt, encode(rawPassword, salt)); return EncodingUtils.concatenate(salt, encode(rawPassword, salt));
} }
/** /**

View File

@ -27,9 +27,7 @@ import org.springframework.security.crypto.codec.Hex;
import org.springframework.security.crypto.codec.Utf8; import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.BytesKeyGenerator;
import org.springframework.security.crypto.keygen.KeyGenerators; import org.springframework.security.crypto.keygen.KeyGenerators;
import org.springframework.security.crypto.util.EncodingUtils;
import static org.springframework.security.crypto.util.EncodingUtils.concatenate;
import static org.springframework.security.crypto.util.EncodingUtils.subArray;
/** /**
* A {@code PasswordEncoder} implementation that uses PBKDF2 with a configurable number of * A {@code PasswordEncoder} implementation that uses PBKDF2 with a configurable number of
@ -147,7 +145,7 @@ public class Pbkdf2PasswordEncoder implements PasswordEncoder {
@Override @Override
public boolean matches(CharSequence rawPassword, String encodedPassword) { public boolean matches(CharSequence rawPassword, String encodedPassword) {
byte[] digested = decode(encodedPassword); byte[] digested = decode(encodedPassword);
byte[] salt = subArray(digested, 0, this.saltGenerator.getKeyLength()); byte[] salt = EncodingUtils.subArray(digested, 0, this.saltGenerator.getKeyLength());
return MessageDigest.isEqual(digested, encode(rawPassword, salt)); return MessageDigest.isEqual(digested, encode(rawPassword, salt));
} }
@ -160,10 +158,10 @@ public class Pbkdf2PasswordEncoder implements PasswordEncoder {
private byte[] encode(CharSequence rawPassword, byte[] salt) { private byte[] encode(CharSequence rawPassword, byte[] salt) {
try { try {
PBEKeySpec spec = new PBEKeySpec(rawPassword.toString().toCharArray(), concatenate(salt, this.secret), PBEKeySpec spec = new PBEKeySpec(rawPassword.toString().toCharArray(),
this.iterations, this.hashWidth); EncodingUtils.concatenate(salt, this.secret), this.iterations, this.hashWidth);
SecretKeyFactory skf = SecretKeyFactory.getInstance(this.algorithm); SecretKeyFactory skf = SecretKeyFactory.getInstance(this.algorithm);
return concatenate(salt, skf.generateSecret(spec).getEncoded()); return EncodingUtils.concatenate(salt, skf.generateSecret(spec).getEncoded());
} }
catch (GeneralSecurityException e) { catch (GeneralSecurityException e) {
throw new IllegalStateException("Could not create hash", e); throw new IllegalStateException("Could not create hash", e);

View File

@ -21,9 +21,7 @@ import org.springframework.security.crypto.codec.Hex;
import org.springframework.security.crypto.codec.Utf8; import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.BytesKeyGenerator;
import org.springframework.security.crypto.keygen.KeyGenerators; import org.springframework.security.crypto.keygen.KeyGenerators;
import org.springframework.security.crypto.util.EncodingUtils;
import static org.springframework.security.crypto.util.EncodingUtils.concatenate;
import static org.springframework.security.crypto.util.EncodingUtils.subArray;
/** /**
* This {@link PasswordEncoder} is provided for legacy purposes only and is not considered * This {@link PasswordEncoder} is provided for legacy purposes only and is not considered
@ -81,7 +79,7 @@ public final class StandardPasswordEncoder implements PasswordEncoder {
@Override @Override
public boolean matches(CharSequence rawPassword, String encodedPassword) { public boolean matches(CharSequence rawPassword, String encodedPassword) {
byte[] digested = decode(encodedPassword); byte[] digested = decode(encodedPassword);
byte[] salt = subArray(digested, 0, this.saltGenerator.getKeyLength()); byte[] salt = EncodingUtils.subArray(digested, 0, this.saltGenerator.getKeyLength());
return MessageDigest.isEqual(digested, digest(rawPassword, salt)); return MessageDigest.isEqual(digested, digest(rawPassword, salt));
} }
@ -99,8 +97,8 @@ public final class StandardPasswordEncoder implements PasswordEncoder {
} }
private byte[] digest(CharSequence rawPassword, byte[] salt) { private byte[] digest(CharSequence rawPassword, byte[] salt) {
byte[] digest = this.digester.digest(concatenate(salt, this.secret, Utf8.encode(rawPassword))); byte[] digest = this.digester.digest(EncodingUtils.concatenate(salt, this.secret, Utf8.encode(rawPassword)));
return concatenate(salt, digest); return EncodingUtils.concatenate(salt, digest);
} }
private byte[] decode(CharSequence encodedPassword) { private byte[] decode(CharSequence encodedPassword) {

View File

@ -23,14 +23,13 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.security.crypto.codec.Hex; import org.springframework.security.crypto.codec.Hex;
import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm;
import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.BytesKeyGenerator;
import org.springframework.security.crypto.password.Pbkdf2PasswordEncoder.SecretKeyFactoryAlgorithm;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm.GCM;
import static org.springframework.security.crypto.encrypt.CipherUtils.newSecretKey;
import static org.springframework.security.crypto.password.Pbkdf2PasswordEncoder.SecretKeyFactoryAlgorithm.PBKDF2WithHmacSHA1;
/** /**
* Tests for {@link AesBytesEncryptor} * Tests for {@link AesBytesEncryptor}
@ -76,7 +75,8 @@ public class AesBytesEncryptorTests {
@Test @Test
public void roundtripWhenUsingGcmThenEncryptsAndDecrypts() { public void roundtripWhenUsingGcmThenEncryptsAndDecrypts() {
CryptoAssumptions.assumeGCMJCE(); CryptoAssumptions.assumeGCMJCE();
AesBytesEncryptor encryptor = new AesBytesEncryptor(this.password, this.hexSalt, this.generator, GCM); AesBytesEncryptor encryptor = new AesBytesEncryptor(this.password, this.hexSalt, this.generator,
CipherAlgorithm.GCM);
byte[] encryption = encryptor.encrypt(this.secret.getBytes()); byte[] encryption = encryptor.encrypt(this.secret.getBytes());
assertThat(new String(Hex.encode(encryption))) assertThat(new String(Hex.encode(encryption)))
@ -90,8 +90,8 @@ public class AesBytesEncryptorTests {
public void roundtripWhenUsingSecretKeyThenEncryptsAndDecrypts() { public void roundtripWhenUsingSecretKeyThenEncryptsAndDecrypts() {
CryptoAssumptions.assumeGCMJCE(); CryptoAssumptions.assumeGCMJCE();
PBEKeySpec keySpec = new PBEKeySpec(this.password.toCharArray(), Hex.decode(this.hexSalt), 1024, 256); PBEKeySpec keySpec = new PBEKeySpec(this.password.toCharArray(), Hex.decode(this.hexSalt), 1024, 256);
SecretKey secretKey = newSecretKey(PBKDF2WithHmacSHA1.name(), keySpec); SecretKey secretKey = CipherUtils.newSecretKey(SecretKeyFactoryAlgorithm.PBKDF2WithHmacSHA1.name(), keySpec);
AesBytesEncryptor encryptor = new AesBytesEncryptor(secretKey, this.generator, GCM); AesBytesEncryptor encryptor = new AesBytesEncryptor(secretKey, this.generator, CipherAlgorithm.GCM);
byte[] encryption = encryptor.encrypt(this.secret.getBytes()); byte[] encryption = encryptor.encrypt(this.secret.getBytes());
assertThat(new String(Hex.encode(encryption))) assertThat(new String(Hex.encode(encryption)))

View File

@ -3,7 +3,6 @@
"-//Checkstyle//DTD SuppressionFilter Configuration 1.2//EN" "-//Checkstyle//DTD SuppressionFilter Configuration 1.2//EN"
"https://checkstyle.org/dtds/suppressions_1_2.dtd"> "https://checkstyle.org/dtds/suppressions_1_2.dtd">
<suppressions> <suppressions>
<suppress files=".*" checks="SpringAvoidStaticImport" />
<suppress files=".*" checks="SpringCatch" /> <suppress files=".*" checks="SpringCatch" />
<suppress files=".*" checks="SpringHeader" /> <suppress files=".*" checks="SpringHeader" />
<suppress files=".*" checks="SpringHideUtilityClassConstructor" /> <suppress files=".*" checks="SpringHideUtilityClassConstructor" />

View File

@ -15,13 +15,13 @@
*/ */
package org.springframework.security.messaging.util.matcher; package org.springframework.security.messaging.util.matcher;
import java.util.Arrays;
import java.util.List; import java.util.List;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import static java.util.Arrays.asList; import org.springframework.util.Assert;
import static org.apache.commons.logging.LogFactory.getLog;
import static org.springframework.util.Assert.notEmpty;
/** /**
* Abstract {@link MessageMatcher} containing multiple {@link MessageMatcher} * Abstract {@link MessageMatcher} containing multiple {@link MessageMatcher}
@ -30,7 +30,7 @@ import static org.springframework.util.Assert.notEmpty;
*/ */
abstract class AbstractMessageMatcherComposite<T> implements MessageMatcher<T> { abstract class AbstractMessageMatcherComposite<T> implements MessageMatcher<T> {
protected final Log LOGGER = getLog(getClass()); protected final Log LOGGER = LogFactory.getLog(getClass());
private final List<MessageMatcher<T>> messageMatchers; private final List<MessageMatcher<T>> messageMatchers;
@ -39,7 +39,7 @@ abstract class AbstractMessageMatcherComposite<T> implements MessageMatcher<T> {
* @param messageMatchers the {@link MessageMatcher} instances to try * @param messageMatchers the {@link MessageMatcher} instances to try
*/ */
AbstractMessageMatcherComposite(List<MessageMatcher<T>> messageMatchers) { AbstractMessageMatcherComposite(List<MessageMatcher<T>> messageMatchers) {
notEmpty(messageMatchers, "messageMatchers must contain a value"); Assert.notEmpty(messageMatchers, "messageMatchers must contain a value");
if (messageMatchers.contains(null)) { if (messageMatchers.contains(null)) {
throw new IllegalArgumentException("messageMatchers cannot contain null values"); throw new IllegalArgumentException("messageMatchers cannot contain null values");
} }
@ -53,7 +53,7 @@ abstract class AbstractMessageMatcherComposite<T> implements MessageMatcher<T> {
*/ */
@SafeVarargs @SafeVarargs
AbstractMessageMatcherComposite(MessageMatcher<T>... messageMatchers) { AbstractMessageMatcherComposite(MessageMatcher<T>... messageMatchers) {
this(asList(messageMatchers)); this(Arrays.asList(messageMatchers));
} }
public List<MessageMatcher<T>> getMessageMatchers() { public List<MessageMatcher<T>> getMessageMatchers() {

View File

@ -31,8 +31,7 @@ import org.springframework.security.messaging.access.intercept.MessageSecurityMe
import org.springframework.security.messaging.util.matcher.MessageMatcher; import org.springframework.security.messaging.util.matcher.MessageMatcher;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.powermock.api.mockito.PowerMockito.when; import static org.mockito.BDDMockito.given;
import static org.springframework.security.messaging.access.expression.ExpressionBasedMessageSecurityMetadataSourceFactory.createExpressionMessageMetadataSource;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class ExpressionBasedMessageSecurityMetadataSourceFactoryTests { public class ExpressionBasedMessageSecurityMetadataSourceFactoryTests {
@ -67,7 +66,8 @@ public class ExpressionBasedMessageSecurityMetadataSourceFactoryTests {
this.matcherToExpression.put(this.matcher1, this.expression1); this.matcherToExpression.put(this.matcher1, this.expression1);
this.matcherToExpression.put(this.matcher2, this.expression2); this.matcherToExpression.put(this.matcher2, this.expression2);
this.source = createExpressionMessageMetadataSource(this.matcherToExpression); this.source = ExpressionBasedMessageSecurityMetadataSourceFactory
.createExpressionMessageMetadataSource(this.matcherToExpression);
this.rootObject = new MessageSecurityExpressionRoot(this.authentication, this.message); this.rootObject = new MessageSecurityExpressionRoot(this.authentication, this.message);
} }
@ -81,7 +81,7 @@ public class ExpressionBasedMessageSecurityMetadataSourceFactoryTests {
@Test @Test
public void createExpressionMessageMetadataSourceMatchFirst() { public void createExpressionMessageMetadataSourceMatchFirst() {
when(this.matcher1.matches(this.message)).thenReturn(true); given(this.matcher1.matches(this.message)).willReturn(true);
Collection<ConfigAttribute> attrs = this.source.getAttributes(this.message); Collection<ConfigAttribute> attrs = this.source.getAttributes(this.message);
@ -94,7 +94,7 @@ public class ExpressionBasedMessageSecurityMetadataSourceFactoryTests {
@Test @Test
public void createExpressionMessageMetadataSourceMatchSecond() { public void createExpressionMessageMetadataSourceMatchSecond() {
when(this.matcher2.matches(this.message)).thenReturn(true); given(this.matcher2.matches(this.message)).willReturn(true);
Collection<ConfigAttribute> attrs = this.source.getAttributes(this.message); Collection<ConfigAttribute> attrs = this.source.getAttributes(this.message);

View File

@ -27,6 +27,7 @@ import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.expression.EvaluationContext; import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression; import org.springframework.expression.Expression;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.security.access.AccessDecisionVoter;
import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.ConfigAttribute;
import org.springframework.security.access.SecurityConfig; import org.springframework.security.access.SecurityConfig;
import org.springframework.security.access.expression.SecurityExpressionHandler; import org.springframework.security.access.expression.SecurityExpressionHandler;
@ -39,9 +40,6 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.access.AccessDecisionVoter.ACCESS_ABSTAIN;
import static org.springframework.security.access.AccessDecisionVoter.ACCESS_DENIED;
import static org.springframework.security.access.AccessDecisionVoter.ACCESS_GRANTED;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class MessageExpressionVoterTests { public class MessageExpressionVoterTests {
@ -79,19 +77,22 @@ public class MessageExpressionVoterTests {
@Test @Test
public void voteGranted() { public void voteGranted() {
given(this.expression.getValue(any(EvaluationContext.class), eq(Boolean.class))).willReturn(true); given(this.expression.getValue(any(EvaluationContext.class), eq(Boolean.class))).willReturn(true);
assertThat(this.voter.vote(this.authentication, this.message, this.attributes)).isEqualTo(ACCESS_GRANTED); assertThat(this.voter.vote(this.authentication, this.message, this.attributes))
.isEqualTo(AccessDecisionVoter.ACCESS_GRANTED);
} }
@Test @Test
public void voteDenied() { public void voteDenied() {
given(this.expression.getValue(any(EvaluationContext.class), eq(Boolean.class))).willReturn(false); given(this.expression.getValue(any(EvaluationContext.class), eq(Boolean.class))).willReturn(false);
assertThat(this.voter.vote(this.authentication, this.message, this.attributes)).isEqualTo(ACCESS_DENIED); assertThat(this.voter.vote(this.authentication, this.message, this.attributes))
.isEqualTo(AccessDecisionVoter.ACCESS_DENIED);
} }
@Test @Test
public void voteAbstain() { public void voteAbstain() {
this.attributes = Arrays.<ConfigAttribute>asList(new SecurityConfig("ROLE_USER")); this.attributes = Arrays.<ConfigAttribute>asList(new SecurityConfig("ROLE_USER"));
assertThat(this.voter.vote(this.authentication, this.message, this.attributes)).isEqualTo(ACCESS_ABSTAIN); assertThat(this.voter.vote(this.authentication, this.message, this.attributes))
.isEqualTo(AccessDecisionVoter.ACCESS_ABSTAIN);
} }
@Test @Test
@ -126,7 +127,8 @@ public class MessageExpressionVoterTests {
.willReturn(this.evaluationContext); .willReturn(this.evaluationContext);
given(this.expression.getValue(this.evaluationContext, Boolean.class)).willReturn(true); given(this.expression.getValue(this.evaluationContext, Boolean.class)).willReturn(true);
assertThat(this.voter.vote(this.authentication, this.message, this.attributes)).isEqualTo(ACCESS_GRANTED); assertThat(this.voter.vote(this.authentication, this.message, this.attributes))
.isEqualTo(AccessDecisionVoter.ACCESS_GRANTED);
verify(this.expressionHandler).createEvaluationContext(this.authentication, this.message); verify(this.expressionHandler).createEvaluationContext(this.authentication, this.message);
} }
@ -142,7 +144,8 @@ public class MessageExpressionVoterTests {
given(configAttribute.postProcess(this.evaluationContext, this.message)).willReturn(this.evaluationContext); given(configAttribute.postProcess(this.evaluationContext, this.message)).willReturn(this.evaluationContext);
given(this.expression.getValue(any(EvaluationContext.class), eq(Boolean.class))).willReturn(true); given(this.expression.getValue(any(EvaluationContext.class), eq(Boolean.class))).willReturn(true);
assertThat(this.voter.vote(this.authentication, this.message, this.attributes)).isEqualTo(ACCESS_GRANTED); assertThat(this.voter.vote(this.authentication, this.message, this.attributes))
.isEqualTo(AccessDecisionVoter.ACCESS_GRANTED);
verify(configAttribute).postProcess(this.evaluationContext, this.message); verify(configAttribute).postProcess(this.evaluationContext, this.message);
} }

View File

@ -32,7 +32,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.messaging.util.matcher.MessageMatcher; import org.springframework.security.messaging.util.matcher.MessageMatcher;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.powermock.api.mockito.PowerMockito.when; import static org.mockito.BDDMockito.given;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class DefaultMessageSecurityMetadataSourceTests { public class DefaultMessageSecurityMetadataSourceTests {
@ -73,14 +73,14 @@ public class DefaultMessageSecurityMetadataSourceTests {
@Test @Test
public void getAttributesFirst() { public void getAttributesFirst() {
when(this.matcher1.matches(this.message)).thenReturn(true); given(this.matcher1.matches(this.message)).willReturn(true);
assertThat(this.source.getAttributes(this.message)).containsOnly(this.config1); assertThat(this.source.getAttributes(this.message)).containsOnly(this.config1);
} }
@Test @Test
public void getAttributesSecond() { public void getAttributesSecond() {
when(this.matcher1.matches(this.message)).thenReturn(true); given(this.matcher1.matches(this.message)).willReturn(true);
assertThat(this.source.getAttributes(this.message)).containsOnly(this.config2); assertThat(this.source.getAttributes(this.message)).containsOnly(this.config2);
} }

View File

@ -35,7 +35,6 @@ import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.core.context.SecurityContextHolder.clearContext;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class SecurityContextChannelInterceptorTests { public class SecurityContextChannelInterceptorTests {
@ -69,7 +68,7 @@ public class SecurityContextChannelInterceptorTests {
@After @After
public void cleanup() { public void cleanup() {
clearContext(); SecurityContextHolder.clearContext();
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)

View File

@ -27,6 +27,7 @@ import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInterceptor;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
@ -54,8 +55,6 @@ import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import static java.util.stream.Collectors.joining;
/** /**
* NOTE: This class is a replica of the same class in spring-web so it can be used for * NOTE: This class is a replica of the same class in spring-web so it can be used for
* tests in spring-messaging. * tests in spring-messaging.
@ -216,13 +215,14 @@ public final class ResolvableMethod {
private String formatMethod() { private String formatMethod() {
return (method().getName() + Arrays.stream(this.method.getParameters()).map(this::formatParameter) return (method().getName() + Arrays.stream(this.method.getParameters()).map(this::formatParameter)
.collect(joining(",\n\t", "(\n\t", "\n)"))); .collect(Collectors.joining(",\n\t", "(\n\t", "\n)")));
} }
private String formatParameter(Parameter param) { private String formatParameter(Parameter param) {
Annotation[] anns = param.getAnnotations(); Annotation[] anns = param.getAnnotations();
return (anns.length > 0 return (anns.length > 0
? Arrays.stream(anns).map(this::formatAnnotation).collect(joining(",", "[", "]")) + " " + param ? Arrays.stream(anns).map(this::formatAnnotation).collect(Collectors.joining(",", "[", "]")) + " "
+ param
: param.toString()); : param.toString());
} }
@ -427,8 +427,8 @@ public final class ResolvableMethod {
} }
private String formatMethods(Set<Method> methods) { private String formatMethods(Set<Method> methods) {
return "\nMatched:\n" return "\nMatched:\n" + methods.stream().map(Method::toGenericString)
+ methods.stream().map(Method::toGenericString).collect(joining(",\n\t", "[\n\t", "\n]")); .collect(Collectors.joining(",\n\t", "[\n\t", "\n]"));
} }
public ResolvableMethod mockCall(Consumer<T> invoker) { public ResolvableMethod mockCall(Consumer<T> invoker) {
@ -504,7 +504,8 @@ public final class ResolvableMethod {
} }
private String formatFilters() { private String formatFilters() {
return this.filters.stream().map(Object::toString).collect(joining(",\n\t\t", "[\n\t\t", "\n\t]")); return this.filters.stream().map(Object::toString)
.collect(Collectors.joining(",\n\t\t", "[\n\t\t", "\n\t]"));
} }
} }

View File

@ -26,6 +26,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -33,8 +34,6 @@ import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
/** /**
* Abstract base class for all of the {@code WebClientReactive*TokenResponseClient}s that * Abstract base class for all of the {@code WebClientReactive*TokenResponseClient}s that
* communicate to the Authorization Server's Token Endpoint. * communicate to the Authorization Server's Token Endpoint.
@ -169,7 +168,7 @@ abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T extend
* @return the token response from the response body. * @return the token response from the response body.
*/ */
private Mono<OAuth2AccessTokenResponse> readTokenResponse(T grantRequest, ClientResponse response) { private Mono<OAuth2AccessTokenResponse> readTokenResponse(T grantRequest, ClientResponse response) {
return response.body(oauth2AccessTokenResponse()) return response.body(OAuth2BodyExtractors.oauth2AccessTokenResponse())
.map(tokenResponse -> populateTokenResponse(grantRequest, tokenResponse)); .map(tokenResponse -> populateTokenResponse(grantRequest, tokenResponse));
} }

View File

@ -24,8 +24,6 @@ import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
/** /**
* Utility methods used by the {@link Converter}'s that convert from an implementation of * Utility methods used by the {@link Converter}'s that convert from an implementation of
* an {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link RequestEntity} * an {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link RequestEntity}
@ -53,7 +51,7 @@ final class OAuth2AuthorizationGrantRequestEntityUtils {
private static HttpHeaders getDefaultTokenRequestHeaders() { private static HttpHeaders getDefaultTokenRequestHeaders() {
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8)); headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
final MediaType contentType = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); final MediaType contentType = MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8");
headers.setContentType(contentType); headers.setContentType(contentType);
return headers; return headers;
} }

View File

@ -29,12 +29,6 @@ import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.MAP_TYPE_REFERENCE;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.SET_TYPE_REFERENCE;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findObjectNode;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findStringValue;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findValue;
/** /**
* A {@code JsonDeserializer} for {@link ClientRegistration}. * A {@code JsonDeserializer} for {@link ClientRegistration}.
* *
@ -55,28 +49,31 @@ final class ClientRegistrationDeserializer extends JsonDeserializer<ClientRegist
public ClientRegistration deserialize(JsonParser parser, DeserializationContext context) throws IOException { public ClientRegistration deserialize(JsonParser parser, DeserializationContext context) throws IOException {
ObjectMapper mapper = (ObjectMapper) parser.getCodec(); ObjectMapper mapper = (ObjectMapper) parser.getCodec();
JsonNode clientRegistrationNode = mapper.readTree(parser); JsonNode clientRegistrationNode = mapper.readTree(parser);
JsonNode providerDetailsNode = findObjectNode(clientRegistrationNode, "providerDetails"); JsonNode providerDetailsNode = JsonNodeUtils.findObjectNode(clientRegistrationNode, "providerDetails");
JsonNode userInfoEndpointNode = findObjectNode(providerDetailsNode, "userInfoEndpoint"); JsonNode userInfoEndpointNode = JsonNodeUtils.findObjectNode(providerDetailsNode, "userInfoEndpoint");
return ClientRegistration.withRegistrationId(findStringValue(clientRegistrationNode, "registrationId")) return ClientRegistration
.clientId(findStringValue(clientRegistrationNode, "clientId")) .withRegistrationId(JsonNodeUtils.findStringValue(clientRegistrationNode, "registrationId"))
.clientSecret(findStringValue(clientRegistrationNode, "clientSecret")) .clientId(JsonNodeUtils.findStringValue(clientRegistrationNode, "clientId"))
.clientSecret(JsonNodeUtils.findStringValue(clientRegistrationNode, "clientSecret"))
.clientAuthenticationMethod(CLIENT_AUTHENTICATION_METHOD_CONVERTER .clientAuthenticationMethod(CLIENT_AUTHENTICATION_METHOD_CONVERTER
.convert(findObjectNode(clientRegistrationNode, "clientAuthenticationMethod"))) .convert(JsonNodeUtils.findObjectNode(clientRegistrationNode, "clientAuthenticationMethod")))
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE_CONVERTER .authorizationGrantType(AUTHORIZATION_GRANT_TYPE_CONVERTER
.convert(findObjectNode(clientRegistrationNode, "authorizationGrantType"))) .convert(JsonNodeUtils.findObjectNode(clientRegistrationNode, "authorizationGrantType")))
.redirectUri(findStringValue(clientRegistrationNode, "redirectUri")) .redirectUri(JsonNodeUtils.findStringValue(clientRegistrationNode, "redirectUri"))
.scope(findValue(clientRegistrationNode, "scopes", SET_TYPE_REFERENCE, mapper)) .scope(JsonNodeUtils.findValue(clientRegistrationNode, "scopes", JsonNodeUtils.SET_TYPE_REFERENCE,
.clientName(findStringValue(clientRegistrationNode, "clientName")) mapper))
.authorizationUri(findStringValue(providerDetailsNode, "authorizationUri")) .clientName(JsonNodeUtils.findStringValue(clientRegistrationNode, "clientName"))
.tokenUri(findStringValue(providerDetailsNode, "tokenUri")) .authorizationUri(JsonNodeUtils.findStringValue(providerDetailsNode, "authorizationUri"))
.userInfoUri(findStringValue(userInfoEndpointNode, "uri")) .tokenUri(JsonNodeUtils.findStringValue(providerDetailsNode, "tokenUri"))
.userInfoUri(JsonNodeUtils.findStringValue(userInfoEndpointNode, "uri"))
.userInfoAuthenticationMethod(AUTHENTICATION_METHOD_CONVERTER .userInfoAuthenticationMethod(AUTHENTICATION_METHOD_CONVERTER
.convert(findObjectNode(userInfoEndpointNode, "authenticationMethod"))) .convert(JsonNodeUtils.findObjectNode(userInfoEndpointNode, "authenticationMethod")))
.userNameAttributeName(findStringValue(userInfoEndpointNode, "userNameAttributeName")) .userNameAttributeName(JsonNodeUtils.findStringValue(userInfoEndpointNode, "userNameAttributeName"))
.jwkSetUri(findStringValue(providerDetailsNode, "jwkSetUri")) .jwkSetUri(JsonNodeUtils.findStringValue(providerDetailsNode, "jwkSetUri"))
.issuerUri(findStringValue(providerDetailsNode, "issuerUri")).providerConfigurationMetadata( .issuerUri(JsonNodeUtils.findStringValue(providerDetailsNode, "issuerUri"))
findValue(providerDetailsNode, "configurationMetadata", MAP_TYPE_REFERENCE, mapper)) .providerConfigurationMetadata(JsonNodeUtils.findValue(providerDetailsNode, "configurationMetadata",
JsonNodeUtils.MAP_TYPE_REFERENCE, mapper))
.build(); .build();
} }

View File

@ -28,12 +28,6 @@ import com.fasterxml.jackson.databind.util.StdConverter;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.MAP_TYPE_REFERENCE;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.SET_TYPE_REFERENCE;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findObjectNode;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findStringValue;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findValue;
/** /**
* A {@code JsonDeserializer} for {@link OAuth2AuthorizationRequest}. * A {@code JsonDeserializer} for {@link OAuth2AuthorizationRequest}.
* *
@ -53,7 +47,7 @@ final class OAuth2AuthorizationRequestDeserializer extends JsonDeserializer<OAut
JsonNode authorizationRequestNode = mapper.readTree(parser); JsonNode authorizationRequestNode = mapper.readTree(parser);
AuthorizationGrantType authorizationGrantType = AUTHORIZATION_GRANT_TYPE_CONVERTER AuthorizationGrantType authorizationGrantType = AUTHORIZATION_GRANT_TYPE_CONVERTER
.convert(findObjectNode(authorizationRequestNode, "authorizationGrantType")); .convert(JsonNodeUtils.findObjectNode(authorizationRequestNode, "authorizationGrantType"));
OAuth2AuthorizationRequest.Builder builder; OAuth2AuthorizationRequest.Builder builder;
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationGrantType)) { if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationGrantType)) {
@ -66,15 +60,19 @@ final class OAuth2AuthorizationRequestDeserializer extends JsonDeserializer<OAut
throw new JsonParseException(parser, "Invalid authorizationGrantType"); throw new JsonParseException(parser, "Invalid authorizationGrantType");
} }
return builder.authorizationUri(findStringValue(authorizationRequestNode, "authorizationUri")) return builder.authorizationUri(JsonNodeUtils.findStringValue(authorizationRequestNode, "authorizationUri"))
.clientId(findStringValue(authorizationRequestNode, "clientId")) .clientId(JsonNodeUtils.findStringValue(authorizationRequestNode, "clientId"))
.redirectUri(findStringValue(authorizationRequestNode, "redirectUri")) .redirectUri(JsonNodeUtils.findStringValue(authorizationRequestNode, "redirectUri"))
.scopes(findValue(authorizationRequestNode, "scopes", SET_TYPE_REFERENCE, mapper)) .scopes(JsonNodeUtils
.state(findStringValue(authorizationRequestNode, "state")) .findValue(authorizationRequestNode, "scopes", JsonNodeUtils.SET_TYPE_REFERENCE, mapper))
.additionalParameters( .state(JsonNodeUtils.findStringValue(authorizationRequestNode, "state"))
findValue(authorizationRequestNode, "additionalParameters", MAP_TYPE_REFERENCE, mapper)) .additionalParameters(JsonNodeUtils.findValue(authorizationRequestNode, "additionalParameters",
.authorizationRequestUri(findStringValue(authorizationRequestNode, "authorizationRequestUri")) JsonNodeUtils.MAP_TYPE_REFERENCE, mapper))
.attributes(findValue(authorizationRequestNode, "attributes", MAP_TYPE_REFERENCE, mapper)).build(); .authorizationRequestUri(
JsonNodeUtils.findStringValue(authorizationRequestNode, "authorizationRequestUri"))
.attributes(JsonNodeUtils.findValue(authorizationRequestNode, "attributes",
JsonNodeUtils.MAP_TYPE_REFERENCE, mapper))
.build();
} }
} }

View File

@ -23,8 +23,6 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findStringValue;
/** /**
* {@code StdConverter} implementations. * {@code StdConverter} implementations.
* *
@ -37,7 +35,7 @@ abstract class StdConverters {
@Override @Override
public OAuth2AccessToken.TokenType convert(JsonNode jsonNode) { public OAuth2AccessToken.TokenType convert(JsonNode jsonNode) {
String value = findStringValue(jsonNode, "value"); String value = JsonNodeUtils.findStringValue(jsonNode, "value");
if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(value)) { if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(value)) {
return OAuth2AccessToken.TokenType.BEARER; return OAuth2AccessToken.TokenType.BEARER;
} }
@ -50,7 +48,7 @@ abstract class StdConverters {
@Override @Override
public ClientAuthenticationMethod convert(JsonNode jsonNode) { public ClientAuthenticationMethod convert(JsonNode jsonNode) {
String value = findStringValue(jsonNode, "value"); String value = JsonNodeUtils.findStringValue(jsonNode, "value");
if (ClientAuthenticationMethod.BASIC.getValue().equalsIgnoreCase(value)) { if (ClientAuthenticationMethod.BASIC.getValue().equalsIgnoreCase(value)) {
return ClientAuthenticationMethod.BASIC; return ClientAuthenticationMethod.BASIC;
} }
@ -69,7 +67,7 @@ abstract class StdConverters {
@Override @Override
public AuthorizationGrantType convert(JsonNode jsonNode) { public AuthorizationGrantType convert(JsonNode jsonNode) {
String value = findStringValue(jsonNode, "value"); String value = JsonNodeUtils.findStringValue(jsonNode, "value");
if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equalsIgnoreCase(value)) { if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equalsIgnoreCase(value)) {
return AuthorizationGrantType.AUTHORIZATION_CODE; return AuthorizationGrantType.AUTHORIZATION_CODE;
} }
@ -91,7 +89,7 @@ abstract class StdConverters {
@Override @Override
public AuthenticationMethod convert(JsonNode jsonNode) { public AuthenticationMethod convert(JsonNode jsonNode) {
String value = findStringValue(jsonNode, "value"); String value = JsonNodeUtils.findStringValue(jsonNode, "value");
if (AuthenticationMethod.HEADER.getValue().equalsIgnoreCase(value)) { if (AuthenticationMethod.HEADER.getValue().equalsIgnoreCase(value)) {
return AuthenticationMethod.HEADER; return AuthenticationMethod.HEADER;
} }

View File

@ -48,9 +48,6 @@ import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecretKey;
/** /**
* A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder} used for * A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder} used for
* {@link OidcIdToken} signature verification. The provided {@link JwtDecoder} is * {@link OidcIdToken} signature verification. The provided {@link JwtDecoder} is
@ -162,7 +159,7 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
null); null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
} }
return withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build(); return NimbusJwtDecoder.withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build();
} }
else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
@ -187,7 +184,7 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
} }
SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8), SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
jcaAlgorithmMappings.get(jwsAlgorithm)); jcaAlgorithmMappings.get(jwsAlgorithm));
return withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build(); return NimbusJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
} }
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE, OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,

View File

@ -48,9 +48,6 @@ import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSetUri;
import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withSecretKey;
/** /**
* A {@link ReactiveJwtDecoderFactory factory} that provides a {@link ReactiveJwtDecoder} * A {@link ReactiveJwtDecoderFactory factory} that provides a {@link ReactiveJwtDecoder}
* used for {@link OidcIdToken} signature verification. The provided * used for {@link OidcIdToken} signature verification. The provided
@ -162,7 +159,8 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
null); null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
} }
return withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build(); return NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm)
.build();
} }
else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
@ -187,7 +185,8 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
} }
SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8), SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
jcaAlgorithmMappings.get(jwsAlgorithm)); jcaAlgorithmMappings.get(jwsAlgorithm));
return withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build(); return NimbusReactiveJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm)
.build();
} }
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE, OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,

View File

@ -33,8 +33,6 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import static java.util.Collections.EMPTY_MAP;
/** /**
* A representation of a client registration with an OAuth 2.0 or OpenID Connect 1.0 * A representation of a client registration with an OAuth 2.0 or OpenID Connect 1.0
* Provider. * Provider.
@ -378,7 +376,7 @@ public final class ClientRegistration implements Serializable {
this.jwkSetUri = clientRegistration.providerDetails.jwkSetUri; this.jwkSetUri = clientRegistration.providerDetails.jwkSetUri;
this.issuerUri = clientRegistration.providerDetails.issuerUri; this.issuerUri = clientRegistration.providerDetails.issuerUri;
Map<String, Object> configurationMetadata = clientRegistration.providerDetails.configurationMetadata; Map<String, Object> configurationMetadata = clientRegistration.providerDetails.configurationMetadata;
if (configurationMetadata != EMPTY_MAP) { if (configurationMetadata != Collections.EMPTY_MAP) {
this.configurationMetadata = new HashMap<>(configurationMetadata); this.configurationMetadata = new HashMap<>(configurationMetadata);
} }
this.clientName = clientRegistration.clientName; this.clientName = clientRegistration.clientName;

View File

@ -30,8 +30,6 @@ import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
/** /**
* A {@link Converter} that converts the provided {@link OAuth2UserRequest} to a * A {@link Converter} that converts the provided {@link OAuth2UserRequest} to a
* {@link RequestEntity} representation of a request for the UserInfo Endpoint. * {@link RequestEntity} representation of a request for the UserInfo Endpoint.
@ -45,7 +43,7 @@ import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VAL
public class OAuth2UserRequestEntityConverter implements Converter<OAuth2UserRequest, RequestEntity<?>> { public class OAuth2UserRequestEntityConverter implements Converter<OAuth2UserRequest, RequestEntity<?>> {
private static final MediaType DEFAULT_CONTENT_TYPE = MediaType private static final MediaType DEFAULT_CONTENT_TYPE = MediaType
.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); .valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8");
/** /**
* Returns the {@link RequestEntity} used for the UserInfo Request. * Returns the {@link RequestEntity} used for the UserInfo Request.

View File

@ -19,11 +19,11 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
/** /**
* Tests for {@link OAuth2AuthorizedClient}. * Tests for {@link OAuth2AuthorizedClient}.
@ -40,9 +40,9 @@ public class OAuth2AuthorizedClientTests {
@Before @Before
public void setUp() { public void setUp() {
this.clientRegistration = clientRegistration().build(); this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.principalName = "principal"; this.principalName = "principal";
this.accessToken = noScopes(); this.accessToken = TestOAuth2AccessTokens.noScopes();
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)

View File

@ -25,23 +25,22 @@ import org.junit.Test;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/** /**
* Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}. * Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}.
@ -61,8 +60,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
@Before @Before
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void setUp() { public void setUp() {
this.clientRegistration = clientRegistration().build(); this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.authorizationRequest = request().build(); this.authorizationRequest = TestOAuth2AuthorizationRequests.request().build();
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient); this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient);
} }
@ -80,7 +79,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
@Test @Test
public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthorizationException() { public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthorizationException() {
OAuth2AuthorizationResponse authorizationResponse = error().errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build(); OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error()
.errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
authorizationResponse); authorizationResponse);
@ -92,7 +92,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
@Test @Test
public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthorizationException() { public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthorizationException() {
OAuth2AuthorizationResponse authorizationResponse = success().state("67890").build(); OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().state("67890")
.build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
authorizationResponse); authorizationResponse);
@ -104,11 +105,12 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
@Test @Test
public void authenticateWhenAuthorizationSuccessResponseThenExchangedForAccessToken() { public void authenticateWhenAuthorizationSuccessResponseThenExchangedForAccessToken() {
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().refreshToken("refresh").build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh").build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
success().build()); TestOAuth2AuthorizationResponses.success().build());
OAuth2AuthorizationCodeAuthenticationToken authenticationResult = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider OAuth2AuthorizationCodeAuthenticationToken authenticationResult = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider
.authenticate( .authenticate(
new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange)); new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange));
@ -131,12 +133,12 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
additionalParameters.put("param1", "value1"); additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2"); additionalParameters.put("param2", "value2");
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().additionalParameters(additionalParameters) OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.build(); .additionalParameters(additionalParameters).build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
success().build()); TestOAuth2AuthorizationResponses.success().build());
OAuth2AuthorizationCodeAuthenticationToken authentication = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider OAuth2AuthorizationCodeAuthenticationToken authentication = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider
.authenticate( .authenticate(

View File

@ -21,15 +21,15 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/** /**
* Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}. * Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}.
@ -46,9 +46,10 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests {
@Before @Before
public void setUp() { public void setUp() {
this.clientRegistration = clientRegistration().build(); this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.authorizationExchange = new OAuth2AuthorizationExchange(request().build(), success().code("code").build()); this.authorizationExchange = new OAuth2AuthorizationExchange(TestOAuth2AuthorizationRequests.request().build(),
this.accessToken = noScopes(); TestOAuth2AuthorizationResponses.success().code("code").build());
this.accessToken = TestOAuth2AccessTokens.noScopes();
} }
@Test @Test

View File

@ -36,6 +36,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
@ -45,6 +46,8 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -53,10 +56,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyCollection; import static org.mockito.ArgumentMatchers.anyCollection;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/** /**
* Tests for {@link OAuth2LoginAuthenticationProvider}. * Tests for {@link OAuth2LoginAuthenticationProvider}.
@ -85,9 +84,9 @@ public class OAuth2LoginAuthenticationProviderTests {
@Before @Before
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void setUp() { public void setUp() {
this.clientRegistration = clientRegistration().build(); this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.authorizationRequest = request().scope("scope1", "scope2").build(); this.authorizationRequest = TestOAuth2AuthorizationRequests.request().scope("scope1", "scope2").build();
this.authorizationResponse = success().build(); this.authorizationResponse = TestOAuth2AuthorizationResponses.success().build();
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
this.authorizationResponse); this.authorizationResponse);
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
@ -121,7 +120,8 @@ public class OAuth2LoginAuthenticationProviderTests {
@Test @Test
public void authenticateWhenAuthorizationRequestContainsOpenidScopeThenReturnNull() { public void authenticateWhenAuthorizationRequestContainsOpenidScopeThenReturnNull() {
OAuth2AuthorizationRequest authorizationRequest = request().scope("openid").build(); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request().scope("openid")
.build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
this.authorizationResponse); this.authorizationResponse);
@ -136,7 +136,8 @@ public class OAuth2LoginAuthenticationProviderTests {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST)); this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST));
OAuth2AuthorizationResponse authorizationResponse = error().errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build(); OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error()
.errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
authorizationResponse); authorizationResponse);
@ -149,7 +150,8 @@ public class OAuth2LoginAuthenticationProviderTests {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_state_parameter")); this.exception.expectMessage(containsString("invalid_state_parameter"));
OAuth2AuthorizationResponse authorizationResponse = success().state("67890").build(); OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().state("67890")
.build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
authorizationResponse); authorizationResponse);

View File

@ -23,16 +23,16 @@ import org.junit.Test;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/** /**
* Tests for {@link OAuth2LoginAuthenticationToken}. * Tests for {@link OAuth2LoginAuthenticationToken}.
@ -55,9 +55,10 @@ public class OAuth2LoginAuthenticationTokenTests {
public void setUp() { public void setUp() {
this.principal = mock(OAuth2User.class); this.principal = mock(OAuth2User.class);
this.authorities = Collections.emptyList(); this.authorities = Collections.emptyList();
this.clientRegistration = clientRegistration().build(); this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.authorizationExchange = new OAuth2AuthorizationExchange(request().build(), success().code("code").build()); this.authorizationExchange = new OAuth2AuthorizationExchange(TestOAuth2AuthorizationRequests.request().build(),
this.accessToken = noScopes(); TestOAuth2AuthorizationResponses.success().code("code").build());
this.accessToken = TestOAuth2AccessTokens.noScopes();
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)

View File

@ -27,6 +27,7 @@ import org.junit.rules.ExpectedException;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
@ -34,12 +35,11 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/** /**
* Tests for {@link NimbusAuthorizationCodeTokenResponseClient}. * Tests for {@link NimbusAuthorizationCodeTokenResponseClient}.
@ -63,10 +63,10 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
@Before @Before
public void setUp() { public void setUp() {
this.clientRegistrationBuilder = clientRegistration() this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration()
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC); .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC);
this.authorizationRequest = request().build(); this.authorizationRequest = TestOAuth2AuthorizationRequests.request().build();
this.authorizationResponse = success().build(); this.authorizationResponse = TestOAuth2AuthorizationResponses.success().build();
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
this.authorizationResponse); this.authorizationResponse);
} }
@ -112,7 +112,8 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
this.exception.expect(IllegalArgumentException.class); this.exception.expect(IllegalArgumentException.class);
String redirectUri = "http:\\example.com"; String redirectUri = "http:\\example.com";
OAuth2AuthorizationRequest authorizationRequest = request().redirectUri(redirectUri).build(); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.redirectUri(redirectUri).build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
this.authorizationResponse); this.authorizationResponse);
@ -260,8 +261,8 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
String tokenUri = server.url("/oauth2/token").toString(); String tokenUri = server.url("/oauth2/token").toString();
this.clientRegistrationBuilder.tokenUri(tokenUri); this.clientRegistrationBuilder.tokenUri(tokenUri);
OAuth2AuthorizationRequest authorizationRequest = request().scope("openid", "profile", "email", "address") OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.build(); .scope("openid", "profile", "email", "address").build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
this.authorizationResponse); this.authorizationResponse);
@ -287,8 +288,8 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
String tokenUri = server.url("/oauth2/token").toString(); String tokenUri = server.url("/oauth2/token").toString();
this.clientRegistrationBuilder.tokenUri(tokenUri); this.clientRegistrationBuilder.tokenUri(tokenUri);
OAuth2AuthorizationRequest authorizationRequest = request().scope("openid", "profile", "email", "address") OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.build(); .scope("openid", "profile", "email", "address").build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
this.authorizationResponse); this.authorizationResponse);

View File

@ -37,7 +37,6 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
/** /**
* Tests for {@link OAuth2AuthorizationCodeGrantRequestEntityConverter}. * Tests for {@link OAuth2AuthorizationCodeGrantRequestEntityConverter}.
@ -84,7 +83,7 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverterTests {
HttpHeaders headers = requestEntity.getHeaders(); HttpHeaders headers = requestEntity.getHeaders();
assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8);
assertThat(headers.getContentType()) assertThat(headers.getContentType())
.isEqualTo(MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"));
assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic ");
MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody(); MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody();
@ -127,7 +126,7 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverterTests {
HttpHeaders headers = requestEntity.getHeaders(); HttpHeaders headers = requestEntity.getHeaders();
assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8);
assertThat(headers.getContentType()) assertThat(headers.getContentType())
.isEqualTo(MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"));
assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isNull(); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isNull();
MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody(); MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody();

Some files were not shown because too many files have changed in this diff Show More