Add Support for Always Running Additional Authentication Checks

Signed-off-by: Josh Cummings <3627351+jzheaux@users.noreply.github.com>
This commit is contained in:
Josh Cummings 2026-04-15 21:00:55 -06:00
parent 68b820ed09
commit a317a3d866
2 changed files with 81 additions and 10 deletions

View File

@ -92,6 +92,8 @@ public abstract class AbstractUserDetailsAuthenticationProvider
private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks();
private boolean alwaysPerformAdditionalChecksOnUser = true;
private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper();
/**
@ -146,8 +148,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider
Assert.notNull(user, "retrieveUser returned null - a violation of the interface contract");
}
try {
this.preAuthenticationChecks.check(user);
additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication);
performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication);
}
catch (AuthenticationException ex) {
if (!cacheWasUsed) {
@ -157,8 +158,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider
// we're using latest data (i.e. not from the cache)
cacheWasUsed = false;
user = retrieveUser(username, (UsernamePasswordAuthenticationToken) authentication);
this.preAuthenticationChecks.check(user);
additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication);
performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication);
}
this.postAuthenticationChecks.check(user);
if (!cacheWasUsed) {
@ -171,6 +171,25 @@ public abstract class AbstractUserDetailsAuthenticationProvider
return createSuccessAuthentication(principalToReturn, authentication, user);
}
private void performPreCheck(UserDetails user, UsernamePasswordAuthenticationToken authentication) {
try {
this.preAuthenticationChecks.check(user);
}
catch (AuthenticationException ex) {
if (!this.alwaysPerformAdditionalChecksOnUser) {
throw ex;
}
try {
additionalAuthenticationChecks(user, authentication);
}
catch (AuthenticationException ignored) {
// preserve the original failed check
}
throw ex;
}
additionalAuthenticationChecks(user, authentication);
}
private String determineUsername(Authentication authentication) {
return (authentication.getPrincipal() == null) ? "NONE_PROVIDED" : authentication.getName();
}
@ -313,6 +332,22 @@ public abstract class AbstractUserDetailsAuthenticationProvider
this.postAuthenticationChecks = postAuthenticationChecks;
}
/**
* Set whether to always perform the additional checks on the user, even if the
* pre-authentication checks fail. This is useful to ensure that regardless of the
* state of the user account, authentication takes the same amount of time to
* complete.
*
* <p>
* For applications that rely on the additional checks running only once should set
* this value to {@code false}
* @param alwaysPerformAdditionalChecksOnUser
* @since 5.7.23
*/
public void setAlwaysPerformAdditionalChecksOnUser(boolean alwaysPerformAdditionalChecksOnUser) {
this.alwaysPerformAdditionalChecksOnUser = alwaysPerformAdditionalChecksOnUser;
}
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
this.authoritiesMapper = authoritiesMapper;
}

View File

@ -21,6 +21,7 @@ import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import org.springframework.cache.Cache;
import org.springframework.dao.DataRetrievalFailureException;
@ -42,6 +43,7 @@ import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsChecker;
import org.springframework.security.core.userdetails.UserDetailsPasswordService;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
@ -62,6 +64,7 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -452,12 +455,10 @@ public class DaoAuthenticationProviderTests {
assertThat(daoAuthenticationProvider.getPasswordEncoder()).isSameAs(NoOpPasswordEncoder.getInstance());
}
/**
* This is an explicit test for SEC-2056. It is intentionally ignored since this test
* is not deterministic and {@link #testUserNotFoundEncodesPassword()} ensures that
* SEC-2056 is fixed.
*/
public void IGNOREtestSec2056() {
// SEC-2056
@Test
@EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true")
public void testSec2056() {
UsernamePasswordAuthenticationToken foundUser = UsernamePasswordAuthenticationToken.unauthenticated("rod",
"koala");
UsernamePasswordAuthenticationToken notFoundUser = UsernamePasswordAuthenticationToken
@ -491,6 +492,41 @@ public class DaoAuthenticationProviderTests {
.isTrue();
}
// related to SEC-2056
@Test
@EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true")
public void testDisabledUserTiming() {
UsernamePasswordAuthenticationToken user = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala");
PasswordEncoder encoder = new BCryptPasswordEncoder();
DaoAuthenticationProvider provider = new DaoAuthenticationProvider();
provider.setPasswordEncoder(encoder);
MockUserDetailsServiceUserRod users = new MockUserDetailsServiceUserRod();
users.password = encoder.encode((CharSequence) user.getCredentials());
provider.setUserDetailsService(users);
int sampleSize = 100;
List<Long> enabledTimes = new ArrayList<>(sampleSize);
for (int i = 0; i < sampleSize; i++) {
long start = System.currentTimeMillis();
provider.authenticate(user);
enabledTimes.add(System.currentTimeMillis() - start);
}
UserDetailsChecker preChecks = mock(UserDetailsChecker.class);
willThrow(new DisabledException("User is disabled")).given(preChecks).check(any(UserDetails.class));
provider.setPreAuthenticationChecks(preChecks);
List<Long> disabledTimes = new ArrayList<>(sampleSize);
for (int i = 0; i < sampleSize; i++) {
long start = System.currentTimeMillis();
assertThatExceptionOfType(DisabledException.class).isThrownBy(() -> provider.authenticate(user));
disabledTimes.add(System.currentTimeMillis() - start);
}
double enabledAvg = avg(enabledTimes);
double disabledAvg = avg(disabledTimes);
assertThat(Math.abs(disabledAvg - enabledAvg) <= 3)
.withFailMessage("Disabled user average " + disabledAvg + " should be within 3ms of enabled user average "
+ enabledAvg)
.isTrue();
}
private double avg(List<Long> counts) {
return counts.stream().mapToLong(Long::longValue).average().orElse(0);
}