diff --git a/solr/core/src/java/org/apache/solr/security/JWTAuthPlugin.java b/solr/core/src/java/org/apache/solr/security/JWTAuthPlugin.java index baf32f41bd9..121e91b47bd 100644 --- a/solr/core/src/java/org/apache/solr/security/JWTAuthPlugin.java +++ b/solr/core/src/java/org/apache/solr/security/JWTAuthPlugin.java @@ -66,6 +66,7 @@ import org.jose4j.jwk.JsonWebKeySet; import org.jose4j.jwt.JwtClaims; import org.jose4j.jwt.MalformedClaimException; import org.jose4j.jwt.consumer.InvalidJwtException; +import org.jose4j.jwt.consumer.InvalidJwtSignatureException; import org.jose4j.jwt.consumer.JwtConsumer; import org.jose4j.jwt.consumer.JwtConsumerBuilder; import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver; @@ -114,7 +115,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, private boolean requireSubject; private boolean requireExpirationTime; private List algWhitelist; - private VerificationKeyResolver verificationKeyResolver; + VerificationKeyResolver verificationKeyResolver; private String principalClaim; private HashMap claimsMatchCompiled; private boolean blockUnknown; @@ -128,6 +129,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, private String authorizationEndpoint; private String adminUiScope; private List redirectUris; + private HttpsJwks httpsJkws; /** @@ -135,6 +137,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, */ public JWTAuthPlugin() {} + @SuppressWarnings("unchecked") @Override public void init(Map pluginConfig) { List unknownKeys = pluginConfig.keySet().stream().filter(k -> !PROPS.contains(k)).collect(Collectors.toList()); @@ -221,6 +224,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, lastInitTime = Instant.now(); } + @SuppressWarnings("unchecked") private void initJwk(Map pluginConfig) { this.pluginConfig = pluginConfig; String confJwkUrl = (String) pluginConfig.get(PARAM_JWK_URL); @@ -247,6 +251,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, try { JsonWebKeySet jwks = parseJwkSet(confJwk); verificationKeyResolver = new JwksVerificationKeyResolver(jwks.getJsonWebKeys()); + httpsJkws = null; } catch (JoseException e) { throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Invalid JWTAuthPlugin configuration, " + PARAM_JWK + " parse error", e); } @@ -255,7 +260,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, log.debug("JWK configured"); } - private void setupJwkUrl(String url) { + void setupJwkUrl(String url) { // The HttpsJwks retrieves and caches keys from a the given HTTPS JWKS endpoint. try { URL jwkUrl = new URL(url); @@ -265,11 +270,13 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, } catch (MalformedURLException e) { throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, PARAM_JWK_URL + " must be a valid URL"); } - HttpsJwks httpsJkws = new HttpsJwks(url); + httpsJkws = new HttpsJwks(url); httpsJkws.setDefaultCacheDuration(jwkCacheDuration); + httpsJkws.setRefreshReprieveThreshold(5000); verificationKeyResolver = new HttpsJwksVerificationKeyResolver(httpsJkws); } + @SuppressWarnings("unchecked") JsonWebKeySet parseJwkSet(Map jwkObj) throws JoseException { JsonWebKeySet webKeySet = new JsonWebKeySet(); if (jwkObj.containsKey("keys")) { @@ -297,7 +304,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, if (header == null && !blockUnknown) { log.info("JWTAuth not configured, but allowing anonymous access since {}==false", PARAM_BLOCK_UNKNOWN); filterChain.doFilter(request, response); - numPassThrough.inc();; + numPassThrough.inc(); return true; } // Retry config @@ -313,15 +320,24 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, } JWTAuthenticationResponse authResponse = authenticate(header); - switch(authResponse.getAuthCode()) { + if (AuthCode.SIGNATURE_INVALID.equals(authResponse.getAuthCode()) && httpsJkws != null) { + log.warn("Signature validation failed. Refreshing JWKs from IdP before trying again: {}", + authResponse.getJwtException() == null ? "" : authResponse.getJwtException().getMessage()); + httpsJkws.refresh(); + authResponse = authenticate(header); + } + String exceptionMessage = authResponse.getJwtException() != null ? authResponse.getJwtException().getMessage() : ""; + + switch (authResponse.getAuthCode()) { case AUTHENTICATED: + final Principal principal = authResponse.getPrincipal(); HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request) { @Override public Principal getUserPrincipal() { - return authResponse.getPrincipal(); + return principal; } }; - if (!(authResponse.getPrincipal() instanceof JWTPrincipal)) { + if (!(principal instanceof JWTPrincipal)) { numErrors.mark(); throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "JWTAuth plugin says AUTHENTICATED but no token extracted"); } @@ -340,6 +356,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, case AUTZ_HEADER_PROBLEM: case JWT_PARSE_ERROR: + log.warn("Authentication failed. {}, {}", authResponse.getAuthCode(), authResponse.getAuthCode().getMsg()); authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_BAD_REQUEST, BearerWwwAuthErrorCode.invalid_request); numErrors.mark(); return false; @@ -348,9 +365,13 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, case JWT_EXPIRED: case JWT_VALIDATION_EXCEPTION: case PRINCIPAL_MISSING: - if (authResponse.getJwtException() != null) { - log.warn("Exception: {}", authResponse.getJwtException().getMessage()); - } + log.warn("Authentication failed. {}, {}", authResponse.getAuthCode(), exceptionMessage); + authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_UNAUTHORIZED, BearerWwwAuthErrorCode.invalid_token); + numWrongCredentials.inc(); + return false; + + case SIGNATURE_INVALID: + log.warn("Signature validation failed: {}", exceptionMessage); authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_UNAUTHORIZED, BearerWwwAuthErrorCode.invalid_token); numWrongCredentials.inc(); return false; @@ -359,7 +380,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_UNAUTHORIZED, BearerWwwAuthErrorCode.insufficient_scope); numWrongCredentials.inc(); return false; - + case NO_AUTZ_HEADER: default: authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_UNAUTHORIZED, null); @@ -427,6 +448,8 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, } else { return new JWTAuthenticationResponse(AuthCode.AUTHENTICATED, new JWTPrincipal(principal, jwtCompact, jwtClaims.getClaimsMap())); } + } catch (InvalidJwtSignatureException ise) { + return new JWTAuthenticationResponse(AuthCode.SIGNATURE_INVALID, ise); } catch (InvalidJwtException e) { // Whether or not the JWT has expired being one common reason for invalidity if (e.hasExpired()) { @@ -516,7 +539,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, return latestConf; } - private enum BearerWwwAuthErrorCode { invalid_request, invalid_token, insufficient_scope}; + private enum BearerWwwAuthErrorCode { invalid_request, invalid_token, insufficient_scope} private void authenticationFailure(HttpServletResponse response, String message, int httpCode, BearerWwwAuthErrorCode responseError) throws IOException { List wwwAuthParams = new ArrayList<>(); @@ -561,8 +584,9 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, JWT_EXPIRED("JWT token expired"), // JWT token has expired CLAIM_MISMATCH("Required JWT claim missing"), // Some required claims are missing or wrong JWT_VALIDATION_EXCEPTION("JWT validation failed"), // The JWT parser failed validation. More details in exception - SCOPE_MISSING("Required scope missing in JWT"); // None of the required scopes were present in JWT - + SCOPE_MISSING("Required scope missing in JWT"), // None of the required scopes were present in JWT + SIGNATURE_INVALID("Signature invalid"); // Validation of JWT signature failed + public String getMsg() { return msg; } @@ -647,6 +671,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, return parse(new ByteArrayInputStream(json.getBytes(charset))); } + @SuppressWarnings("unchecked") public static WellKnownDiscoveryConfig parse(InputStream configStream) { securityConf = (Map) Utils.fromJSON(configStream); return new WellKnownDiscoveryConfig(securityConf); @@ -673,10 +698,12 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider, return (String) securityConf.get("token_endpoint"); } + @SuppressWarnings("unchecked") public List getScopesSupported() { return (List) securityConf.get("scopes_supported"); } + @SuppressWarnings("unchecked") public List getResponseTypesSupported() { return (List) securityConf.get("response_types_supported"); } diff --git a/solr/core/src/test/org/apache/solr/security/JWTAuthPluginTest.java b/solr/core/src/test/org/apache/solr/security/JWTAuthPluginTest.java index 00309df1266..3066ef84815 100644 --- a/solr/core/src/test/org/apache/solr/security/JWTAuthPluginTest.java +++ b/solr/core/src/test/org/apache/solr/security/JWTAuthPluginTest.java @@ -24,6 +24,7 @@ import java.nio.file.Path; import java.security.Principal; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -34,12 +35,16 @@ import org.apache.solr.SolrTestCaseJ4; import org.apache.solr.common.SolrException; import org.apache.solr.common.util.Base64; import org.apache.solr.common.util.Utils; +import org.jose4j.jwk.HttpsJwks; +import org.jose4j.jwk.JsonWebKey; import org.jose4j.jwk.RsaJsonWebKey; import org.jose4j.jwk.RsaJwkGenerator; import org.jose4j.jws.AlgorithmIdentifiers; import org.jose4j.jws.JsonWebSignature; import org.jose4j.jwt.JwtClaims; import org.jose4j.keys.BigEndianBigInteger; +import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver; +import org.jose4j.lang.JoseException; import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; @@ -49,6 +54,7 @@ import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.A import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.NO_AUTZ_HEADER; import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.SCOPE_MISSING; +@SuppressWarnings("unchecked") public class JWTAuthPluginTest extends SolrTestCaseJ4 { private static String testHeader; private static String slimHeader; @@ -178,6 +184,39 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 { plugin.init(authConf); } + /** + * Simulate a rotate of JWK key in IdP. + * Validating of JWK signature will fail since we still use old cached JWK set. + * Using a mock {@link HttpsJwks} we validate that plugin calls refresh() and then passes validation + */ + @Test + public void invalidSigRefreshJwk() throws JoseException { + RsaJsonWebKey rsaJsonWebKey2 = RsaJwkGenerator.generateJwk(2048); + rsaJsonWebKey2.setKeyId("k2"); + HashMap testJwkWrong = new HashMap<>(); + testJwkWrong.put("kty", rsaJsonWebKey2.getKeyType()); + testJwkWrong.put("e", BigEndianBigInteger.toBase64Url(rsaJsonWebKey2.getRsaPublicKey().getPublicExponent())); + testJwkWrong.put("use", rsaJsonWebKey2.getUse()); + testJwkWrong.put("kid", rsaJsonWebKey2.getKeyId()); + testJwkWrong.put("alg", rsaJsonWebKey2.getAlgorithm()); + testJwkWrong.put("n", BigEndianBigInteger.toBase64Url(rsaJsonWebKey2.getRsaPublicKey().getModulus())); + JsonWebKey wrongJwk = JsonWebKey.Factory.newJwk(testJwkWrong); + + // Configure our mock plugin with URL as jwk source + JsonWebKey correctJwk = JsonWebKey.Factory.newJwk(testJwk); + plugin = new MockJwksUrlPlugin(wrongJwk, correctJwk); + HashMap pluginConfigJwkUrl = new HashMap<>(); + pluginConfigJwkUrl.put("class", "org.apache.solr.security.JWTAuthPlugin"); + pluginConfigJwkUrl.put("jwkUrl", "dummy"); + plugin.init(pluginConfigJwkUrl); + + // Validate that plugin will call refresh() on invalid signature, then the call succeeds + assertFalse(((MockJwksUrlPlugin)plugin).isRefreshCalled()); + JWTAuthPlugin.JWTAuthenticationResponse resp = plugin.authenticate(testHeader); + assertTrue(resp.isAuthenticated()); + assertTrue(((MockJwksUrlPlugin)plugin).isRefreshCalled()); + } + @Test public void parseJwkSet() throws Exception { plugin.parseJwkSet(testJwk); @@ -337,13 +376,13 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 { @Test public void minimalConfigPassThrough() { testConfig.put("blockUnknown", false); - plugin.init(testConfig); + plugin.init(minimalConfig); JWTAuthPlugin.JWTAuthenticationResponse resp = plugin.authenticate(null); assertEquals(JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.PASS_THROUGH, resp.getAuthCode()); } @Test - public void wellKnownConfig() throws IOException { + public void wellKnownConfig() { String wellKnownUrl = TEST_PATH().resolve("security").resolve("jwt_well-known-config.json").toAbsolutePath().toUri().toString(); testConfig.put("wellKnownUrl", wellKnownUrl); testConfig.remove("jwk"); @@ -353,13 +392,13 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 { } @Test(expected = SolrException.class) - public void onlyOneJwkConfig() throws IOException { - testConfig.put("jwkUrl", "http://127.0.0.1:45678/.well-known/config"); + public void onlyOneJwkConfig() { + testConfig.put("jwkUrl", "http://127.0.0.1:45678/myJwk"); plugin.init(testConfig); } @Test(expected = SolrException.class) - public void wellKnownConfigNotHttps() throws IOException { + public void wellKnownConfigNotHttps() { testConfig.put("wellKnownUrl", "http://127.0.0.1:45678/.well-known/config"); plugin.init(testConfig); } @@ -402,4 +441,49 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 { assertEquals("http://acmepaymentscorp/oauth/auz/authorize", parsed.get("authorizationEndpoint")); assertEquals("solr-cluster", parsed.get("client_id")); } + + /** + * Mock plugin that simulates a {@link HttpsJwks} with cached JWK that returns + * a different JWK after a call to refresh() + */ + private class MockJwksUrlPlugin extends JWTAuthPlugin { + private final JsonWebKey wrongJwk; + private final JsonWebKey correctJwk; + + boolean isRefreshCalled() { + return refreshCalled; + } + + private boolean refreshCalled; + + MockJwksUrlPlugin(JsonWebKey wrongJwk, JsonWebKey correctJwk) { + this.wrongJwk = wrongJwk; + this.correctJwk = correctJwk; + } + + @Override + void setupJwkUrl(String url) { + MockHttpsJwks httpsJkws = new MockHttpsJwks(url); + verificationKeyResolver = new HttpsJwksVerificationKeyResolver(httpsJkws); + } + + private class MockHttpsJwks extends HttpsJwks { + MockHttpsJwks(String url) { + super(url); + } + + @Override + public List getJsonWebKeys() { + return refreshCalled ? Collections.singletonList(correctJwk) : Collections.singletonList(wrongJwk); + } + + @Override + public void refresh() { + if (refreshCalled) { + fail("Refresh called twice"); + } + refreshCalled = true; + } + } + } }