481567 - permessage-deflate causing data-dependent ju.zip.DataFormatException: invalid stored block lengths

+ Reworked PerMessageDeflateExtensionTest to test with different
  modes (http/ws vs https/wss), different messages sizes, and
  input buffer sizes (these various configurations do trigger
  the reported bug)
+ Made CompressExtension loop over the input buffer if the buffer
  happens to not be entirely consumed.
This commit is contained in:
Joakim Erdfelt 2015-12-07 13:15:29 -07:00
parent bae1138211
commit b9c1535552
5 changed files with 287 additions and 113 deletions

View File

@ -28,13 +28,9 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.security.DigestOutputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.websocket.ClientEndpoint;
import javax.websocket.CloseReason;
@ -63,7 +59,7 @@ import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.common.test.LeakTrackingBufferPoolRule;
import org.eclipse.jetty.websocket.common.util.Hex;
import org.eclipse.jetty.websocket.common.util.Sha1Sum;
import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer;
import org.junit.AfterClass;
import org.junit.Assert;
@ -176,33 +172,12 @@ public class StreamTest
Assert.assertThat("Path should exist: " + file,file.exists(),is(true));
Assert.assertThat("Path should not be a directory:" + file,file.isDirectory(),is(false));
String expectedSha1 = loadExpectedSha1Sum(sha1File);
String actualSha1 = calculateSha1Sum(file);
String expectedSha1 = Sha1Sum.loadSha1(sha1File);
String actualSha1 = Sha1Sum.calculate(file);
Assert.assertThat("SHA1Sum of content: " + file,expectedSha1,equalToIgnoringCase(actualSha1));
}
private String calculateSha1Sum(File file) throws IOException, NoSuchAlgorithmException
{
MessageDigest digest = MessageDigest.getInstance("SHA1");
try (FileInputStream fis = new FileInputStream(file);
NoOpOutputStream noop = new NoOpOutputStream();
DigestOutputStream digester = new DigestOutputStream(noop,digest))
{
IO.copy(fis,digester);
return Hex.asHex(digest.digest());
}
}
private String loadExpectedSha1Sum(File sha1File) throws IOException
{
String contents = IO.readToString(sha1File);
Pattern pat = Pattern.compile("^[0-9A-Fa-f]*");
Matcher mat = pat.matcher(contents);
Assert.assertTrue("Should have found HEX code in SHA1 file: " + sha1File,mat.find());
return mat.group();
}
@ClientEndpoint
public static class ClientSocket
{
@ -317,32 +292,4 @@ public class StreamTest
t.printStackTrace(System.err);
}
}
private static class NoOpOutputStream extends OutputStream
{
@Override
public void write(byte[] b) throws IOException
{
}
@Override
public void write(byte[] b, int off, int len) throws IOException
{
}
@Override
public void flush() throws IOException
{
}
@Override
public void close() throws IOException
{
}
@Override
public void write(int b) throws IOException
{
}
}
}

View File

@ -154,30 +154,33 @@ public abstract class CompressExtension extends AbstractExtension
return;
}
byte[] output = new byte[DECOMPRESS_BUF_SIZE];
if (inflater.needsInput() && !supplyInput(inflater,buf))
while(buf.hasRemaining() && inflater.needsInput())
{
LOG.debug("Needed input, but no buffer could supply input");
return;
}
int read = 0;
while ((read = inflater.inflate(output)) >= 0)
{
if (read == 0)
if (!supplyInput(inflater,buf))
{
LOG.debug("Decompress: read 0 {}",toDetail(inflater));
break;
LOG.debug("Needed input, but no buffer could supply input");
return;
}
else
int read = 0;
while ((read = inflater.inflate(output)) >= 0)
{
// do something with output
if (LOG.isDebugEnabled())
if (read == 0)
{
LOG.debug("Decompressed {} bytes: {}",read,toDetail(inflater));
LOG.debug("Decompress: read 0 {}",toDetail(inflater));
break;
}
else
{
// do something with output
if (LOG.isDebugEnabled())
{
LOG.debug("Decompressed {} bytes: {}",read,toDetail(inflater));
}
accumulator.copyChunk(output,0,read);
}
accumulator.copyChunk(output,0,read);
}
}

View File

@ -0,0 +1,110 @@
//
// ========================================================================
// Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.websocket.common.util;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.security.DigestOutputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.eclipse.jetty.toolchain.test.IO;
import org.junit.Assert;
/**
* Calculate the sha1sum for various content
*/
public class Sha1Sum
{
private static class NoOpOutputStream extends OutputStream
{
@Override
public void write(byte[] b) throws IOException
{
}
@Override
public void write(byte[] b, int off, int len) throws IOException
{
}
@Override
public void flush() throws IOException
{
}
@Override
public void close() throws IOException
{
}
@Override
public void write(int b) throws IOException
{
}
}
public static String calculate(File file) throws NoSuchAlgorithmException, IOException
{
return calculate(file.toPath());
}
public static String calculate(Path path) throws NoSuchAlgorithmException, IOException
{
MessageDigest digest = MessageDigest.getInstance("SHA1");
try (InputStream in = Files.newInputStream(path,StandardOpenOption.READ);
NoOpOutputStream noop = new NoOpOutputStream();
DigestOutputStream digester = new DigestOutputStream(noop,digest))
{
IO.copy(in,digester);
return Hex.asHex(digest.digest());
}
}
public static String calculate(byte[] buf) throws NoSuchAlgorithmException
{
MessageDigest digest = MessageDigest.getInstance("SHA1");
digest.update(buf);
return Hex.asHex(digest.digest());
}
public static String calculate(byte[] buf, int offset, int len) throws NoSuchAlgorithmException
{
MessageDigest digest = MessageDigest.getInstance("SHA1");
digest.update(buf,offset,len);
return Hex.asHex(digest.digest());
}
public static String loadSha1(File sha1File) throws IOException
{
String contents = IO.readToString(sha1File);
Pattern pat = Pattern.compile("^[0-9A-Fa-f]*");
Matcher mat = pat.matcher(contents);
Assert.assertTrue("Should have found HEX code in SHA1 file: " + sha1File,mat.find());
return mat.group();
}
}

View File

@ -20,34 +20,90 @@ package org.eclipse.jetty.websocket.server;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.toolchain.test.EventQueue;
import org.eclipse.jetty.websocket.common.WebSocketFrame;
import org.eclipse.jetty.websocket.common.frames.TextFrame;
import org.eclipse.jetty.websocket.common.test.BlockheadClient;
import org.eclipse.jetty.websocket.common.test.HttpResponse;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.common.test.LeakTrackingBufferPoolRule;
import org.eclipse.jetty.websocket.common.util.Sha1Sum;
import org.eclipse.jetty.websocket.server.helper.CaptureSocket;
import org.eclipse.jetty.websocket.server.helper.EchoServlet;
import org.junit.AfterClass;
import org.junit.After;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
@RunWith(Parameterized.class)
public class PerMessageDeflateExtensionTest
{
private static SimpleServletServer server;
@BeforeClass
public static void startServer() throws Exception
private static enum TestCaseMessageSize
{
server = new SimpleServletServer(new EchoServlet());
server.start();
TINY(10),
SMALL(1024),
MEDIUM(10*1024),
LARGE(100*1024),
HUGE(1024*1024);
private int size;
private TestCaseMessageSize(int size)
{
this.size = size;
}
}
@Parameters(name = "{0} ({3}) (Input Buffer Size: {4} bytes)")
public static List<Object[]> modes()
{
List<Object[]> modes = new ArrayList<>();
for(TestCaseMessageSize size: TestCaseMessageSize.values())
{
modes.add(new Object[] { "Normal HTTP/WS", false, "ws", size, -1 });
modes.add(new Object[] { "Encrypted HTTPS/WSS", true, "wss", size, -1 });
int altInputBufSize = 15*1024;
modes.add(new Object[] { "Normal HTTP/WS", false, "ws", size, altInputBufSize });
modes.add(new Object[] { "Encrypted HTTPS/WSS", true, "wss", size, altInputBufSize });
}
return modes;
}
@AfterClass
public static void stopServer()
@Rule
public LeakTrackingBufferPoolRule bufferPool = new LeakTrackingBufferPoolRule("Test");
private SimpleServletServer server;
private String scheme;
private int msgSize;
private int inputBufferSize;
public PerMessageDeflateExtensionTest(String mode, boolean sslMode, String scheme, TestCaseMessageSize msgSize, int bufferSize) throws Exception
{
server = new SimpleServletServer(new EchoServlet());
server.enableSsl(sslMode);
server.start();
this.scheme = scheme;
this.msgSize = msgSize.size;
this.inputBufferSize = bufferSize;
}
@After
public void stopServer()
{
server.stop();
}
@ -62,42 +118,84 @@ public class PerMessageDeflateExtensionTest
Assume.assumeTrue("Server has permessage-deflate registered",
server.getWebSocketServletFactory().getExtensionFactory().isAvailable("permessage-deflate"));
BlockheadClient client = new BlockheadClient(server.getServerUri());
client.clearExtensions();
client.addExtensions("permessage-deflate");
client.setProtocols("echo");
Assert.assertThat("server scheme",server.getServerUri().getScheme(),is(scheme));
int binBufferSize = (int) (msgSize * 1.5);
WebSocketPolicy serverPolicy = server.getWebSocketServletFactory().getPolicy();
// Ensure binBufferSize is sane (not smaller then other buffers)
binBufferSize = Math.max(binBufferSize,serverPolicy.getMaxBinaryMessageSize());
binBufferSize = Math.max(binBufferSize,serverPolicy.getMaxBinaryMessageBufferSize());
binBufferSize = Math.max(binBufferSize,this.inputBufferSize);
serverPolicy.setMaxBinaryMessageSize(binBufferSize);
serverPolicy.setMaxBinaryMessageBufferSize(binBufferSize);
WebSocketClient client = new WebSocketClient(server.getSslContextFactory(),null,bufferPool);
WebSocketPolicy clientPolicy = client.getPolicy();
clientPolicy.setMaxBinaryMessageSize(binBufferSize);
clientPolicy.setMaxBinaryMessageBufferSize(binBufferSize);
if (inputBufferSize > 0)
{
clientPolicy.setInputBufferSize(inputBufferSize);
}
try
{
client.start();
// Make sure the read times out if there are problems with the implementation
client.setTimeout(1,TimeUnit.SECONDS);
client.connect();
client.sendStandardRequest();
HttpResponse resp = client.expectUpgradeResponse();
client.setMaxIdleTimeout(TimeUnit.SECONDS.toMillis(1));
Assert.assertThat("Response",resp.getExtensionsHeader(),containsString("permessage-deflate"));
CaptureSocket clientSocket = new CaptureSocket();
ClientUpgradeRequest request = new ClientUpgradeRequest();
request.addExtensions("permessage-deflate");
request.setSubProtocols("echo");
String msg = "Hello";
Future<Session> fut = client.connect(clientSocket,server.getServerUri(),request);
// Wait for connect
Session session = fut.get(3,TimeUnit.SECONDS);
assertThat("Response.extensions",getNegotiatedExtensionList(session),containsString("permessage-deflate"));
// Create message
byte msg[] = new byte[msgSize];
Random rand = new Random();
rand.setSeed(8080);
rand.nextBytes(msg);
// Calculate sha1
String sha1 = Sha1Sum.calculate(msg);
// Client sends first message
client.write(new TextFrame().setPayload(msg));
session.getRemote().sendBytes(ByteBuffer.wrap(msg));
EventQueue<WebSocketFrame> frames = client.readFrames(1,1000,TimeUnit.MILLISECONDS);
WebSocketFrame frame = frames.poll();
Assert.assertThat("TEXT.payload",frame.getPayloadAsUTF8(),is(msg.toString()));
// Client sends second message
client.clearCaptured();
msg = "There";
client.write(new TextFrame().setPayload(msg));
frames = client.readFrames(1,1,TimeUnit.SECONDS);
frame = frames.poll();
Assert.assertThat("TEXT.payload",frame.getPayloadAsUTF8(),is(msg.toString()));
clientSocket.messages.awaitEventCount(1,1,TimeUnit.SECONDS);
String echoMsg = clientSocket.messages.poll();
Assert.assertThat("Echo'd Message",echoMsg,is("binary[sha1="+sha1+"]"));
}
finally
{
client.close();
client.stop();
}
}
private String getNegotiatedExtensionList(Session session)
{
StringBuilder actual = new StringBuilder();
actual.append('[');
boolean delim = false;
for (ExtensionConfig ext : session.getUpgradeResponse().getExtensions())
{
if (delim)
actual.append(", ");
actual.append(ext.getName());
delim = true;
}
actual.append(']');
return actual.toString();
}
}

View File

@ -18,12 +18,14 @@
package org.eclipse.jetty.websocket.server.helper;
import java.security.NoSuchAlgorithmException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.toolchain.test.EventQueue;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketAdapter;
import org.eclipse.jetty.websocket.common.util.Sha1Sum;
public class CaptureSocket extends WebSocketAdapter
{
@ -58,4 +60,18 @@ public class CaptureSocket extends WebSocketAdapter
// System.out.printf("Received Message \"%s\" [size %d]%n", message, message.length());
messages.add(message);
}
@Override
public void onWebSocketBinary(byte[] payload, int offset, int len)
{
try
{
messages.add("binary[sha1="+Sha1Sum.calculate(payload,offset,len)+"]");
}
catch (NoSuchAlgorithmException e)
{
messages.add("ERROR: Unable to caclulate Binary SHA1: " + e.getMessage());
e.printStackTrace();
}
}
}