diff --git a/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java index d107f9aa22..8ee72cf3fc 100644 --- a/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java @@ -92,6 +92,8 @@ public abstract class AbstractUserDetailsAuthenticationProvider private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks(); + private boolean alwaysPerformAdditionalChecksOnUser = true; + private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper(); /** @@ -146,8 +148,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider Assert.notNull(user, "retrieveUser returned null - a violation of the interface contract"); } try { - this.preAuthenticationChecks.check(user); - additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication); + performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication); } catch (AuthenticationException ex) { if (!cacheWasUsed) { @@ -157,8 +158,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider // we're using latest data (i.e. not from the cache) cacheWasUsed = false; user = retrieveUser(username, (UsernamePasswordAuthenticationToken) authentication); - this.preAuthenticationChecks.check(user); - additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication); + performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication); } this.postAuthenticationChecks.check(user); if (!cacheWasUsed) { @@ -171,6 +171,25 @@ public abstract class AbstractUserDetailsAuthenticationProvider return createSuccessAuthentication(principalToReturn, authentication, user); } + private void performPreCheck(UserDetails user, UsernamePasswordAuthenticationToken authentication) { + try { + this.preAuthenticationChecks.check(user); + } + catch (AuthenticationException ex) { + if (!this.alwaysPerformAdditionalChecksOnUser) { + throw ex; + } + try { + additionalAuthenticationChecks(user, authentication); + } + catch (AuthenticationException ignored) { + // preserve the original failed check + } + throw ex; + } + additionalAuthenticationChecks(user, authentication); + } + private String determineUsername(Authentication authentication) { return (authentication.getPrincipal() == null) ? "NONE_PROVIDED" : authentication.getName(); } @@ -313,6 +332,22 @@ public abstract class AbstractUserDetailsAuthenticationProvider this.postAuthenticationChecks = postAuthenticationChecks; } + /** + * Set whether to always perform the additional checks on the user, even if the + * pre-authentication checks fail. This is useful to ensure that regardless of the + * state of the user account, authentication takes the same amount of time to + * complete. + * + *

+ * For applications that rely on the additional checks running only once should set + * this value to {@code false} + * @param alwaysPerformAdditionalChecksOnUser + * @since 5.7.23 + */ + public void setAlwaysPerformAdditionalChecksOnUser(boolean alwaysPerformAdditionalChecksOnUser) { + this.alwaysPerformAdditionalChecksOnUser = alwaysPerformAdditionalChecksOnUser; + } + public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { this.authoritiesMapper = authoritiesMapper; } diff --git a/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java index 36a5d3bcb9..24cbe3b332 100644 --- a/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfSystemProperty; import org.springframework.cache.Cache; import org.springframework.dao.DataRetrievalFailureException; @@ -42,6 +43,7 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UserDetailsChecker; import org.springframework.security.core.userdetails.UserDetailsPasswordService; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UsernameNotFoundException; @@ -62,6 +64,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -452,12 +455,10 @@ public class DaoAuthenticationProviderTests { assertThat(daoAuthenticationProvider.getPasswordEncoder()).isSameAs(NoOpPasswordEncoder.getInstance()); } - /** - * This is an explicit test for SEC-2056. It is intentionally ignored since this test - * is not deterministic and {@link #testUserNotFoundEncodesPassword()} ensures that - * SEC-2056 is fixed. - */ - public void IGNOREtestSec2056() { + // SEC-2056 + @Test + @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true") + public void testSec2056() { UsernamePasswordAuthenticationToken foundUser = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala"); UsernamePasswordAuthenticationToken notFoundUser = UsernamePasswordAuthenticationToken @@ -491,6 +492,41 @@ public class DaoAuthenticationProviderTests { .isTrue(); } + // related to SEC-2056 + @Test + @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true") + public void testDisabledUserTiming() { + UsernamePasswordAuthenticationToken user = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala"); + PasswordEncoder encoder = new BCryptPasswordEncoder(); + DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); + provider.setPasswordEncoder(encoder); + MockUserDetailsServiceUserRod users = new MockUserDetailsServiceUserRod(); + users.password = encoder.encode((CharSequence) user.getCredentials()); + provider.setUserDetailsService(users); + int sampleSize = 100; + List enabledTimes = new ArrayList<>(sampleSize); + for (int i = 0; i < sampleSize; i++) { + long start = System.currentTimeMillis(); + provider.authenticate(user); + enabledTimes.add(System.currentTimeMillis() - start); + } + UserDetailsChecker preChecks = mock(UserDetailsChecker.class); + willThrow(new DisabledException("User is disabled")).given(preChecks).check(any(UserDetails.class)); + provider.setPreAuthenticationChecks(preChecks); + List disabledTimes = new ArrayList<>(sampleSize); + for (int i = 0; i < sampleSize; i++) { + long start = System.currentTimeMillis(); + assertThatExceptionOfType(DisabledException.class).isThrownBy(() -> provider.authenticate(user)); + disabledTimes.add(System.currentTimeMillis() - start); + } + double enabledAvg = avg(enabledTimes); + double disabledAvg = avg(disabledTimes); + assertThat(Math.abs(disabledAvg - enabledAvg) <= 3) + .withFailMessage("Disabled user average " + disabledAvg + " should be within 3ms of enabled user average " + + enabledAvg) + .isTrue(); + } + private double avg(List counts) { return counts.stream().mapToLong(Long::longValue).average().orElse(0); }