diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/LogoutRequestEncryptedIDUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/LogoutRequestEncryptedIdUtils.java similarity index 92% rename from saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/LogoutRequestEncryptedIDUtils.java rename to saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/LogoutRequestEncryptedIdUtils.java index a7fd9a6f5f..5ff94a701b 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/LogoutRequestEncryptedIDUtils.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/LogoutRequestEncryptedIdUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,16 +46,16 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP * * @author Robert Stoiber */ -final class LogoutRequestEncryptedIDUtils { +final class LogoutRequestEncryptedIdUtils { private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), new SimpleRetrievalMethodEncryptedKeyResolver())); - static SAMLObject decryptEncryptedID(EncryptedID encryptedID, RelyingPartyRegistration registration) { + static SAMLObject decryptEncryptedId(EncryptedID encryptedId, RelyingPartyRegistration registration) { Decrypter decrypter = decrypter(registration); try { - return decrypter.decrypt(encryptedID); + return decrypter.decrypt(encryptedId); } catch (Exception ex) { @@ -75,7 +75,7 @@ final class LogoutRequestEncryptedIDUtils { return decrypter; } - private LogoutRequestEncryptedIDUtils() { + private LogoutRequestEncryptedIdUtils() { } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSamlLogoutRequestValidator.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSamlLogoutRequestValidator.java index e20082a760..5345aa8875 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSamlLogoutRequestValidator.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSamlLogoutRequestValidator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -161,25 +161,30 @@ public final class OpenSamlLogoutRequestValidator implements Saml2LogoutRequestV if (authentication == null) { return; } - NameID nameId = request.getNameID(); - EncryptedID encryptedID = request.getEncryptedID(); - if (nameId == null && encryptedID == null) { + NameID nameId = getNameId(request, registration); + if (nameId == null) { errors.add( new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND, "Failed to find subject in LogoutRequest")); return; } - if (nameId != null) { - validateNameID(nameId, authentication, errors); - } - else { - final NameID nameIDFromEncryptedID = decryptNameID(encryptedID, registration); - validateNameID(nameIDFromEncryptedID, authentication, errors); - } + validateNameId(nameId, authentication, errors); }; } - private void validateNameID(NameID nameId, Authentication authentication, Collection errors) { + private NameID getNameId(LogoutRequest request, RelyingPartyRegistration registration) { + NameID nameId = request.getNameID(); + if (nameId != null) { + return nameId; + } + EncryptedID encryptedId = request.getEncryptedID(); + if (encryptedId == null) { + return null; + } + return decryptNameId(encryptedId, registration); + } + + private void validateNameId(NameID nameId, Authentication authentication, Collection errors) { String name = nameId.getValue(); if (!name.equals(authentication.getName())) { errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_REQUEST, @@ -187,8 +192,8 @@ public final class OpenSamlLogoutRequestValidator implements Saml2LogoutRequestV } } - private NameID decryptNameID(EncryptedID encryptedID, RelyingPartyRegistration registration) { - final SAMLObject decryptedId = LogoutRequestEncryptedIDUtils.decryptEncryptedID(encryptedID, registration); + private NameID decryptNameId(EncryptedID encryptedId, RelyingPartyRegistration registration) { + final SAMLObject decryptedId = LogoutRequestEncryptedIdUtils.decryptEncryptedId(encryptedId, registration); if (decryptedId instanceof NameID) { return ((NameID) decryptedId); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index 49eae5891d..643df7e685 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -373,8 +373,10 @@ public final class TestOpenSamlObjects { NameID nameId = nameIdBuilder.buildObject(); nameId.setValue("user"); logoutRequest.setNameID(null); - logoutRequest.setEncryptedID(encrypted(nameId, - registration.getAssertingPartyDetails().getEncryptionX509Credentials().stream().findFirst().get())); + Saml2X509Credential credential = registration.getAssertingPartyDetails().getEncryptionX509Credentials() + .iterator().next(); + EncryptedID encrypted = encrypted(nameId, credential); + logoutRequest.setEncryptedID(encrypted); IssuerBuilder issuerBuilder = new IssuerBuilder(); Issuer issuer = issuerBuilder.buildObject(); issuer.setValue(registration.getAssertingPartyDetails().getEntityId()); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSamlLogoutRequestValidatorTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSamlLogoutRequestValidatorTests.java index 8bd38f988e..8def02122c 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSamlLogoutRequestValidatorTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSamlLogoutRequestValidatorTests.java @@ -61,17 +61,16 @@ public class OpenSamlLogoutRequestValidatorTests { } @Test - public void handleWhenNameIdInEncryptedIdPostThenValidates() { + public void handleWhenNameIdIsEncryptedIdPostThenValidates() { - RelyingPartyRegistration registration = registrationWithEncryption() - .assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)).build(); + RelyingPartyRegistration registration = decrypting(encrypting(registration())).build(); LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequestNameIdInEncryptedId(registration); sign(logoutRequest, registration); Saml2LogoutRequest request = post(logoutRequest, registration); Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(request, registration, authentication(registration)); Saml2LogoutValidatorResult result = this.manager.validate(parameters); - assertThat(result.hasErrors()).withFailMessage(() -> result.getErrors().toString()).isFalse().isFalse(); + assertThat(result.hasErrors()).withFailMessage(() -> result.getErrors().toString()).isFalse(); } @@ -149,10 +148,14 @@ public class OpenSamlLogoutRequestValidatorTests { .assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)); } - private RelyingPartyRegistration.Builder registrationWithEncryption() { - return signing(verifying(TestRelyingPartyRegistrations.full())) - .assertingPartyDetails((party) -> party.encryptionX509Credentials( - (c) -> c.add(TestSaml2X509Credentials.assertingPartyEncryptingCredential()))); + private RelyingPartyRegistration.Builder decrypting(RelyingPartyRegistration.Builder builder) { + return builder + .decryptionX509Credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyDecryptingCredential())); + } + + private RelyingPartyRegistration.Builder encrypting(RelyingPartyRegistration.Builder builder) { + return builder.assertingPartyDetails((party) -> party.encryptionX509Credentials( + (c) -> c.add(TestSaml2X509Credentials.assertingPartyEncryptingCredential()))); } private RelyingPartyRegistration.Builder verifying(RelyingPartyRegistration.Builder builder) {