Add RelayStateResolver

Co-authored-by: ghaege <ghaege@qaepps.de>

Closes gh-12538
This commit is contained in:
Josh Cummings 2023-02-16 11:43:35 -07:00
parent ab8337e371
commit c1c28375d6
3 changed files with 39 additions and 4 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -23,6 +23,7 @@ import java.util.function.Consumer;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import org.opensaml.saml.saml2.core.LogoutRequest; import org.opensaml.saml.saml2.core.LogoutRequest;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest; import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
@ -34,6 +35,7 @@ import org.springframework.util.Assert;
* OpenSAML 4 * OpenSAML 4
* *
* @author Josh Cummings * @author Josh Cummings
* @author Gerhard Haege
* @since 5.6 * @since 5.6
*/ */
public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestResolver { public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestResolver {
@ -83,6 +85,16 @@ public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestR
this.clock = clock; this.clock = clock;
} }
/**
* Use this {@link Converter} to compute the RelayState
* @param relayStateResolver the {@link Converter} to use
* @since 6.1
*/
public void setRelayStateResolver(Converter<HttpServletRequest, String> relayStateResolver) {
Assert.notNull(relayStateResolver, "relayStateResolver cannot be null");
this.logoutRequestResolver.setRelayStateResolver(relayStateResolver);
}
public static final class LogoutRequestParameters { public static final class LogoutRequestParameters {
private final HttpServletRequest request; private final HttpServletRequest request;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -38,6 +38,7 @@ import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
import org.opensaml.saml.saml2.core.impl.SessionIndexBuilder; import org.opensaml.saml.saml2.core.impl.SessionIndexBuilder;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.OpenSamlInitializationService;
@ -74,6 +75,8 @@ final class OpenSamlLogoutRequestResolver {
private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
private Converter<HttpServletRequest, String> relayStateResolver = (request) -> UUID.randomUUID().toString();
/** /**
* Construct a {@link OpenSamlLogoutRequestResolver} * Construct a {@link OpenSamlLogoutRequestResolver}
*/ */
@ -95,6 +98,10 @@ final class OpenSamlLogoutRequestResolver {
Assert.notNull(this.sessionIndexBuilder, "sessionIndexBuilder must be configured in OpenSAML"); Assert.notNull(this.sessionIndexBuilder, "sessionIndexBuilder must be configured in OpenSAML");
} }
void setRelayStateResolver(Converter<HttpServletRequest, String> relayStateResolver) {
this.relayStateResolver = relayStateResolver;
}
/** /**
* Prepare to create, sign, and serialize a SAML 2.0 Logout Request. * Prepare to create, sign, and serialize a SAML 2.0 Logout Request.
* *
@ -140,7 +147,7 @@ final class OpenSamlLogoutRequestResolver {
if (logoutRequest.getID() == null) { if (logoutRequest.getID() == null) {
logoutRequest.setID("LR" + UUID.randomUUID()); logoutRequest.setID("LR" + UUID.randomUUID());
} }
String relayState = UUID.randomUUID().toString(); String relayState = this.relayStateResolver.convert(request);
Saml2LogoutRequest.Builder result = Saml2LogoutRequest.withRelyingPartyRegistration(registration) Saml2LogoutRequest.Builder result = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
.id(logoutRequest.getID()); .id(logoutRequest.getID());
if (registration.getAssertingPartyDetails().getSingleLogoutServiceBinding() == Saml2MessageBinding.POST) { if (registration.getAssertingPartyDetails().getSingleLogoutServiceBinding() == Saml2MessageBinding.POST) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,9 +16,11 @@
package org.springframework.security.saml2.provider.service.web.authentication.logout; package org.springframework.security.saml2.provider.service.web.authentication.logout;
import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.core.convert.converter.Converter;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -32,6 +34,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/** /**
* Tests for {@link OpenSaml4LogoutRequestResolver} * Tests for {@link OpenSaml4LogoutRequestResolver}
@ -67,6 +70,19 @@ public class OpenSaml4LogoutRequestResolverTests {
.isThrownBy(() -> this.logoutRequestResolver.setParametersConsumer(null)); .isThrownBy(() -> this.logoutRequestResolver.setParametersConsumer(null));
} }
@Test
public void resolveWhenCustomRelayStateThenUses() {
given(this.registrationResolver.resolve(any(), any())).willReturn(this.registration);
Converter<HttpServletRequest, String> relayState = mock(Converter.class);
given(relayState.convert(any())).willReturn("any-state");
this.logoutRequestResolver.setRelayStateResolver(relayState);
Saml2LogoutRequest logoutRequest = this.logoutRequestResolver.resolve(givenRequest(), givenAuthentication());
assertThat(logoutRequest.getRelayState()).isEqualTo("any-state");
verify(relayState).convert(any());
}
private static Authentication givenAuthentication() { private static Authentication givenAuthentication() {
return new TestingAuthenticationToken("user", "password"); return new TestingAuthenticationToken("user", "password");
} }