From b9c15355520f1c0db9ddf4fa0d237b68269438b5 Mon Sep 17 00:00:00 2001 From: Joakim Erdfelt Date: Mon, 7 Dec 2015 13:15:29 -0700 Subject: [PATCH] 481567 - permessage-deflate causing data-dependent ju.zip.DataFormatException: invalid stored block lengths + Reworked PerMessageDeflateExtensionTest to test with different modes (http/ws vs https/wss), different messages sizes, and input buffer sizes (these various configurations do trigger the reported bug) + Made CompressExtension loop over the input buffer if the buffer happens to not be entirely consumed. --- .../websocket/jsr356/server/StreamTest.java | 59 +----- .../compress/CompressExtension.java | 39 ++-- .../jetty/websocket/common/util/Sha1Sum.java | 110 +++++++++++ .../PerMessageDeflateExtensionTest.java | 176 ++++++++++++++---- .../server/helper/CaptureSocket.java | 16 ++ 5 files changed, 287 insertions(+), 113 deletions(-) create mode 100644 jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/util/Sha1Sum.java diff --git a/jetty-websocket/javax-websocket-server-impl/src/test/java/org/eclipse/jetty/websocket/jsr356/server/StreamTest.java b/jetty-websocket/javax-websocket-server-impl/src/test/java/org/eclipse/jetty/websocket/jsr356/server/StreamTest.java index ca51ec54367..7a39a440964 100644 --- a/jetty-websocket/javax-websocket-server-impl/src/test/java/org/eclipse/jetty/websocket/jsr356/server/StreamTest.java +++ b/jetty-websocket/javax-websocket-server-impl/src/test/java/org/eclipse/jetty/websocket/jsr356/server/StreamTest.java @@ -28,13 +28,9 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.URI; -import java.security.DigestOutputStream; -import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import javax.websocket.ClientEndpoint; import javax.websocket.CloseReason; @@ -63,7 +59,7 @@ import org.eclipse.jetty.toolchain.test.MavenTestingUtils; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.websocket.common.test.LeakTrackingBufferPoolRule; -import org.eclipse.jetty.websocket.common.util.Hex; +import org.eclipse.jetty.websocket.common.util.Sha1Sum; import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer; import org.junit.AfterClass; import org.junit.Assert; @@ -176,33 +172,12 @@ public class StreamTest Assert.assertThat("Path should exist: " + file,file.exists(),is(true)); Assert.assertThat("Path should not be a directory:" + file,file.isDirectory(),is(false)); - String expectedSha1 = loadExpectedSha1Sum(sha1File); - String actualSha1 = calculateSha1Sum(file); + String expectedSha1 = Sha1Sum.loadSha1(sha1File); + String actualSha1 = Sha1Sum.calculate(file); Assert.assertThat("SHA1Sum of content: " + file,expectedSha1,equalToIgnoringCase(actualSha1)); } - private String calculateSha1Sum(File file) throws IOException, NoSuchAlgorithmException - { - MessageDigest digest = MessageDigest.getInstance("SHA1"); - try (FileInputStream fis = new FileInputStream(file); - NoOpOutputStream noop = new NoOpOutputStream(); - DigestOutputStream digester = new DigestOutputStream(noop,digest)) - { - IO.copy(fis,digester); - return Hex.asHex(digest.digest()); - } - } - - private String loadExpectedSha1Sum(File sha1File) throws IOException - { - String contents = IO.readToString(sha1File); - Pattern pat = Pattern.compile("^[0-9A-Fa-f]*"); - Matcher mat = pat.matcher(contents); - Assert.assertTrue("Should have found HEX code in SHA1 file: " + sha1File,mat.find()); - return mat.group(); - } - @ClientEndpoint public static class ClientSocket { @@ -317,32 +292,4 @@ public class StreamTest t.printStackTrace(System.err); } } - - private static class NoOpOutputStream extends OutputStream - { - @Override - public void write(byte[] b) throws IOException - { - } - - @Override - public void write(byte[] b, int off, int len) throws IOException - { - } - - @Override - public void flush() throws IOException - { - } - - @Override - public void close() throws IOException - { - } - - @Override - public void write(int b) throws IOException - { - } - } } diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java index 5c27104c594..ec73e3c384e 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java @@ -154,30 +154,33 @@ public abstract class CompressExtension extends AbstractExtension return; } byte[] output = new byte[DECOMPRESS_BUF_SIZE]; - - if (inflater.needsInput() && !supplyInput(inflater,buf)) + + while(buf.hasRemaining() && inflater.needsInput()) { - LOG.debug("Needed input, but no buffer could supply input"); - return; - } - - int read = 0; - while ((read = inflater.inflate(output)) >= 0) - { - if (read == 0) + if (!supplyInput(inflater,buf)) { - LOG.debug("Decompress: read 0 {}",toDetail(inflater)); - break; + LOG.debug("Needed input, but no buffer could supply input"); + return; } - else + + int read = 0; + while ((read = inflater.inflate(output)) >= 0) { - // do something with output - if (LOG.isDebugEnabled()) + if (read == 0) { - LOG.debug("Decompressed {} bytes: {}",read,toDetail(inflater)); + LOG.debug("Decompress: read 0 {}",toDetail(inflater)); + break; + } + else + { + // do something with output + if (LOG.isDebugEnabled()) + { + LOG.debug("Decompressed {} bytes: {}",read,toDetail(inflater)); + } + + accumulator.copyChunk(output,0,read); } - - accumulator.copyChunk(output,0,read); } } diff --git a/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/util/Sha1Sum.java b/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/util/Sha1Sum.java new file mode 100644 index 00000000000..f3771efb8a3 --- /dev/null +++ b/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/util/Sha1Sum.java @@ -0,0 +1,110 @@ +// +// ======================================================================== +// Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd. +// ------------------------------------------------------------------------ +// All rights reserved. This program and the accompanying materials +// are made available under the terms of the Eclipse Public License v1.0 +// and Apache License v2.0 which accompanies this distribution. +// +// The Eclipse Public License is available at +// http://www.eclipse.org/legal/epl-v10.html +// +// The Apache License v2.0 is available at +// http://www.opensource.org/licenses/apache2.0.php +// +// You may elect to redistribute this code under either of these licenses. +// ======================================================================== +// + +package org.eclipse.jetty.websocket.common.util; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.security.DigestOutputStream; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.eclipse.jetty.toolchain.test.IO; +import org.junit.Assert; + +/** + * Calculate the sha1sum for various content + */ +public class Sha1Sum +{ + private static class NoOpOutputStream extends OutputStream + { + @Override + public void write(byte[] b) throws IOException + { + } + + @Override + public void write(byte[] b, int off, int len) throws IOException + { + } + + @Override + public void flush() throws IOException + { + } + + @Override + public void close() throws IOException + { + } + + @Override + public void write(int b) throws IOException + { + } + } + + public static String calculate(File file) throws NoSuchAlgorithmException, IOException + { + return calculate(file.toPath()); + } + + public static String calculate(Path path) throws NoSuchAlgorithmException, IOException + { + MessageDigest digest = MessageDigest.getInstance("SHA1"); + try (InputStream in = Files.newInputStream(path,StandardOpenOption.READ); + NoOpOutputStream noop = new NoOpOutputStream(); + DigestOutputStream digester = new DigestOutputStream(noop,digest)) + { + IO.copy(in,digester); + return Hex.asHex(digest.digest()); + } + } + + public static String calculate(byte[] buf) throws NoSuchAlgorithmException + { + MessageDigest digest = MessageDigest.getInstance("SHA1"); + digest.update(buf); + return Hex.asHex(digest.digest()); + } + + public static String calculate(byte[] buf, int offset, int len) throws NoSuchAlgorithmException + { + MessageDigest digest = MessageDigest.getInstance("SHA1"); + digest.update(buf,offset,len); + return Hex.asHex(digest.digest()); + } + + public static String loadSha1(File sha1File) throws IOException + { + String contents = IO.readToString(sha1File); + Pattern pat = Pattern.compile("^[0-9A-Fa-f]*"); + Matcher mat = pat.matcher(contents); + Assert.assertTrue("Should have found HEX code in SHA1 file: " + sha1File,mat.find()); + return mat.group(); + } + +} diff --git a/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/PerMessageDeflateExtensionTest.java b/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/PerMessageDeflateExtensionTest.java index 5242339e6aa..1673d6086f1 100644 --- a/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/PerMessageDeflateExtensionTest.java +++ b/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/PerMessageDeflateExtensionTest.java @@ -20,34 +20,90 @@ package org.eclipse.jetty.websocket.server; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; -import org.eclipse.jetty.toolchain.test.EventQueue; -import org.eclipse.jetty.websocket.common.WebSocketFrame; -import org.eclipse.jetty.websocket.common.frames.TextFrame; -import org.eclipse.jetty.websocket.common.test.BlockheadClient; -import org.eclipse.jetty.websocket.common.test.HttpResponse; +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.WebSocketPolicy; +import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.eclipse.jetty.websocket.common.test.LeakTrackingBufferPoolRule; +import org.eclipse.jetty.websocket.common.util.Sha1Sum; +import org.eclipse.jetty.websocket.server.helper.CaptureSocket; import org.eclipse.jetty.websocket.server.helper.EchoServlet; -import org.junit.AfterClass; +import org.junit.After; import org.junit.Assert; import org.junit.Assume; -import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +@RunWith(Parameterized.class) public class PerMessageDeflateExtensionTest { - private static SimpleServletServer server; - - @BeforeClass - public static void startServer() throws Exception + private static enum TestCaseMessageSize { - server = new SimpleServletServer(new EchoServlet()); - server.start(); + TINY(10), + SMALL(1024), + MEDIUM(10*1024), + LARGE(100*1024), + HUGE(1024*1024); + + private int size; + + private TestCaseMessageSize(int size) + { + this.size = size; + } + } + + @Parameters(name = "{0} ({3}) (Input Buffer Size: {4} bytes)") + public static List modes() + { + List modes = new ArrayList<>(); + + for(TestCaseMessageSize size: TestCaseMessageSize.values()) + { + modes.add(new Object[] { "Normal HTTP/WS", false, "ws", size, -1 }); + modes.add(new Object[] { "Encrypted HTTPS/WSS", true, "wss", size, -1 }); + int altInputBufSize = 15*1024; + modes.add(new Object[] { "Normal HTTP/WS", false, "ws", size, altInputBufSize }); + modes.add(new Object[] { "Encrypted HTTPS/WSS", true, "wss", size, altInputBufSize }); + } + + return modes; } - @AfterClass - public static void stopServer() + @Rule + public LeakTrackingBufferPoolRule bufferPool = new LeakTrackingBufferPoolRule("Test"); + + private SimpleServletServer server; + private String scheme; + private int msgSize; + private int inputBufferSize; + + public PerMessageDeflateExtensionTest(String mode, boolean sslMode, String scheme, TestCaseMessageSize msgSize, int bufferSize) throws Exception + { + server = new SimpleServletServer(new EchoServlet()); + server.enableSsl(sslMode); + server.start(); + + this.scheme = scheme; + this.msgSize = msgSize.size; + this.inputBufferSize = bufferSize; + } + + @After + public void stopServer() { server.stop(); } @@ -62,42 +118,84 @@ public class PerMessageDeflateExtensionTest Assume.assumeTrue("Server has permessage-deflate registered", server.getWebSocketServletFactory().getExtensionFactory().isAvailable("permessage-deflate")); - BlockheadClient client = new BlockheadClient(server.getServerUri()); - client.clearExtensions(); - client.addExtensions("permessage-deflate"); - client.setProtocols("echo"); + Assert.assertThat("server scheme",server.getServerUri().getScheme(),is(scheme)); + int binBufferSize = (int) (msgSize * 1.5); + + WebSocketPolicy serverPolicy = server.getWebSocketServletFactory().getPolicy(); + + // Ensure binBufferSize is sane (not smaller then other buffers) + binBufferSize = Math.max(binBufferSize,serverPolicy.getMaxBinaryMessageSize()); + binBufferSize = Math.max(binBufferSize,serverPolicy.getMaxBinaryMessageBufferSize()); + binBufferSize = Math.max(binBufferSize,this.inputBufferSize); + + serverPolicy.setMaxBinaryMessageSize(binBufferSize); + serverPolicy.setMaxBinaryMessageBufferSize(binBufferSize); + + WebSocketClient client = new WebSocketClient(server.getSslContextFactory(),null,bufferPool); + WebSocketPolicy clientPolicy = client.getPolicy(); + clientPolicy.setMaxBinaryMessageSize(binBufferSize); + clientPolicy.setMaxBinaryMessageBufferSize(binBufferSize); + if (inputBufferSize > 0) + { + clientPolicy.setInputBufferSize(inputBufferSize); + } + try { + client.start(); // Make sure the read times out if there are problems with the implementation - client.setTimeout(1,TimeUnit.SECONDS); - client.connect(); - client.sendStandardRequest(); - HttpResponse resp = client.expectUpgradeResponse(); + client.setMaxIdleTimeout(TimeUnit.SECONDS.toMillis(1)); - Assert.assertThat("Response",resp.getExtensionsHeader(),containsString("permessage-deflate")); + CaptureSocket clientSocket = new CaptureSocket(); + ClientUpgradeRequest request = new ClientUpgradeRequest(); + request.addExtensions("permessage-deflate"); + request.setSubProtocols("echo"); - String msg = "Hello"; + Future fut = client.connect(clientSocket,server.getServerUri(),request); + // Wait for connect + Session session = fut.get(3,TimeUnit.SECONDS); + + assertThat("Response.extensions",getNegotiatedExtensionList(session),containsString("permessage-deflate")); + + // Create message + byte msg[] = new byte[msgSize]; + Random rand = new Random(); + rand.setSeed(8080); + rand.nextBytes(msg); + + // Calculate sha1 + String sha1 = Sha1Sum.calculate(msg); + // Client sends first message - client.write(new TextFrame().setPayload(msg)); + session.getRemote().sendBytes(ByteBuffer.wrap(msg)); - EventQueue frames = client.readFrames(1,1000,TimeUnit.MILLISECONDS); - WebSocketFrame frame = frames.poll(); - Assert.assertThat("TEXT.payload",frame.getPayloadAsUTF8(),is(msg.toString())); - - // Client sends second message - client.clearCaptured(); - msg = "There"; - client.write(new TextFrame().setPayload(msg)); - - frames = client.readFrames(1,1,TimeUnit.SECONDS); - frame = frames.poll(); - Assert.assertThat("TEXT.payload",frame.getPayloadAsUTF8(),is(msg.toString())); + clientSocket.messages.awaitEventCount(1,1,TimeUnit.SECONDS); + String echoMsg = clientSocket.messages.poll(); + Assert.assertThat("Echo'd Message",echoMsg,is("binary[sha1="+sha1+"]")); } finally { - client.close(); + client.stop(); } } + + private String getNegotiatedExtensionList(Session session) + { + StringBuilder actual = new StringBuilder(); + actual.append('['); + + boolean delim = false; + for (ExtensionConfig ext : session.getUpgradeResponse().getExtensions()) + { + if (delim) + actual.append(", "); + actual.append(ext.getName()); + delim = true; + } + actual.append(']'); + + return actual.toString(); + } } diff --git a/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/helper/CaptureSocket.java b/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/helper/CaptureSocket.java index 5218c94e09b..8e02950e776 100644 --- a/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/helper/CaptureSocket.java +++ b/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/helper/CaptureSocket.java @@ -18,12 +18,14 @@ package org.eclipse.jetty.websocket.server.helper; +import java.security.NoSuchAlgorithmException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.eclipse.jetty.toolchain.test.EventQueue; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.WebSocketAdapter; +import org.eclipse.jetty.websocket.common.util.Sha1Sum; public class CaptureSocket extends WebSocketAdapter { @@ -58,4 +60,18 @@ public class CaptureSocket extends WebSocketAdapter // System.out.printf("Received Message \"%s\" [size %d]%n", message, message.length()); messages.add(message); } + + @Override + public void onWebSocketBinary(byte[] payload, int offset, int len) + { + try + { + messages.add("binary[sha1="+Sha1Sum.calculate(payload,offset,len)+"]"); + } + catch (NoSuchAlgorithmException e) + { + messages.add("ERROR: Unable to caclulate Binary SHA1: " + e.getMessage()); + e.printStackTrace(); + } + } }