Add support for AuthnRequestsSigned setting

closes gh-12604
This commit is contained in:
Liviu Gheorghe 2023-01-31 15:02:20 +02:00 committed by Josh Cummings
parent 55224b58e0
commit 21d919169a
4 changed files with 66 additions and 16 deletions

View File

@ -86,6 +86,8 @@ public class RelyingPartyRegistration {
private final String nameIdFormat; private final String nameIdFormat;
private final boolean authnRequestsSigned;
private final AssertingPartyDetails assertingPartyDetails; private final AssertingPartyDetails assertingPartyDetails;
private final Collection<Saml2X509Credential> decryptionX509Credentials; private final Collection<Saml2X509Credential> decryptionX509Credentials;
@ -95,7 +97,7 @@ public class RelyingPartyRegistration {
protected RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation, protected RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation,
Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation, Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation,
String singleLogoutServiceResponseLocation, Collection<Saml2MessageBinding> singleLogoutServiceBindings, String singleLogoutServiceResponseLocation, Collection<Saml2MessageBinding> singleLogoutServiceBindings,
AssertingPartyDetails assertingPartyDetails, String nameIdFormat, AssertingPartyDetails assertingPartyDetails, String nameIdFormat, boolean authnRequestsSigned,
Collection<Saml2X509Credential> decryptionX509Credentials, Collection<Saml2X509Credential> decryptionX509Credentials,
Collection<Saml2X509Credential> signingX509Credentials) { Collection<Saml2X509Credential> signingX509Credentials) {
Assert.hasText(registrationId, "registrationId cannot be empty"); Assert.hasText(registrationId, "registrationId cannot be empty");
@ -124,6 +126,7 @@ public class RelyingPartyRegistration {
this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation; this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation;
this.singleLogoutServiceBindings = Collections.unmodifiableList(new LinkedList<>(singleLogoutServiceBindings)); this.singleLogoutServiceBindings = Collections.unmodifiableList(new LinkedList<>(singleLogoutServiceBindings));
this.nameIdFormat = nameIdFormat; this.nameIdFormat = nameIdFormat;
this.authnRequestsSigned = authnRequestsSigned;
this.assertingPartyDetails = assertingPartyDetails; this.assertingPartyDetails = assertingPartyDetails;
this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials)); this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials));
this.signingX509Credentials = Collections.unmodifiableList(new LinkedList<>(signingX509Credentials)); this.signingX509Credentials = Collections.unmodifiableList(new LinkedList<>(signingX509Credentials));
@ -281,6 +284,15 @@ public class RelyingPartyRegistration {
return this.nameIdFormat; return this.nameIdFormat;
} }
/**
* Get the WantAuthnRequestsSigned setting
* @return the WantAuthnRequestsSigned setting
* @since 6.0
*/
public boolean isAuthnRequestsSigned() {
return authnRequestsSigned;
}
/** /**
* Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated * Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated
* with this relying party * with this relying party
@ -357,6 +369,7 @@ public class RelyingPartyRegistration {
.singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation()) .singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation())
.singleLogoutServiceBindings((c) -> c.addAll(registration.getSingleLogoutServiceBindings())) .singleLogoutServiceBindings((c) -> c.addAll(registration.getSingleLogoutServiceBindings()))
.nameIdFormat(registration.getNameIdFormat()) .nameIdFormat(registration.getNameIdFormat())
.authnRequestsSigned(registration.isAuthnRequestsSigned())
.assertingPartyDetails((assertingParty) -> assertingParty .assertingPartyDetails((assertingParty) -> assertingParty
.entityId(registration.getAssertingPartyDetails().getEntityId()) .entityId(registration.getAssertingPartyDetails().getEntityId())
.wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) .wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned())
@ -788,6 +801,8 @@ public class RelyingPartyRegistration {
private String nameIdFormat = null; private String nameIdFormat = null;
private boolean authnRequestsSigned = false;
private AssertingPartyDetails.Builder assertingPartyDetailsBuilder; private AssertingPartyDetails.Builder assertingPartyDetailsBuilder;
protected Builder(String registrationId, AssertingPartyDetails.Builder assertingPartyDetailsBuilder) { protected Builder(String registrationId, AssertingPartyDetails.Builder assertingPartyDetailsBuilder) {
@ -974,6 +989,17 @@ public class RelyingPartyRegistration {
return this; return this;
} }
/**
* Set the AuthnRequestsSigned setting
* @param authnRequestsSigned
* @return the {@link Builder} for further configuration
* @since 6.0
*/
public Builder authnRequestsSigned(Boolean authnRequestsSigned) {
this.authnRequestsSigned = authnRequestsSigned;
return this;
}
/** /**
* Apply this {@link Consumer} to further configure the Asserting Party details * Apply this {@link Consumer} to further configure the Asserting Party details
* @param assertingPartyDetails The {@link Consumer} to apply * @param assertingPartyDetails The {@link Consumer} to apply
@ -1003,8 +1029,8 @@ public class RelyingPartyRegistration {
return new RelyingPartyRegistration(this.registrationId, this.entityId, return new RelyingPartyRegistration(this.registrationId, this.entityId,
this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding, this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding,
this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation, this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation,
this.singleLogoutServiceBindings, party, this.nameIdFormat, this.decryptionX509Credentials, this.singleLogoutServiceBindings, party, this.nameIdFormat, this.authnRequestsSigned,
this.signingX509Credentials); this.decryptionX509Credentials, this.signingX509Credentials);
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -142,7 +142,7 @@ class OpenSamlAuthenticationRequestResolver {
String relayState = this.relayStateResolver.convert(request); String relayState = this.relayStateResolver.convert(request);
Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleSignOnServiceBinding(); Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleSignOnServiceBinding();
if (binding == Saml2MessageBinding.POST) { if (binding == Saml2MessageBinding.POST) {
if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) { if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned() || registration.isAuthnRequestsSigned()) {
OpenSamlSigningUtils.sign(authnRequest, registration); OpenSamlSigningUtils.sign(authnRequest, registration);
} }
String xml = serialize(authnRequest); String xml = serialize(authnRequest);
@ -156,7 +156,7 @@ class OpenSamlAuthenticationRequestResolver {
Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest
.withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState) .withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState)
.id(authnRequest.getID()); .id(authnRequest.getID());
if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) { if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned() || registration.isAuthnRequestsSigned()) {
Map<String, String> parameters = OpenSamlSigningUtils.sign(registration) Map<String, String> parameters = OpenSamlSigningUtils.sign(registration)
.param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded) .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded)
.param(Saml2ParameterNames.RELAY_STATE, relayState).parameters(); .param(Saml2ParameterNames.RELAY_STATE, relayState).parameters();

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,6 +30,7 @@ public class RelyingPartyRegistrationTests {
public void withRelyingPartyRegistrationWorks() { public void withRelyingPartyRegistrationWorks() {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration() RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
.nameIdFormat("format") .nameIdFormat("format")
.authnRequestsSigned(true)
.assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST)) .assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST))
.assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false)) .assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false))
.assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg"))) .assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg")))
@ -82,6 +83,7 @@ public class RelyingPartyRegistrationTests {
assertThat(copy.getAssertingPartyDetails().getSigningAlgorithms()) assertThat(copy.getAssertingPartyDetails().getSigningAlgorithms())
.isEqualTo(registration.getAssertingPartyDetails().getSigningAlgorithms()); .isEqualTo(registration.getAssertingPartyDetails().getSigningAlgorithms());
assertThat(copy.getNameIdFormat()).isEqualTo(registration.getNameIdFormat()); assertThat(copy.getNameIdFormat()).isEqualTo(registration.getNameIdFormat());
assertThat(copy.isAuthnRequestsSigned()).isEqualTo(registration.isAuthnRequestsSigned());
} }
@Test @Test

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,6 +18,9 @@ package org.springframework.security.saml2.provider.service.web.authentication;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.opensaml.xmlsec.signature.support.SignatureConstants; import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
@ -32,6 +35,8 @@ import org.springframework.security.saml2.provider.service.registration.TestRely
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
import java.util.stream.Stream;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -47,11 +52,15 @@ public class OpenSamlAuthenticationRequestResolverTests {
this.relyingPartyRegistrationBuilder = TestRelyingPartyRegistrations.relyingPartyRegistration(); this.relyingPartyRegistrationBuilder = TestRelyingPartyRegistrations.relyingPartyRegistration();
} }
@Test @ParameterizedTest
public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects() { @MethodSource("provideSignRequestFlags")
public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects(boolean wantAuthRequestsSigned, boolean authnRequestsSigned) {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setPathInfo("/saml2/authenticate/registration-id"); request.setPathInfo("/saml2/authenticate/registration-id");
RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.build(); RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
.authnRequestsSigned(authnRequestsSigned)
.assertingPartyDetails(party -> party.wantAuthnRequestsSigned(wantAuthRequestsSigned))
.build();
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
@ -113,8 +122,9 @@ public class OpenSamlAuthenticationRequestResolverTests {
public void resolveAuthenticationRequestWhenUnsignedPostThenOnlyPosts() { public void resolveAuthenticationRequestWhenUnsignedPostThenOnlyPosts() {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setPathInfo("/saml2/authenticate/registration-id"); request.setPathInfo("/saml2/authenticate/registration-id");
RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.assertingPartyDetails( RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
(party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST).wantAuthnRequestsSigned(false)) .assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST).wantAuthnRequestsSigned(false))
.authnRequestsSigned(false)
.build(); .build();
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
@ -134,12 +144,16 @@ public class OpenSamlAuthenticationRequestResolverTests {
assertThat(result.getId()).isNotEmpty(); assertThat(result.getId()).isNotEmpty();
} }
@Test @ParameterizedTest
public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts() { @MethodSource("provideSignRequestFlags")
public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts(boolean wantAuthRequestsSigned, boolean authnRequestsSigned) {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setPathInfo("/saml2/authenticate/registration-id"); request.setPathInfo("/saml2/authenticate/registration-id");
RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
.assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)).build(); .authnRequestsSigned(authnRequestsSigned)
.assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)
.wantAuthnRequestsSigned(wantAuthRequestsSigned))
.build();
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
@ -180,4 +194,12 @@ public class OpenSamlAuthenticationRequestResolverTests {
return new OpenSamlAuthenticationRequestResolver((request, id) -> registration); return new OpenSamlAuthenticationRequestResolver((request, id) -> registration);
} }
private static Stream<Arguments> provideSignRequestFlags() {
return Stream.of(
Arguments.of(true, true),
Arguments.of(true, false),
Arguments.of(false, true)
);
}
} }