Polish spring-security-saml2 main code

Manually polish `spring-security-saml2` following the formatting
and checkstyle fixes.

Issue gh-8945
This commit is contained in:
Phillip Webb 2020-07-31 22:59:05 -07:00 committed by Rob Winch
parent e8094b8cf2
commit 1f03608b73
20 changed files with 179 additions and 230 deletions

View File

@ -118,27 +118,15 @@ public final class OpenSamlInitializationService {
private static boolean initialize(Consumer<XMLObjectProviderRegistry> registryConsumer) { private static boolean initialize(Consumer<XMLObjectProviderRegistry> registryConsumer) {
if (initialized.compareAndSet(false, true)) { if (initialized.compareAndSet(false, true)) {
log.trace("Initializing OpenSAML"); log.trace("Initializing OpenSAML");
try { try {
InitializationService.initialize(); InitializationService.initialize();
} }
catch (Exception ex) { catch (Exception ex) {
throw new Saml2Exception(ex); throw new Saml2Exception(ex);
} }
BasicParserPool parserPool = new BasicParserPool(); BasicParserPool parserPool = new BasicParserPool();
parserPool.setMaxPoolSize(50); parserPool.setMaxPoolSize(50);
parserPool.setBuilderFeatures(getParserBuilderFeatures());
Map<String, Boolean> parserBuilderFeatures = new HashMap<>();
parserBuilderFeatures.put("http://apache.org/xml/features/disallow-doctype-decl", Boolean.TRUE);
parserBuilderFeatures.put(XMLConstants.FEATURE_SECURE_PROCESSING, Boolean.TRUE);
parserBuilderFeatures.put("http://xml.org/sax/features/external-general-entities", Boolean.FALSE);
parserBuilderFeatures.put("http://apache.org/xml/features/validation/schema/normalized-value",
Boolean.FALSE);
parserBuilderFeatures.put("http://xml.org/sax/features/external-parameter-entities", Boolean.FALSE);
parserBuilderFeatures.put("http://apache.org/xml/features/dom/defer-node-expansion", Boolean.FALSE);
parserPool.setBuilderFeatures(parserBuilderFeatures);
try { try {
parserPool.initialize(); parserPool.initialize();
} }
@ -146,16 +134,23 @@ public final class OpenSamlInitializationService {
throw new Saml2Exception(ex); throw new Saml2Exception(ex);
} }
XMLObjectProviderRegistrySupport.setParserPool(parserPool); XMLObjectProviderRegistrySupport.setParserPool(parserPool);
registryConsumer.accept(ConfigurationService.get(XMLObjectProviderRegistry.class)); registryConsumer.accept(ConfigurationService.get(XMLObjectProviderRegistry.class));
log.debug("Initialized OpenSAML"); log.debug("Initialized OpenSAML");
return true; return true;
} }
else {
log.debug("Refused to re-initialize OpenSAML"); log.debug("Refused to re-initialize OpenSAML");
return false; return false;
} }
private static Map<String, Boolean> getParserBuilderFeatures() {
Map<String, Boolean> parserBuilderFeatures = new HashMap<>();
parserBuilderFeatures.put("http://apache.org/xml/features/disallow-doctype-decl", Boolean.TRUE);
parserBuilderFeatures.put(XMLConstants.FEATURE_SECURE_PROCESSING, Boolean.TRUE);
parserBuilderFeatures.put("http://xml.org/sax/features/external-general-entities", Boolean.FALSE);
parserBuilderFeatures.put("http://apache.org/xml/features/validation/schema/normalized-value", Boolean.FALSE);
parserBuilderFeatures.put("http://xml.org/sax/features/external-parameter-entities", Boolean.FALSE);
parserBuilderFeatures.put("http://apache.org/xml/features/dom/defer-node-expansion", Boolean.FALSE);
return parserBuilderFeatures;
} }
} }

View File

@ -37,12 +37,6 @@ import org.springframework.util.Assert;
*/ */
public final class Saml2X509Credential { public final class Saml2X509Credential {
public enum Saml2X509CredentialType {
VERIFICATION, ENCRYPTION, SIGNING, DECRYPTION,
}
private final PrivateKey privateKey; private final PrivateKey privateKey;
private final X509Certificate certificate; private final X509Certificate certificate;
@ -225,4 +219,16 @@ public final class Saml2X509Credential {
} }
} }
public enum Saml2X509CredentialType {
VERIFICATION,
ENCRYPTION,
SIGNING,
DECRYPTION,
}
} }

View File

@ -39,18 +39,6 @@ import org.springframework.util.Assert;
@Deprecated @Deprecated
public class Saml2X509Credential { public class Saml2X509Credential {
/**
* @deprecated Use
* {@link org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType}
* instead
*/
@Deprecated
public enum Saml2X509CredentialType {
VERIFICATION, ENCRYPTION, SIGNING, DECRYPTION,
}
private final PrivateKey privateKey; private final PrivateKey privateKey;
private final X509Certificate certificate; private final X509Certificate certificate;
@ -199,4 +187,22 @@ public class Saml2X509Credential {
} }
} }
/**
* @deprecated Use
* {@link org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType}
* instead
*/
@Deprecated
public enum Saml2X509CredentialType {
VERIFICATION,
ENCRYPTION,
SIGNING,
DECRYPTION,
}
} }

View File

@ -37,7 +37,6 @@ public class DefaultSaml2AuthenticatedPrincipal implements Saml2AuthenticatedPri
public DefaultSaml2AuthenticatedPrincipal(String name, Map<String, List<Object>> attributes) { public DefaultSaml2AuthenticatedPrincipal(String name, Map<String, List<Object>> attributes) {
Assert.notNull(name, "name cannot be null"); Assert.notNull(name, "name cannot be null");
Assert.notNull(attributes, "attributes cannot be null"); Assert.notNull(attributes, "attributes cannot be null");
this.name = name; this.name = name;
this.attributes = attributes; this.attributes = attributes;
} }

View File

@ -100,6 +100,7 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -182,24 +183,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private Duration responseTimeValidationSkew = Duration.ofMinutes(5); private Duration responseTimeValidationSkew = Duration.ofMinutes(5);
private Function<Saml2AuthenticationToken, Converter<Response, AbstractAuthenticationToken>> authenticationConverter = ( private Function<Saml2AuthenticationToken, Converter<Response, AbstractAuthenticationToken>> authenticationConverter = this::getAuthenticationConverter;
token) -> (response) -> {
Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
String username = assertion.getSubject().getNameID().getValue();
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
return new Saml2Authentication(new DefaultSaml2AuthenticatedPrincipal(username, attributes),
token.getSaml2Response(),
this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
};
private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter = new SignatureTrustEngineConverter(); private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter = new SignatureTrustEngineConverter();
private Converter<Tuple, SAML20AssertionValidator> assertionValidatorConverter = new SAML20AssertionValidatorConverter(); private Converter<TokenAndResponse, SAML20AssertionValidator> assertionValidatorConverter = new SAML20AssertionValidatorConverter();
private Collection<ConditionValidator> conditionValidators = Collections private Collection<ConditionValidator> conditionValidators = Collections
.singleton(new AudienceRestrictionConditionValidator()); .singleton(new AudienceRestrictionConditionValidator());
private Converter<Tuple, ValidationContext> validationContextConverter = new ValidationContextConverter(); private Converter<TokenAndResponse, ValidationContext> validationContextConverter = new ValidationContextConverter();
private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter(); private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
@ -220,7 +213,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
* @since 5.4 * @since 5.4
*/ */
public void setConditionValidators(Collection<ConditionValidator> conditionValidators) { public void setConditionValidators(Collection<ConditionValidator> conditionValidators) {
Assert.notEmpty(conditionValidators, "conditionValidators cannot be empty"); Assert.notEmpty(conditionValidators, "conditionValidators cannot be empty");
this.conditionValidators = conditionValidators; this.conditionValidators = conditionValidators;
} }
@ -231,8 +223,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
* @param validationContextConverter the strategy to use * @param validationContextConverter the strategy to use
* @since 5.4 * @since 5.4
*/ */
public void setValidationContextConverter(Converter<Tuple, ValidationContext> validationContextConverter) { public void setValidationContextConverter(
Converter<TokenAndResponse, ValidationContext> validationContextConverter) {
Assert.notNull(validationContextConverter, "validationContextConverter cannot be empty"); Assert.notNull(validationContextConverter, "validationContextConverter cannot be empty");
this.validationContextConverter = validationContextConverter; this.validationContextConverter = validationContextConverter;
} }
@ -289,13 +281,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
throw ex; throw ex;
} }
catch (Exception ex) { catch (Exception ex) {
throw authException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, ex.getMessage(), ex); throw createAuthenticationException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, ex.getMessage(), ex);
} }
} }
/**
* {@inheritDoc}
*/
@Override @Override
public boolean supports(Class<?> authentication) { public boolean supports(Class<?> authentication) {
return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication); return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication);
@ -313,39 +302,32 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return (Response) this.responseUnmarshaller.unmarshall(element); return (Response) this.responseUnmarshaller.unmarshall(element);
} }
catch (Exception ex) { catch (Exception ex) {
throw authException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, ex.getMessage(), ex); throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, ex.getMessage(), ex);
} }
} }
private void process(Saml2AuthenticationToken token, Response response) { private void process(Saml2AuthenticationToken token, Response response) {
String issuer = response.getIssuer().getValue(); String issuer = response.getIssuer().getValue();
if (logger.isDebugEnabled()) { logger.debug(LogMessage.format("Processing SAML response from %s", issuer));
logger.debug("Processing SAML response from " + issuer);
}
boolean responseSigned = response.isSigned(); boolean responseSigned = response.isSigned();
Map<String, Saml2AuthenticationException> validationExceptions = validateResponse(token, response); Map<String, Saml2AuthenticationException> validationExceptions = validateResponse(token, response);
Decrypter decrypter = this.decrypterConverter.convert(token); Decrypter decrypter = this.decrypterConverter.convert(token);
List<Assertion> assertions = decryptAssertions(decrypter, response); List<Assertion> assertions = decryptAssertions(decrypter, response);
if (!isSigned(responseSigned, assertions)) { if (!isSigned(responseSigned, assertions)) {
throw authException(Saml2ErrorCodes.INVALID_SIGNATURE, String description = "Either the response or one of the assertions is unsigned. "
"Either the response or one of the assertions is unsigned. " + "Please either sign the response or all of the assertions.";
+ "Please either sign the response or all of the assertions."); throw createAuthenticationException(Saml2ErrorCodes.INVALID_SIGNATURE, description, null);
} }
validationExceptions.putAll(validateAssertions(token, response)); validationExceptions.putAll(validateAssertions(token, response));
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions()); Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
NameID nameId = decryptPrincipal(decrypter, firstAssertion); NameID nameId = decryptPrincipal(decrypter, firstAssertion);
if (nameId == null || nameId.getValue() == null) { if (nameId == null || nameId.getValue() == null) {
validationExceptions.put(Saml2ErrorCodes.SUBJECT_NOT_FOUND, authException(Saml2ErrorCodes.SUBJECT_NOT_FOUND, String description = "Assertion [" + firstAssertion.getID() + "] is missing a subject";
"Assertion [" + firstAssertion.getID() + "] is missing a subject")); validationExceptions.put(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
createAuthenticationException(Saml2ErrorCodes.SUBJECT_NOT_FOUND, description, null));
} }
if (validationExceptions.isEmpty()) { if (validationExceptions.isEmpty()) {
if (logger.isDebugEnabled()) { logger.debug(LogMessage.of(() -> "Successfully processed SAML Response [" + response.getID() + "]"));
logger.debug("Successfully processed SAML Response [" + response.getID() + "]");
}
} }
else { else {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
@ -357,7 +339,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
+ response.getID() + "]"); + response.getID() + "]");
} }
} }
if (!validationExceptions.isEmpty()) { if (!validationExceptions.isEmpty()) {
throw validationExceptions.values().iterator().next(); throw validationExceptions.values().iterator().next();
} }
@ -365,21 +346,17 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private Map<String, Saml2AuthenticationException> validateResponse(Saml2AuthenticationToken token, private Map<String, Saml2AuthenticationException> validateResponse(Saml2AuthenticationToken token,
Response response) { Response response) {
Map<String, Saml2AuthenticationException> exceptions = new HashMap<>();
Map<String, Saml2AuthenticationException> validationExceptions = new HashMap<>();
String issuer = response.getIssuer().getValue(); String issuer = response.getIssuer().getValue();
if (response.isSigned()) { if (response.isSigned()) {
SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
try { try {
profileValidator.validate(response.getSignature()); profileValidator.validate(response.getSignature());
} }
catch (Exception ex) { catch (Exception ex) {
validationExceptions.put(Saml2ErrorCodes.INVALID_SIGNATURE, String message = "Invalid signature for SAML Response [" + response.getID() + "]: ";
authException(Saml2ErrorCodes.INVALID_SIGNATURE, addValidationException(exceptions, Saml2ErrorCodes.INVALID_SIGNATURE, message, ex);
"Invalid signature for SAML Response [" + response.getID() + "]: ", ex));
} }
try { try {
CriteriaSet criteriaSet = new CriteriaSet(); CriteriaSet criteriaSet = new CriteriaSet();
criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer))); criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer)));
@ -387,34 +364,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS))); new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), criteriaSet)) { if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), criteriaSet)) {
validationExceptions.put(Saml2ErrorCodes.INVALID_SIGNATURE, String message = "Invalid signature for SAML Response [" + response.getID() + "]";
authException(Saml2ErrorCodes.INVALID_SIGNATURE, addValidationException(exceptions, Saml2ErrorCodes.INVALID_SIGNATURE, message, null);
"Invalid signature for SAML Response [" + response.getID() + "]"));
} }
} }
catch (Exception ex) { catch (Exception ex) {
validationExceptions.put(Saml2ErrorCodes.INVALID_SIGNATURE, String message = "Invalid signature for SAML Response [" + response.getID() + "]: ";
authException(Saml2ErrorCodes.INVALID_SIGNATURE, addValidationException(exceptions, Saml2ErrorCodes.INVALID_SIGNATURE, message, ex);
"Invalid signature for SAML Response [" + response.getID() + "]: ", ex));
} }
} }
String destination = response.getDestination(); String destination = response.getDestination();
String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
if (StringUtils.hasText(destination) && !destination.equals(location)) { if (StringUtils.hasText(destination) && !destination.equals(location)) {
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]"; String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
validationExceptions.put(Saml2ErrorCodes.INVALID_DESTINATION, addValidationException(exceptions, Saml2ErrorCodes.INVALID_DESTINATION, message, null);
authException(Saml2ErrorCodes.INVALID_DESTINATION, message));
} }
String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId(); String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId();
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) { if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID()); String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
validationExceptions.put(Saml2ErrorCodes.INVALID_ISSUER, addValidationException(exceptions, Saml2ErrorCodes.INVALID_ISSUER, message, null);
authException(Saml2ErrorCodes.INVALID_ISSUER, message));
} }
return exceptions;
return validationExceptions;
} }
private List<Assertion> decryptAssertions(Decrypter decrypter, Response response) { private List<Assertion> decryptAssertions(Decrypter decrypter, Response response) {
@ -425,7 +395,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
assertions.add(assertion); assertions.add(assertion);
} }
catch (DecryptionException ex) { catch (DecryptionException ex) {
throw authException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
} }
} }
response.getAssertions().addAll(assertions); response.getAssertions().addAll(assertions);
@ -436,52 +406,47 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
Response response) { Response response) {
List<Assertion> assertions = response.getAssertions(); List<Assertion> assertions = response.getAssertions();
if (assertions.isEmpty()) { if (assertions.isEmpty()) {
throw authException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response."); throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
"No assertions found in response.", null);
} }
Map<String, Saml2AuthenticationException> exceptions = new LinkedHashMap<>();
Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>(); logger.debug(LogMessage.format("Validating %s assertions", assertions.size()));
if (logger.isDebugEnabled()) { TokenAndResponse tuple = new TokenAndResponse(token, response);
logger.debug("Validating " + assertions.size() + " assertions");
}
Tuple tuple = new Tuple(token, response);
SAML20AssertionValidator validator = this.assertionValidatorConverter.convert(tuple); SAML20AssertionValidator validator = this.assertionValidatorConverter.convert(tuple);
ValidationContext context = this.validationContextConverter.convert(tuple); ValidationContext context = this.validationContextConverter.convert(tuple);
for (Assertion assertion : assertions) { for (Assertion assertion : assertions) {
if (logger.isTraceEnabled()) { logger.trace(LogMessage.format("Validating assertion %s", assertion.getID()));
logger.trace("Validating assertion " + assertion.getID());
}
try { try {
if (validator.validate(assertion, context) != ValidationResult.VALID) { if (validator.validate(assertion, context) != ValidationResult.VALID) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
assertion.getID(), ((Response) assertion.getParent()).getID(), assertion.getID(), ((Response) assertion.getParent()).getID(),
context.getValidationFailureMessage()); context.getValidationFailureMessage());
validationExceptions.put(Saml2ErrorCodes.INVALID_ASSERTION, addValidationException(exceptions, Saml2ErrorCodes.INVALID_ASSERTION, message, null);
authException(Saml2ErrorCodes.INVALID_ASSERTION, message));
} }
} }
catch (Exception ex) { catch (Exception ex) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(), String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(),
((Response) assertion.getParent()).getID(), ex.getMessage()); ((Response) assertion.getParent()).getID(), ex.getMessage());
validationExceptions.put(Saml2ErrorCodes.INVALID_ASSERTION, addValidationException(exceptions, Saml2ErrorCodes.INVALID_ASSERTION, message, ex);
authException(Saml2ErrorCodes.INVALID_ASSERTION, message, ex));
} }
} }
return exceptions;
}
return validationExceptions; private void addValidationException(Map<String, Saml2AuthenticationException> exceptions, String code,
String message, Exception cause) {
exceptions.put(code, createAuthenticationException(code, message, cause));
} }
private boolean isSigned(boolean responseSigned, List<Assertion> assertions) { private boolean isSigned(boolean responseSigned, List<Assertion> assertions) {
if (responseSigned) { if (responseSigned) {
return true; return true;
} }
for (Assertion assertion : assertions) { for (Assertion assertion : assertions) {
if (!assertion.isSigned()) { if (!assertion.isSigned()) {
return false; return false;
} }
} }
return true; return true;
} }
@ -498,7 +463,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return nameId; return nameId;
} }
catch (DecryptionException ex) { catch (DecryptionException ex) {
throw authException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex); throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
} }
} }
@ -506,7 +471,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
Map<String, List<Object>> attributeMap = new LinkedHashMap<>(); Map<String, List<Object>> attributeMap = new LinkedHashMap<>();
for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) { for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
for (Attribute attribute : attributeStatement.getAttributes()) { for (Attribute attribute : attributeStatement.getAttributes()) {
List<Object> attributeValues = new ArrayList<>(); List<Object> attributeValues = new ArrayList<>();
for (XMLObject xmlObject : attribute.getAttributeValues()) { for (XMLObject xmlObject : attribute.getAttributeValues()) {
Object attributeValue = getXmlObjectValue(xmlObject); Object attributeValue = getXmlObjectValue(xmlObject);
@ -515,7 +479,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
} }
attributeMap.put(attribute.getName(), attributeValues); attributeMap.put(attribute.getName(), attributeValues);
} }
} }
return attributeMap; return attributeMap;
@ -559,20 +522,22 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return xsAny.getTextContent(); return xsAny.getTextContent();
} }
private static Saml2Error validationError(String code, String description) { private Converter<Response, AbstractAuthenticationToken> getAuthenticationConverter(
return new Saml2Error(code, description); Saml2AuthenticationToken token) {
return (response) -> convertAuthenticationToken(token, response);
} }
private static Saml2AuthenticationException authException(String code, String description) private AbstractAuthenticationToken convertAuthenticationToken(Saml2AuthenticationToken token, Response response) {
throws Saml2AuthenticationException { Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
String username = assertion.getSubject().getNameID().getValue();
return new Saml2AuthenticationException(validationError(code, description)); Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
return new Saml2Authentication(new DefaultSaml2AuthenticatedPrincipal(username, attributes),
token.getSaml2Response(), this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
} }
private static Saml2AuthenticationException authException(String code, String description, Exception cause) private static Saml2AuthenticationException createAuthenticationException(String code, String message,
throws Saml2AuthenticationException { Exception cause) {
return new Saml2AuthenticationException(new Saml2Error(code, message), cause);
return new Saml2AuthenticationException(validationError(code, description), cause);
} }
private static class SignatureTrustEngineConverter private static class SignatureTrustEngineConverter
@ -596,10 +561,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
private class ValidationContextConverter implements Converter<Tuple, ValidationContext> { private class ValidationContextConverter implements Converter<TokenAndResponse, ValidationContext> {
@Override @Override
public ValidationContext convert(Tuple tuple) { public ValidationContext convert(TokenAndResponse tuple) {
String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId(); String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId();
String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();
@ -607,17 +572,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis()); OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience)); params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience));
params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient)); params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient));
params.put(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false); // this // this verification is performed earlier
// verification params.put(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false);
// is
// performed
// earlier
return new ValidationContext(params); return new ValidationContext(params);
} }
} }
private class SAML20AssertionValidatorConverter implements Converter<Tuple, SAML20AssertionValidator> { private class SAML20AssertionValidatorConverter implements Converter<TokenAndResponse, SAML20AssertionValidator> {
private final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>(); private final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
@ -638,7 +600,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
@Override @Override
public SAML20AssertionValidator convert(Tuple tuple) { public SAML20AssertionValidator convert(TokenAndResponse tuple) {
Collection<ConditionValidator> conditions = OpenSamlAuthenticationProvider.this.conditionValidators; Collection<ConditionValidator> conditions = OpenSamlAuthenticationProvider.this.conditionValidators;
return new SAML20AssertionValidator(conditions, this.subjects, this.statements, return new SAML20AssertionValidator(conditions, this.subjects, this.statements,
OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(tuple.authentication), OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(tuple.authentication),
@ -674,13 +636,13 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
* *
* @since 5.4 * @since 5.4
*/ */
public static final class Tuple { public static final class TokenAndResponse {
private final Saml2AuthenticationToken authentication; private final Saml2AuthenticationToken authentication;
private final Response response; private final Response response;
private Tuple(Saml2AuthenticationToken authentication, Response response) { private TokenAndResponse(Saml2AuthenticationToken authentication, Response response) {
this.authentication = authentication; this.authentication = authentication;
this.response = response; this.response = response;
} }

View File

@ -117,22 +117,15 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
throw new IllegalArgumentException("No signing credential provided"); throw new IllegalArgumentException("No signing credential provided");
} }
/**
* {@inheritDoc}
*/
@Override @Override
public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) { public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
AuthnRequest authnRequest = createAuthnRequest(context); AuthnRequest authnRequest = createAuthnRequest(context);
String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()
? serialize(sign(authnRequest, context.getRelyingPartyRegistration())) : serialize(authnRequest); ? serialize(sign(authnRequest, context.getRelyingPartyRegistration())) : serialize(authnRequest);
return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context) return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context)
.samlRequest(Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8))).build(); .samlRequest(Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8))).build();
} }
/**
* {@inheritDoc}
*/
@Override @Override
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest( public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(
Saml2AuthenticationRequestContext context) { Saml2AuthenticationRequestContext context) {
@ -141,7 +134,6 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context); Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState()); result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState());
if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) { if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
Collection<Saml2X509Credential> signingCredentials = context.getRelyingPartyRegistration() Collection<Saml2X509Credential> signingCredentials = context.getRelyingPartyRegistration()
.getSigningX509Credentials(); .getSigningX509Credentials();
@ -154,7 +146,6 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
} }
throw new Saml2Exception("No signing credential provided"); throw new Saml2Exception("No signing credential provided");
} }
return result.build(); return result.build();
} }
@ -266,12 +257,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
.append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1)).append("&"); .append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1)).append("&");
} }
queryString.append("SigAlg").append("=").append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1)); queryString.append("SigAlg").append("=").append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1));
try { try {
byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri,
queryString.toString().getBytes(StandardCharsets.UTF_8)); queryString.toString().getBytes(StandardCharsets.UTF_8));
String b64Signature = Saml2Utils.samlEncode(rawSignature); String b64Signature = Saml2Utils.samlEncode(rawSignature);
Map<String, String> result = new LinkedHashMap<>(); Map<String, String> result = new LinkedHashMap<>();
result.put("SAMLRequest", samlRequest); result.put("SAMLRequest", samlRequest);
if (StringUtils.hasText(relayState)) { if (StringUtils.hasText(relayState)) {

View File

@ -56,7 +56,7 @@ public class Saml2AuthenticationException extends AuthenticationException {
* @param cause the root cause * @param cause the root cause
*/ */
public Saml2AuthenticationException(Saml2Error error, Throwable cause) { public Saml2AuthenticationException(Saml2Error error, Throwable cause) {
this(error, cause.getMessage(), cause); this(error, (cause != null) ? cause.getMessage() : error.getDescription(), cause);
} }
/** /**

View File

@ -52,7 +52,6 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
* @since 5.4 * @since 5.4
*/ */
public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) { public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) {
super(Collections.emptyList()); super(Collections.emptyList());
Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null"); Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null");
Assert.notNull(saml2Response, "saml2Response cannot be null"); Assert.notNull(saml2Response, "saml2Response cannot be null");

View File

@ -60,7 +60,6 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
public static final class Builder extends AbstractSaml2AuthenticationRequest.Builder<Builder> { public static final class Builder extends AbstractSaml2AuthenticationRequest.Builder<Builder> {
private Builder() { private Builder() {
super();
} }
/** /**

View File

@ -87,7 +87,6 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
private String signature; private String signature;
private Builder() { private Builder() {
super();
} }
/** /**

View File

@ -67,17 +67,12 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
Assert.notNull(this.entityDescriptorMarshaller, "entityDescriptorMarshaller cannot be null"); Assert.notNull(this.entityDescriptorMarshaller, "entityDescriptorMarshaller cannot be null");
} }
/**
* {@inheritDoc}
*/
@Override @Override
public String resolve(RelyingPartyRegistration relyingPartyRegistration) { public String resolve(RelyingPartyRegistration relyingPartyRegistration) {
EntityDescriptor entityDescriptor = build(EntityDescriptor.ELEMENT_QNAME); EntityDescriptor entityDescriptor = build(EntityDescriptor.ELEMENT_QNAME);
entityDescriptor.setEntityID(relyingPartyRegistration.getEntityId()); entityDescriptor.setEntityID(relyingPartyRegistration.getEntityId());
SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(relyingPartyRegistration); SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(relyingPartyRegistration);
entityDescriptor.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME).add(spSsoDescriptor); entityDescriptor.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME).add(spSsoDescriptor);
return serialize(entityDescriptor); return serialize(entityDescriptor);
} }
@ -107,17 +102,14 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
KeyInfo keyInfo = build(KeyInfo.DEFAULT_ELEMENT_NAME); KeyInfo keyInfo = build(KeyInfo.DEFAULT_ELEMENT_NAME);
X509Certificate x509Certificate = build(X509Certificate.DEFAULT_ELEMENT_NAME); X509Certificate x509Certificate = build(X509Certificate.DEFAULT_ELEMENT_NAME);
X509Data x509Data = build(X509Data.DEFAULT_ELEMENT_NAME); X509Data x509Data = build(X509Data.DEFAULT_ELEMENT_NAME);
try { try {
x509Certificate.setValue(new String(Base64.getEncoder().encode(certificate.getEncoded()))); x509Certificate.setValue(new String(Base64.getEncoder().encode(certificate.getEncoded())));
} }
catch (CertificateEncodingException ex) { catch (CertificateEncodingException ex) {
throw new Saml2Exception("Cannot encode certificate " + certificate.toString()); throw new Saml2Exception("Cannot encode certificate " + certificate.toString());
} }
x509Data.getX509Certificates().add(x509Certificate); x509Data.getX509Certificates().add(x509Certificate);
keyInfo.getX509Datas().add(x509Data); keyInfo.getX509Datas().add(x509Data);
keyDescriptor.setUse(usageType); keyDescriptor.setUse(usageType);
keyDescriptor.setKeyInfo(keyInfo); keyDescriptor.setKeyInfo(keyInfo);
return keyDescriptor; return keyDescriptor;

View File

@ -96,37 +96,24 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter
this.parserPool = registry.getParserPool(); this.parserPool = registry.getParserPool();
} }
/**
* {@inheritDoc}
*/
@Override @Override
public boolean canRead(Class<?> clazz, MediaType mediaType) { public boolean canRead(Class<?> clazz, MediaType mediaType) {
return RelyingPartyRegistration.Builder.class.isAssignableFrom(clazz); return RelyingPartyRegistration.Builder.class.isAssignableFrom(clazz);
} }
/**
* {@inheritDoc}
*/
@Override @Override
public boolean canWrite(Class<?> clazz, MediaType mediaType) { public boolean canWrite(Class<?> clazz, MediaType mediaType) {
return false; return false;
} }
/**
* {@inheritDoc}
*/
@Override @Override
public List<MediaType> getSupportedMediaTypes() { public List<MediaType> getSupportedMediaTypes() {
return Arrays.asList(MediaType.APPLICATION_XML, MediaType.TEXT_XML); return Arrays.asList(MediaType.APPLICATION_XML, MediaType.TEXT_XML);
} }
/**
* {@inheritDoc}
*/
@Override @Override
public RelyingPartyRegistration.Builder read(Class<? extends RelyingPartyRegistration.Builder> clazz, public RelyingPartyRegistration.Builder read(Class<? extends RelyingPartyRegistration.Builder> clazz,
HttpInputMessage inputMessage) throws IOException, HttpMessageNotReadableException { HttpInputMessage inputMessage) throws IOException, HttpMessageNotReadableException {
EntityDescriptor descriptor = entityDescriptor(inputMessage.getBody()); EntityDescriptor descriptor = entityDescriptor(inputMessage.getBody());
IDPSSODescriptor idpssoDescriptor = descriptor.getIDPSSODescriptor(SAMLConstants.SAML20P_NS); IDPSSODescriptor idpssoDescriptor = descriptor.getIDPSSODescriptor(SAMLConstants.SAML20P_NS);
if (idpssoDescriptor == null) { if (idpssoDescriptor == null) {
@ -184,6 +171,32 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter
"Metadata response is missing a SingleSignOnService, necessary for sending AuthnRequests"); "Metadata response is missing a SingleSignOnService, necessary for sending AuthnRequests");
} }
private List<Saml2X509Credential> getVerification(IDPSSODescriptor idpssoDescriptor) {
List<Saml2X509Credential> verification = new ArrayList<>();
for (KeyDescriptor keyDescriptor : idpssoDescriptor.getKeyDescriptors()) {
if (keyDescriptor.getUse().equals(UsageType.SIGNING)) {
List<X509Certificate> certificates = certificates(keyDescriptor);
for (X509Certificate certificate : certificates) {
verification.add(Saml2X509Credential.verification(certificate));
}
}
}
return verification;
}
private List<Saml2X509Credential> getEncryption(IDPSSODescriptor idpssoDescriptor) {
List<Saml2X509Credential> encryption = new ArrayList<>();
for (KeyDescriptor keyDescriptor : idpssoDescriptor.getKeyDescriptors()) {
if (keyDescriptor.getUse().equals(UsageType.ENCRYPTION)) {
List<X509Certificate> certificates = certificates(keyDescriptor);
for (X509Certificate certificate : certificates) {
encryption.add(Saml2X509Credential.encryption(certificate));
}
}
}
return encryption;
}
private List<X509Certificate> certificates(KeyDescriptor keyDescriptor) { private List<X509Certificate> certificates(KeyDescriptor keyDescriptor) {
try { try {
return KeyInfoSupport.getCertificates(keyDescriptor.getKeyInfo()); return KeyInfoSupport.getCertificates(keyDescriptor.getKeyInfo());

View File

@ -28,8 +28,6 @@ import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.ProviderDetails;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -94,7 +92,6 @@ public final class RelyingPartyRegistration {
Collection<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials, Collection<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials,
Collection<Saml2X509Credential> decryptionX509Credentials, Collection<Saml2X509Credential> decryptionX509Credentials,
Collection<Saml2X509Credential> signingX509Credentials) { Collection<Saml2X509Credential> signingX509Credentials) {
Assert.hasText(registrationId, "registrationId cannot be empty"); Assert.hasText(registrationId, "registrationId cannot be empty");
Assert.hasText(entityId, "entityId cannot be empty"); Assert.hasText(entityId, "entityId cannot be empty");
Assert.hasText(assertionConsumerServiceLocation, "assertionConsumerServiceLocation cannot be empty"); Assert.hasText(assertionConsumerServiceLocation, "assertionConsumerServiceLocation cannot be empty");
@ -332,7 +329,6 @@ public final class RelyingPartyRegistration {
private List<org.springframework.security.saml2.credentials.Saml2X509Credential> filterCredentials( private List<org.springframework.security.saml2.credentials.Saml2X509Credential> filterCredentials(
Function<org.springframework.security.saml2.credentials.Saml2X509Credential, Boolean> filter) { Function<org.springframework.security.saml2.credentials.Saml2X509Credential, Boolean> filter) {
List<org.springframework.security.saml2.credentials.Saml2X509Credential> result = new LinkedList<>(); List<org.springframework.security.saml2.credentials.Saml2X509Credential> result = new LinkedList<>();
for (org.springframework.security.saml2.credentials.Saml2X509Credential c : this.credentials) { for (org.springframework.security.saml2.credentials.Saml2X509Credential c : this.credentials) {
if (filter.apply(c)) { if (filter.apply(c)) {
@ -447,7 +443,6 @@ public final class RelyingPartyRegistration {
Collection<Saml2X509Credential> verificationX509Credentials, Collection<Saml2X509Credential> verificationX509Credentials,
Collection<Saml2X509Credential> encryptionX509Credentials, String singleSignOnServiceLocation, Collection<Saml2X509Credential> encryptionX509Credentials, String singleSignOnServiceLocation,
Saml2MessageBinding singleSignOnServiceBinding) { Saml2MessageBinding singleSignOnServiceBinding) {
Assert.hasText(entityId, "entityId cannot be null or empty"); Assert.hasText(entityId, "entityId cannot be null or empty");
Assert.notNull(verificationX509Credentials, "verificationX509Credentials cannot be null"); Assert.notNull(verificationX509Credentials, "verificationX509Credentials cannot be null");
for (Saml2X509Credential credential : verificationX509Credentials) { for (Saml2X509Credential credential : verificationX509Credentials) {
@ -1038,7 +1033,6 @@ public final class RelyingPartyRegistration {
for (Saml2X509Credential credential : this.providerDetails.assertingPartyDetailsBuilder.encryptionX509Credentials) { for (Saml2X509Credential credential : this.providerDetails.assertingPartyDetailsBuilder.encryptionX509Credentials) {
this.credentials.add(toDeprecated(credential)); this.credentials.add(toDeprecated(credential));
} }
return new RelyingPartyRegistration(this.registrationId, this.entityId, return new RelyingPartyRegistration(this.registrationId, this.entityId,
this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding, this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding,
this.providerDetails.build(), this.credentials, this.decryptionX509Credentials, this.providerDetails.build(), this.credentials, this.decryptionX509Credentials,

View File

@ -41,7 +41,6 @@ final class Saml2ServletUtils {
if (!StringUtils.hasText(template)) { if (!StringUtils.hasText(template)) {
return baseUrl; return baseUrl;
} }
String entityId = relyingParty.getAssertingPartyDetails().getEntityId(); String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
String registrationId = relyingParty.getRegistrationId(); String registrationId = relyingParty.getRegistrationId();
Map<String, String> uriVariables = new HashMap<>(); Map<String, String> uriVariables = new HashMap<>();
@ -64,7 +63,6 @@ final class Saml2ServletUtils {
uriVariables.put("baseUrl", uriComponents.toUriString()); uriVariables.put("baseUrl", uriComponents.toUriString());
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString(); return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString();
} }

View File

@ -131,13 +131,9 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
this.redirectMatcher = redirectMatcher; this.redirectMatcher = redirectMatcher;
} }
/**
* {@inheritDoc}
*/
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
MatchResult matcher = this.redirectMatcher.matcher(request); MatchResult matcher = this.redirectMatcher.matcher(request);
if (!matcher.isMatch()) { if (!matcher.isMatch()) {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
@ -192,26 +188,42 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
String authenticationRequestUri = authenticationRequest.getAuthenticationRequestUri(); String authenticationRequestUri = authenticationRequest.getAuthenticationRequestUri();
String relayState = authenticationRequest.getRelayState(); String relayState = authenticationRequest.getRelayState();
String samlRequest = authenticationRequest.getSamlRequest(); String samlRequest = authenticationRequest.getSamlRequest();
StringBuilder postHtml = new StringBuilder().append("<!DOCTYPE html>\n").append("<html>\n") StringBuilder html = new StringBuilder();
.append(" <head>\n").append(" <meta charset=\"utf-8\" />\n").append(" </head>\n") html.append("<!DOCTYPE html>\n");
.append(" <body onload=\"document.forms[0].submit()\">\n").append(" <noscript>\n") html.append("<html>\n").append(" <head>\n");
.append(" <p>\n") html.append(" <meta charset=\"utf-8\" />\n");
.append(" <strong>Note:</strong> Since your browser does not support JavaScript,\n") html.append(" </head>\n");
.append(" you must press the Continue button once to proceed.\n") html.append(" <body onload=\"document.forms[0].submit()\">\n");
.append(" </p>\n").append(" </noscript>\n").append(" \n") html.append(" <noscript>\n");
.append(" <form action=\"").append(authenticationRequestUri).append("\" method=\"post\">\n") html.append(" <p>\n");
.append(" <div>\n") html.append(" <strong>Note:</strong> Since your browser does not support JavaScript,\n");
.append(" <input type=\"hidden\" name=\"SAMLRequest\" value=\"") html.append(" you must press the Continue button once to proceed.\n");
.append(HtmlUtils.htmlEscape(samlRequest)).append("\"/>\n"); html.append(" </p>\n");
html.append(" </noscript>\n");
html.append(" \n");
html.append(" <form action=\"");
html.append(authenticationRequestUri);
html.append("\" method=\"post\">\n");
html.append(" <div>\n");
html.append(" <input type=\"hidden\" name=\"SAMLRequest\" value=\"");
html.append(HtmlUtils.htmlEscape(samlRequest));
html.append("\"/>\n");
if (StringUtils.hasText(relayState)) { if (StringUtils.hasText(relayState)) {
postHtml.append(" <input type=\"hidden\" name=\"RelayState\" value=\"") html.append(" <input type=\"hidden\" name=\"RelayState\" value=\"");
.append(HtmlUtils.htmlEscape(relayState)).append("\"/>\n"); html.append(HtmlUtils.htmlEscape(relayState));
html.append("\"/>\n");
} }
postHtml.append(" </div>\n").append(" <noscript>\n").append(" <div>\n") html.append(" </div>\n");
.append(" <input type=\"submit\" value=\"Continue\"/>\n") html.append(" <noscript>\n");
.append(" </div>\n").append(" </noscript>\n").append(" </form>\n") html.append(" <div>\n");
.append(" \n").append(" </body>\n").append("</html>"); html.append(" <input type=\"submit\" value=\"Continue\"/>\n");
return postHtml.toString(); html.append(" </div>\n");
html.append(" </noscript>\n");
html.append(" </form>\n");
html.append(" \n");
html.append(" </body>\n");
html.append("</html>");
return html.toString();
} }
} }

View File

@ -52,7 +52,6 @@ public final class DefaultRelyingPartyRegistrationResolver
public DefaultRelyingPartyRegistrationResolver( public DefaultRelyingPartyRegistrationResolver(
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null"); Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
} }
@ -68,7 +67,6 @@ public final class DefaultRelyingPartyRegistrationResolver
if (relyingPartyRegistration == null) { if (relyingPartyRegistration == null) {
return null; return null;
} }
String applicationUri = getApplicationUri(request); String applicationUri = getApplicationUri(request);
Function<String, String> templateResolver = templateResolver(applicationUri, relyingPartyRegistration); Function<String, String> templateResolver = templateResolver(applicationUri, relyingPartyRegistration);
String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId()); String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId());
@ -104,7 +102,6 @@ public final class DefaultRelyingPartyRegistrationResolver
uriVariables.put("baseUrl", uriComponents.toUriString()); uriVariables.put("baseUrl", uriComponents.toUriString());
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString(); return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString();
} }

View File

@ -47,9 +47,6 @@ public final class DefaultSaml2AuthenticationRequestContextResolver
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
} }
/**
* {@inheritDoc}
*/
@Override @Override
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) { public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null"); Assert.notNull(request, "request cannot be null");

View File

@ -60,9 +60,6 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
} }
/**
* {@inheritDoc}
*/
@Override @Override
public Saml2AuthenticationToken convert(HttpServletRequest request) { public Saml2AuthenticationToken convert(HttpServletRequest request) {
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request); RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request);
@ -82,10 +79,8 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
if (HttpMethod.GET.matches(request.getMethod())) { if (HttpMethod.GET.matches(request.getMethod())) {
return samlInflate(b); return samlInflate(b);
} }
else {
return new String(b, StandardCharsets.UTF_8); return new String(b, StandardCharsets.UTF_8);
} }
}
private byte[] samlDecode(String s) { private byte[] samlDecode(String s) {
return BASE64.decode(s); return BASE64.decode(s);
@ -94,9 +89,9 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
private String samlInflate(byte[] b) { private String samlInflate(byte[] b) {
try { try {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); InflaterOutputStream inflaterOutputStream = new InflaterOutputStream(out, new Inflater(true));
iout.write(b); inflaterOutputStream.write(b);
iout.finish(); inflaterOutputStream.finish();
return new String(out.toByteArray(), StandardCharsets.UTF_8); return new String(out.toByteArray(), StandardCharsets.UTF_8);
} }
catch (IOException ex) { catch (IOException ex) {

View File

@ -60,19 +60,16 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws ServletException, IOException { throws ServletException, IOException {
RequestMatcher.MatchResult matcher = this.requestMatcher.matcher(request); RequestMatcher.MatchResult matcher = this.requestMatcher.matcher(request);
if (!matcher.isMatch()) { if (!matcher.isMatch()) {
chain.doFilter(request, response); chain.doFilter(request, response);
return; return;
} }
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationConverter.convert(request); RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationConverter.convert(request);
if (relyingPartyRegistration == null) { if (relyingPartyRegistration == null) {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
return; return;
} }
String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration); String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration);
String registrationId = relyingPartyRegistration.getRegistrationId(); String registrationId = relyingPartyRegistration.getRegistrationId();
writeMetadataToResponse(response, registrationId, metadata); writeMetadataToResponse(response, registrationId, metadata);
@ -80,7 +77,6 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata) private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata)
throws IOException { throws IOException {
response.setContentType(MediaType.APPLICATION_XML_VALUE); response.setContentType(MediaType.APPLICATION_XML_VALUE);
response.setHeader(HttpHeaders.CONTENT_DISPOSITION, response.setHeader(HttpHeaders.CONTENT_DISPOSITION,
"attachment; filename=\"saml-" + registrationId + "-metadata.xml\""); "attachment; filename=\"saml-" + registrationId + "-metadata.xml\"");

View File

@ -45,11 +45,12 @@ public final class Saml2Utils {
public static byte[] samlDeflate(String s) { public static byte[] samlDeflate(String s) {
try { try {
ByteArrayOutputStream b = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(Deflater.DEFLATED, true)); DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(out,
deflater.write(s.getBytes(StandardCharsets.UTF_8)); new Deflater(Deflater.DEFLATED, true));
deflater.finish(); deflaterOutputStream.write(s.getBytes(StandardCharsets.UTF_8));
return b.toByteArray(); deflaterOutputStream.finish();
return out.toByteArray();
} }
catch (IOException ex) { catch (IOException ex) {
throw new Saml2Exception("Unable to deflate string", ex); throw new Saml2Exception("Unable to deflate string", ex);
@ -59,9 +60,9 @@ public final class Saml2Utils {
public static String samlInflate(byte[] b) { public static String samlInflate(byte[] b) {
try { try {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); InflaterOutputStream inflaterOutputStream = new InflaterOutputStream(out, new Inflater(true));
iout.write(b); inflaterOutputStream.write(b);
iout.finish(); inflaterOutputStream.finish();
return new String(out.toByteArray(), StandardCharsets.UTF_8); return new String(out.toByteArray(), StandardCharsets.UTF_8);
} }
catch (IOException ex) { catch (IOException ex) {