Polish DefaultSaml2AuthenticationRequestContextResolver

- Added more tests
- Standardized terminology

Issue gh-8360
This commit is contained in:
Josh Cummings 2020-04-17 15:20:35 -06:00
parent 8c0bdd50e2
commit ab772893c7
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 73 additions and 31 deletions

View File

@ -16,8 +16,14 @@
package org.springframework.security.saml2.provider.service.web;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert;
@ -25,11 +31,6 @@ import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.http.HttpServletRequest;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
@ -81,10 +82,6 @@ public final class DefaultSaml2AuthenticationRequestContextResolver implements S
}
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
if (!StringUtils.hasText(template)) {
return baseUrl;
}
String entityId = relyingParty.getProviderDetails().getEntityId();
String registrationId = relyingParty.getRegistrationId();
Map<String, String> uriVariables = new HashMap<>();

View File

@ -23,44 +23,89 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import static org.springframework.security.saml2.provider.service.servlet.filter.TestSaml2SigningCredentials.signingCredential;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
/**
* Tests for {@link DefaultSaml2AuthenticationRequestContextResolver}
*
* @author Shazin Sadakath
* @author Josh Cummings
*/
public class DefaultSaml2AuthenticationRequestContextResolverTests {
private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO";
private static final String TEMPLATE = "template";
private static final String ASSERTING_PARTY_SSO_URL = "https://idp.example.com/sso";
private static final String RELYING_PARTY_SSO_URL = "https://sp.example.com/sso";
private static final String ASSERTING_PARTY_ENTITY_ID = "asserting-party-entity-id";
private static final String RELYING_PARTY_ENTITY_ID = "relying-party-entity-id";
private static final String REGISTRATION_ID = "registration-id";
private static final String IDP_ENTITY_ID = "idp-entity-id";
private MockHttpServletRequest request;
private RelyingPartyRegistration.Builder rpBuilder;
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();
private RelyingPartyRegistration.Builder relyingPartyBuilder;
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver
= new DefaultSaml2AuthenticationRequestContextResolver();
@Before
public void setup() {
request = new MockHttpServletRequest();
rpBuilder = RelyingPartyRegistration
this.request = new MockHttpServletRequest();
this.relyingPartyBuilder = RelyingPartyRegistration
.withRegistrationId(REGISTRATION_ID)
.providerDetails(c -> c.entityId(IDP_ENTITY_ID))
.providerDetails(c -> c.webSsoUrl(IDP_SSO_URL))
.assertionConsumerServiceUrlTemplate(TEMPLATE)
.localEntityIdTemplate(RELYING_PARTY_ENTITY_ID)
.providerDetails(c -> c.entityId(ASSERTING_PARTY_ENTITY_ID))
.providerDetails(c -> c.webSsoUrl(ASSERTING_PARTY_SSO_URL))
.assertionConsumerServiceUrlTemplate(RELYING_PARTY_SSO_URL)
.credentials(c -> c.add(signingCredential()));
}
@Test
public void resoleWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
Saml2AuthenticationRequestContext authenticationRequestContext = authenticationRequestContextResolver.resolve(request, rpBuilder.build());
public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
this.request.addParameter("RelayState", "relay-state");
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder.build();
Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
assertThat(authenticationRequestContext).isNotNull();
assertThat(authenticationRequestContext.getAssertionConsumerServiceUrl()).isEqualTo(TEMPLATE);
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(REGISTRATION_ID);
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getProviderDetails().getEntityId()).isEqualTo(IDP_ENTITY_ID);
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getProviderDetails().getWebSsoUrl()).isEqualTo(IDP_SSO_URL);
assertThat(authenticationRequestContext.getRelyingPartyRegistration().getCredentials()).isNotEmpty();
assertThat(context).isNotNull();
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL);
assertThat(context.getRelayState()).isEqualTo("relay-state");
assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL);
assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID);
assertThat(context.getRelyingPartyRegistration()).isSameAs(relyingParty);
}
@Test(expected = IllegalArgumentException.class)
public void resolveWhenRequestAndRelyingPartyNullThenException() {
authenticationRequestContextResolver.resolve(null, null);
@Test
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() {
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
.assertionConsumerServiceUrlTemplate("/saml2/authenticate/{registrationId}")
.build();
Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id");
}
@Test
public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() {
RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
.assertionConsumerServiceUrlTemplate("{baseUrl}/saml2/authenticate/{registrationId}")
.build();
Saml2AuthenticationRequestContext context =
this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
assertThat(context.getAssertionConsumerServiceUrl())
.isEqualTo("http://localhost/saml2/authenticate/registration-id");
}
@Test
public void resolveWhenRequestNullThenException() {
assertThatCode(() ->
this.authenticationRequestContextResolver.resolve(this.request, null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void resolveWhenRelyingPartyNullThenException() {
assertThatCode(() ->
this.authenticationRequestContextResolver.resolve(null, this.relyingPartyBuilder.build()))
.isInstanceOf(IllegalArgumentException.class);
}
}