From dcb8c563e8f57326d3b6e0ce6b48b87ed0d2d94c Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Fri, 31 May 2024 18:12:21 -0500 Subject: [PATCH] Fix ArrayIndexOutOfBoundsException Issue gh-13310 Closes gh-15184 --- .../messaging/web/csrf/XorCsrfTokenUtils.java | 17 ++-- .../csrf/XorCsrfChannelInterceptorTests.java | 70 ++++++++++++++++- .../XorCsrfTokenRequestAttributeHandler.java | 16 ++-- ...erverCsrfTokenRequestAttributeHandler.java | 16 ++-- ...CsrfTokenRequestAttributeHandlerTests.java | 57 +++++++++++++- ...CsrfTokenRequestAttributeHandlerTests.java | 77 ++++++++++++++++++- 6 files changed, 226 insertions(+), 27 deletions(-) diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java index 46a67cc4d3..b6395f9b14 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.security.messaging.web.csrf; import java.util.Base64; import org.springframework.security.crypto.codec.Utf8; +import org.springframework.util.Assert; /** * Copied from @@ -43,26 +44,26 @@ final class XorCsrfTokenUtils { byte[] tokenBytes = Utf8.encode(token); int tokenSize = tokenBytes.length; - if (actualBytes.length < tokenSize) { + if (actualBytes.length != tokenSize * 2) { return null; } // extract token and random bytes - int randomBytesSize = actualBytes.length - tokenSize; byte[] xoredCsrf = new byte[tokenSize]; - byte[] randomBytes = new byte[randomBytesSize]; + byte[] randomBytes = new byte[tokenSize]; - System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize); - System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); + System.arraycopy(actualBytes, 0, randomBytes, 0, tokenSize); + System.arraycopy(actualBytes, tokenSize, xoredCsrf, 0, tokenSize); byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); return Utf8.decode(csrfBytes); } private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { - int len = Math.min(randomBytes.length, csrfBytes.length); + Assert.isTrue(randomBytes.length == csrfBytes.length, "arrays must be equal length"); + int len = csrfBytes.length; byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); for (int i = 0; i < len; i++) { xoredCsrf[i] ^= randomBytes[i]; } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java index 884c3d2fc2..5d7705c67b 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.security.messaging.web.csrf; +import java.util.Base64; import java.util.HashMap; import org.junit.jupiter.api.BeforeEach; @@ -141,6 +142,73 @@ public class XorCsrfChannelInterceptorTests { this.interceptor.preSend(message(), this.channel); } + // gh-13310, gh-15184 + @Test + public void preSendWhenCsrfBytesIsShorterThanRandomBytesThenThrowsInvalidCsrfTokenException() { + /* + * Token format: 3 random pad bytes + 2 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), actualToken); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + // gh-13310, gh-15184 + @Test + public void preSendWhenCsrfBytesIsLongerThanRandomBytesThenThrowsInvalidCsrfTokenException() { + /* + * Token format: 3 random pad bytes + 4 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99, 98, 97 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), actualToken); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + // gh-13310, gh-15184 + @Test + public void preSendWhenTokenBytesIsShorterThanActualBytesThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + CsrfToken csrfToken = new DefaultCsrfToken("header", "param", "a"); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), csrfToken); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + // gh-13310, gh-15184 + @Test + public void preSendWhenTokenBytesIsLongerThanActualBytesThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + CsrfToken csrfToken = new DefaultCsrfToken("header", "param", "abcde"); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), csrfToken); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + // gh-13310, gh-15184 + @Test + public void preSendWhenActualBytesIsEmptyThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), ""); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + private Message message() { return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build(); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java index e66d075486..275e0bf37c 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java +++ b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -84,17 +84,16 @@ public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestA byte[] tokenBytes = Utf8.encode(token); int tokenSize = tokenBytes.length; - if (actualBytes.length < tokenSize) { + if (actualBytes.length != tokenSize * 2) { return null; } // extract token and random bytes - int randomBytesSize = actualBytes.length - tokenSize; byte[] xoredCsrf = new byte[tokenSize]; - byte[] randomBytes = new byte[randomBytesSize]; + byte[] randomBytes = new byte[tokenSize]; - System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize); - System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); + System.arraycopy(actualBytes, 0, randomBytes, 0, tokenSize); + System.arraycopy(actualBytes, tokenSize, xoredCsrf, 0, tokenSize); byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); return Utf8.decode(csrfBytes); @@ -114,9 +113,10 @@ public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestA } private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { - int len = Math.min(randomBytes.length, csrfBytes.length); + Assert.isTrue(randomBytes.length == csrfBytes.length, "arrays must be equal length"); + int len = csrfBytes.length; byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); for (int i = 0; i < len; i++) { xoredCsrf[i] ^= randomBytes[i]; } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandler.java index d361104ecb..4dc59caca0 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -77,17 +77,16 @@ public final class XorServerCsrfTokenRequestAttributeHandler extends ServerCsrfT byte[] tokenBytes = Utf8.encode(token); int tokenSize = tokenBytes.length; - if (actualBytes.length < tokenSize) { + if (actualBytes.length != tokenSize * 2) { return null; } // extract token and random bytes - int randomBytesSize = actualBytes.length - tokenSize; byte[] xoredCsrf = new byte[tokenSize]; - byte[] randomBytes = new byte[randomBytesSize]; + byte[] randomBytes = new byte[tokenSize]; - System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize); - System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); + System.arraycopy(actualBytes, 0, randomBytes, 0, tokenSize); + System.arraycopy(actualBytes, tokenSize, xoredCsrf, 0, tokenSize); byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); return Utf8.decode(csrfBytes); @@ -107,9 +106,10 @@ public final class XorServerCsrfTokenRequestAttributeHandler extends ServerCsrfT } private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { - int len = Math.min(randomBytes.length, csrfBytes.length); + Assert.isTrue(randomBytes.length == csrfBytes.length, "arrays must be equal length"); + int len = csrfBytes.length; byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); for (int i = 0; i < len; i++) { xoredCsrf[i] ^= randomBytes[i]; } diff --git a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java index 142f90254d..f8f2132bf9 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,6 +44,9 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; */ public class XorCsrfTokenRequestAttributeHandlerTests { + /* + * Token format: 3 random pad bytes + 3 padded bytes. + */ private static final byte[] XOR_CSRF_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 }; private static final String XOR_CSRF_TOKEN_VALUE = Base64.getEncoder().encodeToString(XOR_CSRF_TOKEN_BYTES); @@ -203,6 +206,58 @@ public class XorCsrfTokenRequestAttributeHandlerTests { assertThat(tokenValue).isEqualTo(this.token.getToken()); } + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenCsrfBytesIsShorterThanRandomBytesThenReturnsNull() { + /* + * Token format: 3 random pad bytes + 2 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.request.setParameter(this.token.getParameterName(), actualToken); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenCsrfBytesIsLongerThanRandomBytesThenReturnsNull() { + /* + * Token format: 3 random pad bytes + 4 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99, 98, 97 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.request.setParameter(this.token.getParameterName(), actualToken); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenTokenBytesIsShorterThanActualBytesThenReturnsNull() { + this.request.setParameter(this.token.getParameterName(), XOR_CSRF_TOKEN_VALUE); + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "a"); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, csrfToken); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenTokenBytesIsLongerThanActualBytesThenReturnsNull() { + this.request.setParameter(this.token.getParameterName(), XOR_CSRF_TOKEN_VALUE); + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "abcde"); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, csrfToken); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenActualBytesIsEmptyThenReturnsNull() { + this.request.setParameter(this.token.getParameterName(), ""); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); + assertThat(tokenValue).isNull(); + } + private static Answer fillByteArray() { return (invocation) -> { byte[] bytes = invocation.getArgument(0); diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandlerTests.java index c6b800af06..bd859ce965 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,6 +46,9 @@ import static org.mockito.Mockito.verify; */ public class XorServerCsrfTokenRequestAttributeHandlerTests { + /* + * Token format: 3 random pad bytes + 3 padded bytes. + */ private static final byte[] XOR_CSRF_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 }; private static final String XOR_CSRF_TOKEN_VALUE = Base64.getEncoder().encodeToString(XOR_CSRF_TOKEN_BYTES); @@ -188,6 +191,78 @@ public class XorServerCsrfTokenRequestAttributeHandlerTests { StepVerifier.create(csrfToken).expectNext(this.token.getToken()).verifyComplete(); } + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenCsrfBytesIsShorterThanRandomBytesThenReturnsNull() { + /* + * Token format: 3 random pad bytes + 2 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.exchange = MockServerWebExchange + .builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), actualToken)) + .build(); + String tokenValue = this.handler.resolveCsrfTokenValue(this.exchange, this.token).block(); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenCsrfBytesIsLongerThanRandomBytesThenReturnsNull() { + /* + * Token format: 3 random pad bytes + 4 padded bytes. + */ + byte[] actualBytes = { 1, 1, 1, 96, 99, 98, 97 }; + String actualToken = Base64.getEncoder().encodeToString(actualBytes); + this.exchange = MockServerWebExchange + .builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), actualToken)) + .build(); + String tokenValue = this.handler.resolveCsrfTokenValue(this.exchange, this.token).block(); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenTokenBytesIsShorterThanActualBytesThenReturnsNull() { + this.exchange = MockServerWebExchange + .builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE)) + .build(); + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "a"); + String tokenValue = this.handler.resolveCsrfTokenValue(this.exchange, csrfToken).block(); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenTokenBytesIsLongerThanActualBytesThenReturnsNull() { + this.exchange = MockServerWebExchange + .builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE)) + .build(); + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "abcde"); + String tokenValue = this.handler.resolveCsrfTokenValue(this.exchange, csrfToken).block(); + assertThat(tokenValue).isNull(); + } + + // gh-13310, gh-15184 + @Test + public void resolveCsrfTokenValueWhenActualBytesIsEmptyThenReturnsNull() { + this.exchange = MockServerWebExchange + .builder(MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .header(this.token.getHeaderName(), "")) + .build(); + String tokenValue = this.handler.resolveCsrfTokenValue(this.exchange, this.token).block(); + assertThat(tokenValue).isNull(); + } + private static Answer fillByteArray() { return (invocation) -> { byte[] bytes = invocation.getArgument(0);