parent
8522e9abd6
commit
97d1a49daf
|
@ -21,15 +21,19 @@ import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.LinkedMultiValueMap;
|
||||||
|
import org.springframework.util.MultiValueMap;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An in-memory implementation of {@link RelyingPartyRegistrationRepository}.
|
* An in-memory implementation of {@link RelyingPartyRegistrationRepository}. Also
|
||||||
* Also implements {@link Iterable} to simplify the default login page.
|
* implements {@link Iterable} to simplify the default login page.
|
||||||
*
|
*
|
||||||
* @author Filip Hanik
|
* @author Filip Hanik
|
||||||
|
* @author Josh Cummings
|
||||||
* @since 5.2
|
* @since 5.2
|
||||||
*/
|
*/
|
||||||
public class InMemoryRelyingPartyRegistrationRepository
|
public class InMemoryRelyingPartyRegistrationRepository
|
||||||
|
@ -37,6 +41,8 @@ public class InMemoryRelyingPartyRegistrationRepository
|
||||||
|
|
||||||
private final Map<String, RelyingPartyRegistration> byRegistrationId;
|
private final Map<String, RelyingPartyRegistration> byRegistrationId;
|
||||||
|
|
||||||
|
private final Map<String, List<RelyingPartyRegistration>> byAssertingPartyEntityId;
|
||||||
|
|
||||||
public InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistration... registrations) {
|
public InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistration... registrations) {
|
||||||
this(Arrays.asList(registrations));
|
this(Arrays.asList(registrations));
|
||||||
}
|
}
|
||||||
|
@ -44,6 +50,7 @@ public class InMemoryRelyingPartyRegistrationRepository
|
||||||
public InMemoryRelyingPartyRegistrationRepository(Collection<RelyingPartyRegistration> registrations) {
|
public InMemoryRelyingPartyRegistrationRepository(Collection<RelyingPartyRegistration> registrations) {
|
||||||
Assert.notEmpty(registrations, "registrations cannot be empty");
|
Assert.notEmpty(registrations, "registrations cannot be empty");
|
||||||
this.byRegistrationId = createMappingToIdentityProvider(registrations);
|
this.byRegistrationId = createMappingToIdentityProvider(registrations);
|
||||||
|
this.byAssertingPartyEntityId = createMappingByAssertingPartyEntityId(registrations);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Map<String, RelyingPartyRegistration> createMappingToIdentityProvider(
|
private static Map<String, RelyingPartyRegistration> createMappingToIdentityProvider(
|
||||||
|
@ -59,11 +66,32 @@ public class InMemoryRelyingPartyRegistrationRepository
|
||||||
return Collections.unmodifiableMap(result);
|
return Collections.unmodifiableMap(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static Map<String, List<RelyingPartyRegistration>> createMappingByAssertingPartyEntityId(
|
||||||
|
Collection<RelyingPartyRegistration> rps) {
|
||||||
|
MultiValueMap<String, RelyingPartyRegistration> result = new LinkedMultiValueMap<>();
|
||||||
|
for (RelyingPartyRegistration rp : rps) {
|
||||||
|
result.add(rp.getAssertingPartyDetails().getEntityId(), rp);
|
||||||
|
}
|
||||||
|
return Collections.unmodifiableMap(result);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RelyingPartyRegistration findByRegistrationId(String id) {
|
public RelyingPartyRegistration findByRegistrationId(String id) {
|
||||||
return this.byRegistrationId.get(id);
|
return this.byRegistrationId.get(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
|
||||||
|
Collection<RelyingPartyRegistration> registrations = this.byAssertingPartyEntityId.get(entityId);
|
||||||
|
if (registrations == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (registrations.size() > 1) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return registrations.iterator().next();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Iterator<RelyingPartyRegistration> iterator() {
|
public Iterator<RelyingPartyRegistration> iterator() {
|
||||||
return this.byRegistrationId.values().iterator();
|
return this.byRegistrationId.values().iterator();
|
||||||
|
|
|
@ -20,6 +20,7 @@ package org.springframework.security.saml2.provider.service.registration;
|
||||||
* A repository for {@link RelyingPartyRegistration}s
|
* A repository for {@link RelyingPartyRegistration}s
|
||||||
*
|
*
|
||||||
* @author Filip Hanik
|
* @author Filip Hanik
|
||||||
|
* @author Josh Cummings
|
||||||
* @since 5.2
|
* @since 5.2
|
||||||
*/
|
*/
|
||||||
public interface RelyingPartyRegistrationRepository {
|
public interface RelyingPartyRegistrationRepository {
|
||||||
|
@ -32,4 +33,16 @@ public interface RelyingPartyRegistrationRepository {
|
||||||
*/
|
*/
|
||||||
RelyingPartyRegistration findByRegistrationId(String registrationId);
|
RelyingPartyRegistration findByRegistrationId(String registrationId);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the unique relying party registration associated with the asserting party's
|
||||||
|
* {@code entityId} or {@code null} if there is no unique match.
|
||||||
|
* @param entityId the asserting party's entity id
|
||||||
|
* @return the unique {@link RelyingPartyRegistration} associated the given asserting
|
||||||
|
* party; {@code null} of there is no unique match asserting party
|
||||||
|
* @since 6.1
|
||||||
|
*/
|
||||||
|
default RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
|
||||||
|
return findByRegistrationId(entityId);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,194 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2002-2022 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.springframework.security.saml2.provider.service.web;
|
||||||
|
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Base64;
|
||||||
|
import java.util.function.Function;
|
||||||
|
import java.util.zip.Inflater;
|
||||||
|
import java.util.zip.InflaterOutputStream;
|
||||||
|
|
||||||
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
|
|
||||||
|
import org.springframework.http.HttpMethod;
|
||||||
|
import org.springframework.security.saml2.core.Saml2Error;
|
||||||
|
import org.springframework.security.saml2.core.Saml2ErrorCodes;
|
||||||
|
import org.springframework.security.saml2.core.Saml2ParameterNames;
|
||||||
|
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
|
||||||
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
|
||||||
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
||||||
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||||
|
import org.springframework.security.web.authentication.AuthenticationConverter;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken}
|
||||||
|
* appropriate for authenticated a SAML 2.0 Assertion against an
|
||||||
|
* {@link org.springframework.security.authentication.AuthenticationManager}.
|
||||||
|
*
|
||||||
|
* @author Josh Cummings
|
||||||
|
* @since 5.4
|
||||||
|
*/
|
||||||
|
public final class Saml2AuthenticationTokenConverter implements AuthenticationConverter {
|
||||||
|
|
||||||
|
// MimeDecoder allows extra line-breaks as well as other non-alphabet values.
|
||||||
|
// This matches the behaviour of the commons-codec decoder.
|
||||||
|
private static final Base64.Decoder BASE64 = Base64.getMimeDecoder();
|
||||||
|
|
||||||
|
private static final Base64Checker BASE_64_CHECKER = new Base64Checker();
|
||||||
|
|
||||||
|
private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
|
||||||
|
|
||||||
|
private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
|
||||||
|
* resolving {@link RelyingPartyRegistration}s
|
||||||
|
* @param relyingPartyRegistrationResolver the strategy for resolving
|
||||||
|
* {@link RelyingPartyRegistration}s
|
||||||
|
*/
|
||||||
|
public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
|
||||||
|
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
|
||||||
|
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
|
||||||
|
this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Saml2AuthenticationToken convert(HttpServletRequest request) {
|
||||||
|
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
|
||||||
|
String relyingPartyRegistrationId = (authenticationRequest != null)
|
||||||
|
? authenticationRequest.getRelyingPartyRegistrationId() : null;
|
||||||
|
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
|
||||||
|
relyingPartyRegistrationId);
|
||||||
|
if (relyingPartyRegistration == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
String saml2Response = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
|
||||||
|
if (saml2Response == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
byte[] b = samlDecode(saml2Response);
|
||||||
|
saml2Response = inflateIfRequired(request, b);
|
||||||
|
return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use the given {@link Saml2AuthenticationRequestRepository} to load authentication
|
||||||
|
* request.
|
||||||
|
* @param authenticationRequestRepository the
|
||||||
|
* {@link Saml2AuthenticationRequestRepository} to use
|
||||||
|
* @since 5.6
|
||||||
|
*/
|
||||||
|
public void setAuthenticationRequestRepository(
|
||||||
|
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
|
||||||
|
Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
|
||||||
|
this.loader = authenticationRequestRepository::loadAuthenticationRequest;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
|
||||||
|
return this.loader.apply(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String inflateIfRequired(HttpServletRequest request, byte[] b) {
|
||||||
|
if (HttpMethod.GET.matches(request.getMethod())) {
|
||||||
|
return samlInflate(b);
|
||||||
|
}
|
||||||
|
return new String(b, StandardCharsets.UTF_8);
|
||||||
|
}
|
||||||
|
|
||||||
|
private byte[] samlDecode(String base64EncodedPayload) {
|
||||||
|
try {
|
||||||
|
BASE_64_CHECKER.checkAcceptable(base64EncodedPayload);
|
||||||
|
return BASE64.decode(base64EncodedPayload);
|
||||||
|
}
|
||||||
|
catch (Exception ex) {
|
||||||
|
throw new Saml2AuthenticationException(
|
||||||
|
new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Failed to decode SAMLResponse"), ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String samlInflate(byte[] b) {
|
||||||
|
try {
|
||||||
|
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||||
|
InflaterOutputStream inflaterOutputStream = new InflaterOutputStream(out, new Inflater(true));
|
||||||
|
inflaterOutputStream.write(b);
|
||||||
|
inflaterOutputStream.finish();
|
||||||
|
return out.toString(StandardCharsets.UTF_8.name());
|
||||||
|
}
|
||||||
|
catch (Exception ex) {
|
||||||
|
throw new Saml2AuthenticationException(
|
||||||
|
new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Unable to inflate string"), ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static class Base64Checker {
|
||||||
|
|
||||||
|
private static final int[] values = genValueMapping();
|
||||||
|
|
||||||
|
Base64Checker() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int[] genValueMapping() {
|
||||||
|
byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
|
||||||
|
.getBytes(StandardCharsets.ISO_8859_1);
|
||||||
|
|
||||||
|
int[] values = new int[256];
|
||||||
|
Arrays.fill(values, -1);
|
||||||
|
for (int i = 0; i < alphabet.length; i++) {
|
||||||
|
values[alphabet[i] & 0xff] = i;
|
||||||
|
}
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean isAcceptable(String s) {
|
||||||
|
int goodChars = 0;
|
||||||
|
int lastGoodCharVal = -1;
|
||||||
|
|
||||||
|
// count number of characters from Base64 alphabet
|
||||||
|
for (int i = 0; i < s.length(); i++) {
|
||||||
|
int val = values[0xff & s.charAt(i)];
|
||||||
|
if (val != -1) {
|
||||||
|
lastGoodCharVal = val;
|
||||||
|
goodChars++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// in cases of an incomplete final chunk, ensure the unused bits are zero
|
||||||
|
switch (goodChars % 4) {
|
||||||
|
case 0:
|
||||||
|
return true;
|
||||||
|
case 2:
|
||||||
|
return (lastGoodCharVal & 0b1111) == 0;
|
||||||
|
case 3:
|
||||||
|
return (lastGoodCharVal & 0b11) == 0;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void checkAcceptable(String ins) {
|
||||||
|
if (!isAcceptable(ins)) {
|
||||||
|
throw new IllegalArgumentException("Unaccepted Encoding");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -42,4 +42,22 @@ public class InMemoryRelyingPartyRegistrationRepositoryTests {
|
||||||
assertThat(registrations.findByRegistrationId(null)).isNull();
|
assertThat(registrations.findByRegistrationId(null)).isNull();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findByAssertingPartyEntityIdWhenGivenEntityIdThenReturnsMatchingRegistrations() {
|
||||||
|
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
|
||||||
|
InMemoryRelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(
|
||||||
|
registration);
|
||||||
|
String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId();
|
||||||
|
assertThat(registrations.findUniqueByAssertingPartyEntityId(assertingPartyEntityId)).isEqualTo(registration);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findByAssertingPartyEntityIdWhenGivenWrongEntityIdThenReturnsEmpty() {
|
||||||
|
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
|
||||||
|
InMemoryRelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(
|
||||||
|
registration);
|
||||||
|
String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId();
|
||||||
|
assertThat(registrations.findUniqueByAssertingPartyEntityId(assertingPartyEntityId + "wrong")).isNull();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue