Add AuthnRequestConsumerResolver

Closes gh-8141
This commit is contained in:
Josh Cummings 2020-07-16 14:51:47 -06:00
parent 2e5c87dc75
commit 2c960d2ad1
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
3 changed files with 124 additions and 3 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,12 +16,16 @@
package org.springframework.security.config.annotation.web.configurers.saml2;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URLDecoder;
import java.time.Duration;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.zip.Inflater;
import java.util.zip.InflaterOutputStream;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
@ -54,9 +58,12 @@ import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
@ -69,7 +76,11 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
@ -157,6 +168,20 @@ public class Saml2LoginConfigurerTests {
verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
}
@Test
public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception {
this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire();
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
.andReturn();
UriComponents components = UriComponentsBuilder
.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
String inflated = samlInflate(samlDecode(decoded));
assertThat(inflated).contains("ForceAuthn=\"true\"");
}
private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
// get the OpenSamlAuthenticationProvider
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@ -275,6 +300,29 @@ public class Saml2LoginConfigurerTests {
}
}
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter {
@Override
protected void configure(HttpSecurity http) throws Exception {
http
.authorizeRequests(authz -> authz
.anyRequest().authenticated()
)
.saml2Login(saml2 -> {});
}
@Bean
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
new OpenSamlAuthenticationRequestFactory();
authenticationRequestFactory.setAuthnRequestConsumerResolver(
context -> authnRequest -> authnRequest.setForceAuthn(true));
return authenticationRequestFactory;
}
}
private static AuthenticationManager getAuthenticationManagerMock(String role) {
return new AuthenticationManager() {
@ -315,4 +363,23 @@ public class Saml2LoginConfigurerTests {
}
}
private static org.apache.commons.codec.binary.Base64 BASE64 =
new org.apache.commons.codec.binary.Base64(0, new byte[]{'\n'});
private static byte[] samlDecode(String s) {
return BASE64.decode(s);
}
private static String samlInflate(byte[] b) {
try {
ByteArrayOutputStream out = new ByteArrayOutputStream();
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
iout.write(b);
iout.finish();
return new String(out.toByteArray(), UTF_8);
}
catch (IOException e) {
throw new Saml2Exception("Unable to inflate string", e);
}
}
}

View File

@ -21,6 +21,8 @@ import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Consumer;
import java.util.function.Function;
import org.joda.time.DateTime;
import org.opensaml.saml.common.xml.SAMLConstants;
@ -43,6 +45,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI;
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
= context -> authnRequest -> {};
@Override
@Deprecated
public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
@ -95,8 +100,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
}
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
return createAuthnRequest(context.getIssuer(),
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
context.getDestination(), context.getAssertionConsumerServiceUrl());
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
return authnRequest;
}
private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
@ -114,6 +121,18 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
return auth;
}
/**
* Set the {@link AuthnRequest} post-processor resolver
*
* @param authnRequestConsumerResolver
* @since 5.4
*/
public void setAuthnRequestConsumerResolver(
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
this.authnRequestConsumerResolver = authnRequestConsumerResolver;
}
/**
* '
* Use this {@link Clock} with {@link Instant#now()} for generating

View File

@ -16,6 +16,9 @@
package org.springframework.security.saml2.provider.service.authentication;
import java.util.function.Consumer;
import java.util.function.Function;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
@ -29,9 +32,13 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.hamcrest.CoreMatchers.containsString;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
@ -160,6 +167,34 @@ public class OpenSamlAuthenticationRequestFactoryTests {
factory.setProtocolBinding("my-invalid-binding");
}
@Test
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
mock(Function.class);
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
this.factory.createPostAuthenticationRequest(this.context);
verify(authnRequestConsumerResolver).apply(this.context);
}
@Test
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
mock(Function.class);
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
this.factory.createRedirectAuthenticationRequest(this.context);
verify(authnRequestConsumerResolver).apply(this.context);
}
@Test
public void setAuthnRequestConsumerResolverWhenNullThenException() {
assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
.isInstanceOf(IllegalArgumentException.class);
}
private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) {
AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ?
factory.createRedirectAuthenticationRequest(context) :