diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepository.java index fd95f479a5..738f00952f 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepository.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepository.java @@ -21,15 +21,19 @@ import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; /** - * An in-memory implementation of {@link RelyingPartyRegistrationRepository}. - * Also implements {@link Iterable} to simplify the default login page. + * An in-memory implementation of {@link RelyingPartyRegistrationRepository}. Also + * implements {@link Iterable} to simplify the default login page. * * @author Filip Hanik + * @author Josh Cummings * @since 5.2 */ public class InMemoryRelyingPartyRegistrationRepository @@ -37,6 +41,8 @@ public class InMemoryRelyingPartyRegistrationRepository private final Map byRegistrationId; + private final Map> byAssertingPartyEntityId; + public InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistration... registrations) { this(Arrays.asList(registrations)); } @@ -44,6 +50,7 @@ public class InMemoryRelyingPartyRegistrationRepository public InMemoryRelyingPartyRegistrationRepository(Collection registrations) { Assert.notEmpty(registrations, "registrations cannot be empty"); this.byRegistrationId = createMappingToIdentityProvider(registrations); + this.byAssertingPartyEntityId = createMappingByAssertingPartyEntityId(registrations); } private static Map createMappingToIdentityProvider( @@ -59,11 +66,32 @@ public class InMemoryRelyingPartyRegistrationRepository return Collections.unmodifiableMap(result); } + private static Map> createMappingByAssertingPartyEntityId( + Collection rps) { + MultiValueMap result = new LinkedMultiValueMap<>(); + for (RelyingPartyRegistration rp : rps) { + result.add(rp.getAssertingPartyDetails().getEntityId(), rp); + } + return Collections.unmodifiableMap(result); + } + @Override public RelyingPartyRegistration findByRegistrationId(String id) { return this.byRegistrationId.get(id); } + @Override + public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) { + Collection registrations = this.byAssertingPartyEntityId.get(entityId); + if (registrations == null) { + return null; + } + if (registrations.size() > 1) { + return null; + } + return registrations.iterator().next(); + } + @Override public Iterator iterator() { return this.byRegistrationId.values().iterator(); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationRepository.java index 1c681d92a3..691d82f009 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationRepository.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationRepository.java @@ -20,6 +20,7 @@ package org.springframework.security.saml2.provider.service.registration; * A repository for {@link RelyingPartyRegistration}s * * @author Filip Hanik + * @author Josh Cummings * @since 5.2 */ public interface RelyingPartyRegistrationRepository { @@ -32,4 +33,16 @@ public interface RelyingPartyRegistrationRepository { */ RelyingPartyRegistration findByRegistrationId(String registrationId); + /** + * Returns the unique relying party registration associated with the asserting party's + * {@code entityId} or {@code null} if there is no unique match. + * @param entityId the asserting party's entity id + * @return the unique {@link RelyingPartyRegistration} associated the given asserting + * party; {@code null} of there is no unique match asserting party + * @since 6.1 + */ + default RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) { + return findByRegistrationId(entityId); + } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java new file mode 100644 index 0000000000..cb29a370b0 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java @@ -0,0 +1,194 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Base64; +import java.util.function.Function; +import java.util.zip.Inflater; +import java.util.zip.InflaterOutputStream; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.http.HttpMethod; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.util.Assert; + +/** + * An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken} + * appropriate for authenticated a SAML 2.0 Assertion against an + * {@link org.springframework.security.authentication.AuthenticationManager}. + * + * @author Josh Cummings + * @since 5.4 + */ +public final class Saml2AuthenticationTokenConverter implements AuthenticationConverter { + + // MimeDecoder allows extra line-breaks as well as other non-alphabet values. + // This matches the behaviour of the commons-codec decoder. + private static final Base64.Decoder BASE64 = Base64.getMimeDecoder(); + + private static final Base64Checker BASE_64_CHECKER = new Base64Checker(); + + private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; + + private Function loader; + + /** + * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for + * resolving {@link RelyingPartyRegistration}s + * @param relyingPartyRegistrationResolver the strategy for resolving + * {@link RelyingPartyRegistration}s + */ + public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); + this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; + this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest; + } + + @Override + public Saml2AuthenticationToken convert(HttpServletRequest request) { + AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request); + String relyingPartyRegistrationId = (authenticationRequest != null) + ? authenticationRequest.getRelyingPartyRegistrationId() : null; + RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request, + relyingPartyRegistrationId); + if (relyingPartyRegistration == null) { + return null; + } + String saml2Response = request.getParameter(Saml2ParameterNames.SAML_RESPONSE); + if (saml2Response == null) { + return null; + } + byte[] b = samlDecode(saml2Response); + saml2Response = inflateIfRequired(request, b); + return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest); + } + + /** + * Use the given {@link Saml2AuthenticationRequestRepository} to load authentication + * request. + * @param authenticationRequestRepository the + * {@link Saml2AuthenticationRequestRepository} to use + * @since 5.6 + */ + public void setAuthenticationRequestRepository( + Saml2AuthenticationRequestRepository authenticationRequestRepository) { + Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null"); + this.loader = authenticationRequestRepository::loadAuthenticationRequest; + } + + private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) { + return this.loader.apply(request); + } + + private String inflateIfRequired(HttpServletRequest request, byte[] b) { + if (HttpMethod.GET.matches(request.getMethod())) { + return samlInflate(b); + } + return new String(b, StandardCharsets.UTF_8); + } + + private byte[] samlDecode(String base64EncodedPayload) { + try { + BASE_64_CHECKER.checkAcceptable(base64EncodedPayload); + return BASE64.decode(base64EncodedPayload); + } + catch (Exception ex) { + throw new Saml2AuthenticationException( + new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Failed to decode SAMLResponse"), ex); + } + } + + private String samlInflate(byte[] b) { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + InflaterOutputStream inflaterOutputStream = new InflaterOutputStream(out, new Inflater(true)); + inflaterOutputStream.write(b); + inflaterOutputStream.finish(); + return out.toString(StandardCharsets.UTF_8.name()); + } + catch (Exception ex) { + throw new Saml2AuthenticationException( + new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Unable to inflate string"), ex); + } + } + + static class Base64Checker { + + private static final int[] values = genValueMapping(); + + Base64Checker() { + + } + + private static int[] genValueMapping() { + byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + .getBytes(StandardCharsets.ISO_8859_1); + + int[] values = new int[256]; + Arrays.fill(values, -1); + for (int i = 0; i < alphabet.length; i++) { + values[alphabet[i] & 0xff] = i; + } + return values; + } + + boolean isAcceptable(String s) { + int goodChars = 0; + int lastGoodCharVal = -1; + + // count number of characters from Base64 alphabet + for (int i = 0; i < s.length(); i++) { + int val = values[0xff & s.charAt(i)]; + if (val != -1) { + lastGoodCharVal = val; + goodChars++; + } + } + + // in cases of an incomplete final chunk, ensure the unused bits are zero + switch (goodChars % 4) { + case 0: + return true; + case 2: + return (lastGoodCharVal & 0b1111) == 0; + case 3: + return (lastGoodCharVal & 0b11) == 0; + default: + return false; + } + } + + void checkAcceptable(String ins) { + if (!isAcceptable(ins)) { + throw new IllegalArgumentException("Unaccepted Encoding"); + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepositoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepositoryTests.java index 24ce66792a..22fee63c07 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepositoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepositoryTests.java @@ -42,4 +42,22 @@ public class InMemoryRelyingPartyRegistrationRepositoryTests { assertThat(registrations.findByRegistrationId(null)).isNull(); } + @Test + void findByAssertingPartyEntityIdWhenGivenEntityIdThenReturnsMatchingRegistrations() { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); + InMemoryRelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository( + registration); + String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId(); + assertThat(registrations.findUniqueByAssertingPartyEntityId(assertingPartyEntityId)).isEqualTo(registration); + } + + @Test + void findByAssertingPartyEntityIdWhenGivenWrongEntityIdThenReturnsEmpty() { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); + InMemoryRelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository( + registration); + String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId(); + assertThat(registrations.findUniqueByAssertingPartyEntityId(assertingPartyEntityId + "wrong")).isNull(); + } + }