Overhauling CloseException to terminate the connection when encountered

This commit is contained in:
Joakim Erdfelt 2012-06-29 15:56:08 -07:00
parent 16d366e427
commit 39d8cd1b27
10 changed files with 185 additions and 113 deletions

View File

@ -3,29 +3,29 @@ package org.eclipse.jetty.websocket.api;
@SuppressWarnings("serial")
public class CloseException extends WebSocketException
{
private short closeCode;
private short statusCode;
public CloseException(short closeCode, String message)
{
super(message);
this.closeCode = closeCode;
this.statusCode = closeCode;
}
public CloseException(short closeCode, String message, Throwable cause)
{
super(message,cause);
this.closeCode = closeCode;
this.statusCode = closeCode;
}
public CloseException(short closeCode, Throwable cause)
{
super(cause);
this.closeCode = closeCode;
this.statusCode = closeCode;
}
public short getCloseCode()
public short getStatusCode()
{
return closeCode;
return statusCode;
}
}

View File

@ -0,0 +1,23 @@
package org.eclipse.jetty.websocket.api;
/**
* Per spec, a protocol error should result in a Close frame of status code 1002 (PROTOCOL_ERROR)
*/
@SuppressWarnings("serial")
public class ProtocolException extends CloseException
{
public ProtocolException(String message)
{
super(StatusCode.PROTOCOL,message);
}
public ProtocolException(String message, Throwable t)
{
super(StatusCode.PROTOCOL,message,t);
}
public ProtocolException(Throwable t)
{
super(StatusCode.PROTOCOL,t);
}
}

View File

@ -4,6 +4,7 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.annotations.WebSocket;
@ -164,6 +165,12 @@ public class WebSocketEventDriver implements Parser.Listener
LOG.debug("{}.onWebSocketException({})",websocket.getClass().getSimpleName(),e);
}
if (e instanceof CloseException)
{
CloseException close = (CloseException)e;
terminateConnection(close.getStatusCode(),close.getMessage());
}
if (events.onException != null)
{
events.onException.call(websocket,connection,e);
@ -181,27 +188,41 @@ public class WebSocketEventDriver implements Parser.Listener
this.connection = conn;
}
private void unhandled(Throwable t)
private void terminateConnection(int statusCode, String rawreason)
{
LOG.warn("Unhandled Error (closing connection)",t);
// Unhandled Error, close the connection.
try
{
switch (policy.getBehavior())
String reason = rawreason;
if (StringUtil.isNotBlank(reason))
{
case SERVER:
connection.close(StatusCode.SERVER_ERROR,t.getClass().getSimpleName());
break;
case CLIENT:
connection.close(StatusCode.POLICY_VIOLATION,t.getClass().getSimpleName());
break;
// Trim big exception messages here.
if (reason.length() > CloseFrame.MAX_REASON)
{
reason = reason.substring(0,CloseFrame.MAX_REASON);
}
}
LOG.debug("terminateConnection({},{})",statusCode,reason);
connection.close(statusCode,reason);
}
catch (IOException e)
{
LOG.debug(e);
}
}
private void unhandled(Throwable t)
{
LOG.warn("Unhandled Error (closing connection)",t);
// Unhandled Error, close the connection.
switch (policy.getBehavior())
{
case SERVER:
terminateConnection(StatusCode.SERVER_ERROR,t.getClass().getSimpleName());
break;
case CLIENT:
terminateConnection(StatusCode.POLICY_VIOLATION,t.getClass().getSimpleName());
break;
}
}
}

View File

@ -4,12 +4,17 @@ import java.nio.ByteBuffer;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.websocket.api.OpCode;
import org.eclipse.jetty.websocket.api.ProtocolException;
import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.WebSocketBehavior;
/**
* Representation of a <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">Close Frame (0x08)</a>.
*/
public class CloseFrame extends ControlFrame
{
public static final int MAX_REASON = ControlFrame.MAX_PAYLOAD - 2;
/**
* Construct CloseFrame with no status code or reason
*/
@ -38,19 +43,29 @@ public class CloseFrame extends ControlFrame
public void assertValidPayload(int statusCode, String reason)
{
if ((statusCode <= 999) || (statusCode > 65535))
if ((statusCode < StatusCode.NORMAL) || (statusCode >= 5000))
{
throw new IllegalArgumentException("Status Codes must be in the range 1000 - 65535");
throw new ProtocolException("Status Codes must be in the range 1000 - 5000");
}
if ((reason != null) && (reason.length() > 123))
{
throw new IllegalArgumentException("Reason must not exceed 123 characters.");
throw new ProtocolException("Reason must not exceed 123 characters.");
}
// TODO add check for invalid utf-8
}
public void assertValidPerPolicy(WebSocketBehavior behavior)
{
int code = getStatusCode();
if ((code < StatusCode.NORMAL) || (code == StatusCode.UNDEFINED) || (code == StatusCode.NO_CLOSE) || (code == StatusCode.NO_CODE)
|| ((code > 1011) && (code <= 2999)) || (code >= 5000))
{
throw new ProtocolException("Invalid close code: " + code);
}
}
private void constructPayload(int statusCode, String reason)
{
assertValidPayload(statusCode,reason);
@ -103,7 +118,6 @@ public class CloseFrame extends ControlFrame
}
public boolean hasReason()
{
return getPayloadLength() > 2;
@ -116,7 +130,6 @@ public class CloseFrame extends ControlFrame
assertValidPayload(getStatusCode(),getReason());
}
@Override
public void setPayload(ByteBuffer payload)
{

View File

@ -3,10 +3,13 @@ package org.eclipse.jetty.websocket.frames;
import java.nio.ByteBuffer;
import org.eclipse.jetty.websocket.api.OpCode;
import org.eclipse.jetty.websocket.api.ProtocolException;
import org.eclipse.jetty.websocket.api.WebSocketException;
public abstract class ControlFrame extends BaseFrame
{
public static final int MAX_PAYLOAD = 125;
public ControlFrame()
{
super();
@ -25,6 +28,37 @@ public abstract class ControlFrame extends BaseFrame
return false; // no control frames can be continuation
}
@Override
public void setFin(boolean fin)
{
if (!fin)
{
throw new IllegalArgumentException("Cannot set FIN to off on a " + getOpCode().name());
}
}
@Override
public void setPayload(byte[] buf)
{
if ( buf.length > 125 )
{
throw new WebSocketException("Control Payloads can not exceed 125 bytes in length.");
}
super.setPayload(buf);
}
@Override
public void setPayload(ByteBuffer payload)
{
if (payload.position() > MAX_PAYLOAD)
{
throw new ProtocolException("Control Payloads can not exceed 125 bytes in length.");
}
super.setPayload(payload);
}
@Override
public void setRsv1(boolean rsv1)
{
@ -51,35 +85,4 @@ public abstract class ControlFrame extends BaseFrame
throw new IllegalArgumentException("Cannot set RSV3 to true on a " + getOpCode().name());
}
}
@Override
public void setPayload(ByteBuffer payload)
{
if ( payload.position() > 125 )
{
throw new WebSocketException("Control Payloads can not exceed 125 bytes in length.");
}
super.setPayload(payload);
}
@Override
public void setPayload(byte[] buf)
{
if ( buf.length > 125 )
{
throw new WebSocketException("Control Payloads can not exceed 125 bytes in length.");
}
super.setPayload(buf);
}
@Override
public void setFin(boolean fin)
{
if (!fin)
{
throw new IllegalArgumentException("Cannot set FIN to off on a " + getOpCode().name());
}
}
}

View File

@ -1,9 +1,6 @@
package org.eclipse.jetty.websocket.frames;
import java.nio.ByteBuffer;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.websocket.api.OpCode;
/**

View File

@ -3,7 +3,7 @@ package org.eclipse.jetty.websocket.generator;
import java.nio.ByteBuffer;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.websocket.api.WebSocketException;
import org.eclipse.jetty.websocket.api.ProtocolException;
import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.frames.CloseFrame;
@ -33,7 +33,7 @@ public class CloseFrameGenerator extends FrameGenerator<CloseFrame>
}
else if (close.hasPayload())
{
throw new WebSocketException("Close frames require setting a status code if using payload.");
throw new ProtocolException("Close frames require setting a status code if using payload.");
}
}
}

View File

@ -2,7 +2,8 @@ package org.eclipse.jetty.websocket.parser;
import java.nio.ByteBuffer;
import org.eclipse.jetty.websocket.api.WebSocketException;
import javax.xml.ws.ProtocolException;
import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.frames.CloseFrame;
@ -49,7 +50,7 @@ public class ClosePayloadParser extends FrameParser<CloseFrame>
*/
if ((payloadLength == 1) || (payloadLength > 125))
{
throw new WebSocketException("Close: invalid payload length: " + payloadLength);
throw new ProtocolException("Close: invalid payload length: " + payloadLength);
}
if (payload == null)
@ -65,6 +66,7 @@ public class ClosePayloadParser extends FrameParser<CloseFrame>
if (payload.position() >= payloadLength)
{
frame.setPayload(payload);
frame.assertValidPerPolicy(getPolicy().getBehavior());
return true;
}
}

View File

@ -43,7 +43,7 @@ public class TextPayloadParserTest
capture.assertHasNoFrames();
PolicyViolationException err = (PolicyViolationException)capture.getErrors().get(0);
Assert.assertThat("Error.closeCode",err.getCloseCode(),is(StatusCode.POLICY_VIOLATION));
Assert.assertThat("Error.closeCode",err.getStatusCode(),is(StatusCode.POLICY_VIOLATION));
}
@Test

View File

@ -1,30 +1,31 @@
package org.eclipse.jetty.websocket.server.ab;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.*;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.api.OpCode;
import org.eclipse.jetty.websocket.api.WebSocketAdapter;
import org.eclipse.jetty.websocket.frames.BaseFrame;
import org.eclipse.jetty.websocket.frames.CloseFrame;
import org.eclipse.jetty.websocket.frames.TextFrame;
import org.eclipse.jetty.websocket.generator.FrameGenerator;
import org.eclipse.jetty.websocket.server.SimpleServletServer;
import org.eclipse.jetty.websocket.server.WebSocketServerFactory;
import org.eclipse.jetty.websocket.server.WebSocketServlet;
import org.eclipse.jetty.websocket.server.WebSocketServletRFCTest.RFCServlet;
import org.eclipse.jetty.websocket.server.WebSocketServletRFCTest.RFCSocket;
import org.eclipse.jetty.websocket.server.blockhead.BlockheadClient;
import org.eclipse.jetty.websocket.server.examples.MyEchoServlet;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@ -33,37 +34,6 @@ import org.junit.runners.Parameterized.Parameters;
@RunWith(value = Parameterized.class)
public class TestABCase7_9
{
private int invalidStatusCode;
@Parameters
public static Collection<Integer[]> data()
{
List<Integer[]> data = new ArrayList<>();
// @formatter:off
data.add(new Integer[] { new Integer(0) });
data.add(new Integer[] { new Integer(999) });
data.add(new Integer[] { new Integer(1004) });
data.add(new Integer[] { new Integer(1005) });
data.add(new Integer[] { new Integer(1006) });
data.add(new Integer[] { new Integer(1012) });
data.add(new Integer[] { new Integer(1013) });
data.add(new Integer[] { new Integer(1014) });
data.add(new Integer[] { new Integer(1015) });
data.add(new Integer[] { new Integer(1016) });
data.add(new Integer[] { new Integer(1100) });
data.add(new Integer[] { new Integer(2000) });
data.add(new Integer[] { new Integer(2999) });
// @formatter:on
return data;
}
public TestABCase7_9(Integer invalidStatusCode )
{
this.invalidStatusCode = invalidStatusCode;
}
@SuppressWarnings("serial")
public static class RFCServlet extends WebSocketServlet
{
@ -104,10 +74,33 @@ public class TestABCase7_9
private static SimpleServletServer server;
@Parameters
public static Collection<Integer[]> data()
{
List<Integer[]> data = new ArrayList<>();
// @formatter:off
data.add(new Integer[] { new Integer(0) });
data.add(new Integer[] { new Integer(999) });
data.add(new Integer[] { new Integer(1004) });
data.add(new Integer[] { new Integer(1005) });
data.add(new Integer[] { new Integer(1006) });
data.add(new Integer[] { new Integer(1012) });
data.add(new Integer[] { new Integer(1013) });
data.add(new Integer[] { new Integer(1014) });
data.add(new Integer[] { new Integer(1015) });
data.add(new Integer[] { new Integer(1016) });
data.add(new Integer[] { new Integer(1100) });
data.add(new Integer[] { new Integer(2000) });
data.add(new Integer[] { new Integer(2999) });
// @formatter:on
return data;
}
@BeforeClass
public static void startServer() throws Exception
{
server = new SimpleServletServer(new RFCServlet());
server = new SimpleServletServer(new MyEchoServlet());
server.start();
}
@ -117,11 +110,29 @@ public class TestABCase7_9
server.stop();
}
private int invalidStatusCode;
public TestABCase7_9(Integer invalidStatusCode)
{
this.invalidStatusCode = invalidStatusCode;
}
private void remask(ByteBuffer buf, int position, byte[] mask)
{
int end = buf.position();
int off;
for (int i = position; i < end; i++)
{
off = i - position;
// Mask each byte by its absolute position in the bytebuffer
buf.put(i,(byte)(buf.get(i) ^ mask[off % 4]));
}
}
/**
* Test the requirement of issuing
*/
@Test
@Ignore ("tossing a buffer overflow exception for some reason")
public void testCase7_9_XInvalidCloseStatusCodes() throws Exception
{
BlockheadClient client = new BlockheadClient(server.getServerUri());
@ -131,21 +142,25 @@ public class TestABCase7_9
client.sendStandardRequest();
client.expectUpgradeResponse();
// Generate text frame
client.write(new CloseFrame(invalidStatusCode)
{
@Override
public void assertValidPayload(int statusCode, String reason)
{
ByteBuffer buf = ByteBuffer.allocate(FrameGenerator.OVERHEAD + 2);
BufferUtil.clearToFill(buf);
}
});
// Create Close Frame manually, as we are testing the server's behavior of a bad client.
buf.put((byte)(0x80 | OpCode.CLOSE.getCode()));
buf.put((byte)(0x80 | 2));
byte mask[] = new byte[]
{ 0x44, 0x44, 0x44, 0x44 };
buf.put(mask);
int position = buf.position();
buf.putChar((char)this.invalidStatusCode);
remask(buf,position,mask);
BufferUtil.flipToFlush(buf,0);
client.writeRaw(buf);
// Read frame (hopefully text frame)
Queue<BaseFrame> frames = client.readFrames(1,TimeUnit.MILLISECONDS,500);
CloseFrame closeFrame = (CloseFrame)frames.remove();
Assert.assertThat("CloseFrame.status code", closeFrame.getStatusCode(),is(1002));
Assert.assertThat("CloseFrame.status code",closeFrame.getStatusCode(),is(1002));
}
finally
{
@ -153,6 +168,4 @@ public class TestABCase7_9
}
}
}