mirror of https://github.com/apache/lucene.git
SOLR-12121: Refresh JWK from IdP on invalid sig. Some logging improvements. Minor test fixes
This commit is contained in:
parent
ef9566e13e
commit
3b3879d880
|
@ -66,6 +66,7 @@ import org.jose4j.jwk.JsonWebKeySet;
|
|||
import org.jose4j.jwt.JwtClaims;
|
||||
import org.jose4j.jwt.MalformedClaimException;
|
||||
import org.jose4j.jwt.consumer.InvalidJwtException;
|
||||
import org.jose4j.jwt.consumer.InvalidJwtSignatureException;
|
||||
import org.jose4j.jwt.consumer.JwtConsumer;
|
||||
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
|
||||
import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver;
|
||||
|
@ -114,7 +115,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
private boolean requireSubject;
|
||||
private boolean requireExpirationTime;
|
||||
private List<String> algWhitelist;
|
||||
private VerificationKeyResolver verificationKeyResolver;
|
||||
VerificationKeyResolver verificationKeyResolver;
|
||||
private String principalClaim;
|
||||
private HashMap<String, Pattern> claimsMatchCompiled;
|
||||
private boolean blockUnknown;
|
||||
|
@ -128,6 +129,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
private String authorizationEndpoint;
|
||||
private String adminUiScope;
|
||||
private List<String> redirectUris;
|
||||
private HttpsJwks httpsJkws;
|
||||
|
||||
|
||||
/**
|
||||
|
@ -135,6 +137,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
*/
|
||||
public JWTAuthPlugin() {}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public void init(Map<String, Object> pluginConfig) {
|
||||
List<String> unknownKeys = pluginConfig.keySet().stream().filter(k -> !PROPS.contains(k)).collect(Collectors.toList());
|
||||
|
@ -221,6 +224,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
lastInitTime = Instant.now();
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private void initJwk(Map<String, Object> pluginConfig) {
|
||||
this.pluginConfig = pluginConfig;
|
||||
String confJwkUrl = (String) pluginConfig.get(PARAM_JWK_URL);
|
||||
|
@ -247,6 +251,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
try {
|
||||
JsonWebKeySet jwks = parseJwkSet(confJwk);
|
||||
verificationKeyResolver = new JwksVerificationKeyResolver(jwks.getJsonWebKeys());
|
||||
httpsJkws = null;
|
||||
} catch (JoseException e) {
|
||||
throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Invalid JWTAuthPlugin configuration, " + PARAM_JWK + " parse error", e);
|
||||
}
|
||||
|
@ -255,7 +260,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
log.debug("JWK configured");
|
||||
}
|
||||
|
||||
private void setupJwkUrl(String url) {
|
||||
void setupJwkUrl(String url) {
|
||||
// The HttpsJwks retrieves and caches keys from a the given HTTPS JWKS endpoint.
|
||||
try {
|
||||
URL jwkUrl = new URL(url);
|
||||
|
@ -265,11 +270,13 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
} catch (MalformedURLException e) {
|
||||
throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, PARAM_JWK_URL + " must be a valid URL");
|
||||
}
|
||||
HttpsJwks httpsJkws = new HttpsJwks(url);
|
||||
httpsJkws = new HttpsJwks(url);
|
||||
httpsJkws.setDefaultCacheDuration(jwkCacheDuration);
|
||||
httpsJkws.setRefreshReprieveThreshold(5000);
|
||||
verificationKeyResolver = new HttpsJwksVerificationKeyResolver(httpsJkws);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
JsonWebKeySet parseJwkSet(Map<String, Object> jwkObj) throws JoseException {
|
||||
JsonWebKeySet webKeySet = new JsonWebKeySet();
|
||||
if (jwkObj.containsKey("keys")) {
|
||||
|
@ -297,7 +304,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
if (header == null && !blockUnknown) {
|
||||
log.info("JWTAuth not configured, but allowing anonymous access since {}==false", PARAM_BLOCK_UNKNOWN);
|
||||
filterChain.doFilter(request, response);
|
||||
numPassThrough.inc();;
|
||||
numPassThrough.inc();
|
||||
return true;
|
||||
}
|
||||
// Retry config
|
||||
|
@ -313,15 +320,24 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
}
|
||||
|
||||
JWTAuthenticationResponse authResponse = authenticate(header);
|
||||
switch(authResponse.getAuthCode()) {
|
||||
if (AuthCode.SIGNATURE_INVALID.equals(authResponse.getAuthCode()) && httpsJkws != null) {
|
||||
log.warn("Signature validation failed. Refreshing JWKs from IdP before trying again: {}",
|
||||
authResponse.getJwtException() == null ? "" : authResponse.getJwtException().getMessage());
|
||||
httpsJkws.refresh();
|
||||
authResponse = authenticate(header);
|
||||
}
|
||||
String exceptionMessage = authResponse.getJwtException() != null ? authResponse.getJwtException().getMessage() : "";
|
||||
|
||||
switch (authResponse.getAuthCode()) {
|
||||
case AUTHENTICATED:
|
||||
final Principal principal = authResponse.getPrincipal();
|
||||
HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request) {
|
||||
@Override
|
||||
public Principal getUserPrincipal() {
|
||||
return authResponse.getPrincipal();
|
||||
return principal;
|
||||
}
|
||||
};
|
||||
if (!(authResponse.getPrincipal() instanceof JWTPrincipal)) {
|
||||
if (!(principal instanceof JWTPrincipal)) {
|
||||
numErrors.mark();
|
||||
throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "JWTAuth plugin says AUTHENTICATED but no token extracted");
|
||||
}
|
||||
|
@ -340,6 +356,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
|
||||
case AUTZ_HEADER_PROBLEM:
|
||||
case JWT_PARSE_ERROR:
|
||||
log.warn("Authentication failed. {}, {}", authResponse.getAuthCode(), authResponse.getAuthCode().getMsg());
|
||||
authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_BAD_REQUEST, BearerWwwAuthErrorCode.invalid_request);
|
||||
numErrors.mark();
|
||||
return false;
|
||||
|
@ -348,9 +365,13 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
case JWT_EXPIRED:
|
||||
case JWT_VALIDATION_EXCEPTION:
|
||||
case PRINCIPAL_MISSING:
|
||||
if (authResponse.getJwtException() != null) {
|
||||
log.warn("Exception: {}", authResponse.getJwtException().getMessage());
|
||||
}
|
||||
log.warn("Authentication failed. {}, {}", authResponse.getAuthCode(), exceptionMessage);
|
||||
authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_UNAUTHORIZED, BearerWwwAuthErrorCode.invalid_token);
|
||||
numWrongCredentials.inc();
|
||||
return false;
|
||||
|
||||
case SIGNATURE_INVALID:
|
||||
log.warn("Signature validation failed: {}", exceptionMessage);
|
||||
authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_UNAUTHORIZED, BearerWwwAuthErrorCode.invalid_token);
|
||||
numWrongCredentials.inc();
|
||||
return false;
|
||||
|
@ -427,6 +448,8 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
} else {
|
||||
return new JWTAuthenticationResponse(AuthCode.AUTHENTICATED, new JWTPrincipal(principal, jwtCompact, jwtClaims.getClaimsMap()));
|
||||
}
|
||||
} catch (InvalidJwtSignatureException ise) {
|
||||
return new JWTAuthenticationResponse(AuthCode.SIGNATURE_INVALID, ise);
|
||||
} catch (InvalidJwtException e) {
|
||||
// Whether or not the JWT has expired being one common reason for invalidity
|
||||
if (e.hasExpired()) {
|
||||
|
@ -516,7 +539,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
return latestConf;
|
||||
}
|
||||
|
||||
private enum BearerWwwAuthErrorCode { invalid_request, invalid_token, insufficient_scope};
|
||||
private enum BearerWwwAuthErrorCode { invalid_request, invalid_token, insufficient_scope}
|
||||
|
||||
private void authenticationFailure(HttpServletResponse response, String message, int httpCode, BearerWwwAuthErrorCode responseError) throws IOException {
|
||||
List<String> wwwAuthParams = new ArrayList<>();
|
||||
|
@ -561,7 +584,8 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
JWT_EXPIRED("JWT token expired"), // JWT token has expired
|
||||
CLAIM_MISMATCH("Required JWT claim missing"), // Some required claims are missing or wrong
|
||||
JWT_VALIDATION_EXCEPTION("JWT validation failed"), // The JWT parser failed validation. More details in exception
|
||||
SCOPE_MISSING("Required scope missing in JWT"); // None of the required scopes were present in JWT
|
||||
SCOPE_MISSING("Required scope missing in JWT"), // None of the required scopes were present in JWT
|
||||
SIGNATURE_INVALID("Signature invalid"); // Validation of JWT signature failed
|
||||
|
||||
public String getMsg() {
|
||||
return msg;
|
||||
|
@ -647,6 +671,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
return parse(new ByteArrayInputStream(json.getBytes(charset)));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static WellKnownDiscoveryConfig parse(InputStream configStream) {
|
||||
securityConf = (Map<String, Object>) Utils.fromJSON(configStream);
|
||||
return new WellKnownDiscoveryConfig(securityConf);
|
||||
|
@ -673,10 +698,12 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
|
|||
return (String) securityConf.get("token_endpoint");
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public List<String> getScopesSupported() {
|
||||
return (List<String>) securityConf.get("scopes_supported");
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public List<String> getResponseTypesSupported() {
|
||||
return (List<String>) securityConf.get("response_types_supported");
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.nio.file.Path;
|
|||
import java.security.Principal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
@ -34,12 +35,16 @@ import org.apache.solr.SolrTestCaseJ4;
|
|||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.solr.common.util.Base64;
|
||||
import org.apache.solr.common.util.Utils;
|
||||
import org.jose4j.jwk.HttpsJwks;
|
||||
import org.jose4j.jwk.JsonWebKey;
|
||||
import org.jose4j.jwk.RsaJsonWebKey;
|
||||
import org.jose4j.jwk.RsaJwkGenerator;
|
||||
import org.jose4j.jws.AlgorithmIdentifiers;
|
||||
import org.jose4j.jws.JsonWebSignature;
|
||||
import org.jose4j.jwt.JwtClaims;
|
||||
import org.jose4j.keys.BigEndianBigInteger;
|
||||
import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver;
|
||||
import org.jose4j.lang.JoseException;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
|
@ -49,6 +54,7 @@ import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.A
|
|||
import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.NO_AUTZ_HEADER;
|
||||
import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.SCOPE_MISSING;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public class JWTAuthPluginTest extends SolrTestCaseJ4 {
|
||||
private static String testHeader;
|
||||
private static String slimHeader;
|
||||
|
@ -178,6 +184,39 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 {
|
|||
plugin.init(authConf);
|
||||
}
|
||||
|
||||
/**
|
||||
* Simulate a rotate of JWK key in IdP.
|
||||
* Validating of JWK signature will fail since we still use old cached JWK set.
|
||||
* Using a mock {@link HttpsJwks} we validate that plugin calls refresh() and then passes validation
|
||||
*/
|
||||
@Test
|
||||
public void invalidSigRefreshJwk() throws JoseException {
|
||||
RsaJsonWebKey rsaJsonWebKey2 = RsaJwkGenerator.generateJwk(2048);
|
||||
rsaJsonWebKey2.setKeyId("k2");
|
||||
HashMap<String, Object> testJwkWrong = new HashMap<>();
|
||||
testJwkWrong.put("kty", rsaJsonWebKey2.getKeyType());
|
||||
testJwkWrong.put("e", BigEndianBigInteger.toBase64Url(rsaJsonWebKey2.getRsaPublicKey().getPublicExponent()));
|
||||
testJwkWrong.put("use", rsaJsonWebKey2.getUse());
|
||||
testJwkWrong.put("kid", rsaJsonWebKey2.getKeyId());
|
||||
testJwkWrong.put("alg", rsaJsonWebKey2.getAlgorithm());
|
||||
testJwkWrong.put("n", BigEndianBigInteger.toBase64Url(rsaJsonWebKey2.getRsaPublicKey().getModulus()));
|
||||
JsonWebKey wrongJwk = JsonWebKey.Factory.newJwk(testJwkWrong);
|
||||
|
||||
// Configure our mock plugin with URL as jwk source
|
||||
JsonWebKey correctJwk = JsonWebKey.Factory.newJwk(testJwk);
|
||||
plugin = new MockJwksUrlPlugin(wrongJwk, correctJwk);
|
||||
HashMap<String, Object> pluginConfigJwkUrl = new HashMap<>();
|
||||
pluginConfigJwkUrl.put("class", "org.apache.solr.security.JWTAuthPlugin");
|
||||
pluginConfigJwkUrl.put("jwkUrl", "dummy");
|
||||
plugin.init(pluginConfigJwkUrl);
|
||||
|
||||
// Validate that plugin will call refresh() on invalid signature, then the call succeeds
|
||||
assertFalse(((MockJwksUrlPlugin)plugin).isRefreshCalled());
|
||||
JWTAuthPlugin.JWTAuthenticationResponse resp = plugin.authenticate(testHeader);
|
||||
assertTrue(resp.isAuthenticated());
|
||||
assertTrue(((MockJwksUrlPlugin)plugin).isRefreshCalled());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parseJwkSet() throws Exception {
|
||||
plugin.parseJwkSet(testJwk);
|
||||
|
@ -337,13 +376,13 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 {
|
|||
@Test
|
||||
public void minimalConfigPassThrough() {
|
||||
testConfig.put("blockUnknown", false);
|
||||
plugin.init(testConfig);
|
||||
plugin.init(minimalConfig);
|
||||
JWTAuthPlugin.JWTAuthenticationResponse resp = plugin.authenticate(null);
|
||||
assertEquals(JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.PASS_THROUGH, resp.getAuthCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void wellKnownConfig() throws IOException {
|
||||
public void wellKnownConfig() {
|
||||
String wellKnownUrl = TEST_PATH().resolve("security").resolve("jwt_well-known-config.json").toAbsolutePath().toUri().toString();
|
||||
testConfig.put("wellKnownUrl", wellKnownUrl);
|
||||
testConfig.remove("jwk");
|
||||
|
@ -353,13 +392,13 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 {
|
|||
}
|
||||
|
||||
@Test(expected = SolrException.class)
|
||||
public void onlyOneJwkConfig() throws IOException {
|
||||
testConfig.put("jwkUrl", "http://127.0.0.1:45678/.well-known/config");
|
||||
public void onlyOneJwkConfig() {
|
||||
testConfig.put("jwkUrl", "http://127.0.0.1:45678/myJwk");
|
||||
plugin.init(testConfig);
|
||||
}
|
||||
|
||||
@Test(expected = SolrException.class)
|
||||
public void wellKnownConfigNotHttps() throws IOException {
|
||||
public void wellKnownConfigNotHttps() {
|
||||
testConfig.put("wellKnownUrl", "http://127.0.0.1:45678/.well-known/config");
|
||||
plugin.init(testConfig);
|
||||
}
|
||||
|
@ -402,4 +441,49 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 {
|
|||
assertEquals("http://acmepaymentscorp/oauth/auz/authorize", parsed.get("authorizationEndpoint"));
|
||||
assertEquals("solr-cluster", parsed.get("client_id"));
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock plugin that simulates a {@link HttpsJwks} with cached JWK that returns
|
||||
* a different JWK after a call to refresh()
|
||||
*/
|
||||
private class MockJwksUrlPlugin extends JWTAuthPlugin {
|
||||
private final JsonWebKey wrongJwk;
|
||||
private final JsonWebKey correctJwk;
|
||||
|
||||
boolean isRefreshCalled() {
|
||||
return refreshCalled;
|
||||
}
|
||||
|
||||
private boolean refreshCalled;
|
||||
|
||||
MockJwksUrlPlugin(JsonWebKey wrongJwk, JsonWebKey correctJwk) {
|
||||
this.wrongJwk = wrongJwk;
|
||||
this.correctJwk = correctJwk;
|
||||
}
|
||||
|
||||
@Override
|
||||
void setupJwkUrl(String url) {
|
||||
MockHttpsJwks httpsJkws = new MockHttpsJwks(url);
|
||||
verificationKeyResolver = new HttpsJwksVerificationKeyResolver(httpsJkws);
|
||||
}
|
||||
|
||||
private class MockHttpsJwks extends HttpsJwks {
|
||||
MockHttpsJwks(String url) {
|
||||
super(url);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<JsonWebKey> getJsonWebKeys() {
|
||||
return refreshCalled ? Collections.singletonList(correctJwk) : Collections.singletonList(wrongJwk);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void refresh() {
|
||||
if (refreshCalled) {
|
||||
fail("Refresh called twice");
|
||||
}
|
||||
refreshCalled = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue