Polish spring-security-test main code

Manually polish `spring-security-test` following the formatting
and checkstyle fixes.

Issue gh-8945
This commit is contained in:
Phillip Webb 2020-07-31 23:19:43 -07:00 committed by Rob Winch
parent 2ca6256b89
commit ef951bae90
12 changed files with 88 additions and 159 deletions

View File

@ -56,12 +56,14 @@ import org.springframework.util.Assert;
* @author Rob Winch
* @author Tadaya Tsuyukubo
* @since 4.0
*
*/
public final class TestSecurityContextHolder {
private static final ThreadLocal<SecurityContext> contextHolder = new ThreadLocal<>();
private TestSecurityContextHolder() {
}
/**
* Clears the {@link SecurityContext} from {@link TestSecurityContextHolder} and
* {@link SecurityContextHolder}.
@ -77,12 +79,10 @@ public final class TestSecurityContextHolder {
*/
public static SecurityContext getContext() {
SecurityContext ctx = contextHolder.get();
if (ctx == null) {
ctx = getDefaultContext();
contextHolder.set(ctx);
}
return ctx;
}
@ -120,7 +120,4 @@ public final class TestSecurityContextHolder {
return SecurityContextHolder.getContext();
}
private TestSecurityContextHolder() {
}
}

View File

@ -52,9 +52,11 @@ public class ReactorContextTestExecutionListener extends DelegatingTestExecution
}
private static TestExecutionListener createDelegate() {
return ClassUtils.isPresent(HOOKS_CLASS_NAME, ReactorContextTestExecutionListener.class.getClassLoader())
? new DelegateTestExecutionListener() : new AbstractTestExecutionListener() {
};
if (!ClassUtils.isPresent(HOOKS_CLASS_NAME, ReactorContextTestExecutionListener.class.getClassLoader())) {
return new AbstractTestExecutionListener() {
};
}
return new DelegateTestExecutionListener();
}
/**

View File

@ -33,6 +33,7 @@ public enum TestExecutionEvent {
* event.
*/
TEST_METHOD,
/**
* Associated to
* {@link org.springframework.test.context.TestExecutionListener#beforeTestExecution(TestContext)}

View File

@ -27,6 +27,7 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.User;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
@ -41,21 +42,14 @@ final class WithMockUserSecurityContextFactory implements WithSecurityContextFac
@Override
public SecurityContext createSecurityContext(WithMockUser withUser) {
String username = StringUtils.hasLength(withUser.username()) ? withUser.username() : withUser.value();
if (username == null) {
throw new IllegalArgumentException(
withUser + " cannot have null username on both username and value properties");
}
Assert.notNull(username, () -> withUser + " cannot have null username on both username and value properties");
List<GrantedAuthority> grantedAuthorities = new ArrayList<>();
for (String authority : withUser.authorities()) {
grantedAuthorities.add(new SimpleGrantedAuthority(authority));
}
if (grantedAuthorities.isEmpty()) {
for (String role : withUser.roles()) {
if (role.startsWith("ROLE_")) {
throw new IllegalArgumentException("roles cannot start with ROLE_ Got " + role);
}
Assert.isTrue(!role.startsWith("ROLE_"), () -> "roles cannot start with ROLE_ Got " + role);
grantedAuthorities.add(new SimpleGrantedAuthority("ROLE_" + role));
}
}
@ -63,7 +57,6 @@ final class WithMockUserSecurityContextFactory implements WithSecurityContextFac
throw new IllegalStateException("You cannot define roles attribute " + Arrays.asList(withUser.roles())
+ " with authorities attribute " + Arrays.asList(withUser.authorities()));
}
User principal = new User(username, withUser.password(), true, true, true, true, grantedAuthorities);
Authentication authentication = new UsernamePasswordAuthenticationToken(principal, principal.getPassword(),
principal.getAuthorities());

View File

@ -68,7 +68,6 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut
if (testSecurityContext == null) {
return;
}
Supplier<SecurityContext> supplier = testSecurityContext.getSecurityContextSupplier();
if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) {
TestSecurityContextHolder.setContext(supplier.get());

View File

@ -84,7 +84,7 @@ final class WithUserDetailsSecurityContextFactory implements WithSecurityContext
: this.beans.getBean(ReactiveUserDetailsService.class);
return new ReactiveUserDetailsServiceAdapter(reactiveUserDetailsService);
}
catch (NoSuchBeanDefinitionException | BeanNotOfRequiredTypeException notReactive) {
catch (NoSuchBeanDefinitionException | BeanNotOfRequiredTypeException ex) {
return null;
}
}

View File

@ -108,10 +108,12 @@ public final class SecurityMockServerConfigurers {
*/
public static MockServerConfigurer springSecurity() {
return new MockServerConfigurer() {
@Override
public void beforeServerCreated(WebHttpHandlerBuilder builder) {
builder.filters((filters) -> filters.add(0, new MutatorFilter()));
}
};
}
@ -992,26 +994,22 @@ public final class SecurityMockServerConfigurers {
}
private Collection<GrantedAuthority> getAuthorities() {
if (this.authorities == null) {
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo()));
for (String authority : this.accessToken.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return authorities;
}
else {
if (this.authorities != null) {
return this.authorities;
}
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo()));
for (String authority : this.accessToken.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return authorities;
}
private OidcIdToken getOidcIdToken() {
if (this.idToken == null) {
return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user"));
}
else {
if (this.idToken != null) {
return this.idToken;
}
return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user"));
}
private OidcUserInfo getOidcUserInfo() {
@ -1071,7 +1069,6 @@ public final class SecurityMockServerConfigurers {
*/
public OAuth2ClientMutator clientRegistration(
Consumer<ClientRegistration.Builder> clientRegistrationConfigurer) {
ClientRegistration.Builder builder = clientRegistrationBuilder();
clientRegistrationConfigurer.accept(builder);
this.clientRegistration = builder.build();
@ -1108,7 +1105,6 @@ public final class SecurityMockServerConfigurers {
@Override
public void afterConfigureAdded(WebTestClient.MockServerSpec<?> serverSpec) {
}
@Override
@ -1134,10 +1130,8 @@ public final class SecurityMockServerConfigurers {
}
private OAuth2AuthorizedClient getClient() {
if (this.clientRegistration == null) {
throw new IllegalArgumentException(
"Please specify a ClientRegistration via one " + "of the clientRegistration methods");
}
Assert.notNull(this.clientRegistration,
"Please specify a ClientRegistration via one of the clientRegistration methods");
return new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, this.accessToken);
}
@ -1173,9 +1167,7 @@ public final class SecurityMockServerConfigurers {
OAuth2AuthorizedClient client = exchange.getAttribute(TOKEN_ATTR_NAME);
return Mono.just(client);
}
else {
return this.delegate.authorize(authorizeRequest);
}
return this.delegate.authorize(authorizeRequest);
}
static void enable(ServerWebExchange exchange) {

View File

@ -36,10 +36,12 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder
*
* @author Rob Winch
* @since 4.0
*
*/
public final class SecurityMockMvcRequestBuilders {
private SecurityMockMvcRequestBuilders() {
}
/**
* Creates a request (including any necessary {@link CsrfToken}) that will submit a
* form based login to POST "/login".
@ -91,18 +93,18 @@ public final class SecurityMockMvcRequestBuilders {
private Mergeable parent;
private LogoutRequestBuilder() {
}
@Override
public MockHttpServletRequest buildRequest(ServletContext servletContext) {
MockHttpServletRequestBuilder logoutRequest = post(this.logoutUrl).accept(MediaType.TEXT_HTML,
MediaType.ALL);
if (this.parent != null) {
logoutRequest = (MockHttpServletRequestBuilder) logoutRequest.merge(this.parent);
}
MockHttpServletRequest request = logoutRequest.buildRequest(servletContext);
logoutRequest.postProcessRequest(request);
return this.postProcessor.postProcessRequest(request);
}
@ -141,12 +143,7 @@ public final class SecurityMockMvcRequestBuilders {
this.parent = (Mergeable) parent;
return this;
}
else {
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
}
private LogoutRequestBuilder() {
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
}
@ -175,18 +172,18 @@ public final class SecurityMockMvcRequestBuilders {
private RequestPostProcessor postProcessor = csrf();
private FormLoginRequestBuilder() {
}
@Override
public MockHttpServletRequest buildRequest(ServletContext servletContext) {
MockHttpServletRequestBuilder loginRequest = post(this.loginProcessingUrl).accept(this.acceptMediaType)
.param(this.usernameParam, this.username).param(this.passwordParam, this.password);
if (this.parent != null) {
loginRequest = (MockHttpServletRequestBuilder) loginRequest.merge(this.parent);
}
MockHttpServletRequest request = loginRequest.buildRequest(servletContext);
loginRequest.postProcessRequest(request);
return this.postProcessor.postProcessRequest(request);
}
@ -305,17 +302,9 @@ public final class SecurityMockMvcRequestBuilders {
this.parent = (Mergeable) parent;
return this;
}
else {
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
private FormLoginRequestBuilder() {
}
}
private SecurityMockMvcRequestBuilders() {
}
}

View File

@ -116,6 +116,9 @@ import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandl
*/
public final class SecurityMockMvcRequestPostProcessors {
private SecurityMockMvcRequestPostProcessors() {
}
/**
* Creates a DigestRequestPostProcessor that enables easily adding digest based
* authentication to a request.
@ -634,7 +637,6 @@ public final class SecurityMockMvcRequestPostProcessors {
String toDigest = expiryTime + ":" + "key";
String signatureValue = md5Hex(toDigest);
String nonceValue = expiryTime + ":" + signatureValue;
return new String(Base64.getEncoder().encode(nonceValue.getBytes()));
}
@ -649,7 +651,6 @@ public final class SecurityMockMvcRequestPostProcessors {
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
request.addHeader("Authorization", createAuthorizationHeader(request));
return request;
}
@ -676,28 +677,19 @@ public final class SecurityMockMvcRequestPostProcessors {
String a1Md5 = encodePasswordInA1Format(username, realm, password);
String a2 = httpMethod + ":" + uri;
String a2Md5 = md5Hex(a2);
String digest;
if (qop == null) {
// as per RFC 2069 compliant clients (also reaffirmed by RFC 2617)
digest = a1Md5 + ":" + nonce + ":" + a2Md5;
return md5Hex(a1Md5 + ":" + nonce + ":" + a2Md5);
}
else if ("auth".equals(qop)) {
if ("auth".equals(qop)) {
// As per RFC 2617 compliant clients
digest = a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5;
return md5Hex(a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5);
}
else {
throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'");
}
return md5Hex(digest);
throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'");
}
static String encodePasswordInA1Format(String username, String realm, String password) {
String a1 = username + ":" + realm + ":" + password;
return md5Hex(a1);
return md5Hex(username + ":" + realm + ":" + password);
}
private static String md5Hex(String a2) {
@ -736,15 +728,11 @@ public final class SecurityMockMvcRequestPostProcessors {
securityContextRepository = new TestSecurityContextRepository(securityContextRepository);
WebTestUtils.setSecurityContextRepository(request, securityContextRepository);
}
HttpServletResponse response = new MockHttpServletResponse();
HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(request, response);
securityContextRepository.loadContext(requestResponseHolder);
request = requestResponseHolder.getRequest();
response = requestResponseHolder.getResponse();
securityContextRepository.saveContext(securityContext, request, response);
}
@ -812,12 +800,10 @@ public final class SecurityMockMvcRequestPostProcessors {
if (existingContext != null) {
return request;
}
SecurityContext context = TestSecurityContextHolder.getContext();
if (!this.EMPTY.equals(context)) {
save(context, request);
}
return request;
}
@ -889,7 +875,6 @@ public final class SecurityMockMvcRequestPostProcessors {
UserDetailsRequestPostProcessor(UserDetails user) {
Authentication token = new UsernamePasswordAuthenticationToken(user, user.getPassword(),
user.getAuthorities());
this.delegate = new AuthenticationRequestPostProcessor(token);
}
@ -948,13 +933,9 @@ public final class SecurityMockMvcRequestPostProcessors {
public UserRequestPostProcessor roles(String... roles) {
List<GrantedAuthority> authorities = new ArrayList<>(roles.length);
for (String role : roles) {
if (role.startsWith(ROLE_PREFIX)) {
throw new IllegalArgumentException("Role should not start with " + ROLE_PREFIX
+ " since this method automatically prefixes with this value. Got " + role);
}
else {
authorities.add(new SimpleGrantedAuthority(ROLE_PREFIX + role));
}
Assert.isTrue(!role.startsWith(ROLE_PREFIX), () -> "Role should not start with " + ROLE_PREFIX
+ " since this method automatically prefixes with this value. Got " + role);
authorities.add(new SimpleGrantedAuthority(ROLE_PREFIX + role));
}
this.authorities = authorities;
return this;
@ -1027,8 +1008,7 @@ public final class SecurityMockMvcRequestPostProcessors {
private String headerValue;
private HttpBasicRequestPostProcessor(String username, String password) {
byte[] toEncode;
toEncode = (username + ":" + password).getBytes(StandardCharsets.UTF_8);
byte[] toEncode = (username + ":" + password).getBytes(StandardCharsets.UTF_8);
this.headerValue = "Basic " + new String(Base64.getEncoder().encode(toEncode));
}
@ -1356,7 +1336,6 @@ public final class SecurityMockMvcRequestPostProcessors {
OAuth2User oauth2User = this.oauth2User.get();
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(oauth2User, oauth2User.getAuthorities(),
this.clientRegistration.getRegistrationId());
request = new AuthenticationRequestPostProcessor(token).postProcessRequest(request);
return new OAuth2ClientRequestPostProcessor().clientRegistration(this.clientRegistration)
.principalName(oauth2User.getName()).accessToken(this.accessToken).postProcessRequest(request);
@ -1504,26 +1483,22 @@ public final class SecurityMockMvcRequestPostProcessors {
}
private Collection<GrantedAuthority> getAuthorities() {
if (this.authorities == null) {
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo()));
for (String authority : this.accessToken.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return authorities;
}
else {
if (this.authorities != null) {
return this.authorities;
}
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo()));
for (String authority : this.accessToken.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return authorities;
}
private OidcIdToken getOidcIdToken() {
if (this.idToken == null) {
return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user"));
}
else {
if (this.idToken != null) {
return this.idToken;
}
return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user"));
}
private OidcUserInfo getOidcUserInfo() {
@ -1577,7 +1552,6 @@ public final class SecurityMockMvcRequestPostProcessors {
*/
public OAuth2ClientRequestPostProcessor clientRegistration(
Consumer<ClientRegistration.Builder> clientRegistrationConfigurer) {
ClientRegistration.Builder builder = clientRegistrationBuilder();
clientRegistrationConfigurer.accept(builder);
this.clientRegistration = builder.build();
@ -1613,7 +1587,6 @@ public final class SecurityMockMvcRequestPostProcessors {
}
OAuth2AuthorizedClient client = new OAuth2AuthorizedClient(this.clientRegistration, this.principalName,
this.accessToken);
OAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServletTestUtils
.getOAuth2AuthorizedClientManager(request);
if (!(authorizationClientManager instanceof TestOAuth2AuthorizedClientManager)) {
@ -1654,9 +1627,7 @@ public final class SecurityMockMvcRequestPostProcessors {
if (isEnabled(request)) {
return (OAuth2AuthorizedClient) request.getAttribute(TOKEN_ATTR_NAME);
}
else {
return this.delegate.authorize(authorizeRequest);
}
return this.delegate.authorize(authorizeRequest);
}
static void enable(HttpServletRequest request) {
@ -1762,7 +1733,4 @@ public final class SecurityMockMvcRequestPostProcessors {
}
private SecurityMockMvcRequestPostProcessors() {
}
}

View File

@ -43,6 +43,9 @@ import org.springframework.test.web.servlet.ResultMatcher;
*/
public final class SecurityMockMvcResultMatchers {
private SecurityMockMvcResultMatchers() {
}
/**
* {@link ResultMatcher} that verifies that a specified user is authenticated.
* @return the {@link AuthenticatedMatcher} to use
@ -90,29 +93,26 @@ public final class SecurityMockMvcResultMatchers {
private Consumer<Authentication> assertAuthentication;
AuthenticatedMatcher() {
}
@Override
public void match(MvcResult result) {
SecurityContext context = load(result);
Authentication auth = context.getAuthentication();
AssertionErrors.assertTrue("Authentication should not be null", auth != null);
if (this.assertAuthentication != null) {
this.assertAuthentication.accept(auth);
}
if (this.expectedContext != null) {
AssertionErrors.assertEquals(this.expectedContext + " does not equal " + context, this.expectedContext,
context);
}
if (this.expectedAuthentication != null) {
AssertionErrors.assertEquals(
this.expectedAuthentication + " does not equal " + context.getAuthentication(),
this.expectedAuthentication, context.getAuthentication());
}
if (this.expectedAuthenticationPrincipal != null) {
AssertionErrors.assertTrue("Authentication cannot be null", context.getAuthentication() != null);
AssertionErrors.assertEquals(
@ -120,14 +120,12 @@ public final class SecurityMockMvcResultMatchers {
+ context.getAuthentication().getPrincipal(),
this.expectedAuthenticationPrincipal, context.getAuthentication().getPrincipal());
}
if (this.expectedAuthenticationName != null) {
AssertionErrors.assertTrue("Authentication cannot be null", auth != null);
String name = auth.getName();
AssertionErrors.assertEquals(this.expectedAuthenticationName + " does not equal " + name,
this.expectedAuthenticationName, name);
}
if (this.expectedGrantedAuthorities != null) {
AssertionErrors.assertTrue("Authentication cannot be null", auth != null);
Collection<? extends GrantedAuthority> authorities = auth.getAuthorities();
@ -222,9 +220,6 @@ public final class SecurityMockMvcResultMatchers {
return withAuthorities(authorities);
}
AuthenticatedMatcher() {
}
}
/**
@ -238,6 +233,9 @@ public final class SecurityMockMvcResultMatchers {
private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
private UnAuthenticatedMatcher() {
}
@Override
public void match(MvcResult result) {
SecurityContext context = load(result);
@ -247,12 +245,6 @@ public final class SecurityMockMvcResultMatchers {
authentication == null || this.trustResolver.isAnonymous(authentication));
}
private UnAuthenticatedMatcher() {
}
}
private SecurityMockMvcResultMatchers() {
}
}

View File

@ -29,6 +29,7 @@ import org.springframework.security.config.BeanIds;
import org.springframework.test.web.servlet.request.RequestPostProcessor;
import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder;
import org.springframework.test.web.servlet.setup.MockMvcConfigurerAdapter;
import org.springframework.util.Assert;
import org.springframework.web.context.WebApplicationContext;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.testSecurityContext;
@ -72,15 +73,11 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter {
if (getSpringSecurityFilterChain() == null && context.containsBean(securityBeanId)) {
setSpringSecurityFitlerChain(context.getBean(securityBeanId, Filter.class));
}
if (getSpringSecurityFilterChain() == null) {
throw new IllegalStateException("springSecurityFilterChain cannot be null. Ensure a Bean with the name "
+ securityBeanId + " implementing Filter is present or inject the Filter to be used.");
}
Assert.state(getSpringSecurityFilterChain() != null,
() -> "springSecurityFilterChain cannot be null. Ensure a Bean with the name " + securityBeanId
+ " implementing Filter is present or inject the Filter to be used.");
// This is used by other test support to obtain the FilterChainProxy
context.getServletContext().setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN, getSpringSecurityFilterChain());
return testSecurityContext();
}
@ -118,11 +115,9 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter {
Filter getDelegate() {
Filter result = this.delegate;
if (result == null) {
throw new IllegalStateException(
"delegate cannot be null. Ensure a Bean with the name " + BeanIds.SPRING_SECURITY_FILTER_CHAIN
+ " implementing Filter is present or inject the Filter to be used.");
}
Assert.state(result != null,
() -> "delegate cannot be null. Ensure a Bean with the name " + BeanIds.SPRING_SECURITY_FILTER_CHAIN
+ " implementing Filter is present or inject the Filter to be used.");
return result;
}

View File

@ -47,6 +47,9 @@ public abstract class WebTestUtils {
private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository();
private WebTestUtils() {
}
/**
* Gets the {@link SecurityContextRepository} for the specified
* {@link HttpServletRequest}. If one is not found, a default
@ -134,18 +137,16 @@ public abstract class WebTestUtils {
}
WebApplicationContext webApplicationContext = WebApplicationContextUtils
.getWebApplicationContext(servletContext);
if (webApplicationContext != null) {
try {
return webApplicationContext.getBean(AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME,
Filter.class);
}
catch (NoSuchBeanDefinitionException notFound) {
}
if (webApplicationContext == null) {
return null;
}
try {
String beanName = AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME;
return webApplicationContext.getBean(beanName, Filter.class);
}
catch (NoSuchBeanDefinitionException ex) {
return null;
}
return null;
}
private WebTestUtils() {
}
}