diff --git a/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java index d107f9aa22..8ee72cf3fc 100644 --- a/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java @@ -92,6 +92,8 @@ public abstract class AbstractUserDetailsAuthenticationProvider private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks(); + private boolean alwaysPerformAdditionalChecksOnUser = true; + private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper(); /** @@ -146,8 +148,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider Assert.notNull(user, "retrieveUser returned null - a violation of the interface contract"); } try { - this.preAuthenticationChecks.check(user); - additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication); + performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication); } catch (AuthenticationException ex) { if (!cacheWasUsed) { @@ -157,8 +158,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider // we're using latest data (i.e. not from the cache) cacheWasUsed = false; user = retrieveUser(username, (UsernamePasswordAuthenticationToken) authentication); - this.preAuthenticationChecks.check(user); - additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication); + performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication); } this.postAuthenticationChecks.check(user); if (!cacheWasUsed) { @@ -171,6 +171,25 @@ public abstract class AbstractUserDetailsAuthenticationProvider return createSuccessAuthentication(principalToReturn, authentication, user); } + private void performPreCheck(UserDetails user, UsernamePasswordAuthenticationToken authentication) { + try { + this.preAuthenticationChecks.check(user); + } + catch (AuthenticationException ex) { + if (!this.alwaysPerformAdditionalChecksOnUser) { + throw ex; + } + try { + additionalAuthenticationChecks(user, authentication); + } + catch (AuthenticationException ignored) { + // preserve the original failed check + } + throw ex; + } + additionalAuthenticationChecks(user, authentication); + } + private String determineUsername(Authentication authentication) { return (authentication.getPrincipal() == null) ? "NONE_PROVIDED" : authentication.getName(); } @@ -313,6 +332,22 @@ public abstract class AbstractUserDetailsAuthenticationProvider this.postAuthenticationChecks = postAuthenticationChecks; } + /** + * Set whether to always perform the additional checks on the user, even if the + * pre-authentication checks fail. This is useful to ensure that regardless of the + * state of the user account, authentication takes the same amount of time to + * complete. + * + *

+ * For applications that rely on the additional checks running only once should set + * this value to {@code false} + * @param alwaysPerformAdditionalChecksOnUser + * @since 5.7.23 + */ + public void setAlwaysPerformAdditionalChecksOnUser(boolean alwaysPerformAdditionalChecksOnUser) { + this.alwaysPerformAdditionalChecksOnUser = alwaysPerformAdditionalChecksOnUser; + } + public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { this.authoritiesMapper = authoritiesMapper; } diff --git a/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java b/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java index 7cf52adce1..328ad926cb 100644 --- a/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java +++ b/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java @@ -152,7 +152,9 @@ public final class JdbcOneTimeTokenService implements OneTimeTokenService, Dispo return null; } OneTimeToken token = tokens.get(0); - deleteOneTimeToken(token); + if (deleteOneTimeToken(token) == 0) { + return null; + } if (isExpired(token)) { return null; } @@ -170,11 +172,11 @@ public final class JdbcOneTimeTokenService implements OneTimeTokenService, Dispo return this.jdbcOperations.query(SELECT_ONE_TIME_TOKEN_SQL, pss, this.oneTimeTokenRowMapper); } - private void deleteOneTimeToken(OneTimeToken oneTimeToken) { + private int deleteOneTimeToken(OneTimeToken oneTimeToken) { List parameters = List .of(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getTokenValue())); PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); - this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss); + return this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss); } private ThreadPoolTaskScheduler createTaskScheduler(String cleanupCron) { diff --git a/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java index 36a5d3bcb9..24cbe3b332 100644 --- a/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfSystemProperty; import org.springframework.cache.Cache; import org.springframework.dao.DataRetrievalFailureException; @@ -42,6 +43,7 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UserDetailsChecker; import org.springframework.security.core.userdetails.UserDetailsPasswordService; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UsernameNotFoundException; @@ -62,6 +64,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -452,12 +455,10 @@ public class DaoAuthenticationProviderTests { assertThat(daoAuthenticationProvider.getPasswordEncoder()).isSameAs(NoOpPasswordEncoder.getInstance()); } - /** - * This is an explicit test for SEC-2056. It is intentionally ignored since this test - * is not deterministic and {@link #testUserNotFoundEncodesPassword()} ensures that - * SEC-2056 is fixed. - */ - public void IGNOREtestSec2056() { + // SEC-2056 + @Test + @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true") + public void testSec2056() { UsernamePasswordAuthenticationToken foundUser = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala"); UsernamePasswordAuthenticationToken notFoundUser = UsernamePasswordAuthenticationToken @@ -491,6 +492,41 @@ public class DaoAuthenticationProviderTests { .isTrue(); } + // related to SEC-2056 + @Test + @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true") + public void testDisabledUserTiming() { + UsernamePasswordAuthenticationToken user = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala"); + PasswordEncoder encoder = new BCryptPasswordEncoder(); + DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); + provider.setPasswordEncoder(encoder); + MockUserDetailsServiceUserRod users = new MockUserDetailsServiceUserRod(); + users.password = encoder.encode((CharSequence) user.getCredentials()); + provider.setUserDetailsService(users); + int sampleSize = 100; + List enabledTimes = new ArrayList<>(sampleSize); + for (int i = 0; i < sampleSize; i++) { + long start = System.currentTimeMillis(); + provider.authenticate(user); + enabledTimes.add(System.currentTimeMillis() - start); + } + UserDetailsChecker preChecks = mock(UserDetailsChecker.class); + willThrow(new DisabledException("User is disabled")).given(preChecks).check(any(UserDetails.class)); + provider.setPreAuthenticationChecks(preChecks); + List disabledTimes = new ArrayList<>(sampleSize); + for (int i = 0; i < sampleSize; i++) { + long start = System.currentTimeMillis(); + assertThatExceptionOfType(DisabledException.class).isThrownBy(() -> provider.authenticate(user)); + disabledTimes.add(System.currentTimeMillis() - start); + } + double enabledAvg = avg(enabledTimes); + double disabledAvg = avg(disabledTimes); + assertThat(Math.abs(disabledAvg - enabledAvg) <= 3) + .withFailMessage("Disabled user average " + disabledAvg + " should be within 3ms of enabled user average " + + enabledAvg) + .isTrue(); + } + private double avg(List counts) { return counts.stream().mapToLong(Long::longValue).average().orElse(0); } diff --git a/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java b/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java index a1a9a2e225..7d2e07920b 100644 --- a/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java +++ b/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java @@ -21,19 +21,24 @@ import java.time.Duration; import java.time.Instant; import java.time.ZoneOffset; import java.time.temporal.ChronoUnit; +import java.util.List; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -145,6 +150,27 @@ class JdbcOneTimeTokenServiceTests { assertThat(consumedOneTimeToken).isNull(); } + @Test + void consumeWhenTokenIsDeletedConcurrentlyThenReturnNull() throws Exception { + // Simulates a concurrent consume: SELECT finds the token but DELETE affects + // 0 rows because another caller already consumed it. + JdbcOperations jdbcOperations = mock(JdbcOperations.class); + Instant notExpired = Instant.now().plus(5, ChronoUnit.MINUTES); + OneTimeToken token = new DefaultOneTimeToken(TOKEN_VALUE, USERNAME, notExpired); + given(jdbcOperations.query(any(String.class), any(PreparedStatementSetter.class), + ArgumentMatchers.>any())) + .willReturn(List.of(token)); + given(jdbcOperations.update(any(String.class), any(PreparedStatementSetter.class))).willReturn(0); + JdbcOneTimeTokenService service = new JdbcOneTimeTokenService(jdbcOperations); + try { + OneTimeToken consumed = service.consume(new OneTimeTokenAuthenticationToken(TOKEN_VALUE)); + assertThat(consumed).isNull(); + } + finally { + service.destroy(); + } + } + @Test void consumeWhenTokenIsExpiredThenReturnNull() { GenerateOneTimeTokenRequest request = new GenerateOneTimeTokenRequest(USERNAME); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 73c2b9be44..6ac6e712b1 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -230,7 +230,8 @@ public final class NimbusJwtDecoder implements JwtDecoder { .getConfigurationForIssuerLocation(issuer, rest); JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer); return configuration.get("jwks_uri").toString(); - }, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms); + }, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms) + .validator(JwtValidators.createDefaultWithIssuer(issuer)); } /** @@ -289,6 +290,8 @@ public final class NimbusJwtDecoder implements JwtDecoder { private Consumer> jwtProcessorCustomizer; + private OAuth2TokenValidator validator = JwtValidators.createDefault(); + private JwkSetUriJwtDecoderBuilder(String jwkSetUri) { Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty"); this.jwkSetUri = (rest) -> jwkSetUri; @@ -423,6 +426,12 @@ public final class NimbusJwtDecoder implements JwtDecoder { return this; } + JwkSetUriJwtDecoderBuilder validator(OAuth2TokenValidator validator) { + Assert.notNull(validator, "validator cannot be null"); + this.validator = validator; + return this; + } + JWSKeySelector jwsKeySelector(JWKSource jwkSource) { if (this.signatureAlgorithms.isEmpty()) { return new JWSVerificationKeySelector<>(this.defaultAlgorithms.apply(jwkSource), jwkSource); @@ -461,7 +470,9 @@ public final class NimbusJwtDecoder implements JwtDecoder { * @return the configured {@link NimbusJwtDecoder} */ public NimbusJwtDecoder build() { - return new NimbusJwtDecoder(processor()); + NimbusJwtDecoder decoder = new NimbusJwtDecoder(processor()); + decoder.setJwtValidator(this.validator); + return decoder; } private static final class SpringJWKSource implements JWKSetSource { diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index f3a38d812b..b31213dff9 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -241,7 +241,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } return Mono.just(configuration.get("jwks_uri").toString()); }), - ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms); + ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms) + .validator(JwtValidators.createDefaultWithIssuer(issuer)); } /** @@ -332,6 +333,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { private BiFunction, Mono>> jwtProcessorCustomizer; + private OAuth2TokenValidator validator = JwtValidators.createDefault(); + private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) { Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty"); this.jwkSetUri = (web) -> Mono.just(jwkSetUri); @@ -456,6 +459,11 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { return this; } + JwkSetUriReactiveJwtDecoderBuilder validator(OAuth2TokenValidator validator) { + this.validator = validator; + return this; + } + JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer( BiFunction, Mono>> jwtProcessorCustomizer) { Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null"); @@ -468,7 +476,9 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { * @return the configured {@link NimbusReactiveJwtDecoder} */ public NimbusReactiveJwtDecoder build() { - return new NimbusReactiveJwtDecoder(processor()); + NimbusReactiveJwtDecoder decoder = new NimbusReactiveJwtDecoder(processor()); + decoder.setJwtValidator(this.validator); + return decoder; } Mono> jwsKeySelector(ReactiveRemoteJWKSource source) { diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index ef8de01096..dc49325da2 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -328,11 +328,26 @@ public class NimbusJwtDecoderTests { .willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK)); given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); - JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build(); + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer) + .restOperations(restOperations) + .build(); + jwtDecoder.setJwtValidator(JwtValidators.createDefault()); Jwt jwt = jwtDecoder.decode(SIGNED_JWT); assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull(); } + @Test + public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() { + String issuer = "https://example.org/wrong-issuer"; + RestOperations restOperations = mock(RestOperations.class); + given(restOperations.exchange(any(RequestEntity.class), any(ParameterizedTypeReference.class))) + .willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK)); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); + JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build(); + assertThatExceptionOfType(JwtValidationException.class).isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)); + } + @Test public void withJwkSetUriWhenNullOrEmptyThenThrowsException() { // @formatter:off diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index 5066339c88..e775e2618c 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -617,11 +617,31 @@ public class NimbusReactiveJwtDecoderTests { given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class))) .willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"))); given(spec.retrieve()).willReturn(responseSpec); + NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer) + .webClient(webClient) + .build(); + jwtDecoder.setJwtValidator(JwtValidators.createDefault()); + Jwt jwt = jwtDecoder.decode(this.messageReadToken).block(); + assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull(); + } + + @Test + public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() { + String issuer = "https://example.org/wrong-issuer"; + WebClient real = WebClient.builder().build(); + WebClient.RequestHeadersUriSpec spec = spy(real.get()); + WebClient webClient = spy(WebClient.class); + given(webClient.get()).willReturn(spec); + WebClient.ResponseSpec responseSpec = mock(WebClient.ResponseSpec.class); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(this.jwkSet)); + given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class))) + .willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"))); + given(spec.retrieve()).willReturn(responseSpec); ReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer) .webClient(webClient) .build(); - Jwt jwt = jwtDecoder.decode(this.messageReadToken).block(); - assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull(); + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> jwtDecoder.decode(this.messageReadToken).block()); } @Test