Merge remote-tracking branch 'origin/6.5.x' into 6.5.x

This commit is contained in:
Josh Cummings 2026-04-20 09:51:26 -06:00
commit 6e5f8f2a1d
8 changed files with 175 additions and 20 deletions

View File

@ -92,6 +92,8 @@ public abstract class AbstractUserDetailsAuthenticationProvider
private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks();
private boolean alwaysPerformAdditionalChecksOnUser = true;
private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper();
/**
@ -146,8 +148,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider
Assert.notNull(user, "retrieveUser returned null - a violation of the interface contract");
}
try {
this.preAuthenticationChecks.check(user);
additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication);
performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication);
}
catch (AuthenticationException ex) {
if (!cacheWasUsed) {
@ -157,8 +158,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider
// we're using latest data (i.e. not from the cache)
cacheWasUsed = false;
user = retrieveUser(username, (UsernamePasswordAuthenticationToken) authentication);
this.preAuthenticationChecks.check(user);
additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication);
performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication);
}
this.postAuthenticationChecks.check(user);
if (!cacheWasUsed) {
@ -171,6 +171,25 @@ public abstract class AbstractUserDetailsAuthenticationProvider
return createSuccessAuthentication(principalToReturn, authentication, user);
}
private void performPreCheck(UserDetails user, UsernamePasswordAuthenticationToken authentication) {
try {
this.preAuthenticationChecks.check(user);
}
catch (AuthenticationException ex) {
if (!this.alwaysPerformAdditionalChecksOnUser) {
throw ex;
}
try {
additionalAuthenticationChecks(user, authentication);
}
catch (AuthenticationException ignored) {
// preserve the original failed check
}
throw ex;
}
additionalAuthenticationChecks(user, authentication);
}
private String determineUsername(Authentication authentication) {
return (authentication.getPrincipal() == null) ? "NONE_PROVIDED" : authentication.getName();
}
@ -313,6 +332,22 @@ public abstract class AbstractUserDetailsAuthenticationProvider
this.postAuthenticationChecks = postAuthenticationChecks;
}
/**
* Set whether to always perform the additional checks on the user, even if the
* pre-authentication checks fail. This is useful to ensure that regardless of the
* state of the user account, authentication takes the same amount of time to
* complete.
*
* <p>
* For applications that rely on the additional checks running only once should set
* this value to {@code false}
* @param alwaysPerformAdditionalChecksOnUser
* @since 5.7.23
*/
public void setAlwaysPerformAdditionalChecksOnUser(boolean alwaysPerformAdditionalChecksOnUser) {
this.alwaysPerformAdditionalChecksOnUser = alwaysPerformAdditionalChecksOnUser;
}
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
this.authoritiesMapper = authoritiesMapper;
}

View File

@ -152,7 +152,9 @@ public final class JdbcOneTimeTokenService implements OneTimeTokenService, Dispo
return null;
}
OneTimeToken token = tokens.get(0);
deleteOneTimeToken(token);
if (deleteOneTimeToken(token) == 0) {
return null;
}
if (isExpired(token)) {
return null;
}
@ -170,11 +172,11 @@ public final class JdbcOneTimeTokenService implements OneTimeTokenService, Dispo
return this.jdbcOperations.query(SELECT_ONE_TIME_TOKEN_SQL, pss, this.oneTimeTokenRowMapper);
}
private void deleteOneTimeToken(OneTimeToken oneTimeToken) {
private int deleteOneTimeToken(OneTimeToken oneTimeToken) {
List<SqlParameterValue> parameters = List
.of(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getTokenValue()));
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss);
return this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss);
}
private ThreadPoolTaskScheduler createTaskScheduler(String cleanupCron) {

View File

@ -21,6 +21,7 @@ import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import org.springframework.cache.Cache;
import org.springframework.dao.DataRetrievalFailureException;
@ -42,6 +43,7 @@ import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsChecker;
import org.springframework.security.core.userdetails.UserDetailsPasswordService;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
@ -62,6 +64,7 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -452,12 +455,10 @@ public class DaoAuthenticationProviderTests {
assertThat(daoAuthenticationProvider.getPasswordEncoder()).isSameAs(NoOpPasswordEncoder.getInstance());
}
/**
* This is an explicit test for SEC-2056. It is intentionally ignored since this test
* is not deterministic and {@link #testUserNotFoundEncodesPassword()} ensures that
* SEC-2056 is fixed.
*/
public void IGNOREtestSec2056() {
// SEC-2056
@Test
@EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true")
public void testSec2056() {
UsernamePasswordAuthenticationToken foundUser = UsernamePasswordAuthenticationToken.unauthenticated("rod",
"koala");
UsernamePasswordAuthenticationToken notFoundUser = UsernamePasswordAuthenticationToken
@ -491,6 +492,41 @@ public class DaoAuthenticationProviderTests {
.isTrue();
}
// related to SEC-2056
@Test
@EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true")
public void testDisabledUserTiming() {
UsernamePasswordAuthenticationToken user = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala");
PasswordEncoder encoder = new BCryptPasswordEncoder();
DaoAuthenticationProvider provider = new DaoAuthenticationProvider();
provider.setPasswordEncoder(encoder);
MockUserDetailsServiceUserRod users = new MockUserDetailsServiceUserRod();
users.password = encoder.encode((CharSequence) user.getCredentials());
provider.setUserDetailsService(users);
int sampleSize = 100;
List<Long> enabledTimes = new ArrayList<>(sampleSize);
for (int i = 0; i < sampleSize; i++) {
long start = System.currentTimeMillis();
provider.authenticate(user);
enabledTimes.add(System.currentTimeMillis() - start);
}
UserDetailsChecker preChecks = mock(UserDetailsChecker.class);
willThrow(new DisabledException("User is disabled")).given(preChecks).check(any(UserDetails.class));
provider.setPreAuthenticationChecks(preChecks);
List<Long> disabledTimes = new ArrayList<>(sampleSize);
for (int i = 0; i < sampleSize; i++) {
long start = System.currentTimeMillis();
assertThatExceptionOfType(DisabledException.class).isThrownBy(() -> provider.authenticate(user));
disabledTimes.add(System.currentTimeMillis() - start);
}
double enabledAvg = avg(enabledTimes);
double disabledAvg = avg(disabledTimes);
assertThat(Math.abs(disabledAvg - enabledAvg) <= 3)
.withFailMessage("Disabled user average " + disabledAvg + " should be within 3ms of enabled user average "
+ enabledAvg)
.isTrue();
}
private double avg(List<Long> counts) {
return counts.stream().mapToLong(Long::longValue).average().orElse(0);
}

View File

@ -21,19 +21,24 @@ import java.time.Duration;
import java.time.Instant;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
import java.util.List;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentMatchers;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
@ -145,6 +150,27 @@ class JdbcOneTimeTokenServiceTests {
assertThat(consumedOneTimeToken).isNull();
}
@Test
void consumeWhenTokenIsDeletedConcurrentlyThenReturnNull() throws Exception {
// Simulates a concurrent consume: SELECT finds the token but DELETE affects
// 0 rows because another caller already consumed it.
JdbcOperations jdbcOperations = mock(JdbcOperations.class);
Instant notExpired = Instant.now().plus(5, ChronoUnit.MINUTES);
OneTimeToken token = new DefaultOneTimeToken(TOKEN_VALUE, USERNAME, notExpired);
given(jdbcOperations.query(any(String.class), any(PreparedStatementSetter.class),
ArgumentMatchers.<RowMapper<OneTimeToken>>any()))
.willReturn(List.of(token));
given(jdbcOperations.update(any(String.class), any(PreparedStatementSetter.class))).willReturn(0);
JdbcOneTimeTokenService service = new JdbcOneTimeTokenService(jdbcOperations);
try {
OneTimeToken consumed = service.consume(new OneTimeTokenAuthenticationToken(TOKEN_VALUE));
assertThat(consumed).isNull();
}
finally {
service.destroy();
}
}
@Test
void consumeWhenTokenIsExpiredThenReturnNull() {
GenerateOneTimeTokenRequest request = new GenerateOneTimeTokenRequest(USERNAME);

View File

@ -230,7 +230,8 @@ public final class NimbusJwtDecoder implements JwtDecoder {
.getConfigurationForIssuerLocation(issuer, rest);
JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer);
return configuration.get("jwks_uri").toString();
}, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms);
}, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms)
.validator(JwtValidators.createDefaultWithIssuer(issuer));
}
/**
@ -289,6 +290,8 @@ public final class NimbusJwtDecoder implements JwtDecoder {
private Consumer<ConfigurableJWTProcessor<SecurityContext>> jwtProcessorCustomizer;
private OAuth2TokenValidator<Jwt> validator = JwtValidators.createDefault();
private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
this.jwkSetUri = (rest) -> jwkSetUri;
@ -423,6 +426,12 @@ public final class NimbusJwtDecoder implements JwtDecoder {
return this;
}
JwkSetUriJwtDecoderBuilder validator(OAuth2TokenValidator<Jwt> validator) {
Assert.notNull(validator, "validator cannot be null");
this.validator = validator;
return this;
}
JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
if (this.signatureAlgorithms.isEmpty()) {
return new JWSVerificationKeySelector<>(this.defaultAlgorithms.apply(jwkSource), jwkSource);
@ -461,7 +470,9 @@ public final class NimbusJwtDecoder implements JwtDecoder {
* @return the configured {@link NimbusJwtDecoder}
*/
public NimbusJwtDecoder build() {
return new NimbusJwtDecoder(processor());
NimbusJwtDecoder decoder = new NimbusJwtDecoder(processor());
decoder.setJwtValidator(this.validator);
return decoder;
}
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSetSource<C> {

View File

@ -241,7 +241,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
}
return Mono.just(configuration.get("jwks_uri").toString());
}),
ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms);
ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms)
.validator(JwtValidators.createDefaultWithIssuer(issuer));
}
/**
@ -332,6 +333,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
private BiFunction<ReactiveRemoteJWKSource, ConfigurableJWTProcessor<JWKSecurityContext>, Mono<ConfigurableJWTProcessor<JWKSecurityContext>>> jwtProcessorCustomizer;
private OAuth2TokenValidator<Jwt> validator = JwtValidators.createDefault();
private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
this.jwkSetUri = (web) -> Mono.just(jwkSetUri);
@ -456,6 +459,11 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
return this;
}
JwkSetUriReactiveJwtDecoderBuilder validator(OAuth2TokenValidator<Jwt> validator) {
this.validator = validator;
return this;
}
JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer(
BiFunction<ReactiveRemoteJWKSource, ConfigurableJWTProcessor<JWKSecurityContext>, Mono<ConfigurableJWTProcessor<JWKSecurityContext>>> jwtProcessorCustomizer) {
Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null");
@ -468,7 +476,9 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
* @return the configured {@link NimbusReactiveJwtDecoder}
*/
public NimbusReactiveJwtDecoder build() {
return new NimbusReactiveJwtDecoder(processor());
NimbusReactiveJwtDecoder decoder = new NimbusReactiveJwtDecoder(processor());
decoder.setJwtValidator(this.validator);
return decoder;
}
Mono<JWSKeySelector<JWKSecurityContext>> jwsKeySelector(ReactiveRemoteJWKSource source) {

View File

@ -328,11 +328,26 @@ public class NimbusJwtDecoderTests {
.willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK));
given(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build();
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer)
.restOperations(restOperations)
.build();
jwtDecoder.setJwtValidator(JwtValidators.createDefault());
Jwt jwt = jwtDecoder.decode(SIGNED_JWT);
assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull();
}
@Test
public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() {
String issuer = "https://example.org/wrong-issuer";
RestOperations restOperations = mock(RestOperations.class);
given(restOperations.exchange(any(RequestEntity.class), any(ParameterizedTypeReference.class)))
.willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK));
given(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build();
assertThatExceptionOfType(JwtValidationException.class).isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT));
}
@Test
public void withJwkSetUriWhenNullOrEmptyThenThrowsException() {
// @formatter:off

View File

@ -617,11 +617,31 @@ public class NimbusReactiveJwtDecoderTests {
given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class)))
.willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks")));
given(spec.retrieve()).willReturn(responseSpec);
NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer)
.webClient(webClient)
.build();
jwtDecoder.setJwtValidator(JwtValidators.createDefault());
Jwt jwt = jwtDecoder.decode(this.messageReadToken).block();
assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull();
}
@Test
public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() {
String issuer = "https://example.org/wrong-issuer";
WebClient real = WebClient.builder().build();
WebClient.RequestHeadersUriSpec spec = spy(real.get());
WebClient webClient = spy(WebClient.class);
given(webClient.get()).willReturn(spec);
WebClient.ResponseSpec responseSpec = mock(WebClient.ResponseSpec.class);
given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(this.jwkSet));
given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class)))
.willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks")));
given(spec.retrieve()).willReturn(responseSpec);
ReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer)
.webClient(webClient)
.build();
Jwt jwt = jwtDecoder.decode(this.messageReadToken).block();
assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull();
assertThatExceptionOfType(JwtValidationException.class)
.isThrownBy(() -> jwtDecoder.decode(this.messageReadToken).block());
}
@Test