mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-03-01 10:59:16 +00:00
Add AuthnRequestConsumerResolver
Closes gh-8141
This commit is contained in:
parent
2e5c87dc75
commit
2c960d2ad1
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -16,12 +16,16 @@
|
||||
|
||||
package org.springframework.security.config.annotation.web.configurers.saml2;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.net.URLDecoder;
|
||||
import java.time.Duration;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.zip.Inflater;
|
||||
import java.util.zip.InflaterOutputStream;
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
@ -54,9 +58,12 @@ import org.springframework.security.core.AuthenticationException;
|
||||
import org.springframework.security.core.GrantedAuthority;
|
||||
import org.springframework.security.core.authority.SimpleGrantedAuthority;
|
||||
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
|
||||
import org.springframework.security.saml2.Saml2Exception;
|
||||
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
|
||||
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
|
||||
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
||||
@ -69,7 +76,11 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi
|
||||
import org.springframework.security.web.context.SecurityContextRepository;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
import org.springframework.test.web.servlet.MockMvc;
|
||||
import org.springframework.test.web.servlet.MvcResult;
|
||||
import org.springframework.web.util.UriComponents;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
@ -157,6 +168,20 @@ public class Saml2LoginConfigurerTests {
|
||||
verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception {
|
||||
this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire();
|
||||
|
||||
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
|
||||
.andReturn();
|
||||
UriComponents components = UriComponentsBuilder
|
||||
.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
|
||||
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
|
||||
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
|
||||
String inflated = samlInflate(samlDecode(decoded));
|
||||
assertThat(inflated).contains("ForceAuthn=\"true\"");
|
||||
}
|
||||
|
||||
private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
|
||||
// get the OpenSamlAuthenticationProvider
|
||||
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
|
||||
@ -275,6 +300,29 @@ public class Saml2LoginConfigurerTests {
|
||||
}
|
||||
}
|
||||
|
||||
@EnableWebSecurity
|
||||
@Import(Saml2LoginConfigBeans.class)
|
||||
static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter {
|
||||
|
||||
@Override
|
||||
protected void configure(HttpSecurity http) throws Exception {
|
||||
http
|
||||
.authorizeRequests(authz -> authz
|
||||
.anyRequest().authenticated()
|
||||
)
|
||||
.saml2Login(saml2 -> {});
|
||||
}
|
||||
|
||||
@Bean
|
||||
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
|
||||
OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
|
||||
new OpenSamlAuthenticationRequestFactory();
|
||||
authenticationRequestFactory.setAuthnRequestConsumerResolver(
|
||||
context -> authnRequest -> authnRequest.setForceAuthn(true));
|
||||
return authenticationRequestFactory;
|
||||
}
|
||||
}
|
||||
|
||||
private static AuthenticationManager getAuthenticationManagerMock(String role) {
|
||||
return new AuthenticationManager() {
|
||||
|
||||
@ -315,4 +363,23 @@ public class Saml2LoginConfigurerTests {
|
||||
}
|
||||
}
|
||||
|
||||
private static org.apache.commons.codec.binary.Base64 BASE64 =
|
||||
new org.apache.commons.codec.binary.Base64(0, new byte[]{'\n'});
|
||||
|
||||
private static byte[] samlDecode(String s) {
|
||||
return BASE64.decode(s);
|
||||
}
|
||||
|
||||
private static String samlInflate(byte[] b) {
|
||||
try {
|
||||
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
|
||||
iout.write(b);
|
||||
iout.finish();
|
||||
return new String(out.toByteArray(), UTF_8);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new Saml2Exception("Unable to inflate string", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -21,6 +21,8 @@ import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.joda.time.DateTime;
|
||||
import org.opensaml.saml.common.xml.SAMLConstants;
|
||||
@ -43,6 +45,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
||||
private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
|
||||
private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI;
|
||||
|
||||
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
|
||||
= context -> authnRequest -> {};
|
||||
|
||||
@Override
|
||||
@Deprecated
|
||||
public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
|
||||
@ -95,8 +100,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
||||
}
|
||||
|
||||
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
|
||||
return createAuthnRequest(context.getIssuer(),
|
||||
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
|
||||
context.getDestination(), context.getAssertionConsumerServiceUrl());
|
||||
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
|
||||
return authnRequest;
|
||||
}
|
||||
|
||||
private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
|
||||
@ -114,6 +121,18 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
||||
return auth;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the {@link AuthnRequest} post-processor resolver
|
||||
*
|
||||
* @param authnRequestConsumerResolver
|
||||
* @since 5.4
|
||||
*/
|
||||
public void setAuthnRequestConsumerResolver(
|
||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
|
||||
Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
|
||||
this.authnRequestConsumerResolver = authnRequestConsumerResolver;
|
||||
}
|
||||
|
||||
/**
|
||||
* '
|
||||
* Use this {@link Clock} with {@link Instant#now()} for generating
|
||||
|
@ -16,6 +16,9 @@
|
||||
|
||||
package org.springframework.security.saml2.provider.service.authentication;
|
||||
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
@ -29,9 +32,13 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
import static org.hamcrest.CoreMatchers.containsString;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
|
||||
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
|
||||
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
|
||||
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
|
||||
@ -160,6 +167,34 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
||||
factory.setProtocolBinding("my-invalid-binding");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
|
||||
mock(Function.class);
|
||||
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
|
||||
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
|
||||
|
||||
this.factory.createPostAuthenticationRequest(this.context);
|
||||
verify(authnRequestConsumerResolver).apply(this.context);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
|
||||
mock(Function.class);
|
||||
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
|
||||
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
|
||||
|
||||
this.factory.createRedirectAuthenticationRequest(this.context);
|
||||
verify(authnRequestConsumerResolver).apply(this.context);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthnRequestConsumerResolverWhenNullThenException() {
|
||||
assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) {
|
||||
AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ?
|
||||
factory.createRedirectAuthenticationRequest(context) :
|
||||
|
Loading…
x
Reference in New Issue
Block a user