diff --git a/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/CloseStatus.java b/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/CloseStatus.java index b46407f2c47..3655cc286b3 100644 --- a/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/CloseStatus.java +++ b/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/CloseStatus.java @@ -175,16 +175,8 @@ public class CloseStatus // TODO consider defining a precedence for every CloseStatus, and change SessionState only if higher precedence public static boolean isOrdinary(CloseStatus closeStatus) { - switch (closeStatus.getCode()) - { - case NORMAL: - case SHUTDOWN: - case NO_CODE: - return true; - - default: - return false; - } + int code = closeStatus.getCode(); + return (code == NORMAL || code == NO_CODE || code >= 3000); } public int getCode() @@ -291,8 +283,8 @@ public class CloseStatus public Frame toFrame() { if (isTransmittableStatusCode(code)) - return new CloseFrame(this, OpCode.CLOSE, true, asPayloadBuffer(code, reason)); - return new CloseFrame(this, OpCode.CLOSE); + return new CloseFrame(OpCode.CLOSE, true, asPayloadBuffer(code, reason)); + return new CloseFrame(OpCode.CLOSE); } public static Frame toFrame(int closeStatus) @@ -356,12 +348,12 @@ public class CloseStatus class CloseFrame extends Frame implements CloseStatus.Supplier { - public CloseFrame(CloseStatus closeStatus, byte opcode) + public CloseFrame(byte opcode) { super(opcode); } - public CloseFrame(CloseStatus closeStatus, byte opCode, boolean fin, ByteBuffer payload) + public CloseFrame(byte opCode, boolean fin, ByteBuffer payload) { super(opCode, fin, payload); } diff --git a/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java b/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java index 0e6298138e1..4f8cb909f66 100644 --- a/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java +++ b/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java @@ -38,10 +38,14 @@ import org.eclipse.jetty.util.TypeUtil; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.util.thread.Scheduler; +import org.eclipse.jetty.websocket.core.CloseStatus; import org.eclipse.jetty.websocket.core.Frame; import org.eclipse.jetty.websocket.core.OpCode; +import org.eclipse.jetty.websocket.core.WebSocketException; import org.eclipse.jetty.websocket.core.WebSocketWriteTimeoutException; +import static org.eclipse.jetty.websocket.core.internal.WebSocketCoreSession.AbnormalCloseStatus; + public class FrameFlusher extends IteratingCallback { public static final Frame FLUSH_FRAME = new Frame(OpCode.BINARY); @@ -96,6 +100,8 @@ public class FrameFlusher extends IteratingCallback byte opCode = frame.getOpCode(); Throwable dead; + List failedEntries = null; + CloseStatus closeStatus = null; synchronized (this) { @@ -104,10 +110,29 @@ public class FrameFlusher extends IteratingCallback dead = closedCause; if (dead == null) { - if (opCode == OpCode.PING || opCode == OpCode.PONG) - queue.offerFirst(entry); - else - queue.offerLast(entry); + switch (opCode) + { + case OpCode.CLOSE: + closeStatus = CloseStatus.getCloseStatus(frame); + if (!CloseStatus.isOrdinary(closeStatus)) + { + //fail all existing entries in the queue, and enqueue the error close + failedEntries = new ArrayList<>(queue); + queue.clear(); + } + queue.offerLast(entry); + this.canEnqueue = false; + break; + + case OpCode.PING: + case OpCode.PONG: + queue.offerFirst(entry); + break; + + default: + queue.offerLast(entry); + break; + } /* If the queue was empty then no timeout has been set, so we set a timeout to check the current entry when it expires. When the timeout expires we will go over entries in the queue and @@ -115,9 +140,6 @@ public class FrameFlusher extends IteratingCallback with the soonest expiry time. */ if ((idleTimeout > 0) && (queue.size()==1) && entries.isEmpty()) timeoutScheduler.schedule(this::timeoutExpired, idleTimeout, TimeUnit.MILLISECONDS); - - if (opCode == OpCode.CLOSE) - this.canEnqueue = false; } } else @@ -126,6 +148,21 @@ public class FrameFlusher extends IteratingCallback } } + if (failedEntries != null) + { + WebSocketException failure = new WebSocketException("Flusher received abnormal CloseFrame: " + CloseStatus.codeString(closeStatus.getCode())); + if (closeStatus instanceof AbnormalCloseStatus) + { + Throwable cause = ((AbnormalCloseStatus)closeStatus).getCause(); + failure.initCause(cause); + } + + for (Entry e : failedEntries) + { + notifyCallbackFailure(e.callback, failure); + } + } + if (dead == null) { if (LOG.isDebugEnabled()) diff --git a/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/internal/FrameFlusherTest.java b/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/internal/FrameFlusherTest.java index c751ccbddd4..dee3c6860b8 100644 --- a/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/internal/FrameFlusherTest.java +++ b/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/internal/FrameFlusherTest.java @@ -189,6 +189,41 @@ public class FrameFlusherTest assertThat(error.get(), instanceOf(WebSocketWriteTimeoutException.class)); } + @Test + public void testErrorClose() throws Exception + { + Generator generator = new Generator(bufferPool); + BlockingEndpoint endPoint = new BlockingEndpoint(bufferPool); + endPoint.setBlockTime(100); + int bufferSize = WebSocketConstants.DEFAULT_MAX_TEXT_MESSAGE_SIZE; + int maxGather = 8; + FrameFlusher frameFlusher = new FrameFlusher(bufferPool, scheduler, generator, endPoint, bufferSize, maxGather); + + // Enqueue message before the error close. + Frame frame1 = new Frame(OpCode.TEXT).setPayload("message before close").setFin(true); + CountDownLatch failedFrame1 = new CountDownLatch(1); + Callback callbackFrame1 = Callback.from(()->{}, t->failedFrame1.countDown()); + assertTrue(frameFlusher.enqueue(frame1, callbackFrame1, false)); + + // Enqueue the close frame which should fail the previous frame as it is still in the queue. + Frame closeFrame = new CloseStatus(CloseStatus.MESSAGE_TOO_LARGE).toFrame(); + CountDownLatch succeededCloseFrame = new CountDownLatch(1); + Callback closeFrameCallback = Callback.from(succeededCloseFrame::countDown, t->{}); + assertTrue(frameFlusher.enqueue(closeFrame, closeFrameCallback, false)); + assertTrue(failedFrame1.await(1, TimeUnit.SECONDS)); + + // Any frames enqueued after this should fail. + Frame frame2 = new Frame(OpCode.TEXT).setPayload("message after close").setFin(true); + CountDownLatch failedFrame2 = new CountDownLatch(1); + Callback callbackFrame2 = Callback.from(()->{}, t->failedFrame2.countDown()); + assertFalse(frameFlusher.enqueue(frame2, callbackFrame2, false)); + assertTrue(failedFrame2.await(1, TimeUnit.SECONDS)); + + // Iterating should succeed the close callback. + frameFlusher.iterate(); + assertTrue(succeededCloseFrame.await(1, TimeUnit.SECONDS)); + } + public static class CapturingEndPoint extends MockEndpoint { public Parser parser;