diff --git a/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java
index d107f9aa22..8ee72cf3fc 100644
--- a/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java
+++ b/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java
@@ -92,6 +92,8 @@ public abstract class AbstractUserDetailsAuthenticationProvider
private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks();
+ private boolean alwaysPerformAdditionalChecksOnUser = true;
+
private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper();
/**
@@ -146,8 +148,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider
Assert.notNull(user, "retrieveUser returned null - a violation of the interface contract");
}
try {
- this.preAuthenticationChecks.check(user);
- additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication);
+ performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication);
}
catch (AuthenticationException ex) {
if (!cacheWasUsed) {
@@ -157,8 +158,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider
// we're using latest data (i.e. not from the cache)
cacheWasUsed = false;
user = retrieveUser(username, (UsernamePasswordAuthenticationToken) authentication);
- this.preAuthenticationChecks.check(user);
- additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication);
+ performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication);
}
this.postAuthenticationChecks.check(user);
if (!cacheWasUsed) {
@@ -171,6 +171,25 @@ public abstract class AbstractUserDetailsAuthenticationProvider
return createSuccessAuthentication(principalToReturn, authentication, user);
}
+ private void performPreCheck(UserDetails user, UsernamePasswordAuthenticationToken authentication) {
+ try {
+ this.preAuthenticationChecks.check(user);
+ }
+ catch (AuthenticationException ex) {
+ if (!this.alwaysPerformAdditionalChecksOnUser) {
+ throw ex;
+ }
+ try {
+ additionalAuthenticationChecks(user, authentication);
+ }
+ catch (AuthenticationException ignored) {
+ // preserve the original failed check
+ }
+ throw ex;
+ }
+ additionalAuthenticationChecks(user, authentication);
+ }
+
private String determineUsername(Authentication authentication) {
return (authentication.getPrincipal() == null) ? "NONE_PROVIDED" : authentication.getName();
}
@@ -313,6 +332,22 @@ public abstract class AbstractUserDetailsAuthenticationProvider
this.postAuthenticationChecks = postAuthenticationChecks;
}
+ /**
+ * Set whether to always perform the additional checks on the user, even if the
+ * pre-authentication checks fail. This is useful to ensure that regardless of the
+ * state of the user account, authentication takes the same amount of time to
+ * complete.
+ *
+ *
+ * For applications that rely on the additional checks running only once should set
+ * this value to {@code false}
+ * @param alwaysPerformAdditionalChecksOnUser
+ * @since 5.7.23
+ */
+ public void setAlwaysPerformAdditionalChecksOnUser(boolean alwaysPerformAdditionalChecksOnUser) {
+ this.alwaysPerformAdditionalChecksOnUser = alwaysPerformAdditionalChecksOnUser;
+ }
+
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
this.authoritiesMapper = authoritiesMapper;
}
diff --git a/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java b/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java
index 7cf52adce1..328ad926cb 100644
--- a/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java
+++ b/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java
@@ -152,7 +152,9 @@ public final class JdbcOneTimeTokenService implements OneTimeTokenService, Dispo
return null;
}
OneTimeToken token = tokens.get(0);
- deleteOneTimeToken(token);
+ if (deleteOneTimeToken(token) == 0) {
+ return null;
+ }
if (isExpired(token)) {
return null;
}
@@ -170,11 +172,11 @@ public final class JdbcOneTimeTokenService implements OneTimeTokenService, Dispo
return this.jdbcOperations.query(SELECT_ONE_TIME_TOKEN_SQL, pss, this.oneTimeTokenRowMapper);
}
- private void deleteOneTimeToken(OneTimeToken oneTimeToken) {
+ private int deleteOneTimeToken(OneTimeToken oneTimeToken) {
List parameters = List
.of(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getTokenValue()));
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
- this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss);
+ return this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss);
}
private ThreadPoolTaskScheduler createTaskScheduler(String cleanupCron) {
diff --git a/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java
index 36a5d3bcb9..24cbe3b332 100644
--- a/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java
+++ b/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java
@@ -21,6 +21,7 @@ import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import org.springframework.cache.Cache;
import org.springframework.dao.DataRetrievalFailureException;
@@ -42,6 +43,7 @@ import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
+import org.springframework.security.core.userdetails.UserDetailsChecker;
import org.springframework.security.core.userdetails.UserDetailsPasswordService;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
@@ -62,6 +64,7 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.BDDMockito.given;
+import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -452,12 +455,10 @@ public class DaoAuthenticationProviderTests {
assertThat(daoAuthenticationProvider.getPasswordEncoder()).isSameAs(NoOpPasswordEncoder.getInstance());
}
- /**
- * This is an explicit test for SEC-2056. It is intentionally ignored since this test
- * is not deterministic and {@link #testUserNotFoundEncodesPassword()} ensures that
- * SEC-2056 is fixed.
- */
- public void IGNOREtestSec2056() {
+ // SEC-2056
+ @Test
+ @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true")
+ public void testSec2056() {
UsernamePasswordAuthenticationToken foundUser = UsernamePasswordAuthenticationToken.unauthenticated("rod",
"koala");
UsernamePasswordAuthenticationToken notFoundUser = UsernamePasswordAuthenticationToken
@@ -491,6 +492,41 @@ public class DaoAuthenticationProviderTests {
.isTrue();
}
+ // related to SEC-2056
+ @Test
+ @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true")
+ public void testDisabledUserTiming() {
+ UsernamePasswordAuthenticationToken user = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala");
+ PasswordEncoder encoder = new BCryptPasswordEncoder();
+ DaoAuthenticationProvider provider = new DaoAuthenticationProvider();
+ provider.setPasswordEncoder(encoder);
+ MockUserDetailsServiceUserRod users = new MockUserDetailsServiceUserRod();
+ users.password = encoder.encode((CharSequence) user.getCredentials());
+ provider.setUserDetailsService(users);
+ int sampleSize = 100;
+ List enabledTimes = new ArrayList<>(sampleSize);
+ for (int i = 0; i < sampleSize; i++) {
+ long start = System.currentTimeMillis();
+ provider.authenticate(user);
+ enabledTimes.add(System.currentTimeMillis() - start);
+ }
+ UserDetailsChecker preChecks = mock(UserDetailsChecker.class);
+ willThrow(new DisabledException("User is disabled")).given(preChecks).check(any(UserDetails.class));
+ provider.setPreAuthenticationChecks(preChecks);
+ List disabledTimes = new ArrayList<>(sampleSize);
+ for (int i = 0; i < sampleSize; i++) {
+ long start = System.currentTimeMillis();
+ assertThatExceptionOfType(DisabledException.class).isThrownBy(() -> provider.authenticate(user));
+ disabledTimes.add(System.currentTimeMillis() - start);
+ }
+ double enabledAvg = avg(enabledTimes);
+ double disabledAvg = avg(disabledTimes);
+ assertThat(Math.abs(disabledAvg - enabledAvg) <= 3)
+ .withFailMessage("Disabled user average " + disabledAvg + " should be within 3ms of enabled user average "
+ + enabledAvg)
+ .isTrue();
+ }
+
private double avg(List counts) {
return counts.stream().mapToLong(Long::longValue).average().orElse(0);
}
diff --git a/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java b/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java
index a1a9a2e225..7d2e07920b 100644
--- a/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java
+++ b/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java
@@ -21,19 +21,24 @@ import java.time.Duration;
import java.time.Instant;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
+import java.util.List;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentMatchers;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.core.PreparedStatementSetter;
+import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
@@ -145,6 +150,27 @@ class JdbcOneTimeTokenServiceTests {
assertThat(consumedOneTimeToken).isNull();
}
+ @Test
+ void consumeWhenTokenIsDeletedConcurrentlyThenReturnNull() throws Exception {
+ // Simulates a concurrent consume: SELECT finds the token but DELETE affects
+ // 0 rows because another caller already consumed it.
+ JdbcOperations jdbcOperations = mock(JdbcOperations.class);
+ Instant notExpired = Instant.now().plus(5, ChronoUnit.MINUTES);
+ OneTimeToken token = new DefaultOneTimeToken(TOKEN_VALUE, USERNAME, notExpired);
+ given(jdbcOperations.query(any(String.class), any(PreparedStatementSetter.class),
+ ArgumentMatchers.>any()))
+ .willReturn(List.of(token));
+ given(jdbcOperations.update(any(String.class), any(PreparedStatementSetter.class))).willReturn(0);
+ JdbcOneTimeTokenService service = new JdbcOneTimeTokenService(jdbcOperations);
+ try {
+ OneTimeToken consumed = service.consume(new OneTimeTokenAuthenticationToken(TOKEN_VALUE));
+ assertThat(consumed).isNull();
+ }
+ finally {
+ service.destroy();
+ }
+ }
+
@Test
void consumeWhenTokenIsExpiredThenReturnNull() {
GenerateOneTimeTokenRequest request = new GenerateOneTimeTokenRequest(USERNAME);
diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java
index 73c2b9be44..6ac6e712b1 100644
--- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java
+++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java
@@ -230,7 +230,8 @@ public final class NimbusJwtDecoder implements JwtDecoder {
.getConfigurationForIssuerLocation(issuer, rest);
JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer);
return configuration.get("jwks_uri").toString();
- }, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms);
+ }, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms)
+ .validator(JwtValidators.createDefaultWithIssuer(issuer));
}
/**
@@ -289,6 +290,8 @@ public final class NimbusJwtDecoder implements JwtDecoder {
private Consumer> jwtProcessorCustomizer;
+ private OAuth2TokenValidator validator = JwtValidators.createDefault();
+
private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
this.jwkSetUri = (rest) -> jwkSetUri;
@@ -423,6 +426,12 @@ public final class NimbusJwtDecoder implements JwtDecoder {
return this;
}
+ JwkSetUriJwtDecoderBuilder validator(OAuth2TokenValidator validator) {
+ Assert.notNull(validator, "validator cannot be null");
+ this.validator = validator;
+ return this;
+ }
+
JWSKeySelector jwsKeySelector(JWKSource jwkSource) {
if (this.signatureAlgorithms.isEmpty()) {
return new JWSVerificationKeySelector<>(this.defaultAlgorithms.apply(jwkSource), jwkSource);
@@ -461,7 +470,9 @@ public final class NimbusJwtDecoder implements JwtDecoder {
* @return the configured {@link NimbusJwtDecoder}
*/
public NimbusJwtDecoder build() {
- return new NimbusJwtDecoder(processor());
+ NimbusJwtDecoder decoder = new NimbusJwtDecoder(processor());
+ decoder.setJwtValidator(this.validator);
+ return decoder;
}
private static final class SpringJWKSource implements JWKSetSource {
diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java
index f3a38d812b..b31213dff9 100644
--- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java
+++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java
@@ -241,7 +241,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
}
return Mono.just(configuration.get("jwks_uri").toString());
}),
- ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms);
+ ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms)
+ .validator(JwtValidators.createDefaultWithIssuer(issuer));
}
/**
@@ -332,6 +333,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
private BiFunction, Mono>> jwtProcessorCustomizer;
+ private OAuth2TokenValidator validator = JwtValidators.createDefault();
+
private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
this.jwkSetUri = (web) -> Mono.just(jwkSetUri);
@@ -456,6 +459,11 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
return this;
}
+ JwkSetUriReactiveJwtDecoderBuilder validator(OAuth2TokenValidator validator) {
+ this.validator = validator;
+ return this;
+ }
+
JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer(
BiFunction, Mono>> jwtProcessorCustomizer) {
Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null");
@@ -468,7 +476,9 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
* @return the configured {@link NimbusReactiveJwtDecoder}
*/
public NimbusReactiveJwtDecoder build() {
- return new NimbusReactiveJwtDecoder(processor());
+ NimbusReactiveJwtDecoder decoder = new NimbusReactiveJwtDecoder(processor());
+ decoder.setJwtValidator(this.validator);
+ return decoder;
}
Mono> jwsKeySelector(ReactiveRemoteJWKSource source) {
diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java
index ef8de01096..dc49325da2 100644
--- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java
+++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java
@@ -328,11 +328,26 @@ public class NimbusJwtDecoderTests {
.willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK));
given(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
- JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build();
+ NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer)
+ .restOperations(restOperations)
+ .build();
+ jwtDecoder.setJwtValidator(JwtValidators.createDefault());
Jwt jwt = jwtDecoder.decode(SIGNED_JWT);
assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull();
}
+ @Test
+ public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() {
+ String issuer = "https://example.org/wrong-issuer";
+ RestOperations restOperations = mock(RestOperations.class);
+ given(restOperations.exchange(any(RequestEntity.class), any(ParameterizedTypeReference.class)))
+ .willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK));
+ given(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
+ .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
+ JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build();
+ assertThatExceptionOfType(JwtValidationException.class).isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT));
+ }
+
@Test
public void withJwkSetUriWhenNullOrEmptyThenThrowsException() {
// @formatter:off
diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java
index 5066339c88..e775e2618c 100644
--- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java
+++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java
@@ -617,11 +617,31 @@ public class NimbusReactiveJwtDecoderTests {
given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class)))
.willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks")));
given(spec.retrieve()).willReturn(responseSpec);
+ NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer)
+ .webClient(webClient)
+ .build();
+ jwtDecoder.setJwtValidator(JwtValidators.createDefault());
+ Jwt jwt = jwtDecoder.decode(this.messageReadToken).block();
+ assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull();
+ }
+
+ @Test
+ public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() {
+ String issuer = "https://example.org/wrong-issuer";
+ WebClient real = WebClient.builder().build();
+ WebClient.RequestHeadersUriSpec spec = spy(real.get());
+ WebClient webClient = spy(WebClient.class);
+ given(webClient.get()).willReturn(spec);
+ WebClient.ResponseSpec responseSpec = mock(WebClient.ResponseSpec.class);
+ given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(this.jwkSet));
+ given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class)))
+ .willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks")));
+ given(spec.retrieve()).willReturn(responseSpec);
ReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer)
.webClient(webClient)
.build();
- Jwt jwt = jwtDecoder.decode(this.messageReadToken).block();
- assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull();
+ assertThatExceptionOfType(JwtValidationException.class)
+ .isThrownBy(() -> jwtDecoder.decode(this.messageReadToken).block());
}
@Test