427700 - Outgoing extensions that create multiple frames should flush

them in order and atomically.

Modified DeflateFrameExtension to use a Queue and IteratingCallback to
make sure that frames are iteratively compressed in chunks.
The compression of the next chunk only happen when there is a callback
from the next outgoing layer.
This commit is contained in:
Simone Bordet 2014-02-14 14:10:52 +01:00
parent 53b1ee9e47
commit 81b8990dec
4 changed files with 347 additions and 180 deletions

View File

@ -22,43 +22,27 @@ import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.websocket.api.MessageTooLargeException;
public class ByteAccumulator
{
private static class Buf
{
public Buf(byte[] buffer, int offset, int length)
{
this.buffer = buffer;
this.offset = offset;
this.length = length;
}
byte[] buffer;
int offset;
int length;
}
private final List<Chunk> chunks = new ArrayList<>();
private final int maxSize;
private int length = 0;
private List<Buf> buffers;
public ByteAccumulator(int maxOverallBufferSize)
{
this.maxSize = maxOverallBufferSize;
this.buffers = new ArrayList<>();
}
public void addBuffer(byte buf[], int offset, int length)
public void addChunk(byte buf[], int offset, int length)
{
if (this.length + length > maxSize)
{
throw new MessageTooLargeException("Frame is too large");
}
buffers.add(new Buf(buf,offset,length));
chunks.add(new Chunk(buf, offset, length));
this.length += length;
}
@ -67,17 +51,29 @@ public class ByteAccumulator
return length;
}
public ByteBuffer getByteBuffer(ByteBufferPool pool)
public void transferTo(ByteBuffer buffer)
{
ByteBuffer ret = pool.acquire(length,false);
BufferUtil.clearToFill(ret);
for (Buf buf : buffers)
if (buffer.remaining() < length)
throw new IllegalArgumentException();
int position = buffer.position();
for (Chunk chunk : chunks)
{
ret.put(buf.buffer, buf.offset, buf.length);
buffer.put(chunk.buffer, chunk.offset, chunk.length);
}
BufferUtil.flipToFlush(buffer, position);
}
BufferUtil.flipToFlush(ret,0);
return ret;
private static class Chunk
{
private final byte[] buffer;
private final int offset;
private final int length;
private Chunk(byte[] buffer, int offset, int length)
{
this.buffer = buffer;
this.offset = offset;
this.length = length;
}
}
}

View File

@ -19,38 +19,44 @@
package org.eclipse.jetty.websocket.common.extensions.compress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import java.util.zip.ZipException;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.ConcurrentArrayQueue;
import org.eclipse.jetty.util.IteratingCallback;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.api.BadPayloadException;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.common.OpCode;
import org.eclipse.jetty.websocket.common.extensions.AbstractExtension;
import org.eclipse.jetty.websocket.common.frames.DataFrame;
/**
* Implementation of the <a href="https://tools.ietf.org/id/draft-tyoshino-hybi-websocket-perframe-deflate-05.txt">deflate-frame</a> extension seen out in the
* wild.
* Implementation of the
* <a href="https://tools.ietf.org/id/draft-tyoshino-hybi-websocket-perframe-deflate.txt">deflate-frame</a>
* extension seen out in the wild.
*/
public class DeflateFrameExtension extends AbstractExtension
{
private static final boolean BFINAL_HACK = Boolean.parseBoolean(System.getProperty("jetty.websocket.bfinal.hack","true"));
private static final Logger LOG = Log.getLogger(DeflateFrameExtension.class);
private static final byte[] TAIL_BYTES = new byte[]{0x00, 0x00, (byte)0xFF, (byte)0xFF};
private static final int OVERHEAD = 64;
/** Tail Bytes per Spec */
private static final byte[] TAIL = new byte[] { 0x00, 0x00, (byte)0xFF, (byte)0xFF };
private int bufferSize = 64 * 1024;
private Deflater compressor;
private Inflater decompressor;
private final Queue<FrameEntry> entries = new ConcurrentArrayQueue<>();
private final IteratingCallback flusher = new Flusher();
private final Deflater compressor;
private final Inflater decompressor;
public DeflateFrameExtension()
{
compressor = new Deflater(Deflater.BEST_COMPRESSION, true);
decompressor = new Inflater(true);
}
@Override
public String getName()
@ -61,77 +67,85 @@ public class DeflateFrameExtension extends AbstractExtension
@Override
public void incomingFrame(Frame frame)
{
// Incoming frames are always non concurrent because
// they are read and parsed with a single thread, and
// therefore there is no need for synchronization.
if (OpCode.isControlFrame(frame.getOpCode()) || !frame.isRsv1())
{
// Cannot modify incoming control frames or ones with RSV1 set.
// Cannot modify incoming control frames or ones without RSV1 set.
nextIncomingFrame(frame);
return;
}
if (!frame.hasPayload())
{
// no payload? nothing to do.
// No payload ? Nothing to do.
nextIncomingFrame(frame);
return;
}
// Prime the decompressor
ByteBuffer payload = frame.getPayload();
int inlen = payload.remaining();
byte compressed[] = new byte[inlen + TAIL.length];
payload.get(compressed,0,inlen);
System.arraycopy(TAIL,0,compressed,inlen,TAIL.length);
int remaining = payload.remaining();
byte[] input = new byte[remaining + TAIL_BYTES.length];
payload.get(input, 0, remaining);
System.arraycopy(TAIL_BYTES, 0, input, remaining, TAIL_BYTES.length);
// Since we don't track text vs binary vs continuation state, just grab whatever is the greater value.
int maxSize = Math.max(getPolicy().getMaxTextMessageSize(),getPolicy().getMaxBinaryMessageBufferSize());
int maxSize = Math.max(getPolicy().getMaxTextMessageSize(), getPolicy().getMaxBinaryMessageBufferSize());
ByteAccumulator accumulator = new ByteAccumulator(maxSize);
DataFrame out = new DataFrame(frame);
out.setRsv1(false); // Unset RSV1
// Unset RSV1 since it's not compressed anymore.
out.setRsv1(false);
synchronized (decompressor)
decompressor.setInput(input, 0, input.length);
try
{
decompressor.setInput(compressed,0,compressed.length);
// Perform decompression
while (decompressor.getRemaining() > 0 && !decompressor.finished())
while (decompressor.getRemaining() > 0)
{
byte outbuf[] = new byte[Math.min(inlen * 2,bufferSize)];
try
byte[] output = new byte[Math.min(remaining * 2, 64 * 1024)];
int len = decompressor.inflate(output);
if (len == 0)
{
int len = decompressor.inflate(outbuf);
if (len == 0)
if (decompressor.needsInput())
{
if (decompressor.needsInput())
{
throw new BadPayloadException("Unable to inflate frame, not enough input on frame");
}
if (decompressor.needsDictionary())
{
throw new BadPayloadException("Unable to inflate frame, frame erroneously says it needs a dictionary");
}
throw new BadPayloadException("Unable to inflate frame, not enough input on frame");
}
if (len > 0)
if (decompressor.needsDictionary())
{
accumulator.addBuffer(outbuf,0,len);
throw new BadPayloadException("Unable to inflate frame, frame erroneously says it needs a dictionary");
}
}
catch (DataFormatException e)
else
{
LOG.warn(e);
throw new BadPayloadException(e);
accumulator.addChunk(output, 0, len);
}
}
}
catch (DataFormatException x)
{
throw new BadPayloadException(x);
}
// Forward on the frame
out.setPayload(accumulator.getByteBuffer(getBufferPool()));
nextIncomingFrame(out);
ByteBuffer buffer = getBufferPool().acquire(accumulator.getLength(), false);
try
{
BufferUtil.flipToFill(buffer);
accumulator.transferTo(buffer);
out.setPayload(buffer);
nextIncomingFrame(out);
}
finally
{
getBufferPool().release(buffer);
}
}
/**
* Indicates use of RSV1 flag for indicating deflation is in use.
* <p>
* <p/>
* Also known as the "COMP" framing header bit
*/
@Override
@ -143,112 +157,193 @@ public class DeflateFrameExtension extends AbstractExtension
@Override
public void outgoingFrame(Frame frame, WriteCallback callback)
{
if (OpCode.isControlFrame(frame.getOpCode()))
if (flusher.isFailed())
{
// skip, cannot compress control frames.
nextOutgoingFrame(frame,callback);
if (callback != null)
callback.writeFailed(new ZipException());
return;
}
if (!frame.hasPayload())
{
// pass through, nothing to do
nextOutgoingFrame(frame,callback);
return;
}
if (LOG.isDebugEnabled())
{
LOG.debug("outgoingFrame({}, {}) - {}",OpCode.name(frame.getOpCode()),callback != null?callback.getClass().getSimpleName():"<null>",
BufferUtil.toDetailString(frame.getPayload()));
}
// Prime the compressor
byte uncompressed[] = BufferUtil.toArray(frame.getPayload());
List<DataFrame> dframes = new ArrayList<>();
synchronized (compressor)
{
// Perform the compression
if (!compressor.finished())
{
compressor.setInput(uncompressed,0,uncompressed.length);
byte compressed[] = new byte[uncompressed.length + OVERHEAD];
while (!compressor.needsInput())
{
int len = compressor.deflate(compressed,0,compressed.length,Deflater.SYNC_FLUSH);
ByteBuffer outbuf = getBufferPool().acquire(len,true);
BufferUtil.clearToFill(outbuf);
if (len > 0)
{
outbuf.put(compressed,0,len - 4);
}
BufferUtil.flipToFlush(outbuf,0);
if (len > 0 && BFINAL_HACK)
{
/*
* Per the spec, it says that BFINAL 1 or 0 are allowed.
*
* However, Java always uses BFINAL 1, whereas the browsers Chromium and Safari fail to decompress when it encounters BFINAL 1.
*
* This hack will always set BFINAL 0
*/
byte b0 = outbuf.get(0);
if ((b0 & 1) != 0) // if BFINAL 1
{
outbuf.put(0,(b0 ^= 1)); // flip bit to BFINAL 0
}
}
DataFrame out = new DataFrame(frame);
out.setRsv1(true);
out.setBufferPool(getBufferPool());
out.setPayload(outbuf);
if (!compressor.needsInput())
{
// this is fragmented
out.setFin(false);
}
dframes.add(out);
}
}
}
// notify outside of synchronize
for (DataFrame df : dframes)
{
if (df.isFin())
{
nextOutgoingFrame(df,callback);
}
else
{
// non final frames have no callback
nextOutgoingFrame(df,null);
}
}
}
@Override
public void setConfig(ExtensionConfig config)
{
super.setConfig(config);
boolean nowrap = true;
compressor = new Deflater(Deflater.BEST_COMPRESSION,nowrap);
compressor.setStrategy(Deflater.DEFAULT_STRATEGY);
decompressor = new Inflater(nowrap);
FrameEntry entry = new FrameEntry(frame, callback);
LOG.debug("Queuing {}", entry);
entries.offer(entry);
flusher.iterate();
}
@Override
public String toString()
{
return this.getClass().getSimpleName() + "[]";
return getClass().getSimpleName();
}
private static class FrameEntry
{
private final Frame frame;
private final WriteCallback callback;
private FrameEntry(Frame frame, WriteCallback callback)
{
this.frame = frame;
this.callback = callback;
}
@Override
public String toString()
{
return frame.toString();
}
}
private class Flusher extends IteratingCallback implements WriteCallback
{
private FrameEntry current;
private int inputLength = 64 * 1024;
private ByteBuffer payload;
private boolean finished = true;
@Override
protected Action process() throws Exception
{
if (finished)
{
current = entries.poll();
LOG.debug("Processing {}", current);
if (current == null)
return Action.IDLE;
deflate(current);
}
else
{
compress(current.frame);
}
return Action.SCHEDULED;
}
private void deflate(FrameEntry entry)
{
Frame frame = entry.frame;
if (OpCode.isControlFrame(frame.getOpCode()))
{
// Skip, cannot compress control frames.
nextOutgoingFrame(frame, this);
return;
}
if (!frame.hasPayload())
{
// Pass through, nothing to do
nextOutgoingFrame(frame, this);
return;
}
compress(frame);
}
private void compress(Frame frame)
{
// Get a chunk of the payload to avoid to blow
// the heap if the payload is a huge mapped file.
ByteBuffer data = frame.getPayload();
int remaining = data.remaining();
byte[] input = new byte[Math.min(remaining, inputLength)];
int length = Math.min(remaining, input.length);
LOG.debug("Compressing {}: {} bytes in {} bytes chunk", frame, remaining, length);
finished = length == remaining;
data.get(input, 0, length);
compressor.setInput(input, 0, length);
// Use an additional space in case the content is not compressible.
byte[] output = new byte[length + 64];
int offset = 0;
int total = 0;
while (true)
{
int space = output.length - offset;
int compressed = compressor.deflate(output, offset, space, Deflater.SYNC_FLUSH);
total += compressed;
if (compressed < space)
{
// Everything was compressed.
break;
}
else
{
// The compressed output is bigger than the uncompressed input.
byte[] newOutput = new byte[output.length * 2];
System.arraycopy(output, 0, newOutput, 0, output.length);
offset += output.length;
output = newOutput;
}
}
payload = getBufferPool().acquire(total, true);
BufferUtil.flipToFill(payload);
// Skip the last tail bytes bytes generated by SYNC_FLUSH
payload.put(output, 0, total - TAIL_BYTES.length).flip();
LOG.debug("Compressed {}: {}->{} chunk bytes", frame, length, total);
DataFrame chunk = new DataFrame(frame);
chunk.setRsv1(true);
chunk.setPayload(payload);
chunk.setFin(finished);
nextOutgoingFrame(chunk, this);
}
@Override
protected void completed()
{
// This IteratingCallback never completes.
}
@Override
public void writeSuccess()
{
getBufferPool().release(payload);
if (finished)
notifyCallbackSuccess(current.callback);
succeeded();
}
@Override
public void writeFailed(Throwable x)
{
getBufferPool().release(payload);
notifyCallbackFailure(current.callback, x);
// If something went wrong, very likely the compression context
// will be invalid, so we need to fail this IteratingCallback.
failed(x);
// Now no more frames can be queued, fail those in the queue.
FrameEntry entry;
while ((entry = entries.poll()) != null)
notifyCallbackFailure(entry.callback, x);
}
private void notifyCallbackSuccess(WriteCallback callback)
{
try
{
if (callback != null)
callback.writeSuccess();
}
catch (Throwable x)
{
LOG.debug("Exception while notifying success of callback " + callback, x);
}
}
private void notifyCallbackFailure(WriteCallback callback, Throwable failure)
{
try
{
if (callback != null)
callback.writeFailed(failure);
}
catch (Throwable x)
{
LOG.debug("Exception while notifying failure of callback " + callback, x);
}
}
}
}

View File

@ -144,7 +144,7 @@ public class PerMessageDeflateExtension extends AbstractExtension
}
if (len > 0)
{
accumulator.addBuffer(outbuf,0,len);
accumulator.addChunk(outbuf, 0, len);
}
}
catch (DataFormatException e)
@ -154,9 +154,18 @@ public class PerMessageDeflateExtension extends AbstractExtension
}
}
// Forward on the frame
out.setPayload(accumulator.getByteBuffer(getBufferPool()));
nextIncomingFrame(out);
ByteBuffer buffer = getBufferPool().acquire(accumulator.getLength(), false);
try
{
BufferUtil.flipToFill(buffer);
accumulator.transferTo(buffer);
out.setPayload(buffer);
nextIncomingFrame(out);
}
finally
{
getBufferPool().release(buffer);
}
}
/**

View File

@ -18,29 +18,34 @@
package org.eclipse.jetty.websocket.common.extensions.compress;
import static org.hamcrest.Matchers.*;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import org.eclipse.jetty.io.MappedByteBufferPool;
import org.eclipse.jetty.io.RuntimeIOException;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.TypeUtil;
import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.api.extensions.IncomingFrames;
import org.eclipse.jetty.websocket.api.extensions.OutgoingFrames;
import org.eclipse.jetty.websocket.common.Generator;
import org.eclipse.jetty.websocket.common.OpCode;
import org.eclipse.jetty.websocket.common.Parser;
import org.eclipse.jetty.websocket.common.WebSocketFrame;
import org.eclipse.jetty.websocket.common.extensions.AbstractExtensionTest;
import org.eclipse.jetty.websocket.common.extensions.ExtensionTool.Tester;
import org.eclipse.jetty.websocket.common.frames.BinaryFrame;
import org.eclipse.jetty.websocket.common.frames.TextFrame;
import org.eclipse.jetty.websocket.common.test.ByteBufferAssert;
import org.eclipse.jetty.websocket.common.test.IncomingFramesCapture;
@ -51,6 +56,10 @@ import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
public class DeflateFrameExtensionTest extends AbstractExtensionTest
{
@Rule
@ -119,9 +128,9 @@ public class DeflateFrameExtensionTest extends AbstractExtensionTest
ext.setNextOutgoingFrames(capture);
Frame frame = new TextFrame().setPayload(text);
ext.outgoingFrame(frame,null);
ext.outgoingFrame(frame, null);
capture.assertBytes(0,expectedHex);
capture.assertBytes(0, expectedHex);
}
@Test
@ -372,4 +381,62 @@ public class DeflateFrameExtensionTest extends AbstractExtensionTest
{
assertOutgoing("There","c1070ac9482d4a0500");
}
@Test
public void testCompressAndDecompressBigPayload() throws Exception
{
byte[] input = new byte[1024 * 1024];
// Make them not compressible.
new Random().nextBytes(input);
DeflateFrameExtension clientExtension = new DeflateFrameExtension();
clientExtension.setBufferPool(bufferPool);
clientExtension.setPolicy(WebSocketPolicy.newClientPolicy());
clientExtension.setConfig(ExtensionConfig.parse("deflate-frame"));
final DeflateFrameExtension serverExtension = new DeflateFrameExtension();
serverExtension.setBufferPool(bufferPool);
serverExtension.setPolicy(WebSocketPolicy.newServerPolicy());
serverExtension.setConfig(ExtensionConfig.parse("deflate-frame"));
// Chain the next element to decompress.
clientExtension.setNextOutgoingFrames(new OutgoingFrames()
{
@Override
public void outgoingFrame(Frame frame, WriteCallback callback)
{
serverExtension.incomingFrame(frame);
callback.writeSuccess();
}
});
final ByteArrayOutputStream result = new ByteArrayOutputStream(input.length);
serverExtension.setNextIncomingFrames(new IncomingFrames()
{
@Override
public void incomingFrame(Frame frame)
{
try
{
result.write(BufferUtil.toArray(frame.getPayload()));
}
catch (IOException x)
{
throw new RuntimeIOException(x);
}
}
@Override
public void incomingError(Throwable t)
{
}
});
BinaryFrame frame = new BinaryFrame();
frame.setPayload(input);
frame.setFin(true);
clientExtension.outgoingFrame(frame, null);
Assert.assertArrayEquals(input, result.toByteArray());
}
}