Polish DefaultSaml2AuthenticationRequestContextResolver

- Added more tests
- Standardized terminology

Issue gh-8360
This commit is contained in:
Josh Cummings 2020-04-17 15:20:35 -06:00
parent 8c0bdd50e2
commit ab772893c7
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 73 additions and 31 deletions

View File

@ -16,8 +16,14 @@
package org.springframework.security.saml2.provider.service.web; package org.springframework.security.saml2.provider.service.web;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -25,11 +31,6 @@ import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.http.HttpServletRequest;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl; import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl; import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
@ -81,10 +82,6 @@ public final class DefaultSaml2AuthenticationRequestContextResolver implements S
} }
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
if (!StringUtils.hasText(template)) {
return baseUrl;
}
String entityId = relyingParty.getProviderDetails().getEntityId(); String entityId = relyingParty.getProviderDetails().getEntityId();
String registrationId = relyingParty.getRegistrationId(); String registrationId = relyingParty.getRegistrationId();
Map<String, String> uriVariables = new HashMap<>(); Map<String, String> uriVariables = new HashMap<>();

View File

@ -23,44 +23,89 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import static org.springframework.security.saml2.provider.service.servlet.filter.TestSaml2SigningCredentials.signingCredential; import static org.springframework.security.saml2.provider.service.servlet.filter.TestSaml2SigningCredentials.signingCredential;
import static org.assertj.core.api.Assertions.*; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
/**
* Tests for {@link DefaultSaml2AuthenticationRequestContextResolver}
*
* @author Shazin Sadakath
* @author Josh Cummings
*/
public class DefaultSaml2AuthenticationRequestContextResolverTests { public class DefaultSaml2AuthenticationRequestContextResolverTests {
private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO"; private static final String ASSERTING_PARTY_SSO_URL = "https://idp.example.com/sso";
private static final String TEMPLATE = "template"; private static final String RELYING_PARTY_SSO_URL = "https://sp.example.com/sso";
private static final String ASSERTING_PARTY_ENTITY_ID = "asserting-party-entity-id";
private static final String RELYING_PARTY_ENTITY_ID = "relying-party-entity-id";
private static final String REGISTRATION_ID = "registration-id"; private static final String REGISTRATION_ID = "registration-id";
private static final String IDP_ENTITY_ID = "idp-entity-id";
private MockHttpServletRequest request; private MockHttpServletRequest request;
private RelyingPartyRegistration.Builder rpBuilder; private RelyingPartyRegistration.Builder relyingPartyBuilder;
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver(); private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver
= new DefaultSaml2AuthenticationRequestContextResolver();
@Before @Before
public void setup() { public void setup() {
request = new MockHttpServletRequest(); this.request = new MockHttpServletRequest();
rpBuilder = RelyingPartyRegistration this.relyingPartyBuilder = RelyingPartyRegistration
.withRegistrationId(REGISTRATION_ID) .withRegistrationId(REGISTRATION_ID)
.providerDetails(c -> c.entityId(IDP_ENTITY_ID)) .localEntityIdTemplate(RELYING_PARTY_ENTITY_ID)
.providerDetails(c -> c.webSsoUrl(IDP_SSO_URL)) .providerDetails(c -> c.entityId(ASSERTING_PARTY_ENTITY_ID))
.assertionConsumerServiceUrlTemplate(TEMPLATE) .providerDetails(c -> c.webSsoUrl(ASSERTING_PARTY_SSO_URL))
.assertionConsumerServiceUrlTemplate(RELYING_PARTY_SSO_URL)
.credentials(c -> c.add(signingCredential())); .credentials(c -> c.add(signingCredential()));
} }
@Test @Test
public void resoleWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() { public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
Saml2AuthenticationRequestContext authenticationRequestContext = authenticationRequestContextResolver.resolve(request, rpBuilder.build()); this.request.addParameter("RelayState", "relay-state");
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder.build();
Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
assertThat(authenticationRequestContext).isNotNull(); assertThat(context).isNotNull();
assertThat(authenticationRequestContext.getAssertionConsumerServiceUrl()).isEqualTo(TEMPLATE); assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL);
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(context.getRelayState()).isEqualTo("relay-state");
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getProviderDetails().getEntityId()).isEqualTo(IDP_ENTITY_ID); assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL);
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getProviderDetails().getWebSsoUrl()).isEqualTo(IDP_SSO_URL); assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID);
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getCredentials()).isNotEmpty(); assertThat(context.getRelyingPartyRegistration()).isSameAs(relyingParty);
} }
@Test(expected = IllegalArgumentException.class) @Test
public void resolveWhenRequestAndRelyingPartyNullThenException() { public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() {
authenticationRequestContextResolver.resolve(null, null); RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
.assertionConsumerServiceUrlTemplate("/saml2/authenticate/{registrationId}")
.build();
Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id");
}
@Test
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() {
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
.assertionConsumerServiceUrlTemplate("{baseUrl}/saml2/authenticate/{registrationId}")
.build();
Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
assertThat(context.getAssertionConsumerServiceUrl())
.isEqualTo("http://localhost/saml2/authenticate/registration-id");
}
@Test
public void resolveWhenRequestNullThenException() {
assertThatCode(() ->
this.authenticationRequestContextResolver.resolve(this.request, null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void resolveWhenRelyingPartyNullThenException() {
assertThatCode(() ->
this.authenticationRequestContextResolver.resolve(null, this.relyingPartyBuilder.build()))
.isInstanceOf(IllegalArgumentException.class);
} }
} }