diff --git a/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java b/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java index 7cf52adce1..328ad926cb 100644 --- a/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java +++ b/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java @@ -152,7 +152,9 @@ public final class JdbcOneTimeTokenService implements OneTimeTokenService, Dispo return null; } OneTimeToken token = tokens.get(0); - deleteOneTimeToken(token); + if (deleteOneTimeToken(token) == 0) { + return null; + } if (isExpired(token)) { return null; } @@ -170,11 +172,11 @@ public final class JdbcOneTimeTokenService implements OneTimeTokenService, Dispo return this.jdbcOperations.query(SELECT_ONE_TIME_TOKEN_SQL, pss, this.oneTimeTokenRowMapper); } - private void deleteOneTimeToken(OneTimeToken oneTimeToken) { + private int deleteOneTimeToken(OneTimeToken oneTimeToken) { List parameters = List .of(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getTokenValue())); PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); - this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss); + return this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss); } private ThreadPoolTaskScheduler createTaskScheduler(String cleanupCron) { diff --git a/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java b/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java index a1a9a2e225..7d2e07920b 100644 --- a/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java +++ b/core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java @@ -21,19 +21,24 @@ import java.time.Duration; import java.time.Instant; import java.time.ZoneOffset; import java.time.temporal.ChronoUnit; +import java.util.List; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -145,6 +150,27 @@ class JdbcOneTimeTokenServiceTests { assertThat(consumedOneTimeToken).isNull(); } + @Test + void consumeWhenTokenIsDeletedConcurrentlyThenReturnNull() throws Exception { + // Simulates a concurrent consume: SELECT finds the token but DELETE affects + // 0 rows because another caller already consumed it. + JdbcOperations jdbcOperations = mock(JdbcOperations.class); + Instant notExpired = Instant.now().plus(5, ChronoUnit.MINUTES); + OneTimeToken token = new DefaultOneTimeToken(TOKEN_VALUE, USERNAME, notExpired); + given(jdbcOperations.query(any(String.class), any(PreparedStatementSetter.class), + ArgumentMatchers.>any())) + .willReturn(List.of(token)); + given(jdbcOperations.update(any(String.class), any(PreparedStatementSetter.class))).willReturn(0); + JdbcOneTimeTokenService service = new JdbcOneTimeTokenService(jdbcOperations); + try { + OneTimeToken consumed = service.consume(new OneTimeTokenAuthenticationToken(TOKEN_VALUE)); + assertThat(consumed).isNull(); + } + finally { + service.destroy(); + } + } + @Test void consumeWhenTokenIsExpiredThenReturnNull() { GenerateOneTimeTokenRequest request = new GenerateOneTimeTokenRequest(USERNAME);