Implement internal cache in JtiClaimValidator

Closes gh-17107
This commit is contained in:
Joe Grandja 2025-05-14 05:21:00 -04:00
parent a265ac6ae7
commit 5f7155bfc7

View File

@ -18,12 +18,12 @@ package org.springframework.security.oauth2.jwt;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.time.Clock;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import com.nimbusds.jose.JOSEException;
@ -146,7 +146,7 @@ public final class DPoPProofJwtDecoderFactory implements JwtDecoderFactory<DPoPP
private static final class JtiClaimValidator implements OAuth2TokenValidator<Jwt> {
private static final Map<String, Long> jtiCache = new ConcurrentHashMap<>();
private static final Map<String, Long> JTI_CACHE = Collections.synchronizedMap(new JtiCache());
@Override
public OAuth2TokenValidatorResult validate(Jwt jwt) {
@ -166,8 +166,8 @@ public final class DPoPProofJwtDecoderFactory implements JwtDecoderFactory<DPoPP
OAuth2Error error = createOAuth2Error("jti claim is invalid.");
return OAuth2TokenValidatorResult.failure(error);
}
Instant now = Instant.now(Clock.systemUTC());
if ((jtiCache.putIfAbsent(jtiHash, now.toEpochMilli())) != null) {
Instant expiry = Instant.now().plus(1, ChronoUnit.HOURS);
if ((JTI_CACHE.putIfAbsent(jtiHash, expiry.toEpochMilli())) != null) {
// Already used
OAuth2Error error = createOAuth2Error("jti claim is invalid.");
return OAuth2TokenValidatorResult.failure(error);
@ -185,6 +185,21 @@ public final class DPoPProofJwtDecoderFactory implements JwtDecoderFactory<DPoPP
return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
}
private static final class JtiCache extends LinkedHashMap<String, Long> {
private static final int MAX_SIZE = 1000;
@Override
protected boolean removeEldestEntry(Map.Entry<String, Long> eldest) {
if (size() > MAX_SIZE) {
return true;
}
Instant expiry = Instant.ofEpochMilli(eldest.getValue());
return Instant.now().isAfter(expiry);
}
}
}
}