Single-key Key Selector

Fixes: gh-7049
Fixes: gh-7056
This commit is contained in:
Josh Cummings 2019-06-28 10:55:19 -06:00
parent 3b5a4189ef
commit ce79ef2634
5 changed files with 208 additions and 52 deletions

View File

@ -29,10 +29,6 @@ import javax.crypto.SecretKey;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
import com.nimbusds.jose.jwk.source.ImmutableSecret;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.JWSKeySelector;
@ -316,17 +312,12 @@ public final class NimbusJwtDecoder implements JwtDecoder {
*/
public static final class PublicKeyJwtDecoderBuilder {
private JWSAlgorithm jwsAlgorithm;
private RSAKey key;
private RSAPublicKey key;
private PublicKeyJwtDecoderBuilder(RSAPublicKey key) {
Assert.notNull(key, "key cannot be null");
this.jwsAlgorithm = JWSAlgorithm.RS256;
this.key = rsaKey(key);
}
private static RSAKey rsaKey(RSAPublicKey publicKey) {
return new RSAKey.Builder(publicKey)
.build();
this.key = key;
}
/**
@ -352,10 +343,8 @@ public final class NimbusJwtDecoder implements JwtDecoder {
this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512.");
}
JWKSet jwkSet = new JWKSet(this.key);
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(jwkSet);
JWSKeySelector<SecurityContext> jwsKeySelector =
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key);
DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector);
@ -414,9 +403,8 @@ public final class NimbusJwtDecoder implements JwtDecoder {
}
JWTProcessor<SecurityContext> processor() {
JWKSource<SecurityContext> jwkSource = new ImmutableSecret<>(this.secretKey);
JWSKeySelector<SecurityContext> jwsKeySelector =
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.secretKey);
DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector);

View File

@ -30,12 +30,7 @@ import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
import com.nimbusds.jose.jwk.source.ImmutableSecret;
import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWKSecurityContext;
import com.nimbusds.jose.proc.JWSKeySelector;
@ -318,20 +313,15 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
* @since 5.2
*/
public static final class PublicKeyReactiveJwtDecoderBuilder {
private final RSAKey key;
private final RSAPublicKey key;
private JWSAlgorithm jwsAlgorithm;
private PublicKeyReactiveJwtDecoderBuilder(RSAPublicKey key) {
Assert.notNull(key, "key cannot be null");
this.key = rsaKey(key);
this.key = key;
this.jwsAlgorithm = JWSAlgorithm.RS256;
}
private static RSAKey rsaKey(RSAPublicKey publicKey) {
return new RSAKey.Builder(publicKey)
.build();
}
/**
* Use the given signing
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
@ -363,10 +353,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512.");
}
JWKSet jwkSet = new JWKSet(this.key);
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(jwkSet);
JWSKeySelector<SecurityContext> jwsKeySelector =
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key);
DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector);
@ -418,9 +406,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
}
Converter<JWT, Mono<JWTClaimsSet>> processor() {
JWKSource<SecurityContext> jwkSource = new ImmutableSecret<>(this.secretKey);
JWSKeySelector<SecurityContext> jwsKeySelector =
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.secretKey);
DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector);

View File

@ -0,0 +1,54 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.jwt;
import java.security.Key;
import java.util.Arrays;
import java.util.List;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import org.springframework.util.Assert;
/**
* An internal implementation of {@link JWSKeySelector} that always returns the same key
*
* @author Josh Cummings
* @since 5.2
*/
final class SingleKeyJWSKeySelector<C extends SecurityContext> implements JWSKeySelector<C> {
private final List<Key> keySet;
private final JWSAlgorithm expectedJwsAlgorithm;
SingleKeyJWSKeySelector(JWSAlgorithm expectedJwsAlgorithm, Key key) {
Assert.notNull(expectedJwsAlgorithm, "expectedJwsAlgorithm cannot be null");
Assert.notNull(key, "key cannot be null");
this.keySet = Arrays.asList(key);
this.expectedJwsAlgorithm = expectedJwsAlgorithm;
}
@Override
public List<? extends Key> selectJWSKeys(JWSHeader header, C context) {
if (!this.expectedJwsAlgorithm.equals(header.getAlgorithm())) {
throw new IllegalArgumentException("Unsupported algorithm of " + header.getAlgorithm());
}
return this.keySet;
}
}

View File

@ -15,18 +15,93 @@
*/
package org.springframework.security.oauth2.jose;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Base64;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import java.util.Base64;
/**
* @author Joe Grandja
* @since 5.2
*/
public class TestKeys {
public static final KeyFactory kf;
static {
try {
kf = KeyFactory.getInstance("RSA");
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException(e);
}
}
public static final String DEFAULT_ENCODED_SECRET_KEY = "bCzY/M48bbkwBEWjmNSIEPfwApcvXOnkCxORBEbPr+4=";
public static final SecretKey DEFAULT_SECRET_KEY =
new SecretKeySpec(Base64.getDecoder().decode(DEFAULT_ENCODED_SECRET_KEY), "AES");
public static final String DEFAULT_RSA_PUBLIC_KEY =
"MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3FlqJr5TRskIQIgdE3Dd" +
"7D9lboWdcTUT8a+fJR7MAvQm7XXNoYkm3v7MQL1NYtDvL2l8CAnc0WdSTINU6IRv" +
"c5Kqo2Q4csNX9SHOmEfzoROjQqahEcve1jBXluoCXdYuYpx4/1tfRgG6ii4Uhxh6" +
"iI8qNMJQX+fLfqhbfYfxBQVRPywBkAbIP4x1EAsbC6FSNmkhCxiMNqEgxaIpY8C2" +
"kJdJ/ZIV+WW4noDdzpKqHcwmB8FsrumlVY/DNVvUSDIipiq9PbP4H99TXN1o746o" +
"RaNa07rq1hoCgMSSy+85SagCoxlmyE+D+of9SsMY8Ol9t0rdzpobBuhyJ/o5dfvj" +
"KwIDAQAB";
public static final RSAPublicKey DEFAULT_PUBLIC_KEY = publicKey();
private static RSAPublicKey publicKey() {
X509EncodedKeySpec spec = new X509EncodedKeySpec(Base64.getDecoder().decode(DEFAULT_RSA_PUBLIC_KEY));
try {
return (RSAPublicKey) kf.generatePublic(spec);
} catch (InvalidKeySpecException e) {
throw new IllegalArgumentException(e);
}
}
public static final String DEFAULT_RSA_PRIVATE_KEY =
"MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDcWWomvlNGyQhA" +
"iB0TcN3sP2VuhZ1xNRPxr58lHswC9Cbtdc2hiSbe/sxAvU1i0O8vaXwICdzRZ1JM" +
"g1TohG9zkqqjZDhyw1f1Ic6YR/OhE6NCpqERy97WMFeW6gJd1i5inHj/W19GAbqK" +
"LhSHGHqIjyo0wlBf58t+qFt9h/EFBVE/LAGQBsg/jHUQCxsLoVI2aSELGIw2oSDF" +
"oiljwLaQl0n9khX5ZbiegN3OkqodzCYHwWyu6aVVj8M1W9RIMiKmKr09s/gf31Nc" +
"3WjvjqhFo1rTuurWGgKAxJLL7zlJqAKjGWbIT4P6h/1Kwxjw6X23St3OmhsG6HIn" +
"+jl1++MrAgMBAAECggEBAMf820wop3pyUOwI3aLcaH7YFx5VZMzvqJdNlvpg1jbE" +
"E2Sn66b1zPLNfOIxLcBG8x8r9Ody1Bi2Vsqc0/5o3KKfdgHvnxAB3Z3dPh2WCDek" +
"lCOVClEVoLzziTuuTdGO5/CWJXdWHcVzIjPxmK34eJXioiLaTYqN3XKqKMdpD0ZG" +
"mtNTGvGf+9fQ4i94t0WqIxpMpGt7NM4RHy3+Onggev0zLiDANC23mWrTsUgect/7" +
"62TYg8g1bKwLAb9wCBT+BiOuCc2wrArRLOJgUkj/F4/gtrR9ima34SvWUyoUaKA0" +
"bi4YBX9l8oJwFGHbU9uFGEMnH0T/V0KtIB7qetReywkCgYEA9cFyfBIQrYISV/OA" +
"+Z0bo3vh2aL0QgKrSXZ924cLt7itQAHNZ2ya+e3JRlTczi5mnWfjPWZ6eJB/8MlH" +
"Gpn12o/POEkU+XjZZSPe1RWGt5g0S3lWqyx9toCS9ACXcN9tGbaqcFSVI73zVTRA" +
"8J9grR0fbGn7jaTlTX2tnlOTQ60CgYEA5YjYpEq4L8UUMFkuj+BsS3u0oEBnzuHd" +
"I9LEHmN+CMPosvabQu5wkJXLuqo2TxRnAznsA8R3pCLkdPGoWMCiWRAsCn979TdY" +
"QbqO2qvBAD2Q19GtY7lIu6C35/enQWzJUMQE3WW0OvjLzZ0l/9mA2FBRR+3F9A1d" +
"rBdnmv0c3TcCgYEAi2i+ggVZcqPbtgrLOk5WVGo9F1GqUBvlgNn30WWNTx4zIaEk" +
"HSxtyaOLTxtq2odV7Kr3LGiKxwPpn/T+Ief+oIp92YcTn+VfJVGw4Z3BezqbR8lA" +
"Uf/+HF5ZfpMrVXtZD4Igs3I33Duv4sCuqhEvLWTc44pHifVloozNxYfRfU0CgYBN" +
"HXa7a6cJ1Yp829l62QlJKtx6Ymj95oAnQu5Ez2ROiZMqXRO4nucOjGUP55Orac1a" +
"FiGm+mC/skFS0MWgW8evaHGDbWU180wheQ35hW6oKAb7myRHtr4q20ouEtQMdQIF" +
"snV39G1iyqeeAsf7dxWElydXpRi2b68i3BIgzhzebQKBgQCdUQuTsqV9y/JFpu6H" +
"c5TVvhG/ubfBspI5DhQqIGijnVBzFT//UfIYMSKJo75qqBEyP2EJSmCsunWsAFsM" +
"TszuiGTkrKcZy9G0wJqPztZZl2F2+bJgnA6nBEV7g5PA4Af+QSmaIhRwqGDAuROR" +
"47jndeyIaMTNETEmOnms+as17g==";
public static final RSAPrivateKey DEFAULT_PRIVATE_KEY = privateKey();
private static RSAPrivateKey privateKey() {
PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(Base64.getDecoder().decode(DEFAULT_RSA_PRIVATE_KEY));
try {
return (RSAPrivateKey) kf.generatePrivate(spec);
} catch (InvalidKeySpecException e) {
throw new IllegalArgumentException(e);
}
}
}

View File

@ -16,11 +16,29 @@
package org.springframework.security.oauth2.jwt;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.EncodedKeySpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.text.ParseException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.Map;
import javax.crypto.SecretKey;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.MACSigner;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
@ -32,6 +50,7 @@ import okhttp3.mockwebserver.MockWebServer;
import org.assertj.core.api.Assertions;
import org.junit.BeforeClass;
import org.junit.Test;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpStatus;
import org.springframework.http.RequestEntity;
@ -44,21 +63,6 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.web.client.RestOperations;
import javax.crypto.SecretKey;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.EncodedKeySpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.text.ParseException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
@ -66,7 +70,9 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.*;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withPublicKey;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecretKey;
/**
* Tests for {@link NimbusJwtDecoder}
@ -266,6 +272,23 @@ public class NimbusJwtDecoderTests {
.isEqualTo("test-subject");
}
// gh-7049
@Test
public void decodeWhenUsingPublicKeyWithKidThenStillUsesKey() throws Exception {
RSAPublicKey publicKey = TestKeys.DEFAULT_PUBLIC_KEY;
RSAPrivateKey privateKey = TestKeys.DEFAULT_PRIVATE_KEY;
JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("one").build();
JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
.subject("test-subject")
.expirationTime(Date.from(Instant.now().plusSeconds(60)))
.build();
SignedJWT signedJwt = signedJwt(privateKey, header, claimsSet);
NimbusJwtDecoder decoder = withPublicKey(publicKey).signatureAlgorithm(SignatureAlgorithm.RS256).build();
assertThat(decoder.decode(signedJwt.serialize()))
.extracting(Jwt::getSubject)
.isEqualTo("test-subject");
}
@Test
public void decodeWhenSignatureMismatchesAlgorithmThenThrowsException() throws Exception {
NimbusJwtDecoder decoder = withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build();
@ -315,7 +338,23 @@ public class NimbusJwtDecoderTests {
NimbusJwtDecoder decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS512).build();
assertThatThrownBy(() -> decoder.decode(signedJWT.serialize()))
.isInstanceOf(JwtException.class)
.hasMessage("An error occurred while attempting to decode the Jwt: Signed JWT rejected: Another algorithm expected, or no matching key(s) found");
.hasMessageContaining("Unsupported algorithm of HS256");
}
// gh-7056
@Test
public void decodeWhenUsingSecertKeyWithKidThenStillUsesKey() throws Exception {
SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY;
JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.HS256).keyID("one").build();
JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
.subject("test-subject")
.expirationTime(Date.from(Instant.now().plusSeconds(60)))
.build();
SignedJWT signedJwt = signedJwt(secretKey, header, claimsSet);
NimbusJwtDecoder decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS256).build();
assertThat(decoder.decode(signedJwt.serialize()))
.extracting(Jwt::getSubject)
.isEqualTo("test-subject");
}
private RSAPublicKey key() throws InvalidKeySpecException {
@ -325,8 +364,21 @@ public class NimbusJwtDecoderTests {
}
private SignedJWT signedJwt(SecretKey secretKey, MacAlgorithm jwsAlgorithm, JWTClaimsSet claimsSet) throws Exception {
SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.parse(jwsAlgorithm.getName())), claimsSet);
return signedJwt(secretKey, new JWSHeader(JWSAlgorithm.parse(jwsAlgorithm.getName())), claimsSet);
}
private SignedJWT signedJwt(SecretKey secretKey, JWSHeader header, JWTClaimsSet claimsSet) throws Exception {
JWSSigner signer = new MACSigner(secretKey);
return signedJwt(signer, header, claimsSet);
}
private SignedJWT signedJwt(PrivateKey privateKey, JWSHeader header, JWTClaimsSet claimsSet) throws Exception {
JWSSigner signer = new RSASSASigner(privateKey);
return signedJwt(signer, header, claimsSet);
}
private SignedJWT signedJwt(JWSSigner signer, JWSHeader header, JWTClaimsSet claimsSet) throws Exception {
SignedJWT signedJWT = new SignedJWT(header, claimsSet);
signedJWT.sign(signer);
return signedJWT;
}