427700 - Outgoing extensions that create multiple frames should flush

them in order and atomically.

Refactored PerMessageDeflateExtension and DeflateFrameExtension
introducing superclass CompressExtension that factors in common
functionalities.
This commit is contained in:
Simone Bordet 2014-02-14 21:30:07 +01:00
parent 81b8990dec
commit ad15b27a01
10 changed files with 570 additions and 550 deletions

View File

@ -0,0 +1,171 @@
//
// ========================================================================
// Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.websocket.jsr356.server;
import java.net.URI;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.MessageHandler;
import javax.websocket.SendHandler;
import javax.websocket.SendResult;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpointConfig;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.websocket.api.extensions.OutgoingFrames;
import org.eclipse.jetty.websocket.client.io.WebSocketClientConnection;
import org.eclipse.jetty.websocket.common.extensions.ExtensionStack;
import org.eclipse.jetty.websocket.common.extensions.compress.DeflateFrameExtension;
import org.eclipse.jetty.websocket.jsr356.JsrExtension;
import org.eclipse.jetty.websocket.jsr356.JsrSession;
import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer;
import org.eclipse.jetty.websocket.jsr356.server.samples.echo.BasicEchoEndpoint;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
public class ExtensionStackProcessingTest
{
private Server server;
private ServerConnector connector;
private WebSocketContainer client;
@Before
public void prepare() throws Exception
{
server = new Server();
connector = new ServerConnector(server);
server.addConnector(connector);
ServletContextHandler context = new ServletContextHandler(server, "/", true, false);
ServerContainer container = WebSocketServerContainerInitializer.configureContext(context);
ServerEndpointConfig config = ServerEndpointConfig.Builder.create(BasicEchoEndpoint.class, "/").build();
container.addEndpoint(config);
server.start();
client = ContainerProvider.getWebSocketContainer();
server.addBean(client, true);
}
@After
public void dispose() throws Exception
{
server.stop();
}
@Test
public void testDeflateFrameExtension() throws Exception
{
ClientEndpointConfig config = ClientEndpointConfig.Builder.create()
.extensions(Arrays.<Extension>asList(new JsrExtension("deflate-frame")))
.build();
final String content = "deflate_me";
final CountDownLatch messageLatch = new CountDownLatch(1);
URI uri = URI.create("ws://localhost:" + connector.getLocalPort());
Session session = client.connectToServer(new EndpointAdapter()
{
@Override
public void onMessage(String message)
{
Assert.assertEquals(content, message);
messageLatch.countDown();
}
}, config, uri);
// Make sure everything is wired properly.
OutgoingFrames firstOut = ((JsrSession)session).getOutgoingHandler();
Assert.assertTrue(firstOut instanceof ExtensionStack);
ExtensionStack extensionStack = (ExtensionStack)firstOut;
Assert.assertTrue(extensionStack.isRunning());
OutgoingFrames secondOut = extensionStack.getNextOutgoing();
Assert.assertTrue(secondOut instanceof DeflateFrameExtension);
DeflateFrameExtension deflateExtension = (DeflateFrameExtension)secondOut;
Assert.assertTrue(deflateExtension.isRunning());
OutgoingFrames thirdOut = deflateExtension.getNextOutgoing();
Assert.assertTrue(thirdOut instanceof WebSocketClientConnection);
final CountDownLatch latch = new CountDownLatch(1);
session.getAsyncRemote().sendText(content, new SendHandler()
{
@Override
public void onResult(SendResult result)
{
latch.countDown();
}
});
Assert.assertTrue(latch.await(5, TimeUnit.SECONDS));
Assert.assertTrue(messageLatch.await(5, TimeUnit.SECONDS));
}
@Test
public void testPerMessageDeflateExtension() throws Exception
{
ClientEndpointConfig config = ClientEndpointConfig.Builder.create()
.extensions(Arrays.<Extension>asList(new JsrExtension("permessage-deflate")))
.build();
final String content = "deflate_me";
final CountDownLatch messageLatch = new CountDownLatch(1);
URI uri = URI.create("ws://localhost:" + connector.getLocalPort());
Session session = client.connectToServer(new EndpointAdapter()
{
@Override
public void onMessage(String message)
{
Assert.assertEquals(content, message);
messageLatch.countDown();
}
}, config, uri);
final CountDownLatch latch = new CountDownLatch(1);
session.getAsyncRemote().sendText(content, new SendHandler()
{
@Override
public void onResult(SendResult result)
{
latch.countDown();
}
});
Assert.assertTrue(latch.await(5, TimeUnit.SECONDS));
Assert.assertTrue(messageLatch.await(5, TimeUnit.SECONDS));
}
private static abstract class EndpointAdapter extends Endpoint implements MessageHandler.Whole<String>
{
@Override
public void onOpen(Session session, EndpointConfig config)
{
session.addMessageHandler(this);
}
}
}

View File

@ -68,6 +68,11 @@ public interface Frame
return (opcode == TEXT.getOpCode()) | (opcode == BINARY.getOpCode());
}
public boolean isContinuation()
{
return opcode == CONTINUATION.getOpCode();
}
@Override
public String toString()
{

View File

@ -23,11 +23,10 @@ import java.net.CookieStore;
import java.net.SocketAddress;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
@ -75,7 +74,6 @@ public class WebSocketClient extends ContainerLifeCycle implements SessionListen
private boolean daemon = false;
private EventDriverFactory eventDriverFactory;
private SessionFactory sessionFactory;
private Set<WebSocketSession> openSessions = new CopyOnWriteArraySet<>();
private ByteBufferPool bufferPool;
private Executor executor;
private Scheduler scheduler;
@ -374,7 +372,7 @@ public class WebSocketClient extends ContainerLifeCycle implements SessionListen
public Set<WebSocketSession> getOpenSessions()
{
return Collections.unmodifiableSet(this.openSessions);
return new HashSet<>(getBeans(WebSocketSession.class));
}
public WebSocketPolicy getPolicy()
@ -473,14 +471,13 @@ public class WebSocketClient extends ContainerLifeCycle implements SessionListen
public void onSessionClosed(WebSocketSession session)
{
LOG.info("Session Closed: {}",session);
this.openSessions.remove(session);
removeBean(session);
}
@Override
public void onSessionOpened(WebSocketSession session)
{
LOG.info("Session Opened: {}",session);
this.openSessions.add(session);
}
public void setAsyncWriteTimeout(long ms)

View File

@ -259,6 +259,9 @@ public class UpgradeConnection extends AbstractConnection
session.setOutgoingHandler(extensionStack);
extensionStack.setNextOutgoing(connection);
session.addBean(extensionStack);
connectPromise.getClient().addBean(session);
// Now swap out the connection
endp.setConnection(connection);
connection.onOpen();

View File

@ -81,20 +81,20 @@ public class ExtensionStack extends ContainerLifeCycle implements IncomingFrames
// Wire up Extensions
if ((extensions != null) && (extensions.size() > 0))
{
ListIterator<Extension> eiter = extensions.listIterator();
ListIterator<Extension> exts = extensions.listIterator();
// Connect outgoings
while (eiter.hasNext())
while (exts.hasNext())
{
Extension ext = eiter.next();
Extension ext = exts.next();
ext.setNextOutgoingFrames(nextOutgoing);
nextOutgoing = ext;
}
// Connect incomings
while (eiter.hasPrevious())
while (exts.hasPrevious())
{
Extension ext = eiter.previous();
Extension ext = exts.previous();
ext.setNextIncomingFrames(nextIncoming);
nextIncoming = ext;
}
@ -252,6 +252,8 @@ public class ExtensionStack extends ContainerLifeCycle implements IncomingFrames
// Add Extension
extensions.add(ext);
addBean(ext);
LOG.debug("Adding Extension: {}",config);
// Record RSV Claims
@ -268,8 +270,6 @@ public class ExtensionStack extends ContainerLifeCycle implements IncomingFrames
rsvClaims[2] = ext.getName();
}
}
addBean(extensions);
}
@Override

View File

@ -0,0 +1,337 @@
//
// ========================================================================
// Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.websocket.common.extensions.compress;
import java.nio.ByteBuffer;
import java.util.Queue;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import java.util.zip.ZipException;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.ConcurrentArrayQueue;
import org.eclipse.jetty.util.IteratingCallback;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.api.BadPayloadException;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.common.OpCode;
import org.eclipse.jetty.websocket.common.extensions.AbstractExtension;
import org.eclipse.jetty.websocket.common.frames.DataFrame;
public abstract class CompressExtension extends AbstractExtension
{
protected static final byte[] TAIL_BYTES = new byte[]{0x00, 0x00, (byte)0xFF, (byte)0xFF};
private static final Logger LOG = Log.getLogger(CompressExtension.class);
private final Queue<FrameEntry> entries = new ConcurrentArrayQueue<>();
private final IteratingCallback flusher = new Flusher();
private final Deflater compressor;
private final Inflater decompressor;
protected CompressExtension()
{
compressor = new Deflater(Deflater.BEST_COMPRESSION, true);
decompressor = new Inflater(true);
}
public Deflater getDeflater()
{
return compressor;
}
public Inflater getInflater()
{
return decompressor;
}
/**
* Indicates use of RSV1 flag for indicating deflation is in use.
*/
@Override
public boolean isRsv1User()
{
return true;
}
protected void forwardIncoming(Frame frame, ByteAccumulator accumulator)
{
DataFrame newFrame = new DataFrame(frame);
// Unset RSV1 since it's not compressed anymore.
newFrame.setRsv1(false);
ByteBuffer buffer = getBufferPool().acquire(accumulator.getLength(), false);
try
{
BufferUtil.flipToFill(buffer);
accumulator.transferTo(buffer);
newFrame.setPayload(buffer);
nextIncomingFrame(newFrame);
}
finally
{
getBufferPool().release(buffer);
}
}
protected ByteAccumulator decompress(byte[] input)
{
// Since we don't track text vs binary vs continuation state, just grab whatever is the greater value.
int maxSize = Math.max(getPolicy().getMaxTextMessageSize(), getPolicy().getMaxBinaryMessageBufferSize());
ByteAccumulator accumulator = new ByteAccumulator(maxSize);
decompressor.setInput(input, 0, input.length);
LOG.debug("Decompressing {} bytes", input.length);
try
{
// It is allowed to send DEFLATE blocks with BFINAL=1.
// For such blocks, getRemaining() will be > 0 but finished()
// will be true, so we need to check for both.
// When BFINAL=0, finished() will always be false and we only
// check the remaining bytes.
while (decompressor.getRemaining() > 0 && !decompressor.finished())
{
byte[] output = new byte[Math.min(input.length * 2, 64 * 1024)];
int decompressed = decompressor.inflate(output);
if (decompressed == 0)
{
if (decompressor.needsInput())
{
throw new BadPayloadException("Unable to inflate frame, not enough input on frame");
}
if (decompressor.needsDictionary())
{
throw new BadPayloadException("Unable to inflate frame, frame erroneously says it needs a dictionary");
}
}
else
{
accumulator.addChunk(output, 0, decompressed);
}
}
LOG.debug("Decompressed {}->{} bytes", input.length, accumulator.getLength());
return accumulator;
}
catch (DataFormatException x)
{
throw new BadPayloadException(x);
}
}
@Override
public void outgoingFrame(Frame frame, WriteCallback callback)
{
// We use a queue and an IteratingCallback to handle concurrency.
// We must compress and write atomically, otherwise the compression
// context on the other end gets confused.
if (flusher.isFailed())
{
notifyCallbackFailure(callback, new ZipException());
return;
}
FrameEntry entry = new FrameEntry(frame, callback);
LOG.debug("Queuing {}", entry);
entries.offer(entry);
flusher.iterate();
}
@Override
public String toString()
{
return getClass().getSimpleName();
}
private static class FrameEntry
{
private final Frame frame;
private final WriteCallback callback;
private FrameEntry(Frame frame, WriteCallback callback)
{
this.frame = frame;
this.callback = callback;
}
@Override
public String toString()
{
return frame.toString();
}
}
private class Flusher extends IteratingCallback implements WriteCallback
{
private FrameEntry current;
private int inputLength = 64 * 1024;
private ByteBuffer payload;
private boolean finished = true;
@Override
protected Action process() throws Exception
{
if (finished)
{
current = entries.poll();
LOG.debug("Processing {}", current);
if (current == null)
return Action.IDLE;
deflate(current);
}
else
{
compress(current.frame, false);
}
return Action.SCHEDULED;
}
private void deflate(FrameEntry entry)
{
Frame frame = entry.frame;
if (OpCode.isControlFrame(frame.getOpCode()))
{
// Skip, cannot compress control frames.
nextOutgoingFrame(frame, this);
return;
}
if (!frame.hasPayload())
{
// Pass through, nothing to do
nextOutgoingFrame(frame, this);
return;
}
compress(frame, true);
}
private void compress(Frame frame, boolean first)
{
// Get a chunk of the payload to avoid to blow
// the heap if the payload is a huge mapped file.
ByteBuffer data = frame.getPayload();
int remaining = data.remaining();
byte[] input = new byte[Math.min(remaining, inputLength)];
int length = Math.min(remaining, input.length);
LOG.debug("Compressing {}: {} bytes in {} bytes chunk", frame, remaining, length);
finished = length == remaining;
data.get(input, 0, length);
compressor.setInput(input, 0, length);
// Use an additional space in case the content is not compressible.
byte[] output = new byte[length + 64];
int offset = 0;
int total = 0;
while (true)
{
int space = output.length - offset;
int compressed = compressor.deflate(output, offset, space, Deflater.SYNC_FLUSH);
total += compressed;
if (compressed < space)
{
// Everything was compressed.
break;
}
else
{
// The compressed output is bigger than the uncompressed input.
byte[] newOutput = new byte[output.length * 2];
System.arraycopy(output, 0, newOutput, 0, output.length);
offset += output.length;
output = newOutput;
}
}
payload = getBufferPool().acquire(total, true);
BufferUtil.flipToFill(payload);
// Skip the last tail bytes bytes generated by SYNC_FLUSH
payload.put(output, 0, total - TAIL_BYTES.length).flip();
LOG.debug("Compressed {}: {}->{} chunk bytes", frame, length, total);
boolean continuation = frame.getType().isContinuation() || !first;
DataFrame chunk = new DataFrame(frame, continuation);
chunk.setRsv1(true);
chunk.setPayload(payload);
boolean fin = frame.isFin() && finished;
chunk.setFin(fin);
nextOutgoingFrame(chunk, this);
}
@Override
protected void completed()
{
// This IteratingCallback never completes.
}
@Override
public void writeSuccess()
{
getBufferPool().release(payload);
if (finished)
notifyCallbackSuccess(current.callback);
succeeded();
}
@Override
public void writeFailed(Throwable x)
{
getBufferPool().release(payload);
notifyCallbackFailure(current.callback, x);
// If something went wrong, very likely the compression context
// will be invalid, so we need to fail this IteratingCallback.
failed(x);
// Now no more frames can be queued, fail those in the queue.
FrameEntry entry;
while ((entry = entries.poll()) != null)
notifyCallbackFailure(entry.callback, x);
}
}
protected void notifyCallbackSuccess(WriteCallback callback)
{
try
{
if (callback != null)
callback.writeSuccess();
}
catch (Throwable x)
{
LOG.debug("Exception while notifying success of callback " + callback, x);
}
}
protected void notifyCallbackFailure(WriteCallback callback, Throwable failure)
{
try
{
if (callback != null)
callback.writeFailed(failure);
}
catch (Throwable x)
{
LOG.debug("Exception while notifying failure of callback " + callback, x);
}
}
}

View File

@ -19,45 +19,17 @@
package org.eclipse.jetty.websocket.common.extensions.compress;
import java.nio.ByteBuffer;
import java.util.Queue;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import java.util.zip.ZipException;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.ConcurrentArrayQueue;
import org.eclipse.jetty.util.IteratingCallback;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.api.BadPayloadException;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.common.OpCode;
import org.eclipse.jetty.websocket.common.extensions.AbstractExtension;
import org.eclipse.jetty.websocket.common.frames.DataFrame;
/**
* Implementation of the
* <a href="https://tools.ietf.org/id/draft-tyoshino-hybi-websocket-perframe-deflate.txt">deflate-frame</a>
* extension seen out in the wild.
*/
public class DeflateFrameExtension extends AbstractExtension
public class DeflateFrameExtension extends CompressExtension
{
private static final Logger LOG = Log.getLogger(DeflateFrameExtension.class);
private static final byte[] TAIL_BYTES = new byte[]{0x00, 0x00, (byte)0xFF, (byte)0xFF};
private final Queue<FrameEntry> entries = new ConcurrentArrayQueue<>();
private final IteratingCallback flusher = new Flusher();
private final Deflater compressor;
private final Inflater decompressor;
public DeflateFrameExtension()
{
compressor = new Deflater(Deflater.BEST_COMPRESSION, true);
decompressor = new Inflater(true);
}
@Override
public String getName()
{
@ -71,16 +43,8 @@ public class DeflateFrameExtension extends AbstractExtension
// they are read and parsed with a single thread, and
// therefore there is no need for synchronization.
if (OpCode.isControlFrame(frame.getOpCode()) || !frame.isRsv1())
if (OpCode.isControlFrame(frame.getOpCode()) || !frame.isRsv1() || !frame.hasPayload())
{
// Cannot modify incoming control frames or ones without RSV1 set.
nextIncomingFrame(frame);
return;
}
if (!frame.hasPayload())
{
// No payload ? Nothing to do.
nextIncomingFrame(frame);
return;
}
@ -91,259 +55,6 @@ public class DeflateFrameExtension extends AbstractExtension
payload.get(input, 0, remaining);
System.arraycopy(TAIL_BYTES, 0, input, remaining, TAIL_BYTES.length);
// Since we don't track text vs binary vs continuation state, just grab whatever is the greater value.
int maxSize = Math.max(getPolicy().getMaxTextMessageSize(), getPolicy().getMaxBinaryMessageBufferSize());
ByteAccumulator accumulator = new ByteAccumulator(maxSize);
DataFrame out = new DataFrame(frame);
// Unset RSV1 since it's not compressed anymore.
out.setRsv1(false);
decompressor.setInput(input, 0, input.length);
try
{
while (decompressor.getRemaining() > 0)
{
byte[] output = new byte[Math.min(remaining * 2, 64 * 1024)];
int len = decompressor.inflate(output);
if (len == 0)
{
if (decompressor.needsInput())
{
throw new BadPayloadException("Unable to inflate frame, not enough input on frame");
}
if (decompressor.needsDictionary())
{
throw new BadPayloadException("Unable to inflate frame, frame erroneously says it needs a dictionary");
}
}
else
{
accumulator.addChunk(output, 0, len);
}
}
}
catch (DataFormatException x)
{
throw new BadPayloadException(x);
}
ByteBuffer buffer = getBufferPool().acquire(accumulator.getLength(), false);
try
{
BufferUtil.flipToFill(buffer);
accumulator.transferTo(buffer);
out.setPayload(buffer);
nextIncomingFrame(out);
}
finally
{
getBufferPool().release(buffer);
}
}
/**
* Indicates use of RSV1 flag for indicating deflation is in use.
* <p/>
* Also known as the "COMP" framing header bit
*/
@Override
public boolean isRsv1User()
{
return true;
}
@Override
public void outgoingFrame(Frame frame, WriteCallback callback)
{
if (flusher.isFailed())
{
if (callback != null)
callback.writeFailed(new ZipException());
return;
}
FrameEntry entry = new FrameEntry(frame, callback);
LOG.debug("Queuing {}", entry);
entries.offer(entry);
flusher.iterate();
}
@Override
public String toString()
{
return getClass().getSimpleName();
}
private static class FrameEntry
{
private final Frame frame;
private final WriteCallback callback;
private FrameEntry(Frame frame, WriteCallback callback)
{
this.frame = frame;
this.callback = callback;
}
@Override
public String toString()
{
return frame.toString();
}
}
private class Flusher extends IteratingCallback implements WriteCallback
{
private FrameEntry current;
private int inputLength = 64 * 1024;
private ByteBuffer payload;
private boolean finished = true;
@Override
protected Action process() throws Exception
{
if (finished)
{
current = entries.poll();
LOG.debug("Processing {}", current);
if (current == null)
return Action.IDLE;
deflate(current);
}
else
{
compress(current.frame);
}
return Action.SCHEDULED;
}
private void deflate(FrameEntry entry)
{
Frame frame = entry.frame;
if (OpCode.isControlFrame(frame.getOpCode()))
{
// Skip, cannot compress control frames.
nextOutgoingFrame(frame, this);
return;
}
if (!frame.hasPayload())
{
// Pass through, nothing to do
nextOutgoingFrame(frame, this);
return;
}
compress(frame);
}
private void compress(Frame frame)
{
// Get a chunk of the payload to avoid to blow
// the heap if the payload is a huge mapped file.
ByteBuffer data = frame.getPayload();
int remaining = data.remaining();
byte[] input = new byte[Math.min(remaining, inputLength)];
int length = Math.min(remaining, input.length);
LOG.debug("Compressing {}: {} bytes in {} bytes chunk", frame, remaining, length);
finished = length == remaining;
data.get(input, 0, length);
compressor.setInput(input, 0, length);
// Use an additional space in case the content is not compressible.
byte[] output = new byte[length + 64];
int offset = 0;
int total = 0;
while (true)
{
int space = output.length - offset;
int compressed = compressor.deflate(output, offset, space, Deflater.SYNC_FLUSH);
total += compressed;
if (compressed < space)
{
// Everything was compressed.
break;
}
else
{
// The compressed output is bigger than the uncompressed input.
byte[] newOutput = new byte[output.length * 2];
System.arraycopy(output, 0, newOutput, 0, output.length);
offset += output.length;
output = newOutput;
}
}
payload = getBufferPool().acquire(total, true);
BufferUtil.flipToFill(payload);
// Skip the last tail bytes bytes generated by SYNC_FLUSH
payload.put(output, 0, total - TAIL_BYTES.length).flip();
LOG.debug("Compressed {}: {}->{} chunk bytes", frame, length, total);
DataFrame chunk = new DataFrame(frame);
chunk.setRsv1(true);
chunk.setPayload(payload);
chunk.setFin(finished);
nextOutgoingFrame(chunk, this);
}
@Override
protected void completed()
{
// This IteratingCallback never completes.
}
@Override
public void writeSuccess()
{
getBufferPool().release(payload);
if (finished)
notifyCallbackSuccess(current.callback);
succeeded();
}
@Override
public void writeFailed(Throwable x)
{
getBufferPool().release(payload);
notifyCallbackFailure(current.callback, x);
// If something went wrong, very likely the compression context
// will be invalid, so we need to fail this IteratingCallback.
failed(x);
// Now no more frames can be queued, fail those in the queue.
FrameEntry entry;
while ((entry = entries.poll()) != null)
notifyCallbackFailure(entry.callback, x);
}
private void notifyCallbackSuccess(WriteCallback callback)
{
try
{
if (callback != null)
callback.writeSuccess();
}
catch (Throwable x)
{
LOG.debug("Exception while notifying success of callback " + callback, x);
}
}
private void notifyCallbackFailure(WriteCallback callback, Throwable failure)
{
try
{
if (callback != null)
callback.writeFailed(failure);
}
catch (Throwable x)
{
LOG.debug("Exception while notifying failure of callback " + callback, x);
}
}
forwardIncoming(frame, decompress(input));
}
}

View File

@ -19,49 +19,28 @@
package org.eclipse.jetty.websocket.common.extensions.compress;
import java.nio.ByteBuffer;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.api.BadPayloadException;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.common.OpCode;
import org.eclipse.jetty.websocket.common.extensions.AbstractExtension;
import org.eclipse.jetty.websocket.common.frames.DataFrame;
/**
* Per Message Deflate Compression extension for WebSocket.
* <p>
* <p/>
* Attempts to follow <a href="https://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-12">draft-ietf-hybi-permessage-compression-12</a>
*/
public class PerMessageDeflateExtension extends AbstractExtension
public class PerMessageDeflateExtension extends CompressExtension
{
private static final boolean BFINAL_HACK = Boolean.parseBoolean(System.getProperty("jetty.websocket.bfinal.hack","true"));
private static final Logger LOG = Log.getLogger(PerMessageDeflateExtension.class);
private static final int OVERHEAD = 64;
/** Tail Bytes per Spec */
private static final byte[] TAIL = new byte[]
{ 0x00, 0x00, (byte)0xFF, (byte)0xFF };
private ExtensionConfig configRequested;
private ExtensionConfig configNegotiated;
private Deflater compressor;
private Inflater decompressor;
private boolean incomingCompressed = false;
private boolean outgoingCompressed = false;
/**
* Context Takeover Control.
* <p>
* If true, the same LZ77 window is used between messages. Can be overridden with extension parameters.
*/
private boolean incomingContextTakeover = true;
private boolean outgoingContextTakeover = true;
private boolean incomingCompressed;
@Override
public String getName()
@ -70,213 +49,36 @@ public class PerMessageDeflateExtension extends AbstractExtension
}
@Override
public synchronized void incomingFrame(Frame frame)
public void incomingFrame(Frame frame)
{
switch (frame.getOpCode())
{
case OpCode.BINARY: // fall-thru
case OpCode.TEXT:
// Incoming frames are always non concurrent because
// they are read and parsed with a single thread, and
// therefore there is no need for synchronization.
// This extension requires the RSV1 bit set only in the first frame.
// Subsequent continuation frames don't have RSV1 set, but are compressed.
if (frame.getType().isData())
incomingCompressed = frame.isRsv1();
break;
case OpCode.CONTINUATION:
if (!incomingCompressed)
if (OpCode.isControlFrame(frame.getOpCode()) || !frame.hasPayload() || !incomingCompressed)
{
nextIncomingFrame(frame);
}
break;
default:
// All others are assumed to be control frames
nextIncomingFrame(frame);
return;
}
if (!incomingCompressed || !frame.hasPayload())
{
// nothing to do with this frame
nextIncomingFrame(frame);
return;
}
// Prime the decompressor
boolean appendTail = frame.isFin();
ByteBuffer payload = frame.getPayload();
int inlen = payload.remaining();
byte compressed[] = null;
int remaining = payload.remaining();
byte[] input = new byte[remaining + (appendTail ? TAIL_BYTES.length : 0)];
payload.get(input, 0, remaining);
if (appendTail)
System.arraycopy(TAIL_BYTES, 0, input, remaining, TAIL_BYTES.length);
forwardIncoming(frame, decompress(input));
if (frame.isFin())
{
compressed = new byte[inlen + TAIL.length];
payload.get(compressed,0,inlen);
System.arraycopy(TAIL,0,compressed,inlen,TAIL.length);
incomingCompressed = false;
}
else
{
compressed = new byte[inlen];
payload.get(compressed,0,inlen);
}
decompressor.setInput(compressed,0,compressed.length);
// Since we don't track text vs binary vs continuation state, just grab whatever is the greater value.
int maxSize = Math.max(getPolicy().getMaxTextMessageSize(),getPolicy().getMaxBinaryMessageBufferSize());
ByteAccumulator accumulator = new ByteAccumulator(maxSize);
DataFrame out = new DataFrame(frame);
out.setRsv1(false); // Unset RSV1
// Perform decompression
while (decompressor.getRemaining() > 0 && !decompressor.finished())
{
byte outbuf[] = new byte[inlen];
try
{
int len = decompressor.inflate(outbuf);
if (len == 0)
{
if (decompressor.needsInput())
{
throw new BadPayloadException("Unable to inflate frame, not enough input on frame");
}
if (decompressor.needsDictionary())
{
throw new BadPayloadException("Unable to inflate frame, frame erroneously says it needs a dictionary");
}
}
if (len > 0)
{
accumulator.addChunk(outbuf, 0, len);
}
}
catch (DataFormatException e)
{
LOG.warn(e);
throw new BadPayloadException(e);
}
}
ByteBuffer buffer = getBufferPool().acquire(accumulator.getLength(), false);
try
{
BufferUtil.flipToFill(buffer);
accumulator.transferTo(buffer);
out.setPayload(buffer);
nextIncomingFrame(out);
}
finally
{
getBufferPool().release(buffer);
}
}
/**
* Indicates use of RSV1 flag for indicating deflation is in use.
*/
@Override
public boolean isRsv1User()
{
return true;
}
@Override
public synchronized void outgoingFrame(Frame frame, WriteCallback callback)
{
if (OpCode.isControlFrame(frame.getOpCode()))
{
// skip, cannot compress control frames.
nextOutgoingFrame(frame,callback);
return;
}
if (!frame.hasPayload())
{
// pass through, nothing to do
nextOutgoingFrame(frame,callback);
return;
}
if (LOG.isDebugEnabled())
{
LOG.debug("outgoingFrame({}, {}) - {}",OpCode.name(frame.getOpCode()),callback != null?callback.getClass().getSimpleName():"<null>",
BufferUtil.toDetailString(frame.getPayload()));
}
// Prime the compressor
byte uncompressed[] = BufferUtil.toArray(frame.getPayload());
// Perform the compression
if (!compressor.finished())
{
compressor.setInput(uncompressed,0,uncompressed.length);
byte compressed[] = new byte[uncompressed.length + OVERHEAD];
while (!compressor.needsInput())
{
int len = compressor.deflate(compressed,0,compressed.length,Deflater.SYNC_FLUSH);
ByteBuffer outbuf = getBufferPool().acquire(len,true);
BufferUtil.clearToFill(outbuf);
if (len > 0)
{
if (len > 4)
{
// Test for the 4 tail octets (0x00 0x00 0xff 0xff)
int idx = len - 4;
boolean found = true;
for (int n = 0; n < TAIL.length; n++)
{
if (compressed[idx + n] != TAIL[n])
{
found = false;
break;
}
}
if (found)
{
len = len - 4;
}
}
outbuf.put(compressed,0,len);
}
BufferUtil.flipToFlush(outbuf,0);
if (len > 0 && BFINAL_HACK)
{
/*
* Per the spec, it says that BFINAL 1 or 0 are allowed.
*
* However, Java always uses BFINAL 1, whereas the browsers Chromium and Safari fail to decompress when it encounters BFINAL 1.
*
* This hack will always set BFINAL 0
*/
byte b0 = outbuf.get(0);
if ((b0 & 1) != 0) // if BFINAL 1
{
outbuf.put(0,(b0 ^= 1)); // flip bit to BFINAL 0
}
}
DataFrame out = new DataFrame(frame,outgoingCompressed);
out.setRsv1(true);
out.setBufferPool(getBufferPool());
out.setPayload(outbuf);
if (!compressor.needsInput())
{
// this is fragmented
out.setFin(false);
nextOutgoingFrame(out,null); // non final frames have no callback
}
else
{
// pass through the callback
nextOutgoingFrame(out,callback);
}
outgoingCompressed = !out.isFin();
}
}
}
@Override
protected void nextIncomingFrame(Frame frame)
@ -284,9 +86,8 @@ public class PerMessageDeflateExtension extends AbstractExtension
if (frame.isFin() && !incomingContextTakeover)
{
LOG.debug("Incoming Context Reset");
decompressor.reset();
getInflater().reset();
}
super.nextIncomingFrame(frame);
}
@ -296,9 +97,8 @@ public class PerMessageDeflateExtension extends AbstractExtension
if (frame.isFin() && !outgoingContextTakeover)
{
LOG.debug("Outgoing Context Reset");
compressor.reset();
getDeflater().reset();
}
super.nextOutgoingFrame(frame, callback);
}
@ -308,23 +108,20 @@ public class PerMessageDeflateExtension extends AbstractExtension
configRequested = new ExtensionConfig(config);
configNegotiated = new ExtensionConfig(config.getName());
boolean nowrap = true;
compressor = new Deflater(Deflater.BEST_COMPRESSION,nowrap);
compressor.setStrategy(Deflater.DEFAULT_STRATEGY);
decompressor = new Inflater(nowrap);
for (String key : config.getParameterKeys())
{
key = key.trim();
switch (key)
{
case "client_max_window_bits": // fallthru
case "client_max_window_bits":
case "server_max_window_bits":
{
// Not supported by Jetty
// Don't negotiate these parameters
break;
}
case "client_no_context_takeover":
{
configNegotiated.setParameter("client_no_context_takeover");
switch (getPolicy().getBehavior())
{
@ -336,7 +133,9 @@ public class PerMessageDeflateExtension extends AbstractExtension
break;
}
break;
}
case "server_no_context_takeover":
{
configNegotiated.setParameter("server_no_context_takeover");
switch (getPolicy().getBehavior())
{
@ -349,6 +148,11 @@ public class PerMessageDeflateExtension extends AbstractExtension
}
break;
}
default:
{
throw new IllegalArgumentException();
}
}
}
super.setConfig(configNegotiated);
@ -357,11 +161,9 @@ public class PerMessageDeflateExtension extends AbstractExtension
@Override
public String toString()
{
StringBuilder str = new StringBuilder();
str.append(this.getClass().getSimpleName());
str.append("[requested=").append(configRequested.getParameterizedName());
str.append(",negotiated=").append(configNegotiated.getParameterizedName());
str.append(']');
return str.toString();
return String.format("%s[requested=%s,negotiated=%s]",
getClass().getSimpleName(),
configRequested.getParameterizedName(),
configNegotiated.getParameterizedName());
}
}

View File

@ -29,10 +29,4 @@ public class XWebkitDeflateFrameExtension extends DeflateFrameExtension
{
return "x-webkit-deflate-frame";
}
@Override
public String toString()
{
return this.getClass().getSimpleName() + "[]";
}
}

View File

@ -18,8 +18,6 @@
package org.eclipse.jetty.websocket.common.extensions.compress;
import static org.hamcrest.Matchers.*;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
@ -46,6 +44,8 @@ import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import static org.hamcrest.Matchers.is;
/**
* Client side behavioral tests for permessage-deflate extension.
* <p>