Complete SAML 2.0 SP Metadata Endpoint

Closes gh-8693
This commit is contained in:
Josh Cummings 2020-08-04 22:35:27 -06:00
parent 8a355240bc
commit b999faa5a0
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
15 changed files with 373 additions and 360 deletions

View File

@ -73,9 +73,6 @@ final class FilterComparator implements Comparator<Filter>, Serializable {
filterToOrder.put(
"org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter",
order.next());
filterToOrder.put(
"org.springframework.security.saml2.provider.service.web.Saml2MetadataFilter",
order.next());
filterToOrder.put(
"org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter",
order.next());

View File

@ -38,11 +38,8 @@ import org.springframework.security.saml2.provider.service.servlet.filter.Saml2W
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.OpenSamlMetadataResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.saml2.provider.service.web.Saml2MetadataFilter;
import org.springframework.security.saml2.provider.service.web.Saml2MetadataResolver;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
@ -113,15 +110,10 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private AuthenticationConverter authenticationConverter;
private Saml2MetadataResolver saml2MetadataResolver;
private AuthenticationManager authenticationManager;
private Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter;
private Saml2MetadataFilter saml2MetadataFilter;
/**
* Use this {@link AuthenticationConverter} when converting incoming requests to an {@link Authentication}.
* By default the {@link Saml2AuthenticationTokenConverter} is used.
@ -162,16 +154,6 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
return this;
}
/**
* Sets the {@code Saml2MetadataResolver}
* @param saml2MetadataResolver the implementation of the metadata resolver
* @return the {@link Saml2LoginConfigurer} for further configuration
*/
public Saml2LoginConfigurer saml2MetadataResolver(Saml2MetadataResolver saml2MetadataResolver) {
this.saml2MetadataResolver = saml2MetadataResolver;
return this;
}
/**
* {@inheritDoc}
*/
@ -229,14 +211,6 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
setAuthenticationFilter(saml2WebSsoAuthenticationFilter);
super.loginProcessingUrl(this.loginProcessingUrl);
if (this.saml2MetadataResolver == null) {
this.saml2MetadataResolver = new OpenSamlMetadataResolver();
}
saml2MetadataFilter = new Saml2MetadataFilter(
this.relyingPartyRegistrationRepository, this.saml2MetadataResolver
);
if (hasText(this.loginPage)) {
// Set custom login page
super.loginPage(this.loginPage);
@ -276,7 +250,6 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
@Override
public void configure(B http) throws Exception {
http.addFilter(this.authenticationRequestEndpoint.build(http));
http.addFilter(saml2MetadataFilter);
super.configure(http);
if (this.authenticationManager == null) {
registerDefaultAuthenticationProvider(http);

View File

@ -30,7 +30,7 @@ import org.springframework.security.saml2.credentials.Saml2X509Credential
import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration
import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter
import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get
import java.security.cert.Certificate

View File

@ -0,0 +1,151 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.metadata;
import java.security.cert.CertificateEncodingException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.List;
import javax.xml.namespace.QName;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.opensaml.core.xml.XMLObjectBuilder;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.metadata.AssertionConsumerService;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.KeyDescriptor;
import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
import org.opensaml.saml.saml2.metadata.impl.EntityDescriptorMarshaller;
import org.opensaml.security.credential.UsageType;
import org.opensaml.xmlsec.signature.KeyInfo;
import org.opensaml.xmlsec.signature.X509Certificate;
import org.opensaml.xmlsec.signature.X509Data;
import org.w3c.dom.Element;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
/**
* Resolves the SAML 2.0 Relying Party Metadata for a given {@link RelyingPartyRegistration}
* using the OpenSAML API.
*
* @author Jakub Kubrynski
* @author Josh Cummings
* @since 5.4
*/
public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
static {
OpenSamlInitializationService.initialize();
}
private final EntityDescriptorMarshaller entityDescriptorMarshaller;
public OpenSamlMetadataResolver() {
this.entityDescriptorMarshaller = (EntityDescriptorMarshaller)
getMarshallerFactory().getMarshaller(EntityDescriptor.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.entityDescriptorMarshaller, "entityDescriptorMarshaller cannot be null");
}
/**
* {@inheritDoc}
*/
@Override
public String resolve(RelyingPartyRegistration relyingPartyRegistration) {
EntityDescriptor entityDescriptor = build(EntityDescriptor.ELEMENT_QNAME);
entityDescriptor.setEntityID(relyingPartyRegistration.getEntityId());
SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(relyingPartyRegistration);
entityDescriptor.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME).add(spSsoDescriptor);
return serialize(entityDescriptor);
}
private SPSSODescriptor buildSpSsoDescriptor(RelyingPartyRegistration registration) {
SPSSODescriptor spSsoDescriptor = build(SPSSODescriptor.DEFAULT_ELEMENT_NAME);
spSsoDescriptor.addSupportedProtocol(SAMLConstants.SAML20P_NS);
spSsoDescriptor.setWantAssertionsSigned(true);
spSsoDescriptor.getKeyDescriptors().addAll(buildKeys(
registration.getSigningX509Credentials(), UsageType.SIGNING));
spSsoDescriptor.getKeyDescriptors().addAll(buildKeys(
registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION));
spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration));
return spSsoDescriptor;
}
private List<KeyDescriptor> buildKeys(Collection<Saml2X509Credential> credentials, UsageType usageType) {
List<KeyDescriptor> list = new ArrayList<>();
for (Saml2X509Credential credential : credentials) {
KeyDescriptor keyDescriptor = buildKeyDescriptor(usageType, credential.getCertificate());
list.add(keyDescriptor);
}
return list;
}
private KeyDescriptor buildKeyDescriptor(UsageType usageType, java.security.cert.X509Certificate certificate) {
KeyDescriptor keyDescriptor = build(KeyDescriptor.DEFAULT_ELEMENT_NAME);
KeyInfo keyInfo = build(KeyInfo.DEFAULT_ELEMENT_NAME);
X509Certificate x509Certificate = build(X509Certificate.DEFAULT_ELEMENT_NAME);
X509Data x509Data = build(X509Data.DEFAULT_ELEMENT_NAME);
try {
x509Certificate.setValue(new String(Base64.getEncoder().encode(certificate.getEncoded())));
} catch (CertificateEncodingException e) {
throw new Saml2Exception("Cannot encode certificate " + certificate.toString());
}
x509Data.getX509Certificates().add(x509Certificate);
keyInfo.getX509Datas().add(x509Data);
keyDescriptor.setUse(usageType);
keyDescriptor.setKeyInfo(keyInfo);
return keyDescriptor;
}
private AssertionConsumerService buildAssertionConsumerService(RelyingPartyRegistration registration) {
AssertionConsumerService assertionConsumerService = build(AssertionConsumerService.DEFAULT_ELEMENT_NAME);
assertionConsumerService.setLocation(registration.getAssertionConsumerServiceLocation());
assertionConsumerService.setBinding(registration.getAssertionConsumerServiceBinding().getUrn());
assertionConsumerService.setIndex(1);
return assertionConsumerService;
}
@SuppressWarnings("unchecked")
private <T> T build(QName elementName) {
XMLObjectBuilder<?> builder = getBuilderFactory().getBuilder(elementName);
if (builder == null) {
throw new Saml2Exception("Unable to resolve Builder for " + elementName);
}
return (T) builder.buildObject(elementName);
}
private String serialize(EntityDescriptor entityDescriptor) {
try {
Element element = this.entityDescriptorMarshaller.marshall(entityDescriptor);
return SerializeSupport.prettyPrintXML(element);
} catch (Exception e) {
throw new Saml2Exception(e);
}
}
}

View File

@ -14,16 +14,23 @@
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.web;
package org.springframework.security.saml2.provider.service.metadata;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import javax.servlet.http.HttpServletRequest;
/**
* Resolves the SAML 2.0 Relying Party Metadata for a given {@link RelyingPartyRegistration}
*
* @author Jakub Kubrynski
* @author Josh Cummings
* @since 5.4
*/
public interface Saml2MetadataResolver {
String resolveMetadata(HttpServletRequest request, RelyingPartyRegistration registration);
/**
* Resolve the given relying party's metadata
*
* @param relyingPartyRegistration the relying party
* @return the relying party's metadata
*/
String resolve(RelyingPartyRegistration relyingPartyRegistration);
}

View File

@ -29,7 +29,6 @@ import java.util.function.Function;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter;
import org.springframework.util.Assert;
/**
@ -361,7 +360,6 @@ public class RelyingPartyRegistration {
.encryptionX509Credentials(c -> c.addAll(registration.getAssertingPartyDetails().getEncryptionX509Credentials()))
.singleSignOnServiceLocation(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation())
.singleSignOnServiceBinding(registration.getAssertingPartyDetails().getSingleSignOnServiceBinding())
.nameIdFormat(registration.getAssertingPartyDetails().getNameIdFormat())
);
}
@ -377,7 +375,6 @@ public class RelyingPartyRegistration {
private final Collection<Saml2X509Credential> verificationX509Credentials;
private final Collection<Saml2X509Credential> encryptionX509Credentials;
private final String singleSignOnServiceLocation;
private final String nameIdFormat;
private final Saml2MessageBinding singleSignOnServiceBinding;
private AssertingPartyDetails(
@ -386,7 +383,6 @@ public class RelyingPartyRegistration {
Collection<Saml2X509Credential> verificationX509Credentials,
Collection<Saml2X509Credential> encryptionX509Credentials,
String singleSignOnServiceLocation,
String nameIdFormat,
Saml2MessageBinding singleSignOnServiceBinding) {
Assert.hasText(entityId, "entityId cannot be null or empty");
@ -409,7 +405,6 @@ public class RelyingPartyRegistration {
this.verificationX509Credentials = verificationX509Credentials;
this.encryptionX509Credentials = encryptionX509Credentials;
this.singleSignOnServiceLocation = singleSignOnServiceLocation;
this.nameIdFormat = nameIdFormat;
this.singleSignOnServiceBinding = singleSignOnServiceBinding;
}
@ -477,15 +472,6 @@ public class RelyingPartyRegistration {
return this.singleSignOnServiceLocation;
}
/**
* Get the NameIDFormat setting, indicating which user property should be used as a NameID Format attribute
*
* @return the NameIdFormat value
*/
public String getNameIdFormat() {
return nameIdFormat;
}
/**
* Get the
* <a href="https://wiki.shibboleth.net/confluence/display/CONCEPT/MetadataForIdP#MetadataForIdP-SingleSign-OnServices">SingleSignOnService</a>
@ -507,7 +493,6 @@ public class RelyingPartyRegistration {
private Collection<Saml2X509Credential> verificationX509Credentials = new HashSet<>();
private Collection<Saml2X509Credential> encryptionX509Credentials = new HashSet<>();
private String singleSignOnServiceLocation;
private String nameIdFormat = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified";
private Saml2MessageBinding singleSignOnServiceBinding = Saml2MessageBinding.REDIRECT;
/**
@ -577,18 +562,6 @@ public class RelyingPartyRegistration {
return this;
}
/**
* Set the preference for name identifier returned by IdP.
* See <a href="https://wiki.shibboleth.net/confluence/display/SHIB/NameIdentifierFormat">for possible values</a>
*
* @param nameIdFormat the name identifier
* @return the {@link ProviderDetails.Builder} for further configuration
*/
public Builder nameIdFormat(String nameIdFormat) {
this.nameIdFormat = nameIdFormat;
return this;
}
/**
* Set the
* <a href="https://wiki.shibboleth.net/confluence/display/CONCEPT/MetadataForIdP#MetadataForIdP-SingleSign-OnServices">SingleSignOnService</a>
@ -617,7 +590,6 @@ public class RelyingPartyRegistration {
this.verificationX509Credentials,
this.encryptionX509Credentials,
this.singleSignOnServiceLocation,
this.nameIdFormat,
this.singleSignOnServiceBinding
);
}

View File

@ -1,161 +0,0 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.web;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.opensaml.core.xml.XMLObjectBuilder;
import org.opensaml.core.xml.XMLObjectBuilderFactory;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.metadata.AssertionConsumerService;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.KeyDescriptor;
import org.opensaml.saml.saml2.metadata.NameIDFormat;
import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
import org.opensaml.security.credential.UsageType;
import org.opensaml.xmlsec.signature.KeyInfo;
import org.opensaml.xmlsec.signature.X509Certificate;
import org.opensaml.xmlsec.signature.X509Data;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2ServletUtils;
import org.w3c.dom.Element;
import javax.servlet.http.HttpServletRequest;
import javax.xml.namespace.QName;
import java.security.cert.CertificateEncodingException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
/**
* @author Jakub Kubrynski
* @since 5.4
*/
public class OpenSamlMetadataResolver implements Saml2MetadataResolver {
@Override
public String resolveMetadata(HttpServletRequest request, RelyingPartyRegistration registration) {
XMLObjectBuilderFactory builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory();
EntityDescriptor entityDescriptor = buildObject(builderFactory, EntityDescriptor.ELEMENT_QNAME);
entityDescriptor.setEntityID(
resolveTemplate(registration.getEntityId(), registration, request));
SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(registration, builderFactory, request);
entityDescriptor.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME).add(spSsoDescriptor);
return serializeToXmlString(entityDescriptor);
}
private String serializeToXmlString(EntityDescriptor entityDescriptor) {
Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(entityDescriptor);
if (marshaller == null) {
throw new Saml2Exception("Unable to resolve Marshaller");
}
Element element;
try {
element = marshaller.marshall(entityDescriptor);
} catch (Exception e) {
throw new Saml2Exception(e);
}
return SerializeSupport.prettyPrintXML(element);
}
private SPSSODescriptor buildSpSsoDescriptor(RelyingPartyRegistration registration,
XMLObjectBuilderFactory builderFactory, HttpServletRequest request) {
SPSSODescriptor spSsoDescriptor = buildObject(builderFactory, SPSSODescriptor.DEFAULT_ELEMENT_NAME);
spSsoDescriptor.setAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned());
spSsoDescriptor.setWantAssertionsSigned(true);
spSsoDescriptor.addSupportedProtocol(SAMLConstants.SAML20P_NS);
NameIDFormat nameIdFormat = buildObject(builderFactory, NameIDFormat.DEFAULT_ELEMENT_NAME);
nameIdFormat.setFormat(registration.getAssertingPartyDetails().getNameIdFormat());
spSsoDescriptor.getNameIDFormats().add(nameIdFormat);
spSsoDescriptor.getAssertionConsumerServices().add(
buildAssertionConsumerService(registration, builderFactory, request));
spSsoDescriptor.getKeyDescriptors().addAll(buildKeys(builderFactory,
registration.getSigningCredentials(), UsageType.SIGNING));
spSsoDescriptor.getKeyDescriptors().addAll(buildKeys(builderFactory,
registration.getEncryptionCredentials(), UsageType.ENCRYPTION));
return spSsoDescriptor;
}
private List<KeyDescriptor> buildKeys(XMLObjectBuilderFactory builderFactory,
List<Saml2X509Credential> credentials, UsageType usageType) {
List<KeyDescriptor> list = new ArrayList<>();
for (Saml2X509Credential credential : credentials) {
KeyDescriptor keyDescriptor = buildKeyDescriptor(builderFactory, usageType, credential.getCertificate());
list.add(keyDescriptor);
}
return list;
}
private KeyDescriptor buildKeyDescriptor(XMLObjectBuilderFactory builderFactory, UsageType usageType,
java.security.cert.X509Certificate certificate) {
KeyDescriptor keyDescriptor = buildObject(builderFactory, KeyDescriptor.DEFAULT_ELEMENT_NAME);
KeyInfo keyInfo = buildObject(builderFactory, KeyInfo.DEFAULT_ELEMENT_NAME);
X509Certificate x509Certificate = buildObject(builderFactory, X509Certificate.DEFAULT_ELEMENT_NAME);
X509Data x509Data = buildObject(builderFactory, X509Data.DEFAULT_ELEMENT_NAME);
try {
x509Certificate.setValue(new String(Base64.getEncoder().encode(certificate.getEncoded())));
} catch (CertificateEncodingException e) {
throw new Saml2Exception("Cannot encode certificate " + certificate.toString());
}
x509Data.getX509Certificates().add(x509Certificate);
keyInfo.getX509Datas().add(x509Data);
keyDescriptor.setUse(usageType);
keyDescriptor.setKeyInfo(keyInfo);
return keyDescriptor;
}
private AssertionConsumerService buildAssertionConsumerService(RelyingPartyRegistration registration,
XMLObjectBuilderFactory builderFactory, HttpServletRequest request) {
AssertionConsumerService assertionConsumerService = buildObject(builderFactory, AssertionConsumerService.DEFAULT_ELEMENT_NAME);
assertionConsumerService.setLocation(
resolveTemplate(registration.getAssertionConsumerServiceLocation(), registration, request));
assertionConsumerService.setBinding(registration.getAssertingPartyDetails().getSingleSignOnServiceBinding().getUrn());
assertionConsumerService.setIndex(1);
return assertionConsumerService;
}
@SuppressWarnings("unchecked")
private <T> T buildObject(XMLObjectBuilderFactory builderFactory, QName elementName) {
XMLObjectBuilder<?> builder = builderFactory.getBuilder(elementName);
if (builder == null) {
throw new Saml2Exception("Cannot build object - builder not defined for element " + elementName);
}
return (T) builder.buildObject(elementName);
}
private String resolveTemplate(String template, RelyingPartyRegistration registration, HttpServletRequest request) {
return Saml2ServletUtils.resolveUrlTemplate(template, Saml2ServletUtils.getApplicationUri(request), registration);
}
}

View File

@ -16,66 +16,85 @@
package org.springframework.security.saml2.provider.service.web;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
/**
* This {@code Servlet} returns a generated Service Provider Metadata XML
* A {@link javax.servlet.Filter} that returns the metadata for a Relying Party
*
* @since 5.4
* @author Jakub Kubrynski
* @author Josh Cummings
* @since 5.4
*/
public class Saml2MetadataFilter extends OncePerRequestFilter {
public final class Saml2MetadataFilter extends OncePerRequestFilter {
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter;
private final Saml2MetadataResolver saml2MetadataResolver;
private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/service-provider-metadata/{registrationId}");
private RequestMatcher requestMatcher = new AntPathRequestMatcher(
"/saml2/service-provider-metadata/{registrationId}");
public Saml2MetadataFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, Saml2MetadataResolver saml2MetadataResolver) {
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
public Saml2MetadataFilter(
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter,
Saml2MetadataResolver saml2MetadataResolver) {
this.relyingPartyRegistrationConverter = relyingPartyRegistrationConverter;
this.saml2MetadataResolver = saml2MetadataResolver;
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws ServletException, IOException {
RequestMatcher.MatchResult matcher = this.redirectMatcher.matcher(request);
RequestMatcher.MatchResult matcher = this.requestMatcher.matcher(request);
if (!matcher.isMatch()) {
filterChain.doFilter(request, response);
chain.doFilter(request, response);
return;
}
String registrationId = matcher.getVariables().get("registrationId");
RelyingPartyRegistration registration = relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
if (registration == null) {
RelyingPartyRegistration relyingPartyRegistration =
this.relyingPartyRegistrationConverter.convert(request);
if (relyingPartyRegistration == null) {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
return;
}
String xml = saml2MetadataResolver.resolveMetadata(request, registration);
writeMetadataToResponse(response, registrationId, xml);
String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration);
String registrationId = relyingPartyRegistration.getRegistrationId();
writeMetadataToResponse(response, registrationId, metadata);
}
private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String xml) throws IOException {
private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata)
throws IOException {
response.setContentType(MediaType.APPLICATION_XML_VALUE);
response.setHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"saml-" + registrationId + "-metadata.xml\"");
response.setContentLength(xml.length());
response.getWriter().write(xml);
response.setHeader(HttpHeaders.CONTENT_DISPOSITION,
"attachment; filename=\"saml-" + registrationId + "-metadata.xml\"");
response.setContentLength(metadata.length());
response.getWriter().write(metadata);
}
/**
* Set the {@link RequestMatcher} that determines whether this filter should
* handle the incoming {@link HttpServletRequest}
*
* @param requestMatcher
*/
public void setRequestMatcher(RequestMatcher requestMatcher) {
Assert.notNull(requestMatcher, "requestMatcher cannot be null");
this.requestMatcher = requestMatcher;
}
}

View File

@ -0,0 +1,80 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.metadata;
import org.junit.Test;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.full;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
/**
* Tests for {@link OpenSamlMetadataResolver}
*/
public class OpenSamlMetadataResolverTests {
@Test
public void resolveWhenRelyingPartyThenMetadataMatches() {
// given
RelyingPartyRegistration relyingPartyRegistration = full()
.assertionConsumerServiceBinding(REDIRECT)
.build();
OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
// when
String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration);
// then
assertThat(metadata)
.contains("<EntityDescriptor")
.contains("entityID=\"rp-entity-id\"")
.contains("WantAssertionsSigned=\"true\"")
.contains("<md:KeyDescriptor use=\"signing\">")
.contains("<md:KeyDescriptor use=\"encryption\">")
.contains("<ds:X509Certificate>MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh")
.contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"")
.contains("Location=\"https://rp.example.org/acs\" index=\"1\"");
}
@Test
public void resolveWhenRelyingPartyNoCredentialsThenMetadataMatches() {
// given
RelyingPartyRegistration relyingPartyRegistration = noCredentials()
.assertingPartyDetails(party -> party
.verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential()))
)
.build();
OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
// when
String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration);
// then
assertThat(metadata)
.contains("<EntityDescriptor")
.contains("entityID=\"rp-entity-id\"")
.contains("WantAssertionsSigned=\"true\"")
.doesNotContain("<md:KeyDescriptor use=\"signing\">")
.doesNotContain("<md:KeyDescriptor use=\"encryption\">")
.contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"")
.contains("Location=\"https://rp.example.org/acs\" index=\"1\"");
}
}

View File

@ -18,7 +18,7 @@ package org.springframework.security.saml2.provider.service.registration;
import org.junit.Test;
import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;

View File

@ -16,8 +16,9 @@
package org.springframework.security.saml2.provider.service.registration;
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
@ -57,4 +58,15 @@ public class TestRelyingPartyRegistrations {
.singleSignOnServiceLocation("https://ap.example.org/sso")
);
}
public static RelyingPartyRegistration.Builder full() {
return noCredentials()
.signingX509Credentials(c -> c.add(TestSaml2X509Credentials.relyingPartySigningCredential()))
.decryptionX509Credentials(c -> c.add(TestSaml2X509Credentials.relyingPartyDecryptingCredential()))
.assertingPartyDetails(party -> party
.verificationX509Credentials(c -> c.add(
TestSaml2X509Credentials.relyingPartyVerifyingCredential())
)
);
}
}

View File

@ -1,67 +0,0 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.web;
import org.junit.Before;
import org.junit.Test;
import org.opensaml.core.config.InitializationException;
import org.opensaml.core.config.InitializationService;
import org.opensaml.saml.saml2.core.NameIDType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import javax.servlet.http.HttpServletRequest;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
public class OpenSamlMetadataResolverTest {
@Before
public void setUp() throws InitializationException {
InitializationService.initialize();
}
@Test
public void shouldGenerateMetadata() {
// given
OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration()
.assertingPartyDetails(p -> p.singleSignOnServiceBinding(REDIRECT))
.assertingPartyDetails(p -> p.wantAuthnRequestsSigned(true))
.assertingPartyDetails(p -> p.nameIdFormat(NameIDType.EMAIL))
.build();
HttpServletRequest servletRequestMock = new MockHttpServletRequest();
// when
String metadataXml = openSamlMetadataResolver.resolveMetadata(servletRequestMock, relyingPartyRegistration);
// then
assertThat(metadataXml)
.contains("<EntityDescriptor")
.contains("entityID=\"http://localhost/saml2/service-provider-metadata/simplesamlphp\"")
.contains("AuthnRequestsSigned=\"true\"")
.contains("WantAssertionsSigned=\"true\"")
.contains("<md:KeyDescriptor use=\"signing\">")
.contains("<ds:X509Certificate>MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh")
.contains("<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>")
.contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"")
.contains("Location=\"http://localhost/login/saml2/sso/simplesamlphp\" index=\"1\"");
}
}

View File

@ -20,95 +20,125 @@ import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import javax.servlet.FilterChain;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
/**
* Tests for {@link Saml2MetadataFilter}
*/
public class Saml2MetadataFilterTest {
RelyingPartyRegistrationRepository repository;
Saml2MetadataResolver saml2MetadataResolver;
Saml2MetadataResolver resolver;
Saml2MetadataFilter filter;
MockHttpServletRequest request;
MockHttpServletResponse response;
FilterChain filterChain;
FilterChain chain;
@Before
public void setup() {
repository = mock(RelyingPartyRegistrationRepository.class);
saml2MetadataResolver = mock(Saml2MetadataResolver.class);
filter = new Saml2MetadataFilter(repository, saml2MetadataResolver);
request = new MockHttpServletRequest();
response = new MockHttpServletResponse();
filterChain = mock(FilterChain.class);
this.repository = mock(RelyingPartyRegistrationRepository.class);
this.resolver = mock(Saml2MetadataResolver.class);
this.filter = new Saml2MetadataFilter(
new DefaultRelyingPartyRegistrationResolver(this.repository), this.resolver);
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.chain = mock(FilterChain.class);
}
@Test
public void shouldReturnValueWhenMatcherSucceed() throws Exception {
public void doFilterWhenMatcherSucceedsThenResolverInvoked() throws Exception {
// given
request.setPathInfo("/saml2/service-provider-metadata/registration-id");
this.request.setPathInfo("/saml2/service-provider-metadata/registration-id");
// when
filter.doFilter(request, response, filterChain);
this.filter.doFilter(this.request, this.response, this.chain);
// then
verifyNoInteractions(filterChain);
verifyNoInteractions(this.chain);
verify(this.repository).findByRegistrationId("registration-id");
}
@Test
public void shouldProcessFilterChainIfMatcherFails() throws Exception {
public void doFilterWhenMatcherFailsThenProcessesFilterChain() throws Exception {
// given
request.setPathInfo("/saml2/authenticate/registration-id");
this.request.setPathInfo("/saml2/authenticate/registration-id");
// when
filter.doFilter(request, response, filterChain);
this.filter.doFilter(this.request, this.response, this.chain);
// then
verify(filterChain).doFilter(request, response);
verify(this.chain).doFilter(this.request, this.response);
}
@Test
public void shouldReturn401IfNoRegistrationIsFound() throws Exception {
public void doFilterWhenNoRelyingPartyRegistrationThenUnauthorized() throws Exception {
// given
request.setPathInfo("/saml2/service-provider-metadata/invalidRegistration");
when(repository.findByRegistrationId("invalidRegistration")).thenReturn(null);
this.request.setPathInfo("/saml2/service-provider-metadata/invalidRegistration");
when(this.repository.findByRegistrationId("invalidRegistration")).thenReturn(null);
// when
filter.doFilter(request, response, filterChain);
this.filter.doFilter(this.request, this.response, this.chain);
// then
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(401);
verifyNoInteractions(this.chain);
assertThat(this.response.getStatus()).isEqualTo(401);
}
@Test
public void shouldInvokeMetadataGenerationIfRegistrationIsFound() throws Exception {
public void doFilterWhenRelyingPartyRegistrationFoundThenInvokesMetadataResolver() throws Exception {
// given
request.setPathInfo("/saml2/service-provider-metadata/validRegistration");
RelyingPartyRegistration validRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
when(repository.findByRegistrationId("validRegistration")).thenReturn(validRegistration);
this.request.setPathInfo("/saml2/service-provider-metadata/validRegistration");
RelyingPartyRegistration validRegistration = noCredentials()
.assertingPartyDetails(party -> party
.verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential())))
.build();
String generatedMetadata = "<xml>test</xml>";
when(saml2MetadataResolver.resolveMetadata(request, validRegistration)).thenReturn(generatedMetadata);
when(this.resolver.resolve(validRegistration)).thenReturn(generatedMetadata);
filter = new Saml2MetadataFilter(repository, saml2MetadataResolver);
this.filter = new Saml2MetadataFilter(request -> validRegistration, this.resolver);
// when
filter.doFilter(request, response, filterChain);
this.filter.doFilter(this.request, this.response, this.chain);
// then
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getContentAsString()).isEqualTo(generatedMetadata);
verify(saml2MetadataResolver).resolveMetadata(request, validRegistration);
verifyNoInteractions(this.chain);
assertThat(this.response.getStatus()).isEqualTo(200);
assertThat(this.response.getContentAsString()).isEqualTo(generatedMetadata);
verify(this.resolver).resolve(validRegistration);
}
@Test
public void doFilterWhenCustomRequestMatcherThenUses() throws Exception {
// given
this.request.setPathInfo("/path");
this.filter.setRequestMatcher(new AntPathRequestMatcher("/path"));
// when
this.filter.doFilter(this.request, this.response, this.chain);
// then
verifyNoInteractions(this.chain);
verify(this.repository).findByRegistrationId("path");
}
@Test
public void setRequestMatcherWhenNullThenIllegalArgument() {
assertThatCode(() -> this.filter.setRequestMatcher(null))
.isInstanceOf(IllegalArgumentException.class);
}
}

View File

@ -29,7 +29,7 @@ import org.springframework.security.converter.RsaKeyConverters;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION;
import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.SIGNING;

View File

@ -17,7 +17,7 @@ package org.springframework.security.samples.config;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;