SOLR-12121: Refresh JWK from IdP on invalid sig. Some logging improvements. Minor test fixes

This commit is contained in:
Jan Høydahl 2019-04-12 09:49:11 +02:00
parent ef9566e13e
commit 3b3879d880
2 changed files with 130 additions and 19 deletions

View File

@ -66,6 +66,7 @@ import org.jose4j.jwk.JsonWebKeySet;
import org.jose4j.jwt.JwtClaims; import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.MalformedClaimException; import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.jwt.consumer.InvalidJwtException; import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.InvalidJwtSignatureException;
import org.jose4j.jwt.consumer.JwtConsumer; import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder; import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver; import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver;
@ -114,7 +115,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
private boolean requireSubject; private boolean requireSubject;
private boolean requireExpirationTime; private boolean requireExpirationTime;
private List<String> algWhitelist; private List<String> algWhitelist;
private VerificationKeyResolver verificationKeyResolver; VerificationKeyResolver verificationKeyResolver;
private String principalClaim; private String principalClaim;
private HashMap<String, Pattern> claimsMatchCompiled; private HashMap<String, Pattern> claimsMatchCompiled;
private boolean blockUnknown; private boolean blockUnknown;
@ -128,6 +129,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
private String authorizationEndpoint; private String authorizationEndpoint;
private String adminUiScope; private String adminUiScope;
private List<String> redirectUris; private List<String> redirectUris;
private HttpsJwks httpsJkws;
/** /**
@ -135,6 +137,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
*/ */
public JWTAuthPlugin() {} public JWTAuthPlugin() {}
@SuppressWarnings("unchecked")
@Override @Override
public void init(Map<String, Object> pluginConfig) { public void init(Map<String, Object> pluginConfig) {
List<String> unknownKeys = pluginConfig.keySet().stream().filter(k -> !PROPS.contains(k)).collect(Collectors.toList()); List<String> unknownKeys = pluginConfig.keySet().stream().filter(k -> !PROPS.contains(k)).collect(Collectors.toList());
@ -221,6 +224,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
lastInitTime = Instant.now(); lastInitTime = Instant.now();
} }
@SuppressWarnings("unchecked")
private void initJwk(Map<String, Object> pluginConfig) { private void initJwk(Map<String, Object> pluginConfig) {
this.pluginConfig = pluginConfig; this.pluginConfig = pluginConfig;
String confJwkUrl = (String) pluginConfig.get(PARAM_JWK_URL); String confJwkUrl = (String) pluginConfig.get(PARAM_JWK_URL);
@ -247,6 +251,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
try { try {
JsonWebKeySet jwks = parseJwkSet(confJwk); JsonWebKeySet jwks = parseJwkSet(confJwk);
verificationKeyResolver = new JwksVerificationKeyResolver(jwks.getJsonWebKeys()); verificationKeyResolver = new JwksVerificationKeyResolver(jwks.getJsonWebKeys());
httpsJkws = null;
} catch (JoseException e) { } catch (JoseException e) {
throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Invalid JWTAuthPlugin configuration, " + PARAM_JWK + " parse error", e); throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Invalid JWTAuthPlugin configuration, " + PARAM_JWK + " parse error", e);
} }
@ -255,7 +260,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
log.debug("JWK configured"); log.debug("JWK configured");
} }
private void setupJwkUrl(String url) { void setupJwkUrl(String url) {
// The HttpsJwks retrieves and caches keys from a the given HTTPS JWKS endpoint. // The HttpsJwks retrieves and caches keys from a the given HTTPS JWKS endpoint.
try { try {
URL jwkUrl = new URL(url); URL jwkUrl = new URL(url);
@ -265,11 +270,13 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
} catch (MalformedURLException e) { } catch (MalformedURLException e) {
throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, PARAM_JWK_URL + " must be a valid URL"); throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, PARAM_JWK_URL + " must be a valid URL");
} }
HttpsJwks httpsJkws = new HttpsJwks(url); httpsJkws = new HttpsJwks(url);
httpsJkws.setDefaultCacheDuration(jwkCacheDuration); httpsJkws.setDefaultCacheDuration(jwkCacheDuration);
httpsJkws.setRefreshReprieveThreshold(5000);
verificationKeyResolver = new HttpsJwksVerificationKeyResolver(httpsJkws); verificationKeyResolver = new HttpsJwksVerificationKeyResolver(httpsJkws);
} }
@SuppressWarnings("unchecked")
JsonWebKeySet parseJwkSet(Map<String, Object> jwkObj) throws JoseException { JsonWebKeySet parseJwkSet(Map<String, Object> jwkObj) throws JoseException {
JsonWebKeySet webKeySet = new JsonWebKeySet(); JsonWebKeySet webKeySet = new JsonWebKeySet();
if (jwkObj.containsKey("keys")) { if (jwkObj.containsKey("keys")) {
@ -297,7 +304,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
if (header == null && !blockUnknown) { if (header == null && !blockUnknown) {
log.info("JWTAuth not configured, but allowing anonymous access since {}==false", PARAM_BLOCK_UNKNOWN); log.info("JWTAuth not configured, but allowing anonymous access since {}==false", PARAM_BLOCK_UNKNOWN);
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
numPassThrough.inc();; numPassThrough.inc();
return true; return true;
} }
// Retry config // Retry config
@ -313,15 +320,24 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
} }
JWTAuthenticationResponse authResponse = authenticate(header); JWTAuthenticationResponse authResponse = authenticate(header);
switch(authResponse.getAuthCode()) { if (AuthCode.SIGNATURE_INVALID.equals(authResponse.getAuthCode()) && httpsJkws != null) {
log.warn("Signature validation failed. Refreshing JWKs from IdP before trying again: {}",
authResponse.getJwtException() == null ? "" : authResponse.getJwtException().getMessage());
httpsJkws.refresh();
authResponse = authenticate(header);
}
String exceptionMessage = authResponse.getJwtException() != null ? authResponse.getJwtException().getMessage() : "";
switch (authResponse.getAuthCode()) {
case AUTHENTICATED: case AUTHENTICATED:
final Principal principal = authResponse.getPrincipal();
HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request) { HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request) {
@Override @Override
public Principal getUserPrincipal() { public Principal getUserPrincipal() {
return authResponse.getPrincipal(); return principal;
} }
}; };
if (!(authResponse.getPrincipal() instanceof JWTPrincipal)) { if (!(principal instanceof JWTPrincipal)) {
numErrors.mark(); numErrors.mark();
throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "JWTAuth plugin says AUTHENTICATED but no token extracted"); throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "JWTAuth plugin says AUTHENTICATED but no token extracted");
} }
@ -340,6 +356,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
case AUTZ_HEADER_PROBLEM: case AUTZ_HEADER_PROBLEM:
case JWT_PARSE_ERROR: case JWT_PARSE_ERROR:
log.warn("Authentication failed. {}, {}", authResponse.getAuthCode(), authResponse.getAuthCode().getMsg());
authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_BAD_REQUEST, BearerWwwAuthErrorCode.invalid_request); authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_BAD_REQUEST, BearerWwwAuthErrorCode.invalid_request);
numErrors.mark(); numErrors.mark();
return false; return false;
@ -348,9 +365,13 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
case JWT_EXPIRED: case JWT_EXPIRED:
case JWT_VALIDATION_EXCEPTION: case JWT_VALIDATION_EXCEPTION:
case PRINCIPAL_MISSING: case PRINCIPAL_MISSING:
if (authResponse.getJwtException() != null) { log.warn("Authentication failed. {}, {}", authResponse.getAuthCode(), exceptionMessage);
log.warn("Exception: {}", authResponse.getJwtException().getMessage()); authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_UNAUTHORIZED, BearerWwwAuthErrorCode.invalid_token);
} numWrongCredentials.inc();
return false;
case SIGNATURE_INVALID:
log.warn("Signature validation failed: {}", exceptionMessage);
authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_UNAUTHORIZED, BearerWwwAuthErrorCode.invalid_token); authenticationFailure(response, authResponse.getAuthCode().getMsg(), HttpServletResponse.SC_UNAUTHORIZED, BearerWwwAuthErrorCode.invalid_token);
numWrongCredentials.inc(); numWrongCredentials.inc();
return false; return false;
@ -427,6 +448,8 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
} else { } else {
return new JWTAuthenticationResponse(AuthCode.AUTHENTICATED, new JWTPrincipal(principal, jwtCompact, jwtClaims.getClaimsMap())); return new JWTAuthenticationResponse(AuthCode.AUTHENTICATED, new JWTPrincipal(principal, jwtCompact, jwtClaims.getClaimsMap()));
} }
} catch (InvalidJwtSignatureException ise) {
return new JWTAuthenticationResponse(AuthCode.SIGNATURE_INVALID, ise);
} catch (InvalidJwtException e) { } catch (InvalidJwtException e) {
// Whether or not the JWT has expired being one common reason for invalidity // Whether or not the JWT has expired being one common reason for invalidity
if (e.hasExpired()) { if (e.hasExpired()) {
@ -516,7 +539,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
return latestConf; return latestConf;
} }
private enum BearerWwwAuthErrorCode { invalid_request, invalid_token, insufficient_scope}; private enum BearerWwwAuthErrorCode { invalid_request, invalid_token, insufficient_scope}
private void authenticationFailure(HttpServletResponse response, String message, int httpCode, BearerWwwAuthErrorCode responseError) throws IOException { private void authenticationFailure(HttpServletResponse response, String message, int httpCode, BearerWwwAuthErrorCode responseError) throws IOException {
List<String> wwwAuthParams = new ArrayList<>(); List<String> wwwAuthParams = new ArrayList<>();
@ -561,7 +584,8 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
JWT_EXPIRED("JWT token expired"), // JWT token has expired JWT_EXPIRED("JWT token expired"), // JWT token has expired
CLAIM_MISMATCH("Required JWT claim missing"), // Some required claims are missing or wrong CLAIM_MISMATCH("Required JWT claim missing"), // Some required claims are missing or wrong
JWT_VALIDATION_EXCEPTION("JWT validation failed"), // The JWT parser failed validation. More details in exception JWT_VALIDATION_EXCEPTION("JWT validation failed"), // The JWT parser failed validation. More details in exception
SCOPE_MISSING("Required scope missing in JWT"); // None of the required scopes were present in JWT SCOPE_MISSING("Required scope missing in JWT"), // None of the required scopes were present in JWT
SIGNATURE_INVALID("Signature invalid"); // Validation of JWT signature failed
public String getMsg() { public String getMsg() {
return msg; return msg;
@ -647,6 +671,7 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
return parse(new ByteArrayInputStream(json.getBytes(charset))); return parse(new ByteArrayInputStream(json.getBytes(charset)));
} }
@SuppressWarnings("unchecked")
public static WellKnownDiscoveryConfig parse(InputStream configStream) { public static WellKnownDiscoveryConfig parse(InputStream configStream) {
securityConf = (Map<String, Object>) Utils.fromJSON(configStream); securityConf = (Map<String, Object>) Utils.fromJSON(configStream);
return new WellKnownDiscoveryConfig(securityConf); return new WellKnownDiscoveryConfig(securityConf);
@ -673,10 +698,12 @@ public class JWTAuthPlugin extends AuthenticationPlugin implements SpecProvider,
return (String) securityConf.get("token_endpoint"); return (String) securityConf.get("token_endpoint");
} }
@SuppressWarnings("unchecked")
public List<String> getScopesSupported() { public List<String> getScopesSupported() {
return (List<String>) securityConf.get("scopes_supported"); return (List<String>) securityConf.get("scopes_supported");
} }
@SuppressWarnings("unchecked")
public List<String> getResponseTypesSupported() { public List<String> getResponseTypesSupported() {
return (List<String>) securityConf.get("response_types_supported"); return (List<String>) securityConf.get("response_types_supported");
} }

View File

@ -24,6 +24,7 @@ import java.nio.file.Path;
import java.security.Principal; import java.security.Principal;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -34,12 +35,16 @@ import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrException;
import org.apache.solr.common.util.Base64; import org.apache.solr.common.util.Base64;
import org.apache.solr.common.util.Utils; import org.apache.solr.common.util.Utils;
import org.jose4j.jwk.HttpsJwks;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.RsaJsonWebKey; import org.jose4j.jwk.RsaJsonWebKey;
import org.jose4j.jwk.RsaJwkGenerator; import org.jose4j.jwk.RsaJwkGenerator;
import org.jose4j.jws.AlgorithmIdentifiers; import org.jose4j.jws.AlgorithmIdentifiers;
import org.jose4j.jws.JsonWebSignature; import org.jose4j.jws.JsonWebSignature;
import org.jose4j.jwt.JwtClaims; import org.jose4j.jwt.JwtClaims;
import org.jose4j.keys.BigEndianBigInteger; import org.jose4j.keys.BigEndianBigInteger;
import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver;
import org.jose4j.lang.JoseException;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.BeforeClass; import org.junit.BeforeClass;
@ -49,6 +54,7 @@ import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.A
import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.NO_AUTZ_HEADER; import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.NO_AUTZ_HEADER;
import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.SCOPE_MISSING; import static org.apache.solr.security.JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.SCOPE_MISSING;
@SuppressWarnings("unchecked")
public class JWTAuthPluginTest extends SolrTestCaseJ4 { public class JWTAuthPluginTest extends SolrTestCaseJ4 {
private static String testHeader; private static String testHeader;
private static String slimHeader; private static String slimHeader;
@ -178,6 +184,39 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 {
plugin.init(authConf); plugin.init(authConf);
} }
/**
* Simulate a rotate of JWK key in IdP.
* Validating of JWK signature will fail since we still use old cached JWK set.
* Using a mock {@link HttpsJwks} we validate that plugin calls refresh() and then passes validation
*/
@Test
public void invalidSigRefreshJwk() throws JoseException {
RsaJsonWebKey rsaJsonWebKey2 = RsaJwkGenerator.generateJwk(2048);
rsaJsonWebKey2.setKeyId("k2");
HashMap<String, Object> testJwkWrong = new HashMap<>();
testJwkWrong.put("kty", rsaJsonWebKey2.getKeyType());
testJwkWrong.put("e", BigEndianBigInteger.toBase64Url(rsaJsonWebKey2.getRsaPublicKey().getPublicExponent()));
testJwkWrong.put("use", rsaJsonWebKey2.getUse());
testJwkWrong.put("kid", rsaJsonWebKey2.getKeyId());
testJwkWrong.put("alg", rsaJsonWebKey2.getAlgorithm());
testJwkWrong.put("n", BigEndianBigInteger.toBase64Url(rsaJsonWebKey2.getRsaPublicKey().getModulus()));
JsonWebKey wrongJwk = JsonWebKey.Factory.newJwk(testJwkWrong);
// Configure our mock plugin with URL as jwk source
JsonWebKey correctJwk = JsonWebKey.Factory.newJwk(testJwk);
plugin = new MockJwksUrlPlugin(wrongJwk, correctJwk);
HashMap<String, Object> pluginConfigJwkUrl = new HashMap<>();
pluginConfigJwkUrl.put("class", "org.apache.solr.security.JWTAuthPlugin");
pluginConfigJwkUrl.put("jwkUrl", "dummy");
plugin.init(pluginConfigJwkUrl);
// Validate that plugin will call refresh() on invalid signature, then the call succeeds
assertFalse(((MockJwksUrlPlugin)plugin).isRefreshCalled());
JWTAuthPlugin.JWTAuthenticationResponse resp = plugin.authenticate(testHeader);
assertTrue(resp.isAuthenticated());
assertTrue(((MockJwksUrlPlugin)plugin).isRefreshCalled());
}
@Test @Test
public void parseJwkSet() throws Exception { public void parseJwkSet() throws Exception {
plugin.parseJwkSet(testJwk); plugin.parseJwkSet(testJwk);
@ -337,13 +376,13 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 {
@Test @Test
public void minimalConfigPassThrough() { public void minimalConfigPassThrough() {
testConfig.put("blockUnknown", false); testConfig.put("blockUnknown", false);
plugin.init(testConfig); plugin.init(minimalConfig);
JWTAuthPlugin.JWTAuthenticationResponse resp = plugin.authenticate(null); JWTAuthPlugin.JWTAuthenticationResponse resp = plugin.authenticate(null);
assertEquals(JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.PASS_THROUGH, resp.getAuthCode()); assertEquals(JWTAuthPlugin.JWTAuthenticationResponse.AuthCode.PASS_THROUGH, resp.getAuthCode());
} }
@Test @Test
public void wellKnownConfig() throws IOException { public void wellKnownConfig() {
String wellKnownUrl = TEST_PATH().resolve("security").resolve("jwt_well-known-config.json").toAbsolutePath().toUri().toString(); String wellKnownUrl = TEST_PATH().resolve("security").resolve("jwt_well-known-config.json").toAbsolutePath().toUri().toString();
testConfig.put("wellKnownUrl", wellKnownUrl); testConfig.put("wellKnownUrl", wellKnownUrl);
testConfig.remove("jwk"); testConfig.remove("jwk");
@ -353,13 +392,13 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 {
} }
@Test(expected = SolrException.class) @Test(expected = SolrException.class)
public void onlyOneJwkConfig() throws IOException { public void onlyOneJwkConfig() {
testConfig.put("jwkUrl", "http://127.0.0.1:45678/.well-known/config"); testConfig.put("jwkUrl", "http://127.0.0.1:45678/myJwk");
plugin.init(testConfig); plugin.init(testConfig);
} }
@Test(expected = SolrException.class) @Test(expected = SolrException.class)
public void wellKnownConfigNotHttps() throws IOException { public void wellKnownConfigNotHttps() {
testConfig.put("wellKnownUrl", "http://127.0.0.1:45678/.well-known/config"); testConfig.put("wellKnownUrl", "http://127.0.0.1:45678/.well-known/config");
plugin.init(testConfig); plugin.init(testConfig);
} }
@ -402,4 +441,49 @@ public class JWTAuthPluginTest extends SolrTestCaseJ4 {
assertEquals("http://acmepaymentscorp/oauth/auz/authorize", parsed.get("authorizationEndpoint")); assertEquals("http://acmepaymentscorp/oauth/auz/authorize", parsed.get("authorizationEndpoint"));
assertEquals("solr-cluster", parsed.get("client_id")); assertEquals("solr-cluster", parsed.get("client_id"));
} }
/**
* Mock plugin that simulates a {@link HttpsJwks} with cached JWK that returns
* a different JWK after a call to refresh()
*/
private class MockJwksUrlPlugin extends JWTAuthPlugin {
private final JsonWebKey wrongJwk;
private final JsonWebKey correctJwk;
boolean isRefreshCalled() {
return refreshCalled;
}
private boolean refreshCalled;
MockJwksUrlPlugin(JsonWebKey wrongJwk, JsonWebKey correctJwk) {
this.wrongJwk = wrongJwk;
this.correctJwk = correctJwk;
}
@Override
void setupJwkUrl(String url) {
MockHttpsJwks httpsJkws = new MockHttpsJwks(url);
verificationKeyResolver = new HttpsJwksVerificationKeyResolver(httpsJkws);
}
private class MockHttpsJwks extends HttpsJwks {
MockHttpsJwks(String url) {
super(url);
}
@Override
public List<JsonWebKey> getJsonWebKeys() {
return refreshCalled ? Collections.singletonList(correctJwk) : Collections.singletonList(wrongJwk);
}
@Override
public void refresh() {
if (refreshCalled) {
fail("Refresh called twice");
}
refreshCalled = true;
}
}
}
} }