diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java index 46a67cc4d3..b6395f9b14 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.security.messaging.web.csrf; import java.util.Base64; import org.springframework.security.crypto.codec.Utf8; +import org.springframework.util.Assert; /** * Copied from @@ -43,26 +44,26 @@ final class XorCsrfTokenUtils { byte[] tokenBytes = Utf8.encode(token); int tokenSize = tokenBytes.length; - if (actualBytes.length < tokenSize) { + if (actualBytes.length != tokenSize * 2) { return null; } // extract token and random bytes - int randomBytesSize = actualBytes.length - tokenSize; byte[] xoredCsrf = new byte[tokenSize]; - byte[] randomBytes = new byte[randomBytesSize]; + byte[] randomBytes = new byte[tokenSize]; - System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize); - System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); + System.arraycopy(actualBytes, 0, randomBytes, 0, tokenSize); + System.arraycopy(actualBytes, tokenSize, xoredCsrf, 0, tokenSize); byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); return Utf8.decode(csrfBytes); } private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { - int len = Math.min(randomBytes.length, csrfBytes.length); + Assert.isTrue(randomBytes.length == csrfBytes.length, "arrays must be equal length"); + int len = csrfBytes.length; byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); for (int i = 0; i < len; i++) { xoredCsrf[i] ^= randomBytes[i]; } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java index 884c3d2fc2..5d7705c67b 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.security.messaging.web.csrf; +import java.util.Base64; import java.util.HashMap; import org.junit.jupiter.api.BeforeEach; @@ -141,6 +142,73 @@ public class XorCsrfChannelInterceptorTests { this.interceptor.preSend(message(), this.channel); } + // gh-13310, gh-15184 + @Test + public void preSendWhenCsrfBytesIsShorterThanRandomBytesThenThrowsInvalidCsrfTokenException() { + /* + * Token format: 3 random pad bytes + 2 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), actualToken); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + // gh-13310, gh-15184 + @Test + public void preSendWhenCsrfBytesIsLongerThanRandomBytesThenThrowsInvalidCsrfTokenException() { + /* + * Token format: 3 random pad bytes + 4 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99, 98, 97 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), actualToken); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + // gh-13310, gh-15184 + @Test + public void preSendWhenTokenBytesIsShorterThanActualBytesThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + CsrfToken csrfToken = new DefaultCsrfToken("header", "param", "a"); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), csrfToken); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + // gh-13310, gh-15184 + @Test + public void preSendWhenTokenBytesIsLongerThanActualBytesThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + CsrfToken csrfToken = new DefaultCsrfToken("header", "param", "abcde"); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), csrfToken); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + // gh-13310, gh-15184 + @Test + public void preSendWhenActualBytesIsEmptyThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), ""); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + private Message message() { return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build(); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java index 8d966331ae..9416d233eb 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java +++ b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -84,20 +84,19 @@ public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestA byte[] tokenBytes = Utf8.encode(token); int tokenSize = tokenBytes.length; - if (actualBytes.length < tokenSize) { + if (actualBytes.length != tokenSize * 2) { return null; } // extract token and random bytes - int randomBytesSize = actualBytes.length - tokenSize; byte[] xoredCsrf = new byte[tokenSize]; - byte[] randomBytes = new byte[randomBytesSize]; + byte[] randomBytes = new byte[tokenSize]; - System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize); - System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); + System.arraycopy(actualBytes, 0, randomBytes, 0, tokenSize); + System.arraycopy(actualBytes, tokenSize, xoredCsrf, 0, tokenSize); byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); - return (csrfBytes != null) ? Utf8.decode(csrfBytes) : null; + return Utf8.decode(csrfBytes); } private static String createXoredCsrfToken(SecureRandom secureRandom, String token) { @@ -114,12 +113,10 @@ public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestA } private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { - if (csrfBytes.length < randomBytes.length) { - return null; - } - int len = Math.min(randomBytes.length, csrfBytes.length); + Assert.isTrue(randomBytes.length == csrfBytes.length, "arrays must be equal length"); + int len = csrfBytes.length; byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); for (int i = 0; i < len; i++) { xoredCsrf[i] ^= randomBytes[i]; } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandler.java index de59167e8c..32e3642351 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -77,17 +77,16 @@ public final class XorServerCsrfTokenRequestAttributeHandler extends ServerCsrfT byte[] tokenBytes = Utf8.encode(token); int tokenSize = tokenBytes.length; - if (actualBytes.length < tokenSize) { + if (actualBytes.length != tokenSize * 2) { return null; } // extract token and random bytes - int randomBytesSize = actualBytes.length - tokenSize; byte[] xoredCsrf = new byte[tokenSize]; - byte[] randomBytes = new byte[randomBytesSize]; + byte[] randomBytes = new byte[tokenSize]; - System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize); - System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); + System.arraycopy(actualBytes, 0, randomBytes, 0, tokenSize); + System.arraycopy(actualBytes, tokenSize, xoredCsrf, 0, tokenSize); byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); return (csrfBytes != null) ? Utf8.decode(csrfBytes) : null; @@ -107,12 +106,10 @@ public final class XorServerCsrfTokenRequestAttributeHandler extends ServerCsrfT } private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { - if (csrfBytes.length < randomBytes.length) { - return null; - } - int len = Math.min(randomBytes.length, csrfBytes.length); + Assert.isTrue(randomBytes.length == csrfBytes.length, "arrays must be equal length"); + int len = csrfBytes.length; byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); for (int i = 0; i < len; i++) { xoredCsrf[i] ^= randomBytes[i]; } diff --git a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java index 6f50862411..fc5696fb24 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,6 +44,9 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; */ public class XorCsrfTokenRequestAttributeHandlerTests { + /* + * Token format: 3 random pad bytes + 3 padded bytes. + */ private static final byte[] XOR_CSRF_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 }; private static final String XOR_CSRF_TOKEN_VALUE = Base64.getEncoder().encodeToString(XOR_CSRF_TOKEN_BYTES); @@ -208,14 +211,58 @@ public class XorCsrfTokenRequestAttributeHandlerTests { assertThat(tokenValue).isEqualTo(this.token.getToken()); } + // gh-13310, gh-15184 @Test - public void resolveCsrfTokenIsInvalidThenReturnsNull() { + public void resolveCsrfTokenValueWhenCsrfBytesIsShorterThanRandomBytesThenReturnsNull() { + /* + * Token format: 3 random pad bytes + 2 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.request.setParameter(this.token.getParameterName(), actualToken); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenCsrfBytesIsLongerThanRandomBytesThenReturnsNull() { + /* + * Token format: 3 random pad bytes + 4 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99, 98, 97 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.request.setParameter(this.token.getParameterName(), actualToken); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenTokenBytesIsShorterThanActualBytesThenReturnsNull() { this.request.setParameter(this.token.getParameterName(), XOR_CSRF_TOKEN_VALUE); CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "a"); String tokenValue = this.handler.resolveCsrfTokenValue(this.request, csrfToken); assertThat(tokenValue).isNull(); } + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenTokenBytesIsLongerThanActualBytesThenReturnsNull() { + this.request.setParameter(this.token.getParameterName(), XOR_CSRF_TOKEN_VALUE); + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "abcde"); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, csrfToken); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenActualBytesIsEmptyThenReturnsNull() { + this.request.setParameter(this.token.getParameterName(), ""); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); + assertThat(tokenValue).isNull(); + } + private static Answer fillByteArray() { return (invocation) -> { byte[] bytes = invocation.getArgument(0); diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandlerTests.java index 08b54b01f9..bd859ce965 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,6 +46,9 @@ import static org.mockito.Mockito.verify; */ public class XorServerCsrfTokenRequestAttributeHandlerTests { + /* + * Token format: 3 random pad bytes + 3 padded bytes. + */ private static final byte[] XOR_CSRF_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 }; private static final String XOR_CSRF_TOKEN_VALUE = Base64.getEncoder().encodeToString(XOR_CSRF_TOKEN_BYTES); @@ -188,16 +191,76 @@ public class XorServerCsrfTokenRequestAttributeHandlerTests { StepVerifier.create(csrfToken).expectNext(this.token.getToken()).verifyComplete(); } + // gh-13310, gh-15184 @Test - public void resolveCsrfTokenIsInvalidThenReturnsNull() { + public void resolveCsrfTokenValueWhenCsrfBytesIsShorterThanRandomBytesThenReturnsNull() { + /* + * Token format: 3 random pad bytes + 2 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); this.exchange = MockServerWebExchange .builder(MockServerHttpRequest.post("/") .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) - .body(this.token.getParameterName() + "=" + XOR_CSRF_TOKEN_VALUE)) + .header(this.token.getHeaderName(), actualToken)) .build(); - CsrfToken token = new DefaultCsrfToken("headerName", "paramName", "a"); - Mono csrfToken = this.handler.resolveCsrfTokenValue(this.exchange, token); - assertThat(csrfToken.block()).isNull(); + String tokenValue = this.handler.resolveCsrfTokenValue(this.exchange, this.token).block(); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenCsrfBytesIsLongerThanRandomBytesThenReturnsNull() { + /* + * Token format: 3 random pad bytes + 4 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99, 98, 97 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.exchange = MockServerWebExchange + .builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), actualToken)) + .build(); + String tokenValue = this.handler.resolveCsrfTokenValue(this.exchange, this.token).block(); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenTokenBytesIsShorterThanActualBytesThenReturnsNull() { + this.exchange = MockServerWebExchange + .builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE)) + .build(); + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "a"); + String tokenValue = this.handler.resolveCsrfTokenValue(this.exchange, csrfToken).block(); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenTokenBytesIsLongerThanActualBytesThenReturnsNull() { + this.exchange = MockServerWebExchange + .builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE)) + .build(); + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "abcde"); + String tokenValue = this.handler.resolveCsrfTokenValue(this.exchange, csrfToken).block(); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenActualBytesIsEmptyThenReturnsNull() { + this.exchange = MockServerWebExchange + .builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), "")) + .build(); + String tokenValue = this.handler.resolveCsrfTokenValue(this.exchange, this.token).block(); + assertThat(tokenValue).isNull(); } private static Answer fillByteArray() {