diff --git a/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java index e33aa2077e..a45ef17358 100644 --- a/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java @@ -97,6 +97,8 @@ public abstract class AbstractUserDetailsAuthenticationProvider private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks(); + private boolean alwaysPerformAdditionalChecksOnUser = true; + private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper(); private static final String AUTHORITY = FactorGrantedAuthority.PASSWORD_AUTHORITY; @@ -154,8 +156,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider Assert.notNull(user, "retrieveUser returned null - a violation of the interface contract"); } try { - this.preAuthenticationChecks.check(user); - additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication); + performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication); } catch (AuthenticationException ex) { if (!cacheWasUsed) { @@ -165,8 +166,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider // we're using latest data (i.e. not from the cache) cacheWasUsed = false; user = retrieveUser(username, (UsernamePasswordAuthenticationToken) authentication); - this.preAuthenticationChecks.check(user); - additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication); + performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication); } this.postAuthenticationChecks.check(user); if (!cacheWasUsed) { @@ -179,6 +179,25 @@ public abstract class AbstractUserDetailsAuthenticationProvider return createSuccessAuthentication(principalToReturn, authentication, user); } + private void performPreCheck(UserDetails user, UsernamePasswordAuthenticationToken authentication) { + try { + this.preAuthenticationChecks.check(user); + } + catch (AuthenticationException ex) { + if (!this.alwaysPerformAdditionalChecksOnUser) { + throw ex; + } + try { + additionalAuthenticationChecks(user, authentication); + } + catch (AuthenticationException ignored) { + // preserve the original failed check + } + throw ex; + } + additionalAuthenticationChecks(user, authentication); + } + private String determineUsername(Authentication authentication) { return (authentication.getPrincipal() == null) ? "NONE_PROVIDED" : authentication.getName(); } @@ -324,6 +343,22 @@ public abstract class AbstractUserDetailsAuthenticationProvider this.postAuthenticationChecks = postAuthenticationChecks; } + /** + * Set whether to always perform the additional checks on the user, even if the + * pre-authentication checks fail. This is useful to ensure that regardless of the + * state of the user account, authentication takes the same amount of time to + * complete. + * + *

+ * For applications that rely on the additional checks running only once should set + * this value to {@code false} + * @param alwaysPerformAdditionalChecksOnUser + * @since 5.7.23 + */ + public void setAlwaysPerformAdditionalChecksOnUser(boolean alwaysPerformAdditionalChecksOnUser) { + this.alwaysPerformAdditionalChecksOnUser = alwaysPerformAdditionalChecksOnUser; + } + public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { this.authoritiesMapper = authoritiesMapper; } diff --git a/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java index b491510865..a2aa54656f 100644 --- a/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfSystemProperty; import org.springframework.cache.Cache; import org.springframework.dao.DataRetrievalFailureException; @@ -44,6 +45,7 @@ import org.springframework.security.core.authority.FactorGrantedAuthority; import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UserDetailsChecker; import org.springframework.security.core.userdetails.UserDetailsPasswordService; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UsernameNotFoundException; @@ -64,6 +66,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -422,12 +425,10 @@ public class DaoAuthenticationProviderTests { assertThat(daoAuthenticationProvider.getPasswordEncoder()).isSameAs(NoOpPasswordEncoder.getInstance()); } - /** - * This is an explicit test for SEC-2056. It is intentionally ignored since this test - * is not deterministic and {@link #testUserNotFoundEncodesPassword()} ensures that - * SEC-2056 is fixed. - */ - public void IGNOREtestSec2056() { + // SEC-2056 + @Test + @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true") + public void testSec2056() { UsernamePasswordAuthenticationToken foundUser = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala"); UsernamePasswordAuthenticationToken notFoundUser = UsernamePasswordAuthenticationToken @@ -460,6 +461,40 @@ public class DaoAuthenticationProviderTests { .isTrue(); } + // related to SEC-2056 + @Test + @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true") + public void testDisabledUserTiming() { + UsernamePasswordAuthenticationToken user = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala"); + PasswordEncoder encoder = new BCryptPasswordEncoder(); + MockUserDetailsServiceUserRod users = new MockUserDetailsServiceUserRod(); + users.password = encoder.encode((CharSequence) user.getCredentials()); + DaoAuthenticationProvider provider = new DaoAuthenticationProvider(users); + provider.setPasswordEncoder(encoder); + int sampleSize = 100; + List enabledTimes = new ArrayList<>(sampleSize); + for (int i = 0; i < sampleSize; i++) { + long start = System.currentTimeMillis(); + provider.authenticate(user); + enabledTimes.add(System.currentTimeMillis() - start); + } + UserDetailsChecker preChecks = mock(UserDetailsChecker.class); + willThrow(new DisabledException("User is disabled")).given(preChecks).check(any(UserDetails.class)); + provider.setPreAuthenticationChecks(preChecks); + List disabledTimes = new ArrayList<>(sampleSize); + for (int i = 0; i < sampleSize; i++) { + long start = System.currentTimeMillis(); + assertThatExceptionOfType(DisabledException.class).isThrownBy(() -> provider.authenticate(user)); + disabledTimes.add(System.currentTimeMillis() - start); + } + double enabledAvg = avg(enabledTimes); + double disabledAvg = avg(disabledTimes); + assertThat(Math.abs(disabledAvg - enabledAvg) <= 3) + .withFailMessage("Disabled user average " + disabledAvg + " should be within 3ms of enabled user average " + + enabledAvg) + .isTrue(); + } + private double avg(List counts) { return counts.stream().mapToLong(Long::longValue).average().orElse(0); }