Add @Nullable Annotations to saml2-service-provider

Issue gh-17823

Signed-off-by: Josh Cummings <3627351+jzheaux@users.noreply.github.com>
This commit is contained in:
Josh Cummings 2026-02-02 15:18:56 -07:00
parent f3656b4991
commit e771ec04b7
No known key found for this signature in database
GPG Key ID: 869B37A20E876129
97 changed files with 971 additions and 398 deletions

View File

@ -18,6 +18,8 @@ package org.springframework.security.saml2.core;
import java.io.Serializable;
import org.jspecify.annotations.Nullable;
import org.springframework.util.Assert;
/**
@ -37,14 +39,14 @@ public class Saml2Error implements Serializable {
private final String errorCode;
private final String description;
private final @Nullable String description;
/**
* Constructs a {@code Saml2Error} using the provided parameters.
* @param errorCode the error code
* @param description the error description
*/
public Saml2Error(String errorCode, String description) {
public Saml2Error(String errorCode, @Nullable String description) {
Assert.hasText(errorCode, "errorCode cannot be empty");
this.errorCode = errorCode;
this.description = description;
@ -56,7 +58,7 @@ public class Saml2Error implements Serializable {
* @return the resulting {@link Saml2Error}
* @since 7.0
*/
public static Saml2Error invalidResponse(String description) {
public static Saml2Error invalidResponse(@Nullable String description) {
return new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, description);
}
@ -66,7 +68,7 @@ public class Saml2Error implements Serializable {
* @return the resulting {@link Saml2Error}
* @since 7.0
*/
public static Saml2Error internalValidationError(String description) {
public static Saml2Error internalValidationError(@Nullable String description) {
return new Saml2Error(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, description);
}
@ -76,7 +78,7 @@ public class Saml2Error implements Serializable {
* @return the resulting {@link Saml2Error}
* @since 7.0
*/
public static Saml2Error malformedResponseData(String description) {
public static Saml2Error malformedResponseData(@Nullable String description) {
return new Saml2Error(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, description);
}
@ -86,7 +88,7 @@ public class Saml2Error implements Serializable {
* @return the resulting {@link Saml2Error}
* @since 7.0
*/
public static Saml2Error decryptionError(String description) {
public static Saml2Error decryptionError(@Nullable String description) {
return new Saml2Error(Saml2ErrorCodes.DECRYPTION_ERROR, description);
}
@ -96,7 +98,7 @@ public class Saml2Error implements Serializable {
* @return the resulting {@link Saml2Error}
* @since 7.0
*/
public static Saml2Error relyingPartyRegistrationNotFound(String description) {
public static Saml2Error relyingPartyRegistrationNotFound(@Nullable String description) {
return new Saml2Error(Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND, description);
}
@ -106,7 +108,7 @@ public class Saml2Error implements Serializable {
* @return the resulting {@link Saml2Error}
* @since 7.0
*/
public static Saml2Error subjectNotFound(String description) {
public static Saml2Error subjectNotFound(@Nullable String description) {
return new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND, description);
}
@ -122,7 +124,7 @@ public class Saml2Error implements Serializable {
* Returns the error description.
* @return the error description
*/
public final String getDescription() {
public final @Nullable String getDescription() {
return this.description;
}

View File

@ -24,6 +24,8 @@ import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import org.jspecify.annotations.Nullable;
import org.springframework.util.Assert;
/**
@ -40,7 +42,7 @@ public final class Saml2X509Credential implements Serializable {
private static final long serialVersionUID = -1015853414272603517L;
private final PrivateKey privateKey;
private final @Nullable PrivateKey privateKey;
private final X509Certificate certificate;
@ -77,7 +79,8 @@ public final class Saml2X509Credential implements Serializable {
* @param certificate the credential's public certificate
* @param types the credential's intended usages
*/
public Saml2X509Credential(PrivateKey privateKey, X509Certificate certificate, Set<Saml2X509CredentialType> types) {
public Saml2X509Credential(@Nullable PrivateKey privateKey, X509Certificate certificate,
Set<Saml2X509CredentialType> types) {
Assert.notNull(certificate, "certificate cannot be null");
Assert.notNull(types, "credentialTypes cannot be null");
this.privateKey = privateKey;
@ -123,7 +126,7 @@ public final class Saml2X509Credential implements Serializable {
return new Saml2X509Credential(privateKey, certificate, Saml2X509Credential.Saml2X509CredentialType.SIGNING);
}
private Saml2X509Credential(PrivateKey privateKey, boolean keyRequired, X509Certificate certificate,
private Saml2X509Credential(@Nullable PrivateKey privateKey, boolean keyRequired, X509Certificate certificate,
Saml2X509CredentialType... types) {
Assert.notNull(certificate, "certificate cannot be null");
Assert.notEmpty(types, "credentials types cannot be empty");
@ -140,7 +143,7 @@ public final class Saml2X509Credential implements Serializable {
* @return the private key, may be null
* @see #Saml2X509Credential(PrivateKey, X509Certificate, Saml2X509CredentialType...)
*/
public PrivateKey getPrivateKey() {
public @Nullable PrivateKey getPrivateKey() {
return this.privateKey;
}

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Core SAML2 types and utilities.
*/
@NullMarked
package org.springframework.security.saml2.core;
import org.jspecify.annotations.NullMarked;

View File

@ -25,6 +25,7 @@ import java.util.Objects;
import javax.xml.namespace.QName;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.RequestAbstractType;
@ -35,6 +36,7 @@ import org.w3c.dom.Element;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
interface OpenSamlOperations {
@ -89,14 +91,17 @@ interface OpenSamlOperations {
private final String algorithm;
private final byte[] signature;
private final byte @Nullable [] signature;
private final byte[] content;
RedirectParameters(Map<String, String> parameters, String parametersQuery, RequestAbstractType request) {
Assert.notNull(request.getID(), "SAML request's ID cannot be null");
Assert.notNull(request.getIssuer(), "SAML request's Issuer cannot be null");
this.id = request.getID();
this.issuer = request.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -113,9 +118,12 @@ interface OpenSamlOperations {
}
RedirectParameters(Map<String, String> parameters, String parametersQuery, StatusResponseType response) {
Assert.notNull(response.getID(), "SAML response's ID cannot be null");
Assert.notNull(response.getIssuer(), "SAML response's Issuer cannot be null");
this.id = response.getID();
this.issuer = response.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -131,7 +139,8 @@ interface OpenSamlOperations {
this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams);
}
static byte[] getContent(String samlObject, String relayState, final Map<String, String> queryParams) {
static byte[] getContent(String samlObject, @Nullable String relayState,
final Map<String, String> queryParams) {
if (Objects.nonNull(relayState)) {
return String
.format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject),
@ -163,7 +172,7 @@ interface OpenSamlOperations {
return this.algorithm;
}
byte[] getSignature() {
byte @Nullable [] getSignature() {
return this.signature;
}

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Internal utilities for SAML2 support (not for public use).
*/
@NullMarked
package org.springframework.security.saml2.internal;
import org.jspecify.annotations.NullMarked;

View File

@ -22,6 +22,7 @@ import java.util.Map;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.jackson.SecurityJacksonModules;
import org.springframework.security.saml2.provider.service.authentication.DefaultSaml2AuthenticatedPrincipal;
@ -38,6 +39,7 @@ import org.springframework.security.saml2.provider.service.authentication.Defaul
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
@NullUnmarked
class DefaultSaml2AuthenticatedPrincipalMixin {
@JsonProperty("registrationId")

View File

@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.jackson.SecurityJacksonModules;
@ -41,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
isGetterVisibility = JsonAutoDetect.Visibility.NONE)
@NullUnmarked
class Saml2AssertionAuthenticationMixin {
@JsonCreator

View File

@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
@ -38,6 +39,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.NONE, getterVisibility = JsonAutoDetect.Visibility.NONE,
isGetterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties({ "cause", "stackTrace", "suppressedExceptions" })
@NullUnmarked
abstract class Saml2AuthenticationExceptionMixin {
@JsonProperty("error")

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.core.AuthenticatedPrincipal;
import org.springframework.security.core.GrantedAuthority;
@ -41,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
isGetterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties({ "authenticated" })
@NullUnmarked
class Saml2AuthenticationMixin {
@JsonCreator

View File

@ -20,6 +20,7 @@ import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.saml2.core.Saml2Error;
@ -35,6 +36,7 @@ import org.springframework.security.saml2.core.Saml2Error;
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
isGetterVisibility = JsonAutoDetect.Visibility.NONE)
@NullUnmarked
class Saml2ErrorMixin {
@JsonCreator

View File

@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.jackson.SecurityJacksonModules;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
@ -40,6 +41,7 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
@NullUnmarked
class Saml2LogoutRequestMixin {
@JsonIgnore

View File

@ -20,6 +20,7 @@ import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.jackson.SecurityJacksonModules;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
@ -36,6 +37,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2P
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
@NullUnmarked
class Saml2PostAuthenticationRequestMixin {
@JsonCreator

View File

@ -20,6 +20,7 @@ import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.jackson.SecurityJacksonModules;
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest;
@ -36,6 +37,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
@NullUnmarked
class Saml2RedirectAuthenticationRequestMixin {
@JsonCreator

View File

@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.jackson.SecurityJacksonModules;
import org.springframework.security.saml2.provider.service.authentication.Saml2ResponseAssertion;
@ -41,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
isGetterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties({ "authenticated" })
@NullUnmarked
class SimpleSaml2ResponseAssertionAccessorMixin {
@JsonCreator

View File

@ -17,4 +17,7 @@
/**
* Jackson 3+ serialization support for SAML2.
*/
@NullMarked
package org.springframework.security.saml2.jackson;
import org.jspecify.annotations.NullMarked;

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.saml2.provider.service.authentication.DefaultSaml2AuthenticatedPrincipal;
@ -49,6 +50,7 @@ import org.springframework.security.saml2.provider.service.authentication.Defaul
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties(ignoreUnknown = true)
@NullUnmarked
class DefaultSaml2AuthenticatedPrincipalMixin {
@JsonProperty("registrationId")

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.jackson2.SecurityJackson2Modules;
@ -52,6 +53,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
isGetterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties(value = { "authenticated" }, ignoreUnknown = true)
@NullUnmarked
class Saml2AssertionAuthenticationMixin {
@JsonCreator

View File

@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
@ -42,6 +43,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.NONE, getterVisibility = JsonAutoDetect.Visibility.NONE,
isGetterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties(ignoreUnknown = true, value = { "cause", "stackTrace", "suppressedExceptions" })
@NullUnmarked
abstract class Saml2AuthenticationExceptionMixin {
@JsonProperty("error")

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.core.AuthenticatedPrincipal;
import org.springframework.security.core.GrantedAuthority;
@ -51,6 +52,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
isGetterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties(value = { "authenticated" }, ignoreUnknown = true)
@NullUnmarked
class Saml2AuthenticationMixin {
@JsonCreator

View File

@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.saml2.core.Saml2Error;
@ -40,6 +41,7 @@ import org.springframework.security.saml2.core.Saml2Error;
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
isGetterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties(ignoreUnknown = true)
@NullUnmarked
class Saml2ErrorMixin {
@JsonCreator

View File

@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
@ -51,6 +52,7 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties(ignoreUnknown = true)
@NullUnmarked
class Saml2LogoutRequestMixin {
@JsonIgnore

View File

@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
@ -47,6 +48,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2P
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties(ignoreUnknown = true)
@NullUnmarked
class Saml2PostAuthenticationRequestMixin {
@JsonCreator

View File

@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest;
@ -47,6 +48,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties(ignoreUnknown = true)
@NullUnmarked
class Saml2RedirectAuthenticationRequestMixin {
@JsonCreator

View File

@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.saml2.provider.service.authentication.Saml2ResponseAssertion;
@ -50,6 +51,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
isGetterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonIgnoreProperties(value = { "authenticated" }, ignoreUnknown = true)
@NullUnmarked
class SimpleSaml2ResponseAssertionAccessorMixin {
@JsonCreator

View File

@ -17,4 +17,7 @@
/**
* Jackson 2 serialization support for SAML2.
*/
@NullMarked
package org.springframework.security.saml2.jackson2;
import org.jspecify.annotations.NullMarked;

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Core SAML2 support for Spring Security.
*/
@NullMarked
package org.springframework.security.saml2;
import org.jspecify.annotations.NullMarked;

View File

@ -19,6 +19,8 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.Serializable;
import java.nio.charset.Charset;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert;
@ -42,13 +44,13 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
private final String samlRequest;
private final String relayState;
private final @Nullable String relayState;
private final String authenticationRequestUri;
private final String relyingPartyRegistrationId;
private final @Nullable String relyingPartyRegistrationId;
private final String id;
private final @Nullable String id;
/**
* Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest}
@ -62,8 +64,8 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
* @param id This is the unique id used in the {@link #samlRequest}, cannot be empty
* or null
*/
AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
String relyingPartyRegistrationId, String id) {
AbstractSaml2AuthenticationRequest(String samlRequest, @Nullable String relayState, String authenticationRequestUri,
@Nullable String relyingPartyRegistrationId, @Nullable String id) {
Assert.hasText(samlRequest, "samlRequest cannot be null or empty");
Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty");
this.authenticationRequestUri = authenticationRequestUri;
@ -88,7 +90,7 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
* Returns the RelayState value, if present in the parameters
* @return the RelayState value, or null if not available
*/
public String getRelayState() {
public @Nullable String getRelayState() {
return this.relayState;
}
@ -106,7 +108,7 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
* @return the {@link RelyingPartyRegistration} id
* @since 5.8
*/
public String getRelyingPartyRegistrationId() {
public @Nullable String getRelyingPartyRegistrationId() {
return this.relyingPartyRegistrationId;
}
@ -115,7 +117,7 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
* @return the Authentication Request identifier
* @since 5.8
*/
public String getId() {
public @Nullable String getId() {
return this.id;
}
@ -132,15 +134,15 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
*/
public static class Builder<T extends Builder<T>> {
String authenticationRequestUri;
@Nullable String authenticationRequestUri;
String samlRequest;
@Nullable String samlRequest;
String relayState;
@Nullable String relayState;
String relyingPartyRegistrationId;
@Nullable String relyingPartyRegistrationId;
String id;
@Nullable String id;
/**
* @deprecated Use {@link #Builder(RelyingPartyRegistration)} instead
@ -173,7 +175,7 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
* parameter will be removed from the map.
* @return this object
*/
public T relayState(String relayState) {
public T relayState(@Nullable String relayState) {
this.relayState = relayState;
return _this();
}

View File

@ -26,6 +26,7 @@ import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
@ -33,6 +34,7 @@ import javax.xml.namespace.QName;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.schema.XSBoolean;
@ -58,15 +60,17 @@ import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AuthnStatement;
import org.opensaml.saml.saml2.core.Condition;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.OneTimeUse;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.Status;
import org.opensaml.saml.saml2.core.StatusCode;
import org.opensaml.saml.saml2.core.Subject;
import org.opensaml.saml.saml2.core.SubjectConfirmation;
import org.opensaml.saml.saml2.core.SubjectConfirmationData;
import org.springframework.core.convert.converter.Converter;
import org.springframework.core.log.LogMessage;
import org.springframework.lang.NonNull;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
@ -189,7 +193,7 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
};
}
private static String issuer(Response response) {
private static @Nullable String issuer(Response response) {
if (response.getIssuer() == null) {
return null;
}
@ -197,18 +201,20 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
}
static List<String> getStatusCodes(Response response) {
if (response.getStatus() == null) {
Status status = response.getStatus();
if (status == null) {
return List.of(StatusCode.SUCCESS);
}
if (response.getStatus().getStatusCode() == null) {
StatusCode statusCode = status.getStatusCode();
if (statusCode == null) {
return List.of(StatusCode.SUCCESS);
}
StatusCode parentStatusCode = response.getStatus().getStatusCode();
String parentStatusCodeValue = parentStatusCode.getValue();
String parentStatusCodeValue = statusCode.getValue();
Assert.notNull(parentStatusCodeValue, "Response#Status#StatusCode has not value");
if (!includeChildStatusCodes.contains(parentStatusCodeValue)) {
return List.of(parentStatusCodeValue);
}
StatusCode childStatusCode = parentStatusCode.getStatusCode();
StatusCode childStatusCode = statusCode.getStatusCode();
if (childStatusCode == null) {
return List.of(parentStatusCodeValue);
}
@ -228,8 +234,8 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
return StatusCode.SUCCESS.equals(statusCode);
}
static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
String inResponseTo) {
static Saml2ResponseValidatorResult validateInResponseTo(@Nullable AbstractSaml2AuthenticationRequest storedRequest,
@Nullable String inResponseTo) {
if (!StringUtils.hasText(inResponseTo)) {
return Saml2ResponseValidatorResult.success();
}
@ -265,7 +271,13 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
Response response = responseToken.response;
Saml2AuthenticationToken token = responseToken.token;
Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
String username = assertion.getSubject().getNameID().getValue();
Assert.notNull(assertion, "response must have at least one assertion");
Subject subject = assertion.getSubject();
Assert.notNull(subject, "response assertion must have a subject");
NameID nameId = subject.getNameID();
Assert.notNull(nameId, "response assertion subject must have a nameId");
String username = nameId.getValue();
Assert.notNull(username, "required elements must have a value");
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
List<String> sessionIndexes = getSessionIndexes(assertion);
DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal(username, attributes,
@ -301,7 +313,7 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
@Override
public boolean supports(Class<?> authentication) {
return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication);
return Saml2AuthenticationToken.class.isAssignableFrom(authentication);
}
private Response parseResponse(String response) throws Saml2Exception, Saml2AuthenticationException {
@ -435,7 +447,7 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
};
}
static boolean hasName(Assertion assertion) {
static boolean hasName(@Nullable Assertion assertion) {
if (assertion == null) {
return false;
}
@ -459,7 +471,9 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
attributeValues.add(attributeValue);
}
}
attributeMap.addAll(attribute.getName(), attributeValues);
String name = attribute.getName();
Assert.notNull(name, "all attributes must have a name");
attributeMap.addAll(name, attributeValues);
}
}
return new LinkedHashMap<>(attributeMap); // gh-11785
@ -468,12 +482,15 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
static List<String> getSessionIndexes(Assertion assertion) {
List<String> sessionIndexes = new ArrayList<>();
for (AuthnStatement statement : assertion.getAuthnStatements()) {
sessionIndexes.add(statement.getSessionIndex());
String sessionIndex = statement.getSessionIndex();
if (sessionIndex != null) {
sessionIndexes.add(sessionIndex);
}
}
return sessionIndexes;
}
private static Object getXmlObjectValue(XMLObject xmlObject) {
private static @Nullable Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject instanceof XSAny) {
return ((XSAny) xmlObject).getTextContent();
}
@ -504,6 +521,7 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
Assertion assertion = assertionToken.assertion;
SAML20AssertionValidator validator = validatorConverter.convert(assertionToken);
ValidationContext context = contextConverter.convert(assertionToken);
Response response = (Response) Objects.requireNonNull(assertion.getParent());
try {
ValidationResult result = validator.validate(assertion, context);
if (result == ValidationResult.VALID) {
@ -512,11 +530,11 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
}
catch (Exception ex) {
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(),
((Response) assertion.getParent()).getID(), ex.getMessage());
response.getID(), ex.getMessage());
return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message));
}
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(),
((Response) assertion.getParent()).getID(), context.getValidationFailureMessages());
response.getID(), context.getValidationFailureMessages());
return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message));
};
}
@ -557,7 +575,7 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
return false;
}
private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) {
private static @Nullable String getAuthnRequestId(@Nullable AbstractSaml2AuthenticationRequest serialized) {
return (serialized != null) ? serialized.getId() : null;
}
@ -573,13 +591,11 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
conditions.add(new AudienceRestrictionConditionValidator());
conditions.add(new DelegationRestrictionConditionValidator());
conditions.add(new ConditionValidator() {
@NonNull
@Override
public QName getServicedCondition() {
return OneTimeUse.DEFAULT_ELEMENT_NAME;
}
@NonNull
@Override
public ValidationResult validate(Condition condition, Assertion assertion, ValidationContext context) {
// applications should validate their own OneTimeUse conditions
@ -588,16 +604,13 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
});
conditions.add(new ProxyRestrictionConditionValidator());
subjects.add(new BearerSubjectConfirmationValidator() {
@NonNull
protected ValidationResult validateAddress(@NonNull SubjectConfirmation confirmation,
@NonNull Assertion assertion, @NonNull ValidationContext context, boolean required)
throws AssertionValidationException {
protected ValidationResult validateAddress(SubjectConfirmation confirmation, Assertion assertion,
ValidationContext context, boolean required) throws AssertionValidationException {
return ValidationResult.VALID;
}
@NonNull
protected ValidationResult validateAddress(@NonNull SubjectConfirmationData confirmationData,
@NonNull Assertion assertion, @NonNull ValidationContext context, boolean required)
protected ValidationResult validateAddress(SubjectConfirmationData confirmationData,
Assertion assertion, ValidationContext context, boolean required)
throws AssertionValidationException {
// applications should validate their own addresses - gh-7514
return ValidationResult.VALID;
@ -607,7 +620,6 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider {
static final SAML20AssertionValidator attributeValidator = new SAML20AssertionValidator(conditions, subjects,
statements, null, null, null) {
@NonNull
@Override
protected ValidationResult validateSignature(Assertion token, ValidationContext context) {
return ValidationResult.VALID;

View File

@ -23,6 +23,8 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.util.Assert;
/**
@ -33,6 +35,7 @@ import org.springframework.util.Assert;
* @deprecated Please use {@link Saml2ResponseAssertionAccessor}
*/
@Deprecated
@NullUnmarked
public class DefaultSaml2AuthenticatedPrincipal implements Saml2AuthenticatedPrincipal, Serializable {
@Serial

View File

@ -25,6 +25,7 @@ import java.util.Objects;
import javax.xml.namespace.QName;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.RequestAbstractType;
@ -35,6 +36,7 @@ import org.w3c.dom.Element;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
interface OpenSamlOperations {
@ -89,14 +91,17 @@ interface OpenSamlOperations {
private final String algorithm;
private final byte[] signature;
private final byte @Nullable [] signature;
private final byte[] content;
RedirectParameters(Map<String, String> parameters, String parametersQuery, RequestAbstractType request) {
Assert.notNull(request.getID(), "SAML request's ID cannot be null");
Assert.notNull(request.getIssuer(), "SAML request's Issuer cannot be null");
this.id = request.getID();
this.issuer = request.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -113,9 +118,12 @@ interface OpenSamlOperations {
}
RedirectParameters(Map<String, String> parameters, String parametersQuery, StatusResponseType response) {
Assert.notNull(response.getID(), "SAML response's ID cannot be null");
Assert.notNull(response.getIssuer(), "SAML response's Issuer cannot be null");
this.id = response.getID();
this.issuer = response.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -131,7 +139,8 @@ interface OpenSamlOperations {
this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams);
}
static byte[] getContent(String samlObject, String relayState, final Map<String, String> queryParams) {
static byte[] getContent(String samlObject, @Nullable String relayState,
final Map<String, String> queryParams) {
if (Objects.nonNull(relayState)) {
return String
.format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject),
@ -163,7 +172,7 @@ interface OpenSamlOperations {
return this.algorithm;
}
byte[] getSignature() {
byte @Nullable [] getSignature() {
return this.signature;
}

View File

@ -20,7 +20,8 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.core.AuthenticatedPrincipal;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.CollectionUtils;
@ -35,6 +36,7 @@ import org.springframework.util.CollectionUtils;
* {@link Saml2ResponseAssertionAccessor} instead
*/
@Deprecated
@NullUnmarked
public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal {
/**
@ -44,7 +46,6 @@ public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal {
* @return the first attribute value or {@code null} otherwise
* @since 5.4
*/
@Nullable
default <A> A getFirstAttribute(String name) {
List<A> values = getAttribute(name);
return CollectionUtils.firstElement(values);
@ -57,7 +58,6 @@ public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal {
* @return the attribute or {@code null} otherwise
* @since 5.4
*/
@Nullable
default <A> List<A> getAttribute(String name) {
return (List<A>) getAttributes().get(name);
}

View File

@ -18,6 +18,8 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.Serial;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.saml2.core.Saml2Error;
@ -52,7 +54,7 @@ public class Saml2AuthenticationException extends AuthenticationException {
* @param error the {@link Saml2Error SAML 2.0 Error}
*/
public Saml2AuthenticationException(Saml2Error error) {
this(error, error.getDescription());
this(error, defaultMessage(error.getDescription(), error.getErrorCode()));
}
/**
@ -61,7 +63,8 @@ public class Saml2AuthenticationException extends AuthenticationException {
* @param cause the root cause
*/
public Saml2AuthenticationException(Saml2Error error, Throwable cause) {
this(error, (cause != null) ? cause.getMessage() : error.getDescription(), cause);
this(error, defaultMessage((cause != null) ? cause.getMessage() : error.getDescription(), error.getErrorCode()),
cause);
}
/**
@ -69,8 +72,9 @@ public class Saml2AuthenticationException extends AuthenticationException {
* @param error the {@link Saml2Error SAML 2.0 Error}
* @param message the detail message
*/
public Saml2AuthenticationException(Saml2Error error, String message) {
this(error, message, null);
public Saml2AuthenticationException(Saml2Error error, @Nullable String message) {
super(defaultMessage(message, error.getErrorCode()));
this.error = error;
}
/**
@ -79,12 +83,16 @@ public class Saml2AuthenticationException extends AuthenticationException {
* @param message the detail message
* @param cause the root cause
*/
public Saml2AuthenticationException(Saml2Error error, String message, Throwable cause) {
super(message, cause);
public Saml2AuthenticationException(Saml2Error error, @Nullable String message, Throwable cause) {
super(defaultMessage(message, error.getErrorCode()), cause);
Assert.notNull(error, "error cannot be null");
this.error = error;
}
private static String defaultMessage(@Nullable String message, String errorCode) {
return (message != null) ? message : errorCode;
}
/**
* Get the associated {@link Saml2Error}
* @return the associated {@link Saml2Error}

View File

@ -18,6 +18,8 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.util.Collections;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter;
@ -37,7 +39,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
private final String saml2Response;
private final AbstractSaml2AuthenticationRequest authenticationRequest;
private final @Nullable AbstractSaml2AuthenticationRequest authenticationRequest;
/**
* Creates a {@link Saml2AuthenticationToken} with the provided parameters.
@ -53,7 +55,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
* @since 5.6
*/
public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response,
AbstractSaml2AuthenticationRequest authenticationRequest) {
@Nullable AbstractSaml2AuthenticationRequest authenticationRequest) {
super(Collections.emptyList());
Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null");
Assert.notNull(saml2Response, "saml2Response cannot be null");
@ -92,7 +94,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
* @return null
*/
@Override
public Object getPrincipal() {
public @Nullable Object getPrincipal() {
return null;
}
@ -136,7 +138,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
* @return the authentication request sent to the assertion party
* @since 5.6
*/
public AbstractSaml2AuthenticationRequest getAuthenticationRequest() {
public @Nullable AbstractSaml2AuthenticationRequest getAuthenticationRequest() {
return this.authenticationRequest;
}

View File

@ -18,8 +18,11 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.Serial;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert;
/**
* Data holder for information required to send an {@code AuthNRequest} over a POST
@ -35,8 +38,8 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
@Serial
private static final long serialVersionUID = -6412064305715642123L;
Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
String relyingPartyRegistrationId, String id) {
Saml2PostAuthenticationRequest(String samlRequest, @Nullable String relayState, String authenticationRequestUri,
@Nullable String relyingPartyRegistrationId, @Nullable String id) {
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId, id);
}
@ -73,6 +76,8 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
* @return an immutable {@link Saml2PostAuthenticationRequest} object.
*/
public Saml2PostAuthenticationRequest build() {
Assert.notNull(this.samlRequest, "samlRequest cannot be null");
Assert.notNull(this.authenticationRequestUri, "authenticationRequestUri cannot be null");
return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri,
this.relyingPartyRegistrationId, this.id);
}

View File

@ -18,8 +18,11 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.Serial;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert;
/**
* Data holder for information required to send an {@code AuthNRequest} over a REDIRECT
@ -35,12 +38,13 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
@Serial
private static final long serialVersionUID = 6476874109764554798L;
private final String sigAlg;
private final @Nullable String sigAlg;
private final String signature;
private final @Nullable String signature;
private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState,
String authenticationRequestUri, String relyingPartyRegistrationId, String id) {
private Saml2RedirectAuthenticationRequest(String samlRequest, @Nullable String sigAlg, @Nullable String signature,
@Nullable String relayState, String authenticationRequestUri, String relyingPartyRegistrationId,
@Nullable String id) {
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId, id);
this.sigAlg = sigAlg;
this.signature = signature;
@ -50,7 +54,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
* Returns the SigAlg value for {@link Saml2MessageBinding#REDIRECT} requests
* @return the SigAlg value
*/
public String getSigAlg() {
public @Nullable String getSigAlg() {
return this.sigAlg;
}
@ -58,7 +62,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
* Returns the Signature value for {@link Saml2MessageBinding#REDIRECT} requests
* @return the Signature value
*/
public String getSignature() {
public @Nullable String getSignature() {
return this.signature;
}
@ -87,9 +91,9 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
*/
public static final class Builder extends AbstractSaml2AuthenticationRequest.Builder<Builder> {
private String sigAlg;
private @Nullable String sigAlg;
private String signature;
private @Nullable String signature;
private Builder(RelyingPartyRegistration registration) {
super(registration);
@ -100,7 +104,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
* @param sigAlg the SigAlg parameter value.
* @return this object
*/
public Builder sigAlg(String sigAlg) {
public Builder sigAlg(@Nullable String sigAlg) {
this.sigAlg = sigAlg;
return _this();
}
@ -110,7 +114,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
* @param signature the Signature parameter value.
* @return this object
*/
public Builder signature(String signature) {
public Builder signature(@Nullable String signature) {
this.signature = signature;
return _this();
}
@ -120,6 +124,9 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
* @return an immutable {@link Saml2RedirectAuthenticationRequest} object.
*/
public Saml2RedirectAuthenticationRequest build() {
Assert.notNull(this.samlRequest, "samlRequest cannot be null");
Assert.notNull(this.authenticationRequestUri, "authenticationRequestUri cannot be null");
Assert.notNull(this.relyingPartyRegistrationId, "relyingPartyRegistrationId cannot be null");
return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature,
this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId, this.id);
}

View File

@ -20,6 +20,8 @@ import java.io.Serial;
import java.util.List;
import java.util.Map;
import org.jspecify.annotations.Nullable;
import org.springframework.util.Assert;
/**
@ -81,7 +83,7 @@ public class Saml2ResponseAssertion implements Saml2ResponseAssertionAccessor {
private final String responseValue;
private String nameId;
private @Nullable String nameId;
private List<String> sessionIndexes = List.of();
@ -107,6 +109,7 @@ public class Saml2ResponseAssertion implements Saml2ResponseAssertionAccessor {
}
public Saml2ResponseAssertion build() {
Assert.notNull(this.nameId, "nameId cannot be null");
return new Saml2ResponseAssertion(this.responseValue, this.nameId, this.sessionIndexes, this.attributes);
}

View File

@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.authentication.logou
import java.util.Collection;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.opensaml.saml.saml2.core.LogoutRequest;
import org.opensaml.saml.saml2.core.NameID;
@ -33,6 +34,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout
import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert;
class BaseOpenSamlLogoutRequestValidator implements Saml2LogoutRequestValidator {
@ -73,15 +75,17 @@ class BaseOpenSamlLogoutRequestValidator implements Saml2LogoutRequestValidator
errors.addAll(verify.verify(logoutRequest));
}
else {
RedirectParameters params = new RedirectParameters(request.getParameters(),
request.getParametersQuery(), logoutRequest);
String parametersQuery = request.getParametersQuery();
Assert.notNull(parametersQuery, "parametersQuery cannot be null for redirect binding");
RedirectParameters params = new RedirectParameters(request.getParameters(), parametersQuery,
logoutRequest);
errors.addAll(verify.verify(params));
}
};
}
private Consumer<Collection<Saml2Error>> validateRequest(LogoutRequest request,
RelyingPartyRegistration registration, Authentication authentication) {
RelyingPartyRegistration registration, @Nullable Authentication authentication) {
return (errors) -> {
validateIssuer(request, registration).accept(errors);
validateDestination(request, registration).accept(errors);
@ -97,7 +101,7 @@ class BaseOpenSamlLogoutRequestValidator implements Saml2LogoutRequestValidator
return;
}
String issuer = request.getIssuer().getValue();
if (!issuer.equals(registration.getAssertingPartyMetadata().getEntityId())) {
if (!registration.getAssertingPartyMetadata().getEntityId().equals(issuer)) {
errors
.add(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, "Failed to match issuer to configured issuer"));
}
@ -121,7 +125,7 @@ class BaseOpenSamlLogoutRequestValidator implements Saml2LogoutRequestValidator
}
private Consumer<Collection<Saml2Error>> validateSubject(LogoutRequest request,
RelyingPartyRegistration registration, Authentication authentication) {
RelyingPartyRegistration registration, @Nullable Authentication authentication) {
return (errors) -> {
if (authentication == null) {
return;
@ -137,7 +141,7 @@ class BaseOpenSamlLogoutRequestValidator implements Saml2LogoutRequestValidator
};
}
private NameID getNameId(LogoutRequest request, RelyingPartyRegistration registration) {
private @Nullable NameID getNameId(LogoutRequest request, RelyingPartyRegistration registration) {
this.saml.withDecryptionKeys(registration.getDecryptionX509Credentials()).decrypt(request);
return request.getNameID();
}
@ -145,7 +149,7 @@ class BaseOpenSamlLogoutRequestValidator implements Saml2LogoutRequestValidator
private void validateNameId(NameID nameId, Authentication authentication, Collection<Saml2Error> errors) {
String name = (authentication.getCredentials() instanceof Saml2ResponseAssertionAccessor assertion)
? assertion.getNameId() : authentication.getName();
if (!nameId.getValue().equals(name)) {
if (!name.equals(nameId.getValue())) {
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_REQUEST,
"Failed to match subject in LogoutRequest with currently logged in user"));
}

View File

@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.authentication.logou
import java.util.Collection;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.opensaml.saml.saml2.core.LogoutResponse;
import org.opensaml.saml.saml2.core.StatusCode;
@ -31,6 +32,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout
import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert;
class BaseOpenSamlLogoutResponseValidator implements Saml2LogoutResponseValidator {
@ -71,8 +73,10 @@ class BaseOpenSamlLogoutResponseValidator implements Saml2LogoutResponseValidato
errors.addAll(verify.verify(logoutResponse));
}
else {
RedirectParameters params = new RedirectParameters(response.getParameters(),
response.getParametersQuery(), logoutResponse);
String parametersQuery = response.getParametersQuery();
Assert.notNull(parametersQuery, "parametersQuery cannot be null for redirect binding");
RedirectParameters params = new RedirectParameters(response.getParameters(), parametersQuery,
logoutResponse);
errors.addAll(verify.verify(params));
}
};
@ -95,7 +99,7 @@ class BaseOpenSamlLogoutResponseValidator implements Saml2LogoutResponseValidato
return;
}
String issuer = response.getIssuer().getValue();
if (!issuer.equals(registration.getAssertingPartyMetadata().getEntityId())) {
if (!registration.getAssertingPartyMetadata().getEntityId().equals(issuer)) {
errors
.add(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, "Failed to match issuer to configured issuer"));
}
@ -136,7 +140,7 @@ class BaseOpenSamlLogoutResponseValidator implements Saml2LogoutResponseValidato
};
}
private Consumer<Collection<Saml2Error>> validateLogoutRequest(LogoutResponse response, String id) {
private Consumer<Collection<Saml2Error>> validateLogoutRequest(LogoutResponse response, @Nullable String id) {
return (errors) -> {
if (response.getInResponseTo() == null) {
return;

View File

@ -25,6 +25,7 @@ import java.util.Objects;
import javax.xml.namespace.QName;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.RequestAbstractType;
@ -35,6 +36,7 @@ import org.w3c.dom.Element;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
interface OpenSamlOperations {
@ -89,14 +91,17 @@ interface OpenSamlOperations {
private final String algorithm;
private final byte[] signature;
private final byte @Nullable [] signature;
private final byte[] content;
RedirectParameters(Map<String, String> parameters, String parametersQuery, RequestAbstractType request) {
Assert.notNull(request.getID(), "SAML request's ID cannot be null");
Assert.notNull(request.getIssuer(), "SAML request's Issuer cannot be null");
this.id = request.getID();
this.issuer = request.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -113,9 +118,12 @@ interface OpenSamlOperations {
}
RedirectParameters(Map<String, String> parameters, String parametersQuery, StatusResponseType response) {
Assert.notNull(response.getID(), "SAML response's ID cannot be null");
Assert.notNull(response.getIssuer(), "SAML response's Issuer cannot be null");
this.id = response.getID();
this.issuer = response.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -131,7 +139,8 @@ interface OpenSamlOperations {
this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams);
}
static byte[] getContent(String samlObject, String relayState, final Map<String, String> queryParams) {
static byte[] getContent(String samlObject, @Nullable String relayState,
final Map<String, String> queryParams) {
if (Objects.nonNull(relayState)) {
return String
.format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject),
@ -163,7 +172,7 @@ interface OpenSamlOperations {
return this.algorithm;
}
byte[] getSignature() {
byte @Nullable [] getSignature() {
return this.signature;
}

View File

@ -25,10 +25,13 @@ import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestResolver;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils;
@ -43,7 +46,7 @@ public final class Saml2LogoutRequest implements Serializable {
@Serial
private static final long serialVersionUID = -3588981995674761337L;
private static final Function<Map<String, String>, String> DEFAULT_ENCODER = (params) -> {
private static final Function<Map<String, String>, @Nullable String> DEFAULT_ENCODER = (params) -> {
if (params.isEmpty()) {
return null;
}
@ -60,19 +63,20 @@ public final class Saml2LogoutRequest implements Serializable {
private final Map<String, String> parameters;
private final String id;
private final @Nullable String id;
private final String relyingPartyRegistrationId;
private transient Function<Map<String, String>, String> encoder;
private transient Function<Map<String, String>, @Nullable String> encoder;
private Saml2LogoutRequest(String location, Saml2MessageBinding binding, Map<String, String> parameters, String id,
String relyingPartyRegistrationId) {
this(location, binding, parameters, id, relyingPartyRegistrationId, DEFAULT_ENCODER);
}
private Saml2LogoutRequest(String location, Saml2MessageBinding binding, Map<String, String> parameters, String id,
String relyingPartyRegistrationId, Function<Map<String, String>, String> encoder) {
private Saml2LogoutRequest(String location, Saml2MessageBinding binding, Map<String, String> parameters,
@Nullable String id, String relyingPartyRegistrationId,
Function<Map<String, String>, @Nullable String> encoder) {
this.location = location;
this.binding = binding;
this.parameters = Collections.unmodifiableMap(new LinkedHashMap<>(parameters));
@ -85,7 +89,7 @@ public final class Saml2LogoutRequest implements Serializable {
* The unique identifier for this Logout Request
* @return the Logout Request identifier
*/
public String getId() {
public @Nullable String getId() {
return this.id;
}
@ -112,14 +116,16 @@ public final class Saml2LogoutRequest implements Serializable {
* @return the signed and serialized &lt;saml2:LogoutRequest&gt; payload
*/
public String getSamlRequest() {
return this.parameters.get(Saml2ParameterNames.SAML_REQUEST);
String samlRequest = this.parameters.get(Saml2ParameterNames.SAML_REQUEST);
Assert.notNull(samlRequest, "samlRequest cannot be null");
return samlRequest;
}
/**
* The relay state associated with this Logout Request
* @return the relay state
*/
public String getRelayState() {
public @Nullable String getRelayState() {
return this.parameters.get(Saml2ParameterNames.RELAY_STATE);
}
@ -132,7 +138,7 @@ public final class Saml2LogoutRequest implements Serializable {
* @param name the parameter's name
* @return the parameter's value
*/
public String getParameter(String name) {
public @Nullable String getParameter(String name) {
return this.parameters.get(name);
}
@ -152,7 +158,7 @@ public final class Saml2LogoutRequest implements Serializable {
* @return an encoded string of all parameters
* @since 5.8
*/
public String getParametersQuery() {
public @Nullable String getParametersQuery() {
return this.encoder.apply(this.parameters);
}
@ -182,15 +188,15 @@ public final class Saml2LogoutRequest implements Serializable {
private final RelyingPartyRegistration registration;
private String location;
private @Nullable String location;
private Saml2MessageBinding binding;
private Map<String, String> parameters = new LinkedHashMap<>();
private Function<Map<String, String>, String> encoder = DEFAULT_ENCODER;
private Function<Map<String, String>, @Nullable String> encoder = DEFAULT_ENCODER;
private String id;
private @Nullable String id;
private Builder(RelyingPartyRegistration registration) {
this.registration = registration;
@ -284,7 +290,7 @@ public final class Saml2LogoutRequest implements Serializable {
* @return the {@link Builder} for further configurations
* @since 5.8
*/
public Builder parametersQuery(Function<Map<String, String>, String> encoder) {
public Builder parametersQuery(Function<Map<String, String>, @Nullable String> encoder) {
this.encoder = encoder;
return this;
}
@ -294,6 +300,8 @@ public final class Saml2LogoutRequest implements Serializable {
* @return a constructed {@link Saml2LogoutRequest}
*/
public Saml2LogoutRequest build() {
Assert.notNull(this.location, "singleLocationServiceLocation cannot be null");
Assert.notNull(this.parameters.get(Saml2ParameterNames.SAML_REQUEST), "samlRequest cannot be null");
return new Saml2LogoutRequest(this.location, this.binding, this.parameters, this.id,
this.registration.getRegistrationId(), this.encoder);
}

View File

@ -16,6 +16,8 @@
package org.springframework.security.saml2.provider.service.authentication.logout;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
@ -31,7 +33,7 @@ public class Saml2LogoutRequestValidatorParameters {
private final RelyingPartyRegistration registration;
private final Authentication authentication;
private final @Nullable Authentication authentication;
/**
* Construct a {@link Saml2LogoutRequestValidatorParameters}
@ -40,7 +42,7 @@ public class Saml2LogoutRequestValidatorParameters {
* @param authentication the current user
*/
public Saml2LogoutRequestValidatorParameters(Saml2LogoutRequest request, RelyingPartyRegistration registration,
Authentication authentication) {
@Nullable Authentication authentication) {
this.request = request;
this.registration = registration;
this.authentication = authentication;
@ -66,7 +68,7 @@ public class Saml2LogoutRequestValidatorParameters {
* The current {@link Authentication}
* @return the authenticated user
*/
public Authentication getAuthentication() {
public @Nullable Authentication getAuthentication() {
return this.authentication;
}

View File

@ -23,10 +23,13 @@ import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutResponseResolver;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils;
@ -38,7 +41,7 @@ import org.springframework.web.util.UriUtils;
*/
public final class Saml2LogoutResponse {
private static final Function<Map<String, String>, String> DEFAULT_ENCODER = (params) -> {
private static final Function<Map<String, String>, @Nullable String> DEFAULT_ENCODER = (params) -> {
if (params.isEmpty()) {
return null;
}
@ -55,10 +58,10 @@ public final class Saml2LogoutResponse {
private final Map<String, String> parameters;
private final Function<Map<String, String>, String> encoder;
private final Function<Map<String, String>, @Nullable String> encoder;
private Saml2LogoutResponse(String location, Saml2MessageBinding binding, Map<String, String> parameters,
Function<Map<String, String>, String> encoder) {
Function<Map<String, String>, @Nullable String> encoder) {
this.location = location;
this.binding = binding;
this.parameters = Collections.unmodifiableMap(new LinkedHashMap<>(parameters));
@ -88,14 +91,16 @@ public final class Saml2LogoutResponse {
* @return the signed and serialized &lt;saml2:LogoutResponse&gt; payload
*/
public String getSamlResponse() {
return this.parameters.get(Saml2ParameterNames.SAML_RESPONSE);
String samlResponse = this.parameters.get(Saml2ParameterNames.SAML_RESPONSE);
Assert.notNull(samlResponse, "samlResponse cannot be null");
return samlResponse;
}
/**
* The relay state associated with this Logout Request
* @return the relay state
*/
public String getRelayState() {
public @Nullable String getRelayState() {
return this.parameters.get(Saml2ParameterNames.RELAY_STATE);
}
@ -108,7 +113,7 @@ public final class Saml2LogoutResponse {
* @param name the parameter's name
* @return the parameter's value
*/
public String getParameter(String name) {
public @Nullable String getParameter(String name) {
return this.parameters.get(name);
}
@ -128,7 +133,7 @@ public final class Saml2LogoutResponse {
* @return an encoded string of all parameters
* @since 5.8
*/
public String getParametersQuery() {
public @Nullable String getParametersQuery() {
return this.encoder.apply(this.parameters);
}
@ -147,13 +152,13 @@ public final class Saml2LogoutResponse {
public static final class Builder {
private String location;
private @Nullable String location;
private Saml2MessageBinding binding;
private Map<String, String> parameters = new LinkedHashMap<>();
private Function<Map<String, String>, String> encoder = DEFAULT_ENCODER;
private Function<Map<String, String>, @Nullable String> encoder = DEFAULT_ENCODER;
private Builder(RelyingPartyRegistration registration) {
this.location = registration.getAssertingPartyMetadata().getSingleLogoutServiceResponseLocation();
@ -236,7 +241,7 @@ public final class Saml2LogoutResponse {
* @return the {@link Saml2LogoutRequest.Builder} for further configurations
* @since 5.8
*/
public Builder parametersQuery(Function<Map<String, String>, String> encoder) {
public Builder parametersQuery(Function<Map<String, String>, @Nullable String> encoder) {
this.encoder = encoder;
return this;
}
@ -246,6 +251,8 @@ public final class Saml2LogoutResponse {
* @return a constructed {@link Saml2LogoutResponse}
*/
public Saml2LogoutResponse build() {
Assert.notNull(this.location, "singleLogoutResponseLocation cannot be null");
Assert.notNull(this.parameters.get(Saml2ParameterNames.SAML_RESPONSE), "samlResponse cannot be null");
return new Saml2LogoutResponse(this.location, this.binding, this.parameters, this.encoder);
}

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Internal utilities for SAML2 support (not for public use).
*/
@NullMarked
package org.springframework.security.saml2.provider.service.authentication.logout;
import org.jspecify.annotations.NullMarked;

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Internal utilities for SAML2 support (not for public use).
*/
@NullMarked
package org.springframework.security.saml2.provider.service.authentication;
import org.jspecify.annotations.NullMarked;

View File

@ -25,6 +25,7 @@ import java.util.Objects;
import javax.xml.namespace.QName;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.RequestAbstractType;
@ -35,6 +36,7 @@ import org.w3c.dom.Element;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
interface OpenSamlOperations {
@ -89,14 +91,17 @@ interface OpenSamlOperations {
private final String algorithm;
private final byte[] signature;
private final byte @Nullable [] signature;
private final byte[] content;
RedirectParameters(Map<String, String> parameters, String parametersQuery, RequestAbstractType request) {
Assert.notNull(request.getID(), "SAML request's ID cannot be null");
Assert.notNull(request.getIssuer(), "SAML request's Issuer cannot be null");
this.id = request.getID();
this.issuer = request.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -113,9 +118,12 @@ interface OpenSamlOperations {
}
RedirectParameters(Map<String, String> parameters, String parametersQuery, StatusResponseType response) {
Assert.notNull(response.getID(), "SAML response's ID cannot be null");
Assert.notNull(response.getIssuer(), "SAML response's Issuer cannot be null");
this.id = response.getID();
this.issuer = response.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -131,7 +139,8 @@ interface OpenSamlOperations {
this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams);
}
static byte[] getContent(String samlObject, String relayState, final Map<String, String> queryParams) {
static byte[] getContent(String samlObject, @Nullable String relayState,
final Map<String, String> queryParams) {
if (Objects.nonNull(relayState)) {
return String
.format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject),
@ -163,7 +172,7 @@ interface OpenSamlOperations {
return this.algorithm;
}
byte[] getSignature() {
byte @Nullable [] getSignature() {
return this.signature;
}

View File

@ -16,6 +16,8 @@
package org.springframework.security.saml2.provider.service.metadata;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.web.util.matcher.RequestMatcher;
@ -30,6 +32,7 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
* {@link org.springframework.security.saml2.provider.service.web.metadata.RequestMatcherMetadataResponseResolver}
*/
@Deprecated
@NullUnmarked
public final class RequestMatcherMetadataResponseResolver extends
org.springframework.security.saml2.provider.service.web.metadata.RequestMatcherMetadataResponseResolver {

View File

@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.metadata;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
/**
* Resolves Relying Party SAML 2.0 Metadata given details from the
@ -33,6 +34,6 @@ public interface Saml2MetadataResponseResolver {
* @param request the HTTP request
* @return a {@link Saml2MetadataResponse} instance
*/
Saml2MetadataResponse resolve(HttpServletRequest request);
@Nullable Saml2MetadataResponse resolve(HttpServletRequest request);
}

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Internal utilities for SAML2 support (not for public use).
*/
@NullMarked
package org.springframework.security.saml2.provider.service.metadata;
import org.jspecify.annotations.NullMarked;

View File

@ -21,6 +21,8 @@ import java.util.Collection;
import java.util.List;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.core.Saml2X509Credential;
/**
@ -119,7 +121,7 @@ public interface AssertingPartyMetadata extends Serializable {
* @return the SingleLogoutService Location
* @since 5.6
*/
String getSingleLogoutServiceLocation();
@Nullable String getSingleLogoutServiceLocation();
/**
* Get the <a href=
@ -132,7 +134,7 @@ public interface AssertingPartyMetadata extends Serializable {
* @return the SingleLogoutService Response Location
* @since 5.6
*/
String getSingleLogoutServiceResponseLocation();
@Nullable String getSingleLogoutServiceResponseLocation();
/**
* Get the <a href=

View File

@ -16,7 +16,7 @@
package org.springframework.security.saml2.provider.service.registration;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
/**
* A repository for retrieving SAML 2.0 Asserting Party Metadata
@ -34,8 +34,7 @@ public interface AssertingPartyMetadataRepository extends Iterable<AssertingPart
* @param entityId the EntityID to lookup
* @return the found {@link AssertingPartyMetadata}, or {@code null} otherwise
*/
@Nullable
default AssertingPartyMetadata findByEntityId(String entityId) {
default @Nullable AssertingPartyMetadata findByEntityId(String entityId) {
for (AssertingPartyMetadata metadata : this) {
if (metadata.getEntityId().equals(entityId)) {
return metadata;

View File

@ -20,6 +20,7 @@ import java.util.Iterator;
import java.util.Set;
import java.util.function.Supplier;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.saml.criterion.EntityRoleCriterion;
import org.opensaml.saml.metadata.IterableMetadataSource;
@ -31,8 +32,6 @@ import org.opensaml.saml.metadata.resolver.index.impl.RoleMetadataIndex;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.IDPSSODescriptor;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.util.Assert;
@ -95,7 +94,6 @@ class BaseOpenSamlAssertingPartyMetadataRepository implements AssertingPartyMeta
}
@Override
@NonNull
public Iterator<AssertingPartyMetadata> iterator() {
Iterator<EntityDescriptor> descriptors = this.descriptors.get();
return new Iterator<>() {
@ -111,9 +109,8 @@ class BaseOpenSamlAssertingPartyMetadataRepository implements AssertingPartyMeta
};
}
@Nullable
@Override
public AssertingPartyMetadata findByEntityId(String entityId) {
public @Nullable AssertingPartyMetadata findByEntityId(String entityId) {
EntityDescriptor descriptor = resolveSingle(new EntityIdCriterion(entityId));
if (descriptor == null) {
return null;

View File

@ -21,6 +21,8 @@ import java.util.Spliterator;
import java.util.concurrent.Callable;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.springframework.cache.Cache;
import org.springframework.cache.concurrent.ConcurrentMapCache;
import org.springframework.util.Assert;
@ -55,12 +57,12 @@ public final class CachingRelyingPartyRegistrationRepository implements Iterable
* {@inheritDoc}
*/
@Override
public RelyingPartyRegistration findByRegistrationId(String registrationId) {
public @Nullable RelyingPartyRegistration findByRegistrationId(String registrationId) {
return registrations().findByRegistrationId(registrationId);
}
@Override
public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
public @Nullable RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
return registrations().findUniqueByAssertingPartyEntityId(entityId);
}
@ -75,7 +77,10 @@ public final class CachingRelyingPartyRegistrationRepository implements Iterable
}
private IterableRelyingPartyRegistrationRepository registrations() {
return this.cache.get("registrations", this.registrationLoader);
IterableRelyingPartyRegistrationRepository registrations = this.cache.get("registrations",
this.registrationLoader);
Assert.notNull(registrations, "cache loader failed to return a repostory instance");
return registrations;
}
/**

View File

@ -24,6 +24,8 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.jspecify.annotations.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
@ -75,12 +77,12 @@ public class InMemoryRelyingPartyRegistrationRepository implements IterableRelyi
}
@Override
public RelyingPartyRegistration findByRegistrationId(String id) {
public @Nullable RelyingPartyRegistration findByRegistrationId(String id) {
return this.byRegistrationId.get(id);
}
@Override
public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
public @Nullable RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
Collection<RelyingPartyRegistration> registrations = this.byAssertingPartyEntityId.get(entityId);
if (registrations == null) {
return null;

View File

@ -27,6 +27,8 @@ import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import org.jspecify.annotations.Nullable;
import org.springframework.core.serializer.DefaultDeserializer;
import org.springframework.core.serializer.DefaultSerializer;
import org.springframework.core.serializer.Deserializer;
@ -105,7 +107,7 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
}
@Override
public AssertingPartyMetadata findByEntityId(String entityId) {
public @Nullable AssertingPartyMetadata findByEntityId(String entityId) {
Assert.hasText(entityId, "entityId cannot be empty");
SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, entityId) };
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
@ -158,6 +160,7 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
String entityId = rs.getString(COLUMN_NAMES[0]);
String singleSignOnUrl = rs.getString(COLUMN_NAMES[1]);
Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding.from(rs.getString(COLUMN_NAMES[2]));
Assert.notNull(singleSignOnBinding, "retrieved an unsupported binding " + rs.getString(COLUMN_NAMES[2]));
boolean singleSignOnSignRequest = rs.getBoolean(COLUMN_NAMES[3]);
List<String> algorithms = List.of(rs.getString(COLUMN_NAMES[4]).split(","));
byte[] verificationCredentialsBytes = rs.getBytes(COLUMN_NAMES[5]);
@ -171,6 +174,7 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
String singleLogoutUrl = rs.getString(COLUMN_NAMES[7]);
String singleLogoutResponseUrl = rs.getString(COLUMN_NAMES[8]);
Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString(COLUMN_NAMES[9]));
Assert.notNull(singleLogoutBinding, "retrieved an unsupported binding " + rs.getString(COLUMN_NAMES[9]));
builder.entityId(entityId)
.wantAuthnRequestsSigned(singleSignOnSignRequest)

View File

@ -22,8 +22,10 @@ import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.ext.saml2alg.SigningMethod;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
@ -37,6 +39,7 @@ import org.opensaml.xmlsec.keyinfo.KeyInfoSupport;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
/**
* A {@link RelyingPartyRegistration.AssertingPartyDetails} that contains
@ -86,19 +89,19 @@ public final class OpenSamlAssertingPartyDetails extends RelyingPartyRegistratio
List<Saml2X509Credential> verification = new ArrayList<>();
List<Saml2X509Credential> encryption = new ArrayList<>();
for (KeyDescriptor keyDescriptor : idpssoDescriptor.getKeyDescriptors()) {
if (keyDescriptor.getUse().equals(UsageType.SIGNING)) {
if (UsageType.SIGNING.equals(keyDescriptor.getUse())) {
List<X509Certificate> certificates = certificates(keyDescriptor);
for (X509Certificate certificate : certificates) {
verification.add(Saml2X509Credential.verification(certificate));
}
}
if (keyDescriptor.getUse().equals(UsageType.ENCRYPTION)) {
if (UsageType.ENCRYPTION.equals(keyDescriptor.getUse())) {
List<X509Certificate> certificates = certificates(keyDescriptor);
for (X509Certificate certificate : certificates) {
encryption.add(Saml2X509Credential.encryption(certificate));
}
}
if (keyDescriptor.getUse().equals(UsageType.UNSPECIFIED)) {
if (UsageType.UNSPECIFIED.equals(keyDescriptor.getUse())) {
List<X509Certificate> certificates = certificates(keyDescriptor);
for (X509Certificate certificate : certificates) {
verification.add(Saml2X509Credential.verification(certificate));
@ -110,14 +113,17 @@ public final class OpenSamlAssertingPartyDetails extends RelyingPartyRegistratio
throw new Saml2Exception(
"Metadata response is missing verification certificates, necessary for verifying SAML assertions");
}
String entityId = entity.getEntityID();
Assert.notNull(entityId, "EntityDescriptor#EntityID cannot be null");
OpenSamlAssertingPartyDetails.Builder builder = new OpenSamlAssertingPartyDetails.Builder(entity)
.entityId(entity.getEntityID())
.entityId(entityId)
.wantAuthnRequestsSigned(Boolean.TRUE.equals(idpssoDescriptor.getWantAuthnRequestsSigned()))
.verificationX509Credentials((c) -> c.addAll(verification))
.encryptionX509Credentials((c) -> c.addAll(encryption));
List<SigningMethod> signingMethods = signingMethods(idpssoDescriptor);
for (SigningMethod method : signingMethods) {
Assert.notNull(method.getAlgorithm(), "EntityDescriptor declares a SigningMethod with no value");
builder.signingAlgorithms((algorithms) -> algorithms.add(method.getAlgorithm()));
}
if (idpssoDescriptor.getSingleSignOnServices().isEmpty()) {
@ -126,32 +132,36 @@ public final class OpenSamlAssertingPartyDetails extends RelyingPartyRegistratio
}
for (SingleSignOnService singleSignOnService : idpssoDescriptor.getSingleSignOnServices()) {
Saml2MessageBinding binding;
if (singleSignOnService.getBinding().equals(Saml2MessageBinding.POST.getUrn())) {
if (Saml2MessageBinding.POST.getUrn().equals(singleSignOnService.getBinding())) {
binding = Saml2MessageBinding.POST;
}
else if (singleSignOnService.getBinding().equals(Saml2MessageBinding.REDIRECT.getUrn())) {
else if (Saml2MessageBinding.REDIRECT.getUrn().equals(singleSignOnService.getBinding())) {
binding = Saml2MessageBinding.REDIRECT;
}
else {
continue;
}
builder.singleSignOnServiceLocation(singleSignOnService.getLocation()).singleSignOnServiceBinding(binding);
String location = singleSignOnService.getLocation();
Assert.notNull(location, "EntityDescriptor has a SingleSignOnService declaration, but no Location");
builder.singleSignOnServiceLocation(location).singleSignOnServiceBinding(binding);
break;
}
for (SingleLogoutService singleLogoutService : idpssoDescriptor.getSingleLogoutServices()) {
Saml2MessageBinding binding;
if (singleLogoutService.getBinding().equals(Saml2MessageBinding.POST.getUrn())) {
if (Saml2MessageBinding.POST.getUrn().equals(singleLogoutService.getBinding())) {
binding = Saml2MessageBinding.POST;
}
else if (singleLogoutService.getBinding().equals(Saml2MessageBinding.REDIRECT.getUrn())) {
else if (Saml2MessageBinding.REDIRECT.getUrn().equals(singleLogoutService.getBinding())) {
binding = Saml2MessageBinding.REDIRECT;
}
else {
continue;
}
String responseLocation = (singleLogoutService.getResponseLocation() == null)
? singleLogoutService.getLocation() : singleLogoutService.getResponseLocation();
builder.singleLogoutServiceLocation(singleLogoutService.getLocation())
String location = singleLogoutService.getLocation();
Assert.notNull(location, "EntityDescriptor has a SingleLogoutService declaration, but no Location");
String responseLocation = (singleLogoutService.getResponseLocation() == null) ? location
: singleLogoutService.getResponseLocation();
builder.singleLogoutServiceLocation(location)
.singleLogoutServiceResponseLocation(responseLocation)
.singleLogoutServiceBinding(binding);
break;
@ -174,12 +184,12 @@ public final class OpenSamlAssertingPartyDetails extends RelyingPartyRegistratio
if (!result.isEmpty()) {
return result;
}
EntityDescriptor descriptor = (EntityDescriptor) idpssoDescriptor.getParent();
EntityDescriptor descriptor = (EntityDescriptor) Objects.requireNonNull(idpssoDescriptor.getParent());
extensions = descriptor.getExtensions();
return signingMethods(extensions);
}
private static <T> List<T> signingMethods(Extensions extensions) {
private static <T> List<T> signingMethods(@Nullable Extensions extensions) {
if (extensions != null) {
return (List<T>) extensions.getUnknownXMLObjects(SigningMethod.DEFAULT_ELEMENT_NAME);
}
@ -273,7 +283,7 @@ public final class OpenSamlAssertingPartyDetails extends RelyingPartyRegistratio
* {@inheritDoc}
*/
@Override
public Builder singleLogoutServiceLocation(String singleLogoutServiceLocation) {
public Builder singleLogoutServiceLocation(@Nullable String singleLogoutServiceLocation) {
return (Builder) super.singleLogoutServiceLocation(singleLogoutServiceLocation);
}
@ -281,7 +291,7 @@ public final class OpenSamlAssertingPartyDetails extends RelyingPartyRegistratio
* {@inheritDoc}
*/
@Override
public Builder singleLogoutServiceResponseLocation(String singleLogoutServiceResponseLocation) {
public Builder singleLogoutServiceResponseLocation(@Nullable String singleLogoutServiceResponseLocation) {
return (Builder) super.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation);
}

View File

@ -20,6 +20,7 @@ import java.io.InputStream;
import java.util.Collection;
import java.util.Collections;
import net.shibboleth.shared.xml.ParserPool;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Unmarshaller;
@ -31,6 +32,7 @@ import org.w3c.dom.Element;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.util.Assert;
final class OpenSamlMetadataUtils {
@ -71,7 +73,9 @@ final class OpenSamlMetadataUtils {
@Override
public XMLObject deserialize(InputStream serialized) {
try {
Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized);
ParserPool parserPool = XMLObjectProviderRegistrySupport.getParserPool();
Assert.notNull(parserPool, "A ParserPool must be configured");
Document document = parserPool.parse(serialized);
Element element = document.getDocumentElement();
UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory();
Unmarshaller unmarshaller = factory.getUnmarshaller(element);

View File

@ -25,6 +25,7 @@ import java.util.Objects;
import javax.xml.namespace.QName;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.RequestAbstractType;
@ -35,6 +36,7 @@ import org.w3c.dom.Element;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
interface OpenSamlOperations {
@ -89,14 +91,17 @@ interface OpenSamlOperations {
private final String algorithm;
private final byte[] signature;
private final byte @Nullable [] signature;
private final byte[] content;
RedirectParameters(Map<String, String> parameters, String parametersQuery, RequestAbstractType request) {
Assert.notNull(request.getID(), "SAML request's ID cannot be null");
Assert.notNull(request.getIssuer(), "SAML request's Issuer cannot be null");
this.id = request.getID();
this.issuer = request.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -113,9 +118,12 @@ interface OpenSamlOperations {
}
RedirectParameters(Map<String, String> parameters, String parametersQuery, StatusResponseType response) {
Assert.notNull(response.getID(), "SAML response's ID cannot be null");
Assert.notNull(response.getIssuer(), "SAML response's Issuer cannot be null");
this.id = response.getID();
this.issuer = response.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -131,7 +139,8 @@ interface OpenSamlOperations {
this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams);
}
static byte[] getContent(String samlObject, String relayState, final Map<String, String> queryParams) {
static byte[] getContent(String samlObject, @Nullable String relayState,
final Map<String, String> queryParams) {
if (Objects.nonNull(relayState)) {
return String
.format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject),
@ -163,7 +172,7 @@ interface OpenSamlOperations {
return this.algorithm;
}
byte[] getSignature() {
byte @Nullable [] getSignature() {
return this.signature;
}

View File

@ -20,6 +20,8 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
@ -63,12 +65,12 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter
}
@Override
public boolean canRead(Class<?> clazz, MediaType mediaType) {
public boolean canRead(Class<?> clazz, @Nullable MediaType mediaType) {
return RelyingPartyRegistration.Builder.class.isAssignableFrom(clazz);
}
@Override
public boolean canWrite(Class<?> clazz, MediaType mediaType) {
public boolean canWrite(Class<?> clazz, @Nullable MediaType mediaType) {
return false;
}
@ -84,8 +86,8 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter
}
@Override
public void write(RelyingPartyRegistration.Builder builder, MediaType contentType, HttpOutputMessage outputMessage)
throws HttpMessageNotWritableException {
public void write(RelyingPartyRegistration.Builder builder, @Nullable MediaType contentType,
HttpOutputMessage outputMessage) throws HttpMessageNotWritableException {
throw new HttpMessageNotWritableException("This converter cannot write a RelyingPartyRegistration.Builder");
}

View File

@ -25,6 +25,8 @@ import java.util.LinkedList;
import java.util.List;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@ -79,13 +81,13 @@ public class RelyingPartyRegistration implements Serializable {
private final Saml2MessageBinding assertionConsumerServiceBinding;
private final String singleLogoutServiceLocation;
private final @Nullable String singleLogoutServiceLocation;
private final String singleLogoutServiceResponseLocation;
private final @Nullable String singleLogoutServiceResponseLocation;
private final Collection<Saml2MessageBinding> singleLogoutServiceBindings;
private final String nameIdFormat;
private final @Nullable String nameIdFormat;
private final boolean authnRequestsSigned;
@ -96,9 +98,10 @@ public class RelyingPartyRegistration implements Serializable {
private final Collection<Saml2X509Credential> signingX509Credentials;
protected RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation,
Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation,
String singleLogoutServiceResponseLocation, Collection<Saml2MessageBinding> singleLogoutServiceBindings,
AssertingPartyDetails assertingPartyDetails, String nameIdFormat, boolean authnRequestsSigned,
Saml2MessageBinding assertionConsumerServiceBinding, @Nullable String singleLogoutServiceLocation,
@Nullable String singleLogoutServiceResponseLocation,
Collection<Saml2MessageBinding> singleLogoutServiceBindings, AssertingPartyDetails assertingPartyDetails,
@Nullable String nameIdFormat, boolean authnRequestsSigned,
Collection<Saml2X509Credential> decryptionX509Credentials,
Collection<Saml2X509Credential> signingX509Credentials) {
Assert.hasText(registrationId, "registrationId cannot be empty");
@ -134,9 +137,10 @@ public class RelyingPartyRegistration implements Serializable {
}
private RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation,
Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation,
String singleLogoutServiceResponseLocation, Collection<Saml2MessageBinding> singleLogoutServiceBindings,
AssertingPartyMetadata assertingPartyMetadata, String nameIdFormat, boolean authnRequestsSigned,
Saml2MessageBinding assertionConsumerServiceBinding, @Nullable String singleLogoutServiceLocation,
@Nullable String singleLogoutServiceResponseLocation,
Collection<Saml2MessageBinding> singleLogoutServiceBindings, AssertingPartyMetadata assertingPartyMetadata,
@Nullable String nameIdFormat, boolean authnRequestsSigned,
Collection<Saml2X509Credential> decryptionX509Credentials,
Collection<Saml2X509Credential> signingX509Credentials) {
Assert.hasText(registrationId, "registrationId cannot be empty");
@ -285,7 +289,7 @@ public class RelyingPartyRegistration implements Serializable {
* @return the SingleLogoutService Location
* @since 5.6
*/
public String getSingleLogoutServiceLocation() {
public @Nullable String getSingleLogoutServiceLocation() {
return this.singleLogoutServiceLocation;
}
@ -300,7 +304,7 @@ public class RelyingPartyRegistration implements Serializable {
* @return the SingleLogoutService Response Location
* @since 5.6
*/
public String getSingleLogoutServiceResponseLocation() {
public @Nullable String getSingleLogoutServiceResponseLocation() {
return this.singleLogoutServiceResponseLocation;
}
@ -309,7 +313,7 @@ public class RelyingPartyRegistration implements Serializable {
* @return the NameID format
* @since 5.7
*/
public String getNameIdFormat() {
public @Nullable String getNameIdFormat() {
return this.nameIdFormat;
}
@ -409,17 +413,17 @@ public class RelyingPartyRegistration implements Serializable {
private final Saml2MessageBinding singleSignOnServiceBinding;
private final String singleLogoutServiceLocation;
private final @Nullable String singleLogoutServiceLocation;
private final String singleLogoutServiceResponseLocation;
private final @Nullable String singleLogoutServiceResponseLocation;
private final Saml2MessageBinding singleLogoutServiceBinding;
AssertingPartyDetails(String entityId, boolean wantAuthnRequestsSigned, List<String> signingAlgorithms,
Collection<Saml2X509Credential> verificationX509Credentials,
Collection<Saml2X509Credential> encryptionX509Credentials, String singleSignOnServiceLocation,
Saml2MessageBinding singleSignOnServiceBinding, String singleLogoutServiceLocation,
String singleLogoutServiceResponseLocation, Saml2MessageBinding singleLogoutServiceBinding) {
Saml2MessageBinding singleSignOnServiceBinding, @Nullable String singleLogoutServiceLocation,
@Nullable String singleLogoutServiceResponseLocation, Saml2MessageBinding singleLogoutServiceBinding) {
Assert.hasText(entityId, "entityId cannot be null or empty");
Assert.notEmpty(signingAlgorithms, "signingAlgorithms cannot be empty");
Assert.notNull(verificationX509Credentials, "verificationX509Credentials cannot be null");
@ -550,7 +554,7 @@ public class RelyingPartyRegistration implements Serializable {
* @return the SingleLogoutService Location
* @since 5.6
*/
public String getSingleLogoutServiceLocation() {
public @Nullable String getSingleLogoutServiceLocation() {
return this.singleLogoutServiceLocation;
}
@ -565,7 +569,7 @@ public class RelyingPartyRegistration implements Serializable {
* @return the SingleLogoutService Response Location
* @since 5.6
*/
public String getSingleLogoutServiceResponseLocation() {
public @Nullable String getSingleLogoutServiceResponseLocation() {
return this.singleLogoutServiceResponseLocation;
}
@ -599,7 +603,7 @@ public class RelyingPartyRegistration implements Serializable {
public static class Builder implements AssertingPartyMetadata.Builder<Builder> {
private String entityId;
private @Nullable String entityId;
private boolean wantAuthnRequestsSigned = true;
@ -609,13 +613,13 @@ public class RelyingPartyRegistration implements Serializable {
private Collection<Saml2X509Credential> encryptionX509Credentials = new LinkedHashSet<>();
private String singleSignOnServiceLocation;
private @Nullable String singleSignOnServiceLocation;
private Saml2MessageBinding singleSignOnServiceBinding = Saml2MessageBinding.REDIRECT;
private String singleLogoutServiceLocation;
private @Nullable String singleLogoutServiceLocation;
private String singleLogoutServiceResponseLocation;
private @Nullable String singleLogoutServiceResponseLocation;
private Saml2MessageBinding singleLogoutServiceBinding = Saml2MessageBinding.REDIRECT;
@ -727,7 +731,7 @@ public class RelyingPartyRegistration implements Serializable {
* @return the {@link AssertingPartyDetails.Builder} for further configuration
* @since 5.6
*/
public Builder singleLogoutServiceLocation(String singleLogoutServiceLocation) {
public Builder singleLogoutServiceLocation(@Nullable String singleLogoutServiceLocation) {
this.singleLogoutServiceLocation = singleLogoutServiceLocation;
return this;
}
@ -746,7 +750,7 @@ public class RelyingPartyRegistration implements Serializable {
* @return the {@link AssertingPartyDetails.Builder} for further configuration
* @since 5.6
*/
public Builder singleLogoutServiceResponseLocation(String singleLogoutServiceResponseLocation) {
public Builder singleLogoutServiceResponseLocation(@Nullable String singleLogoutServiceResponseLocation) {
this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation;
return this;
}
@ -777,7 +781,8 @@ public class RelyingPartyRegistration implements Serializable {
List<String> signingAlgorithms = this.signingAlgorithms.isEmpty()
? Collections.singletonList("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256")
: Collections.unmodifiableList(this.signingAlgorithms);
Assert.notNull(this.entityId, "entityId cannot be null");
Assert.notNull(this.singleSignOnServiceLocation, "singleSignOnServiceLocation cannot be null");
return new AssertingPartyDetails(this.entityId, this.wantAuthnRequestsSigned, signingAlgorithms,
this.verificationX509Credentials, this.encryptionX509Credentials,
this.singleSignOnServiceLocation, this.singleSignOnServiceBinding,
@ -803,13 +808,13 @@ public class RelyingPartyRegistration implements Serializable {
private Saml2MessageBinding assertionConsumerServiceBinding = Saml2MessageBinding.POST;
private String singleLogoutServiceLocation;
private @Nullable String singleLogoutServiceLocation;
private String singleLogoutServiceResponseLocation;
private @Nullable String singleLogoutServiceResponseLocation;
private Collection<Saml2MessageBinding> singleLogoutServiceBindings = new LinkedHashSet<>();
private String nameIdFormat = null;
private @Nullable String nameIdFormat = null;
private boolean authnRequestsSigned = false;
@ -965,7 +970,7 @@ public class RelyingPartyRegistration implements Serializable {
* @return the {@link Builder} for further configuration
* @since 5.6
*/
public Builder singleLogoutServiceLocation(String singleLogoutServiceLocation) {
public Builder singleLogoutServiceLocation(@Nullable String singleLogoutServiceLocation) {
this.singleLogoutServiceLocation = singleLogoutServiceLocation;
return this;
}
@ -983,7 +988,7 @@ public class RelyingPartyRegistration implements Serializable {
* @return the {@link Builder} for further configuration
* @since 5.6
*/
public Builder singleLogoutServiceResponseLocation(String singleLogoutServiceResponseLocation) {
public Builder singleLogoutServiceResponseLocation(@Nullable String singleLogoutServiceResponseLocation) {
this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation;
return this;
}
@ -994,7 +999,7 @@ public class RelyingPartyRegistration implements Serializable {
* @return the {@link Builder} for further configuration
* @since 5.7
*/
public Builder nameIdFormat(String nameIdFormat) {
public Builder nameIdFormat(@Nullable String nameIdFormat) {
this.nameIdFormat = nameIdFormat;
return this;
}

View File

@ -16,6 +16,8 @@
package org.springframework.security.saml2.provider.service.registration;
import org.jspecify.annotations.Nullable;
/**
* A repository for {@link RelyingPartyRegistration}s
*
@ -31,7 +33,7 @@ public interface RelyingPartyRegistrationRepository {
* @param registrationId the registration identifier
* @return the {@link RelyingPartyRegistration} if found, otherwise {@code null}
*/
RelyingPartyRegistration findByRegistrationId(String registrationId);
@Nullable RelyingPartyRegistration findByRegistrationId(String registrationId);
/**
* Returns the unique relying party registration associated with the asserting party's
@ -41,7 +43,7 @@ public interface RelyingPartyRegistrationRepository {
* party; {@code null} of there is no unique match asserting party
* @since 6.1
*/
default RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
default @Nullable RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
return findByRegistrationId(entityId);
}

View File

@ -16,6 +16,8 @@
package org.springframework.security.saml2.provider.service.registration;
import org.jspecify.annotations.Nullable;
/**
* The type of bindings that messages are exchanged using Supported bindings are
* {@code urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST} and
@ -50,7 +52,7 @@ public enum Saml2MessageBinding {
* @return the resolved {@code Saml2MessageBinding}, or {@code null} if not found
* @since 5.5
*/
public static Saml2MessageBinding from(String name) {
public static @Nullable Saml2MessageBinding from(String name) {
for (Saml2MessageBinding value : values()) {
if (value.getUrn().equals(name)) {
return value;

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Internal utilities for SAML2 support (not for public use).
*/
@NullMarked
package org.springframework.security.saml2.provider.service.registration;
import org.jspecify.annotations.NullMarked;

View File

@ -16,7 +16,12 @@
package org.springframework.security.saml2.provider.service.web;
import java.util.Objects;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.Response;
import org.springframework.http.HttpMethod;
@ -92,7 +97,7 @@ final class BaseOpenSamlAuthenticationTokenConverter implements AuthenticationCo
* non-existent {@code registrationId}
*/
@Override
public Saml2AuthenticationToken convert(HttpServletRequest request) {
public @Nullable Saml2AuthenticationToken convert(HttpServletRequest request) {
String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
if (serialized == null) {
return null;
@ -111,18 +116,21 @@ final class BaseOpenSamlAuthenticationTokenConverter implements AuthenticationCo
return token;
}
private Saml2AuthenticationToken tokenByAuthenticationRequest(HttpServletRequest request) {
private @Nullable Saml2AuthenticationToken tokenByAuthenticationRequest(HttpServletRequest request) {
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequests
.loadAuthenticationRequest(request);
if (authenticationRequest == null) {
return null;
}
String registrationId = authenticationRequest.getRelyingPartyRegistrationId();
if (registrationId == null) {
return null;
}
RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId);
return tokenByRegistration(request, registration, authenticationRequest);
}
private Saml2AuthenticationToken tokenByRegistrationId(HttpServletRequest request,
private @Nullable Saml2AuthenticationToken tokenByRegistrationId(HttpServletRequest request,
RequestMatcher.MatchResult result) {
String registrationId = result.getVariables().get("registrationId");
if (registrationId == null) {
@ -132,27 +140,46 @@ final class BaseOpenSamlAuthenticationTokenConverter implements AuthenticationCo
return tokenByRegistration(request, registration, null);
}
private Saml2AuthenticationToken tokenByEntityId(HttpServletRequest request) {
Response response = this.saml.deserialize(decode(request));
String issuer = response.getIssuer().getValue();
RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(issuer);
private @Nullable Saml2AuthenticationToken tokenByEntityId(HttpServletRequest request) {
String decoded = decode(request);
if (decoded == null) {
return null;
}
Response response = this.saml.deserialize(decoded);
Issuer issuer = response.getIssuer();
Assert.notNull(issuer, "Response#Issuer cannot be null");
RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(getValue(issuer));
return tokenByRegistration(request, registration, null);
}
private Saml2AuthenticationToken tokenByRegistration(HttpServletRequest request,
RelyingPartyRegistration registration, AbstractSaml2AuthenticationRequest authenticationRequest) {
private @Nullable Saml2AuthenticationToken tokenByRegistration(HttpServletRequest request,
@Nullable RelyingPartyRegistration registration,
@Nullable AbstractSaml2AuthenticationRequest authenticationRequest) {
if (registration == null) {
return null;
}
String decoded = decode(request);
if (decoded == null) {
return null;
}
UriResolver resolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
String entityId = resolver.resolve(registration.getEntityId());
entityId = Objects.requireNonNull(entityId);
String assertionConsumerServiceLocation = resolver.resolve(registration.getAssertionConsumerServiceLocation());
assertionConsumerServiceLocation = Objects.requireNonNull(assertionConsumerServiceLocation);
registration = registration.mutate()
.entityId(resolver.resolve(registration.getEntityId()))
.assertionConsumerServiceLocation(resolver.resolve(registration.getAssertionConsumerServiceLocation()))
.entityId(entityId)
.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
.build();
return new Saml2AuthenticationToken(registration, decoded, authenticationRequest);
}
private String getValue(XSString object) {
String value = object.getValue();
Assert.notNull(value, "required elements must have a value");
return value;
}
/**
* Use the given {@link Saml2AuthenticationRequestRepository} to load authentication
* request.
@ -178,7 +205,7 @@ final class BaseOpenSamlAuthenticationTokenConverter implements AuthenticationCo
this.shouldConvertGetRequests = shouldConvertGetRequests;
}
private String decode(HttpServletRequest request) {
private @Nullable String decode(HttpServletRequest request) {
String encoded = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
boolean isGet = HttpMethod.GET.matches(request.getMethod());
if (!this.shouldConvertGetRequests && isGet) {

View File

@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.web;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.jspecify.annotations.Nullable;
import org.springframework.cache.Cache;
import org.springframework.cache.concurrent.ConcurrentMapCache;
@ -43,14 +44,14 @@ public final class CacheSaml2AuthenticationRequestRepository
private Cache cache = new ConcurrentMapCache("authentication-requests");
@Override
public AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
public @Nullable AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE);
Assert.notNull(relayState, "relayState must not be null");
return this.cache.get(relayState, AbstractSaml2AuthenticationRequest.class);
}
@Override
public void saveAuthenticationRequest(AbstractSaml2AuthenticationRequest authenticationRequest,
public void saveAuthenticationRequest(@Nullable AbstractSaml2AuthenticationRequest authenticationRequest,
HttpServletRequest request, HttpServletResponse response) {
String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE);
Assert.notNull(relayState, "relayState must not be null");
@ -58,7 +59,7 @@ public final class CacheSaml2AuthenticationRequestRepository
}
@Override
public AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request,
public @Nullable AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request,
HttpServletResponse response) {
String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE);
Assert.notNull(relayState, "relayState must not be null");

View File

@ -17,10 +17,12 @@
package org.springframework.security.saml2.provider.service.web;
import java.util.Map;
import java.util.Objects;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.server.PathContainer;
@ -40,7 +42,7 @@ import org.springframework.util.Assert;
* @since 5.4
*/
public final class DefaultRelyingPartyRegistrationResolver
implements Converter<HttpServletRequest, RelyingPartyRegistration>, RelyingPartyRegistrationResolver {
implements Converter<HttpServletRequest, @Nullable RelyingPartyRegistration>, RelyingPartyRegistrationResolver {
private Log logger = LogFactory.getLog(getClass());
@ -76,7 +78,7 @@ public final class DefaultRelyingPartyRegistrationResolver
* {@inheritDoc}
*/
@Override
public RelyingPartyRegistration convert(HttpServletRequest request) {
public @Nullable RelyingPartyRegistration convert(HttpServletRequest request) {
return resolve(request, null);
}
@ -84,7 +86,8 @@ public final class DefaultRelyingPartyRegistrationResolver
* {@inheritDoc}
*/
@Override
public RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId) {
public @Nullable RelyingPartyRegistration resolve(HttpServletRequest request,
@Nullable String relyingPartyRegistrationId) {
if (relyingPartyRegistrationId == null) {
if (this.logger.isTraceEnabled()) {
this.logger.trace("Attempting to resolve from " + this.registrationRequestMatcher
@ -106,9 +109,14 @@ public final class DefaultRelyingPartyRegistrationResolver
return null;
}
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
String entityId = uriResolver.resolve(registration.getEntityId());
entityId = Objects.requireNonNull(entityId);
String assertionConsumerServiceLocation = uriResolver
.resolve(registration.getAssertionConsumerServiceLocation());
assertionConsumerServiceLocation = Objects.requireNonNull(assertionConsumerServiceLocation);
return registration.mutate()
.entityId(uriResolver.resolve(registration.getEntityId()))
.assertionConsumerServiceLocation(uriResolver.resolve(registration.getAssertionConsumerServiceLocation()))
.entityId(entityId)
.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
.singleLogoutServiceLocation(uriResolver.resolve(registration.getSingleLogoutServiceLocation()))
.singleLogoutServiceResponseLocation(
uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation()))

View File

@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.web;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
@ -40,7 +41,7 @@ public class HttpSessionSaml2AuthenticationRequestRepository
private String saml2AuthnRequestAttributeName = DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME;
@Override
public AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
public @Nullable AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
HttpSession httpSession = request.getSession(false);
if (httpSession == null) {
return null;
@ -49,7 +50,7 @@ public class HttpSessionSaml2AuthenticationRequestRepository
}
@Override
public void saveAuthenticationRequest(AbstractSaml2AuthenticationRequest authenticationRequest,
public void saveAuthenticationRequest(@Nullable AbstractSaml2AuthenticationRequest authenticationRequest,
HttpServletRequest request, HttpServletResponse response) {
if (authenticationRequest == null) {
removeAuthenticationRequest(request, response);
@ -60,7 +61,7 @@ public class HttpSessionSaml2AuthenticationRequestRepository
}
@Override
public AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request,
public @Nullable AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request,
HttpServletResponse response) {
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
if (authenticationRequest == null) {

View File

@ -25,6 +25,7 @@ import java.util.Objects;
import javax.xml.namespace.QName;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.RequestAbstractType;
@ -35,6 +36,7 @@ import org.w3c.dom.Element;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
interface OpenSamlOperations {
@ -89,14 +91,17 @@ interface OpenSamlOperations {
private final String algorithm;
private final byte[] signature;
private final byte @Nullable [] signature;
private final byte[] content;
RedirectParameters(Map<String, String> parameters, String parametersQuery, RequestAbstractType request) {
Assert.notNull(request.getID(), "SAML request's ID cannot be null");
Assert.notNull(request.getIssuer(), "SAML request's Issuer cannot be null");
this.id = request.getID();
this.issuer = request.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -113,9 +118,12 @@ interface OpenSamlOperations {
}
RedirectParameters(Map<String, String> parameters, String parametersQuery, StatusResponseType response) {
Assert.notNull(response.getID(), "SAML response's ID cannot be null");
Assert.notNull(response.getIssuer(), "SAML response's Issuer cannot be null");
this.id = response.getID();
this.issuer = response.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -131,7 +139,8 @@ interface OpenSamlOperations {
this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams);
}
static byte[] getContent(String samlObject, String relayState, final Map<String, String> queryParams) {
static byte[] getContent(String samlObject, @Nullable String relayState,
final Map<String, String> queryParams) {
if (Objects.nonNull(relayState)) {
return String
.format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject),
@ -163,7 +172,7 @@ interface OpenSamlOperations {
return this.algorithm;
}
byte[] getSignature() {
byte @Nullable [] getSignature() {
return this.signature;
}

View File

@ -20,6 +20,7 @@ import java.util.HashMap;
import java.util.Map;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.web.util.UrlUtils;
@ -122,7 +123,7 @@ public final class RelyingPartyRegistrationPlaceholderResolvers {
this.uriVariables = uriVariables;
}
public String resolve(String uri) {
public @Nullable String resolve(@Nullable String uri) {
if (uri == null) {
return null;
}

View File

@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.web;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
@ -32,9 +33,10 @@ public interface RelyingPartyRegistrationResolver {
* Resolve a {@link RelyingPartyRegistration} from the HTTP request, using the
* {@code relyingPartyRegistrationId}, if it is provided
* @param request the HTTP request
* @param relyingPartyRegistrationId the {@link RelyingPartyRegistration} identifier
* @param relyingPartyRegistrationId the {@link RelyingPartyRegistration} identifier;
* when {@code null}, may attempt to resolve from the request
* @return the resolved {@link RelyingPartyRegistration}
*/
RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId);
@Nullable RelyingPartyRegistration resolve(HttpServletRequest request, @Nullable String relyingPartyRegistrationId);
}

View File

@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.web;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
@ -36,16 +37,18 @@ public interface Saml2AuthenticationRequestRepository<T extends AbstractSaml2Aut
* @return the {@link AbstractSaml2AuthenticationRequest} or {@code null} if it is not
* present
*/
T loadAuthenticationRequest(HttpServletRequest request);
@Nullable T loadAuthenticationRequest(HttpServletRequest request);
/**
* Saves the current authentication request using the {@link HttpServletRequest} and
* {@link HttpServletResponse}
* @param authenticationRequest the {@link AbstractSaml2AuthenticationRequest}
* @param authenticationRequest the {@link AbstractSaml2AuthenticationRequest}, if
* {@code null}, then remove
* @param request the current request
* @param response the current response
*/
void saveAuthenticationRequest(T authenticationRequest, HttpServletRequest request, HttpServletResponse response);
void saveAuthenticationRequest(@Nullable T authenticationRequest, HttpServletRequest request,
HttpServletResponse response);
/**
* Removes the authentication request using the {@link HttpServletRequest} and
@ -55,6 +58,6 @@ public interface Saml2AuthenticationRequestRepository<T extends AbstractSaml2Aut
* @return the removed {@link AbstractSaml2AuthenticationRequest} or {@code null} if
* it is not present
*/
T removeAuthenticationRequest(HttpServletRequest request, HttpServletResponse response);
@Nullable T removeAuthenticationRequest(HttpServletRequest request, HttpServletResponse response);
}

View File

@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.web;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpMethod;
import org.springframework.security.saml2.core.Saml2Error;
@ -57,7 +58,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
}
@Override
public Saml2AuthenticationToken convert(HttpServletRequest request) {
public @Nullable Saml2AuthenticationToken convert(HttpServletRequest request) {
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
.loadAuthenticationRequest(request);
String relyingPartyRegistrationId = (authenticationRequest != null)
@ -97,7 +98,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
this.shouldConvertGetRequests = shouldConvertGetRequests;
}
private String decode(HttpServletRequest request) {
private @Nullable String decode(HttpServletRequest request) {
String encoded = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
if (encoded == null) {
return null;

View File

@ -24,6 +24,7 @@ import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpHeaders;
import org.springframework.security.saml2.Saml2Exception;
@ -160,7 +161,7 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
}
@Override
public Saml2MetadataResponse resolve(HttpServletRequest request) {
public @Nullable Saml2MetadataResponse resolve(HttpServletRequest request) {
RequestMatcher.MatchResult matcher = this.requestMatcher.matcher(request);
if (!matcher.isMatch()) {
return null;

View File

@ -23,6 +23,7 @@ import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.jspecify.annotations.Nullable;
import org.springframework.http.MediaType;
import org.springframework.security.saml2.core.Saml2ParameterNames;
@ -121,7 +122,7 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
response.sendRedirect(redirectUrl);
}
private void addParameter(String name, String value, UriComponentsBuilder builder) {
private void addParameter(String name, @Nullable String value, UriComponentsBuilder builder) {
Assert.hasText(name, "name cannot be empty or null");
if (StringUtils.hasText(value)) {
builder.queryParam(UriUtils.encode(name, StandardCharsets.ISO_8859_1),

View File

@ -22,19 +22,24 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.function.Consumer;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.xml.XMLObjectBuilder;
import org.opensaml.core.xml.XMLObjectBuilderFactory;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.saml.common.AbstractSAMLObjectBuilder;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.NameIDPolicy;
import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller;
import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
import org.opensaml.saml.saml2.core.impl.NameIDPolicyBuilder;
@ -75,8 +80,6 @@ class BaseOpenSamlAuthenticationRequestResolver implements Saml2AuthenticationRe
private final AuthnRequestBuilder authnRequestBuilder;
private final AuthnRequestMarshaller marshaller;
private final IssuerBuilder issuerBuilder;
private final NameIDBuilder nameIdBuilder;
@ -90,7 +93,8 @@ class BaseOpenSamlAuthenticationRequestResolver implements Saml2AuthenticationRe
private Clock clock = Clock.systemUTC();
private Converter<HttpServletRequest, String> relayStateResolver = (request) -> UUID.randomUUID().toString();
private Converter<HttpServletRequest, @Nullable String> relayStateResolver = (request) -> UUID.randomUUID()
.toString();
private Consumer<AuthnRequestParameters> parametersConsumer = (parameters) -> {
};
@ -107,26 +111,25 @@ class BaseOpenSamlAuthenticationRequestResolver implements Saml2AuthenticationRe
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
this.marshaller = (AuthnRequestMarshaller) registry.getMarshallerFactory()
.getMarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.marshaller, "authnRequestMarshaller must be configured in OpenSAML");
this.authnRequestBuilder = (AuthnRequestBuilder) XMLObjectProviderRegistrySupport.getBuilderFactory()
.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.authnRequestBuilder, "authnRequestBuilder must be configured in OpenSAML");
this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.issuerBuilder, "issuerBuilder must be configured in OpenSAML");
this.nameIdBuilder = (NameIDBuilder) registry.getBuilderFactory().getBuilder(NameID.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.nameIdBuilder, "nameIdBuilder must be configured in OpenSAML");
this.nameIdPolicyBuilder = (NameIDPolicyBuilder) registry.getBuilderFactory()
.getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.nameIdPolicyBuilder, "nameIdPolicyBuilder must be configured in OpenSAML");
Assert.notNull(registry, "XMLObjectProviderRegistry must be configured");
XMLObjectBuilderFactory builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory();
Assert.notNull(builderFactory, "XMLObjectBuilderFactory must be configured");
this.authnRequestBuilder = builder(builderFactory.ensureBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME));
this.issuerBuilder = builder(builderFactory.ensureBuilder(Issuer.DEFAULT_ELEMENT_NAME));
this.nameIdBuilder = builder(builderFactory.ensureBuilder(NameID.DEFAULT_ELEMENT_NAME));
this.nameIdPolicyBuilder = builder(builderFactory.ensureBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME));
}
private static <T extends SAMLObject, B extends AbstractSAMLObjectBuilder<T>> B builder(
XMLObjectBuilder<T> builder) {
return (B) builder;
}
void setClock(Clock clock) {
this.clock = clock;
}
void setRelayStateResolver(Converter<HttpServletRequest, String> relayStateResolver) {
void setRelayStateResolver(Converter<HttpServletRequest, @Nullable String> relayStateResolver) {
this.relayStateResolver = relayStateResolver;
}
@ -139,7 +142,7 @@ class BaseOpenSamlAuthenticationRequestResolver implements Saml2AuthenticationRe
}
@Override
public <T extends AbstractSaml2AuthenticationRequest> T resolve(HttpServletRequest request) {
public <T extends AbstractSaml2AuthenticationRequest> @Nullable T resolve(HttpServletRequest request) {
RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
if (!result.isMatch()) {
return null;
@ -186,7 +189,7 @@ class BaseOpenSamlAuthenticationRequestResolver implements Saml2AuthenticationRe
return (T) Saml2PostAuthenticationRequest.withRelyingPartyRegistration(registration)
.samlRequest(encoded)
.relayState(relayState)
.id(authnRequest.getID())
.id(Objects.requireNonNull(authnRequest.getID()))
.build();
}
else {
@ -196,7 +199,7 @@ class BaseOpenSamlAuthenticationRequestResolver implements Saml2AuthenticationRe
.withRelyingPartyRegistration(registration)
.samlRequest(deflatedAndEncoded)
.relayState(relayState)
.id(authnRequest.getID());
.id(Objects.requireNonNull(authnRequest.getID()));
if (registration.getAssertingPartyMetadata().getWantAuthnRequestsSigned()
|| registration.isAuthnRequestsSigned()) {
Map<String, String> signingParameters = new HashMap<>();

View File

@ -25,6 +25,7 @@ import java.util.Objects;
import javax.xml.namespace.QName;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.RequestAbstractType;
@ -35,6 +36,7 @@ import org.w3c.dom.Element;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
interface OpenSamlOperations {
@ -89,14 +91,17 @@ interface OpenSamlOperations {
private final String algorithm;
private final byte[] signature;
private final byte @Nullable [] signature;
private final byte[] content;
RedirectParameters(Map<String, String> parameters, String parametersQuery, RequestAbstractType request) {
Assert.notNull(request.getID(), "SAML request's ID cannot be null");
Assert.notNull(request.getIssuer(), "SAML request's Issuer cannot be null");
this.id = request.getID();
this.issuer = request.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -113,9 +118,12 @@ interface OpenSamlOperations {
}
RedirectParameters(Map<String, String> parameters, String parametersQuery, StatusResponseType response) {
Assert.notNull(response.getID(), "SAML response's ID cannot be null");
Assert.notNull(response.getIssuer(), "SAML response's Issuer cannot be null");
this.id = response.getID();
this.issuer = response.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -131,7 +139,8 @@ interface OpenSamlOperations {
this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams);
}
static byte[] getContent(String samlObject, String relayState, final Map<String, String> queryParams) {
static byte[] getContent(String samlObject, @Nullable String relayState,
final Map<String, String> queryParams) {
if (Objects.nonNull(relayState)) {
return String
.format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject),
@ -163,7 +172,7 @@ interface OpenSamlOperations {
return this.algorithm;
}
byte[] getSignature() {
byte @Nullable [] getSignature() {
return this.signature;
}

View File

@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.web.authentication;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
@ -31,6 +32,6 @@ public interface Saml2AuthenticationRequestResolver {
String DEFAULT_AUTHENTICATION_REQUEST_URI = "/saml2/authenticate/{registrationId}";
<T extends AbstractSaml2AuthenticationRequest> T resolve(HttpServletRequest request);
<T extends AbstractSaml2AuthenticationRequest> @Nullable T resolve(HttpServletRequest request);
}

View File

@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.web.authentication;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
@ -122,7 +123,7 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
}
@Override
public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
public @Nullable Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
throws AuthenticationException {
Authentication authentication = this.authenticationConverter.convert(request);
if (authentication == null) {

View File

@ -20,21 +20,26 @@ import java.time.Clock;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.function.Consumer;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.xml.XMLObjectBuilder;
import org.opensaml.core.xml.XMLObjectBuilderFactory;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.saml.common.AbstractSAMLObjectBuilder;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.LogoutRequest;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.SessionIndex;
import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
import org.opensaml.saml.saml2.core.impl.LogoutRequestBuilder;
import org.opensaml.saml.saml2.core.impl.LogoutRequestMarshaller;
import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
import org.opensaml.saml.saml2.core.impl.SessionIndexBuilder;
@ -69,8 +74,6 @@ final class BaseOpenSamlLogoutRequestResolver implements Saml2LogoutRequestResol
private Clock clock = Clock.systemUTC();
private final LogoutRequestMarshaller marshaller;
private final IssuerBuilder issuerBuilder;
private final NameIDBuilder nameIdBuilder;
@ -94,19 +97,17 @@ final class BaseOpenSamlLogoutRequestResolver implements Saml2LogoutRequestResol
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
this.saml = saml;
XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
this.marshaller = (LogoutRequestMarshaller) registry.getMarshallerFactory()
.getMarshaller(LogoutRequest.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.marshaller, "logoutRequestMarshaller must be configured in OpenSAML");
this.logoutRequestBuilder = (LogoutRequestBuilder) registry.getBuilderFactory()
.getBuilder(LogoutRequest.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.logoutRequestBuilder, "logoutRequestBuilder must be configured in OpenSAML");
this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.issuerBuilder, "issuerBuilder must be configured in OpenSAML");
this.nameIdBuilder = (NameIDBuilder) registry.getBuilderFactory().getBuilder(NameID.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.nameIdBuilder, "nameIdBuilder must be configured in OpenSAML");
this.sessionIndexBuilder = (SessionIndexBuilder) registry.getBuilderFactory()
.getBuilder(SessionIndex.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.sessionIndexBuilder, "sessionIndexBuilder must be configured in OpenSAML");
Assert.notNull(registry, "XMLObjectProviderRegistry must be configured");
XMLObjectBuilderFactory builderFactory = registry.getBuilderFactory();
this.logoutRequestBuilder = builder(builderFactory.ensureBuilder(LogoutRequest.DEFAULT_ELEMENT_NAME));
this.issuerBuilder = builder(builderFactory.ensureBuilder(Issuer.DEFAULT_ELEMENT_NAME));
this.nameIdBuilder = builder(builderFactory.ensureBuilder(NameID.DEFAULT_ELEMENT_NAME));
this.sessionIndexBuilder = builder(builderFactory.ensureBuilder(SessionIndex.DEFAULT_ELEMENT_NAME));
}
private static <T extends SAMLObject, B extends AbstractSAMLObjectBuilder<T>> B builder(
XMLObjectBuilder<T> builder) {
return (B) builder;
}
void setClock(Clock clock) {
@ -132,7 +133,7 @@ final class BaseOpenSamlLogoutRequestResolver implements Saml2LogoutRequestResol
* @return a signed and serialized SAML 2.0 Logout Request
*/
@Override
public Saml2LogoutRequest resolve(HttpServletRequest request, Authentication authentication) {
public @Nullable Saml2LogoutRequest resolve(HttpServletRequest request, Authentication authentication) {
String registrationId = getRegistrationId(authentication);
RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request, registrationId);
if (registration == null) {
@ -178,7 +179,7 @@ final class BaseOpenSamlLogoutRequestResolver implements Saml2LogoutRequestResol
}
String relayState = this.relayStateResolver.convert(request);
Saml2LogoutRequest.Builder result = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
.id(logoutRequest.getID());
.id(Objects.requireNonNull(logoutRequest.getID()));
if (registration.getAssertingPartyMetadata().getSingleLogoutServiceBinding() == Saml2MessageBinding.POST) {
String xml = serialize(this.saml.withSigningKeys(registration.getSigningX509Credentials())
.algorithms(registration.getAssertingPartyMetadata().getSigningAlgorithms())
@ -200,13 +201,10 @@ final class BaseOpenSamlLogoutRequestResolver implements Saml2LogoutRequestResol
}
}
private String getRegistrationId(Authentication authentication) {
private @Nullable String getRegistrationId(Authentication authentication) {
if (this.logger.isTraceEnabled()) {
this.logger.trace("Attempting to resolve registrationId from " + authentication);
}
if (authentication == null) {
return null;
}
if (authentication instanceof Saml2AssertionAuthentication response) {
return response.getRelyingPartyRegistrationId();
}

View File

@ -16,7 +16,12 @@
package org.springframework.security.saml2.provider.service.web.authentication.logout;
import java.util.Objects;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.LogoutRequest;
import org.springframework.http.HttpMethod;
@ -95,7 +100,8 @@ final class BaseOpenSamlLogoutRequestValidatorParametersResolver
* non-existent {@code registrationId}
*/
@Override
public Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request, Authentication authentication) {
public @Nullable Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request,
@Nullable Authentication authentication) {
if (request.getParameter(Saml2ParameterNames.SAML_REQUEST) == null) {
return null;
}
@ -126,7 +132,8 @@ final class BaseOpenSamlLogoutRequestValidatorParametersResolver
this.requestMatcher = requestMatcher;
}
private String getRegistrationId(RequestMatcher.MatchResult result, Authentication authentication) {
private @Nullable String getRegistrationId(RequestMatcher.MatchResult result,
@Nullable Authentication authentication) {
String registrationId = result.getVariables().get("registrationId");
if (registrationId != null) {
return registrationId;
@ -143,8 +150,8 @@ final class BaseOpenSamlLogoutRequestValidatorParametersResolver
return null;
}
private Saml2LogoutRequestValidatorParameters logoutRequestById(HttpServletRequest request,
Authentication authentication, String registrationId) {
private @Nullable Saml2LogoutRequestValidatorParameters logoutRequestById(HttpServletRequest request,
@Nullable Authentication authentication, String registrationId) {
RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId);
if (registration == null) {
throw new Saml2AuthenticationException(
@ -153,29 +160,32 @@ final class BaseOpenSamlLogoutRequestValidatorParametersResolver
return logoutRequestByRegistration(request, registration, authentication);
}
private Saml2LogoutRequestValidatorParameters logoutRequestByEntityId(HttpServletRequest request,
Authentication authentication) {
private @Nullable Saml2LogoutRequestValidatorParameters logoutRequestByEntityId(HttpServletRequest request,
@Nullable Authentication authentication) {
String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST);
LogoutRequest logoutRequest = this.saml.deserialize(
Saml2Utils.withEncoded(serialized).inflate(HttpMethod.GET.matches(request.getMethod())).decode());
String issuer = logoutRequest.getIssuer().getValue();
RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(issuer);
Issuer issuer = logoutRequest.getIssuer();
Assert.notNull(issuer, "LogoutRequest#Issuer cannot be null");
RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(getValue(issuer));
return logoutRequestByRegistration(request, registration, authentication);
}
private Saml2LogoutRequestValidatorParameters logoutRequestByRegistration(HttpServletRequest request,
RelyingPartyRegistration registration, Authentication authentication) {
private @Nullable Saml2LogoutRequestValidatorParameters logoutRequestByRegistration(HttpServletRequest request,
@Nullable RelyingPartyRegistration registration, @Nullable Authentication authentication) {
if (registration == null) {
return null;
}
Saml2MessageBinding saml2MessageBinding = Saml2MessageBindingUtils.resolveBinding(request);
registration = fromRequest(request, registration);
String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST);
String location = registration.getSingleLogoutServiceLocation();
Assert.notNull(location, "logoutServiceLocation must be configured");
Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
.samlRequest(serialized)
.relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE))
.binding(saml2MessageBinding)
.location(registration.getSingleLogoutServiceLocation())
.location(location)
.parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG,
request.getParameter(Saml2ParameterNames.SIG_ALG)))
.parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE,
@ -188,7 +198,7 @@ final class BaseOpenSamlLogoutRequestValidatorParametersResolver
private RelyingPartyRegistration fromRequest(HttpServletRequest request, RelyingPartyRegistration registration) {
RelyingPartyRegistrationPlaceholderResolvers.UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers
.uriResolver(request, registration);
String entityId = uriResolver.resolve(registration.getEntityId());
String entityId = Objects.requireNonNull(uriResolver.resolve(registration.getEntityId()));
String logoutLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
String logoutResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
return registration.mutate()
@ -198,4 +208,10 @@ final class BaseOpenSamlLogoutRequestValidatorParametersResolver
.build();
}
private String getValue(XSString element) {
String value = element.getValue();
Assert.notNull(value, "required elements must have a value");
return value;
}
}

View File

@ -26,18 +26,21 @@ import java.util.function.Consumer;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.xml.XMLObjectBuilder;
import org.opensaml.core.xml.XMLObjectBuilderFactory;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.saml.common.AbstractSAMLObjectBuilder;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.LogoutRequest;
import org.opensaml.saml.saml2.core.LogoutResponse;
import org.opensaml.saml.saml2.core.Status;
import org.opensaml.saml.saml2.core.StatusCode;
import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
import org.opensaml.saml.saml2.core.impl.LogoutRequestUnmarshaller;
import org.opensaml.saml.saml2.core.impl.LogoutResponseBuilder;
import org.opensaml.saml.saml2.core.impl.LogoutResponseMarshaller;
import org.opensaml.saml.saml2.core.impl.StatusBuilder;
import org.opensaml.saml.saml2.core.impl.StatusCodeBuilder;
@ -70,12 +73,6 @@ final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseRes
private final Log logger = LogFactory.getLog(getClass());
private XMLObjectProviderRegistry registry;
private final LogoutRequestUnmarshaller unmarshaller;
private final LogoutResponseMarshaller marshaller;
private final LogoutResponseBuilder logoutResponseBuilder;
private final IssuerBuilder issuerBuilder;
@ -86,7 +83,7 @@ final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseRes
private final OpenSamlOperations saml;
private final RelyingPartyRegistrationRepository registrations;
private final @Nullable RelyingPartyRegistrationRepository registrations;
private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
@ -98,27 +95,23 @@ final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseRes
/**
* Construct a {@link BaseOpenSamlLogoutResponseResolver}
*/
BaseOpenSamlLogoutResponseResolver(RelyingPartyRegistrationRepository registrations,
BaseOpenSamlLogoutResponseResolver(@Nullable RelyingPartyRegistrationRepository registrations,
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver, OpenSamlOperations saml) {
this.saml = saml;
this.registrations = registrations;
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
this.registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
this.unmarshaller = (LogoutRequestUnmarshaller) XMLObjectProviderRegistrySupport.getUnmarshallerFactory()
.getUnmarshaller(LogoutRequest.DEFAULT_ELEMENT_NAME);
this.marshaller = (LogoutResponseMarshaller) this.registry.getMarshallerFactory()
.getMarshaller(LogoutResponse.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.marshaller, "logoutResponseMarshaller must be configured in OpenSAML");
this.logoutResponseBuilder = (LogoutResponseBuilder) this.registry.getBuilderFactory()
.getBuilder(LogoutResponse.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.logoutResponseBuilder, "logoutResponseBuilder must be configured in OpenSAML");
this.issuerBuilder = (IssuerBuilder) this.registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.issuerBuilder, "issuerBuilder must be configured in OpenSAML");
this.statusBuilder = (StatusBuilder) this.registry.getBuilderFactory().getBuilder(Status.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.statusBuilder, "statusBuilder must be configured in OpenSAML");
this.statusCodeBuilder = (StatusCodeBuilder) this.registry.getBuilderFactory()
.getBuilder(StatusCode.DEFAULT_ELEMENT_NAME);
Assert.notNull(this.statusCodeBuilder, "statusCodeBuilder must be configured in OpenSAML");
XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
Assert.notNull(registry, "XMLObjectProviderRegistry cannot be null");
XMLObjectBuilderFactory builderFactory = registry.getBuilderFactory();
this.logoutResponseBuilder = builder(builderFactory.ensureBuilder(LogoutResponse.DEFAULT_ELEMENT_NAME));
this.issuerBuilder = builder(builderFactory.ensureBuilder(Issuer.DEFAULT_ELEMENT_NAME));
this.statusBuilder = builder(builderFactory.ensureBuilder(Status.DEFAULT_ELEMENT_NAME));
this.statusCodeBuilder = builder(builderFactory.ensureBuilder(StatusCode.DEFAULT_ELEMENT_NAME));
}
private static <T extends SAMLObject, B extends AbstractSAMLObjectBuilder<T>> B builder(
XMLObjectBuilder<T> builder) {
return (B) builder;
}
/**
@ -133,23 +126,25 @@ final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseRes
* @return a signed and serialized SAML 2.0 Logout Response
*/
@Override
public Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication) {
public @Nullable Saml2LogoutResponse resolve(HttpServletRequest request, @Nullable Authentication authentication) {
return resolve(request, authentication, StatusCode.SUCCESS);
}
@Override
public Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication,
public @Nullable Saml2LogoutResponse resolve(HttpServletRequest request, @Nullable Authentication authentication,
Saml2AuthenticationException authenticationException) {
return resolve(request, authentication, getSamlStatus(authenticationException));
}
private Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication, String statusCode) {
private @Nullable Saml2LogoutResponse resolve(HttpServletRequest request, @Nullable Authentication authentication,
String statusCode) {
LogoutRequest logoutRequest = this.saml.deserialize(extractSamlRequest(request));
String registrationId = getRegistrationId(authentication);
RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request, registrationId);
if (registration == null && this.registrations != null) {
String issuer = logoutRequest.getIssuer().getValue();
registration = this.registrations.findUniqueByAssertingPartyEntityId(issuer);
Issuer issuer = logoutRequest.getIssuer();
Assert.notNull(issuer, "LogoutRequest#Issuer cannot be null");
registration = this.registrations.findUniqueByAssertingPartyEntityId(getValue(issuer));
}
if (registration == null) {
return null;
@ -162,12 +157,12 @@ final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseRes
LogoutResponse logoutResponse = this.logoutResponseBuilder.buildObject();
logoutResponse
.setDestination(registration.getAssertingPartyMetadata().getSingleLogoutServiceResponseLocation());
Issuer issuer = this.issuerBuilder.buildObject();
Issuer issuer = this.issuerBuilder.buildObject(Issuer.DEFAULT_ELEMENT_NAME);
issuer.setValue(entityId);
logoutResponse.setIssuer(issuer);
StatusCode code = this.statusCodeBuilder.buildObject();
StatusCode code = this.statusCodeBuilder.buildObject(StatusCode.DEFAULT_ELEMENT_NAME);
code.setValue(statusCode);
Status status = this.statusBuilder.buildObject();
Status status = this.statusBuilder.buildObject(Status.DEFAULT_ELEMENT_NAME);
status.setStatusCode(code);
logoutResponse.setStatus(status);
logoutResponse.setInResponseTo(logoutRequest.getID());
@ -206,6 +201,12 @@ final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseRes
}
}
String getValue(XSString object) {
String value = object.getValue();
Assert.notNull(value, "required elements must have a value");
return value;
}
void setClock(Clock clock) {
this.clock = clock;
}
@ -214,7 +215,7 @@ final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseRes
this.parametersConsumer = parametersConsumer;
}
private String getRegistrationId(Authentication authentication) {
private @Nullable String getRegistrationId(@Nullable Authentication authentication) {
if (this.logger.isTraceEnabled()) {
this.logger.trace("Attempting to resolve registrationId from " + authentication);
}
@ -255,12 +256,12 @@ final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseRes
private final RelyingPartyRegistration registration;
private final Authentication authentication;
private final @Nullable Authentication authentication;
private final LogoutRequest logoutRequest;
LogoutResponseParameters(HttpServletRequest request, RelyingPartyRegistration registration,
Authentication authentication, LogoutRequest logoutRequest) {
@Nullable Authentication authentication, LogoutRequest logoutRequest) {
this.request = request;
this.registration = registration;
this.authentication = authentication;
@ -275,7 +276,7 @@ final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseRes
return this.registration;
}
Authentication getAuthentication() {
@Nullable Authentication getAuthentication() {
return this.authentication;
}

View File

@ -21,6 +21,7 @@ import java.security.MessageDigest;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import org.jspecify.annotations.Nullable;
import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.saml2.core.Saml2ParameterNames;
@ -45,7 +46,7 @@ public final class HttpSessionLogoutRequestRepository implements Saml2LogoutRequ
* {@inheritDoc}
*/
@Override
public Saml2LogoutRequest loadLogoutRequest(HttpServletRequest request) {
public @Nullable Saml2LogoutRequest loadLogoutRequest(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
HttpSession session = request.getSession(false);
if (session == null) {
@ -62,7 +63,7 @@ public final class HttpSessionLogoutRequestRepository implements Saml2LogoutRequ
* {@inheritDoc}
*/
@Override
public void saveLogoutRequest(Saml2LogoutRequest logoutRequest, HttpServletRequest request,
public void saveLogoutRequest(@Nullable Saml2LogoutRequest logoutRequest, HttpServletRequest request,
HttpServletResponse response) {
Assert.notNull(request, "request cannot be null");
Assert.notNull(response, "response cannot be null");
@ -79,7 +80,7 @@ public final class HttpSessionLogoutRequestRepository implements Saml2LogoutRequ
* {@inheritDoc}
*/
@Override
public Saml2LogoutRequest removeLogoutRequest(HttpServletRequest request, HttpServletResponse response) {
public @Nullable Saml2LogoutRequest removeLogoutRequest(HttpServletRequest request, HttpServletResponse response) {
Assert.notNull(request, "request cannot be null");
Assert.notNull(response, "response cannot be null");
Saml2LogoutRequest logoutRequest = loadLogoutRequest(request);
@ -90,16 +91,19 @@ public final class HttpSessionLogoutRequestRepository implements Saml2LogoutRequ
return logoutRequest;
}
private String getStateParameter(HttpServletRequest request) {
private @Nullable String getStateParameter(HttpServletRequest request) {
return request.getParameter(Saml2ParameterNames.RELAY_STATE);
}
private boolean stateParameterEquals(HttpServletRequest request, Saml2LogoutRequest logoutRequest) {
private boolean stateParameterEquals(HttpServletRequest request, @Nullable Saml2LogoutRequest logoutRequest) {
String stateParameter = getStateParameter(request);
if (stateParameter == null || logoutRequest == null) {
return false;
}
String relayState = logoutRequest.getRelayState();
if (relayState == null) {
return false;
}
return MessageDigest.isEqual(Utf8.encode(stateParameter), Utf8.encode(relayState));
}

View File

@ -25,6 +25,7 @@ import java.util.Objects;
import javax.xml.namespace.QName;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.RequestAbstractType;
@ -35,6 +36,7 @@ import org.w3c.dom.Element;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
interface OpenSamlOperations {
@ -89,14 +91,17 @@ interface OpenSamlOperations {
private final String algorithm;
private final byte[] signature;
private final byte @Nullable [] signature;
private final byte[] content;
RedirectParameters(Map<String, String> parameters, String parametersQuery, RequestAbstractType request) {
Assert.notNull(request.getID(), "SAML request's ID cannot be null");
Assert.notNull(request.getIssuer(), "SAML request's Issuer cannot be null");
this.id = request.getID();
this.issuer = request.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -113,9 +118,12 @@ interface OpenSamlOperations {
}
RedirectParameters(Map<String, String> parameters, String parametersQuery, StatusResponseType response) {
Assert.notNull(response.getID(), "SAML response's ID cannot be null");
Assert.notNull(response.getIssuer(), "SAML response's Issuer cannot be null");
this.id = response.getID();
this.issuer = response.getIssuer();
this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG);
this.algorithm = Objects.requireNonNull(parameters.get(Saml2ParameterNames.SIG_ALG),
"sigAlg parameter cannot be null");
if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) {
this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE));
}
@ -131,7 +139,8 @@ interface OpenSamlOperations {
this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams);
}
static byte[] getContent(String samlObject, String relayState, final Map<String, String> queryParams) {
static byte[] getContent(String samlObject, @Nullable String relayState,
final Map<String, String> queryParams) {
if (Objects.nonNull(relayState)) {
return String
.format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject),
@ -163,7 +172,7 @@ interface OpenSamlOperations {
return this.algorithm;
}
byte[] getSignature() {
byte @Nullable [] getSignature() {
return this.signature;
}

View File

@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.web.authentication.logout;
import java.io.IOException;
import java.util.Objects;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
@ -24,6 +25,7 @@ import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage;
import org.springframework.http.MediaType;
@ -210,8 +212,9 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
private void doRedirect(HttpServletRequest request, HttpServletResponse response,
Saml2LogoutResponse logoutResponse) throws IOException {
String location = logoutResponse.getResponseLocation();
UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(location)
.query(logoutResponse.getParametersQuery());
String query = logoutResponse.getParametersQuery();
Assert.notNull(query, "logout response must have a parameters query when using redirect binding");
UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(location).query(query);
this.redirectStrategy.sendRedirect(request, response, uriBuilder.build(true).toUriString());
}
@ -224,7 +227,7 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
response.getWriter().write(html);
}
private String createSamlPostRequestFormData(String location, String saml, String relayState) {
private String createSamlPostRequestFormData(String location, String saml, @Nullable String relayState) {
StringBuilder html = new StringBuilder();
html.append("<!DOCTYPE html>\n");
html.append("<html>\n").append(" <head>\n");
@ -279,8 +282,8 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
}
@Override
public Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request,
Authentication authentication) {
public @Nullable Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request,
@Nullable Authentication authentication) {
String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST);
if (serialized == null) {
return null;
@ -298,6 +301,7 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
}
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
String entityId = uriResolver.resolve(registration.getEntityId());
entityId = Objects.requireNonNull(entityId);
String logoutLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
String logoutResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
registration = registration.mutate()
@ -310,7 +314,6 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
.samlRequest(serialized)
.relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE))
.binding(saml2MessageBinding)
.location(registration.getSingleLogoutServiceLocation())
.parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG,
request.getParameter(Saml2ParameterNames.SIG_ALG)))
.parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE,
@ -325,7 +328,8 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
this.logoutRequestMatcher = logoutRequestMatcher;
}
private String getRegistrationId(RequestMatcher.MatchResult result, Authentication authentication) {
private @Nullable String getRegistrationId(RequestMatcher.MatchResult result,
@Nullable Authentication authentication) {
String registrationId = result.getVariables().get("registrationId");
if (registrationId != null) {
return registrationId;

View File

@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.web.authentication.l
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
@ -44,16 +45,18 @@ public interface Saml2LogoutRequestRepository {
* @param request the {@code HttpServletRequest}
* @return the {@link Saml2LogoutRequest} or {@code null} if not available
*/
Saml2LogoutRequest loadLogoutRequest(HttpServletRequest request);
@Nullable Saml2LogoutRequest loadLogoutRequest(HttpServletRequest request);
/**
* Persists the {@link Saml2LogoutRequest} associating it to the provided
* {@code HttpServletRequest} and/or {@code HttpServletResponse}.
* @param logoutRequest the {@link Saml2LogoutRequest}
* @param logoutRequest the {@link Saml2LogoutRequest}, if {@code null}, then remove
* logout request
* @param request the {@code HttpServletRequest}
* @param response the {@code HttpServletResponse}
*/
void saveLogoutRequest(Saml2LogoutRequest logoutRequest, HttpServletRequest request, HttpServletResponse response);
void saveLogoutRequest(@Nullable Saml2LogoutRequest logoutRequest, HttpServletRequest request,
HttpServletResponse response);
/**
* Removes and returns the {@link Saml2LogoutRequest} associated to the provided
@ -63,6 +66,6 @@ public interface Saml2LogoutRequestRepository {
* @param response the {@code HttpServletResponse}
* @return the {@link Saml2LogoutRequest} or {@code null} if not available
*/
Saml2LogoutRequest removeLogoutRequest(HttpServletRequest request, HttpServletResponse response);
@Nullable Saml2LogoutRequest removeLogoutRequest(HttpServletRequest request, HttpServletResponse response);
}

View File

@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.web.authentication.logout;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
@ -44,6 +45,6 @@ public interface Saml2LogoutRequestResolver {
* @param authentication the current user
* @return a signed and serialized SAML 2.0 Logout Request
*/
Saml2LogoutRequest resolve(HttpServletRequest request, Authentication authentication);
@Nullable Saml2LogoutRequest resolve(HttpServletRequest request, Authentication authentication);
}

View File

@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.web.authentication.logout;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters;
@ -40,6 +41,6 @@ public interface Saml2LogoutRequestValidatorParametersResolver {
* @param authentication the current user, if any; may be null
* @return a SAML 2.0 Logout Request, if any; may be null
*/
Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request, Authentication authentication);
@Nullable Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request, @Nullable Authentication authentication);
}

View File

@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.web.authentication.logout;
import java.io.IOException;
import java.util.Objects;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
@ -135,16 +136,17 @@ public final class Saml2LogoutResponseFilter extends OncePerRequestFilter {
response.sendError(HttpServletResponse.SC_BAD_REQUEST, error.toString());
return;
}
if (registration.getSingleLogoutServiceResponseLocation() == null) {
String responseLocation = registration.getSingleLogoutServiceResponseLocation();
if (responseLocation == null) {
this.logger.trace(
"Did not process logout response since RelyingPartyRegistration has not been configured with a logout response endpoint");
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
return;
}
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
String entityId = uriResolver.resolve(registration.getEntityId());
String entityId = Objects.requireNonNull(uriResolver.resolve(registration.getEntityId()));
String logoutLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
String logoutResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
String logoutResponseLocation = Objects.requireNonNull(uriResolver.resolve(responseLocation));
registration = registration.mutate()
.entityId(entityId)
.singleLogoutServiceLocation(logoutLocation)
@ -162,7 +164,7 @@ public final class Saml2LogoutResponseFilter extends OncePerRequestFilter {
.samlResponse(serialized)
.relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE))
.binding(saml2MessageBinding)
.location(registration.getSingleLogoutServiceResponseLocation())
.location(logoutResponseLocation)
.parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG,
request.getParameter(Saml2ParameterNames.SIG_ALG)))
.parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE,

View File

@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.web.authentication.logout;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
@ -43,7 +44,7 @@ public interface Saml2LogoutResponseResolver {
* @param authentication the current user
* @return a signed and serialized SAML 2.0 Logout Response
*/
Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication);
@Nullable Saml2LogoutResponse resolve(HttpServletRequest request, @Nullable Authentication authentication);
/**
* Prepare to create, sign, and serialize a SAML 2.0 Error Logout Response.
@ -55,7 +56,7 @@ public interface Saml2LogoutResponseResolver {
* cannot generate a SAML 2.0 Error Logout Response
* @since 7.0
*/
default Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication,
default @Nullable Saml2LogoutResponse resolve(HttpServletRequest request, @Nullable Authentication authentication,
Saml2AuthenticationException authenticationException) {
return null;
}

View File

@ -50,11 +50,11 @@ final class Saml2MessageBindingUtils {
}
static boolean isHttpRedirectBinding(HttpServletRequest request) {
return request != null && "GET".equalsIgnoreCase(request.getMethod()) && isSamlRequestResponse(request);
return "GET".equalsIgnoreCase(request.getMethod()) && isSamlRequestResponse(request);
}
static boolean isHttpPostBinding(HttpServletRequest request) {
return request != null && "POST".equalsIgnoreCase(request.getMethod()) && isSamlRequestResponse(request);
return "POST".equalsIgnoreCase(request.getMethod()) && isSamlRequestResponse(request);
}
}

View File

@ -22,6 +22,7 @@ import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.http.MediaType;
import org.springframework.security.core.Authentication;
@ -69,8 +70,13 @@ public final class Saml2RelyingPartyInitiatedLogoutSuccessHandler implements Log
* @throws IOException when failing to write to the response
*/
@Override
public void onLogoutSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException {
public void onLogoutSuccess(HttpServletRequest request, HttpServletResponse response,
@Nullable Authentication authentication) throws IOException {
if (authentication == null) {
this.logger.trace("Returning 401 since no logout request generated");
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
return;
}
Saml2LogoutRequest logoutRequest = this.logoutRequestResolver.resolve(request, authentication);
if (logoutRequest == null) {
this.logger.trace("Returning 401 since no logout request generated");
@ -99,8 +105,9 @@ public final class Saml2RelyingPartyInitiatedLogoutSuccessHandler implements Log
private void doRedirect(HttpServletRequest request, HttpServletResponse response, Saml2LogoutRequest logoutRequest)
throws IOException {
String location = logoutRequest.getLocation();
UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(location)
.query(logoutRequest.getParametersQuery());
String query = logoutRequest.getParametersQuery();
Assert.notNull(query, "logout request must have a parameters query when using redirect binding");
UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(location).query(query);
this.redirectStrategy.sendRedirect(request, response, uriBuilder.build(true).toUriString());
}
@ -113,7 +120,7 @@ public final class Saml2RelyingPartyInitiatedLogoutSuccessHandler implements Log
response.getWriter().write(html);
}
private String createSamlPostRequestFormData(String location, String saml, String relayState) {
private String createSamlPostRequestFormData(String location, String saml, @Nullable String relayState) {
StringBuilder html = new StringBuilder();
html.append("<!DOCTYPE html>\n");
html.append("<html>\n").append(" <head>\n");

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Internal utilities for SAML2 support (not for public use).
*/
@NullMarked
package org.springframework.security.saml2.provider.service.web.authentication.logout;
import org.jspecify.annotations.NullMarked;

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Internal utilities for SAML2 support (not for public use).
*/
@NullMarked
package org.springframework.security.saml2.provider.service.web.authentication;
import org.jspecify.annotations.NullMarked;

View File

@ -22,9 +22,11 @@ import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver;
@ -95,7 +97,7 @@ public class RequestMatcherMetadataResponseResolver implements Saml2MetadataResp
* {@code registrationId}
*/
@Override
public Saml2MetadataResponse resolve(HttpServletRequest request) {
public @Nullable Saml2MetadataResponse resolve(HttpServletRequest request) {
RequestMatcher.MatchResult result = this.matcher.matcher(request);
if (!result.isMatch()) {
return null;
@ -115,7 +117,8 @@ public class RequestMatcherMetadataResponseResolver implements Saml2MetadataResp
return null;
}
private Saml2MetadataResponse responseByRegistrationId(HttpServletRequest request, String registrationId) {
private @Nullable Saml2MetadataResponse responseByRegistrationId(HttpServletRequest request,
@Nullable String registrationId) {
if (registrationId == null) {
return null;
}
@ -132,9 +135,10 @@ public class RequestMatcherMetadataResponseResolver implements Saml2MetadataResp
for (RelyingPartyRegistration registration : registrations) {
RelyingPartyRegistrationPlaceholderResolvers.UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers
.uriResolver(request, registration);
String entityId = uriResolver.resolve(registration.getEntityId());
String entityId = Objects.requireNonNull(uriResolver.resolve(registration.getEntityId()));
results.computeIfAbsent(entityId, (e) -> {
String ssoLocation = uriResolver.resolve(registration.getAssertionConsumerServiceLocation());
ssoLocation = Objects.requireNonNull(ssoLocation);
String sloLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
String sloResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
return registration.mutate()

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Web metadata endpoint support for SAML2 relying party.
*/
@NullMarked
package org.springframework.security.saml2.provider.service.web.metadata;
import org.jspecify.annotations.NullMarked;

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Internal utilities for SAML2 support (not for public use).
*/
@NullMarked
package org.springframework.security.saml2.provider.service.web;
import org.jspecify.annotations.NullMarked;

View File

@ -17,6 +17,7 @@
package org.springframework.security.saml2.provider.service.web.authentication.logout;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
@ -77,7 +78,8 @@ public final class OpenSaml5LogoutRequestValidatorParametersResolver
* non-existent {@code registrationId}
*/
@Override
public Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request, Authentication authentication) {
public @Nullable Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request,
@Nullable Authentication authentication) {
return this.delegate.resolve(request, authentication);
}

View File

@ -21,6 +21,7 @@ import java.time.Instant;
import java.util.function.Consumer;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.opensaml.saml.saml2.core.LogoutRequest;
import org.springframework.security.core.Authentication;
@ -63,7 +64,7 @@ public final class OpenSaml5LogoutResponseResolver implements Saml2LogoutRespons
* {@inheritDoc}
*/
@Override
public Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication) {
public @Nullable Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication) {
return this.delegate.resolve(request, authentication);
}
@ -71,7 +72,7 @@ public final class OpenSaml5LogoutResponseResolver implements Saml2LogoutRespons
* {@inheritDoc}
*/
@Override
public Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication,
public @Nullable Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication,
Saml2AuthenticationException exception) {
return this.delegate.resolve(request, authentication, exception);
}