diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java deleted file mode 100644 index 91515b6a3a..0000000000 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright 2002-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.servlet.filter; - -import java.util.HashMap; -import java.util.Map; -import javax.servlet.http.HttpServletRequest; - -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.util.StringUtils; -import org.springframework.web.util.UriComponents; -import org.springframework.web.util.UriComponentsBuilder; - -import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl; -import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl; - -/** - * @since 5.3 - */ -final class Saml2ServletUtils { - - private static final char PATH_DELIMITER = '/'; - - static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { - if (!StringUtils.hasText(template)) { - return baseUrl; - } - - String entityId = relyingParty.getAssertingPartyDetails().getEntityId(); - String registrationId = relyingParty.getRegistrationId(); - Map uriVariables = new HashMap<>(); - UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl) - .replaceQuery(null) - .fragment(null) - .build(); - String scheme = uriComponents.getScheme(); - uriVariables.put("baseScheme", scheme == null ? "" : scheme); - String host = uriComponents.getHost(); - uriVariables.put("baseHost", host == null ? "" : host); - // following logic is based on HierarchicalUriComponents#toUriString() - int port = uriComponents.getPort(); - uriVariables.put("basePort", port == -1 ? "" : ":" + port); - String path = uriComponents.getPath(); - if (StringUtils.hasLength(path)) { - if (path.charAt(0) != PATH_DELIMITER) { - path = PATH_DELIMITER + path; - } - } - uriVariables.put("basePath", path == null ? "" : path); - uriVariables.put("baseUrl", uriComponents.toUriString()); - uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); - uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); - - return UriComponentsBuilder.fromUriString(template) - .buildAndExpand(uriVariables) - .toUriString(); - } - - static String getApplicationUri(HttpServletRequest request) { - UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request)) - .replacePath(request.getContextPath()) - .replaceQuery(null) - .fragment(null) - .build(); - return uriComponents.toUriString(); - } -} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2Utils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2Utils.java deleted file mode 100644 index f472a2376b..0000000000 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2Utils.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2002-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.servlet.filter; - -import org.apache.commons.codec.binary.Base64; -import org.springframework.security.saml2.Saml2Exception; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.util.zip.Deflater; -import java.util.zip.DeflaterOutputStream; -import java.util.zip.Inflater; -import java.util.zip.InflaterOutputStream; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.zip.Deflater.DEFLATED; - -/** - * @since 5.3 - */ -final class Saml2Utils { - - - private static Base64 BASE64 = new Base64(0, new byte[]{'\n'}); - - static String samlEncode(byte[] b) { - return BASE64.encodeAsString(b); - } - - static byte[] samlDecode(String s) { - return BASE64.decode(s); - } - - static byte[] samlDeflate(String s) { - try { - ByteArrayOutputStream b = new ByteArrayOutputStream(); - DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(DEFLATED, true)); - deflater.write(s.getBytes(UTF_8)); - deflater.finish(); - return b.toByteArray(); - } - catch (IOException e) { - throw new Saml2Exception("Unable to deflate string", e); - } - } - - static String samlInflate(byte[] b) { - try { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); - iout.write(b); - iout.finish(); - return new String(out.toByteArray(), UTF_8); - } - catch (IOException e) { - throw new Saml2Exception("Unable to inflate string", e); - } - } -} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2UtilsTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2UtilsTests.java deleted file mode 100644 index be3405f924..0000000000 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2UtilsTests.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2002-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.servlet.filter; - -import org.apache.commons.codec.binary.Base64; -import org.junit.Test; -import org.springframework.core.io.ClassPathResource; -import org.springframework.util.StreamUtils; -import org.springframework.web.util.UriUtils; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.assertj.core.api.Assertions.assertThat; - -public class Saml2UtilsTests { - - private static Base64 UNCHUNKED_ENCODER = new Base64(0, new byte[]{'\n'}); - private static final Base64 CHUNKED_ENCODER = new Base64(76, new byte[] { '\n' }); - - @Test - public void decodeWhenUsingApacheCommonsBase64ThenXmlIsValid() throws Exception { - String responseUrlDecoded = getSsoCircleEncodedXml(); - String xml = new String(UNCHUNKED_ENCODER.decode(responseUrlDecoded.getBytes(UTF_8)), UTF_8); - validateSsoCircleXml(xml); - } - - @Test - public void decodeWhenUsingApacheCommonsBase64ChunkedThenXmlIsValid() throws Exception { - String responseUrlDecoded = getSsoCircleEncodedXml(); - String xml = new String(CHUNKED_ENCODER.decode(responseUrlDecoded.getBytes(UTF_8)), UTF_8); - validateSsoCircleXml(xml); - } - - @Test - public void decodeWhenUsingSamlUtilsBase64ThenXmlIsValid() throws Exception { - String responseUrlDecoded = getSsoCircleEncodedXml(); - String xml = new String(Saml2Utils.samlDecode(responseUrlDecoded), UTF_8); - validateSsoCircleXml(xml); - } - - private void validateSsoCircleXml(String xml) { - assertThat(xml) - .contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"") - .contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"") - .contains("https://idp.ssocircle.com"); - } - - private String getSsoCircleEncodedXml() throws IOException { - ClassPathResource resource = new ClassPathResource("saml2-response-sso-circle.encoded"); - String response = StreamUtils.copyToString(resource.getInputStream(), StandardCharsets.UTF_8); - return UriUtils.decode(response, UTF_8); - } - -} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java index 3c9fa627c0..74b987ea63 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java @@ -16,6 +16,8 @@ package org.springframework.security.saml2.provider.service.web; +import java.io.IOException; +import java.nio.charset.StandardCharsets; import javax.servlet.http.HttpServletRequest; import org.junit.Test; @@ -24,10 +26,13 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.core.convert.converter.Converter; +import org.springframework.core.io.ClassPathResource; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.saml2.core.Saml2Utils; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.util.StreamUtils; +import org.springframework.web.util.UriUtils; import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; @@ -99,4 +104,29 @@ public class Saml2AuthenticationTokenConverterTests { assertThatCode(() -> new Saml2AuthenticationTokenConverter(null)) .isInstanceOf(IllegalArgumentException.class); } + + @Test + public void convertWhenUsingSamlUtilsBase64ThenXmlIsValid() throws Exception { + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter + (this.relyingPartyRegistrationResolver); + when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .thenReturn(this.relyingPartyRegistration); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter("SAMLResponse", getSsoCircleEncodedXml()); + Saml2AuthenticationToken token = converter.convert(request); + validateSsoCircleXml(token.getSaml2Response()); + } + + private void validateSsoCircleXml(String xml) { + assertThat(xml) + .contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"") + .contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"") + .contains("https://idp.ssocircle.com"); + } + + private String getSsoCircleEncodedXml() throws IOException { + ClassPathResource resource = new ClassPathResource("saml2-response-sso-circle.encoded"); + String response = StreamUtils.copyToString(resource.getInputStream(), StandardCharsets.UTF_8); + return UriUtils.decode(response, UTF_8); + } }