363124 improved websocket close handling

This commit is contained in:
Greg Wilkins 2011-11-08 18:18:34 +11:00
parent 5f2323418b
commit 21e692aee6
6 changed files with 267 additions and 50 deletions

View File

@ -109,7 +109,23 @@ public interface WebSocket
String getProtocol(); String getProtocol();
void sendMessage(String data) throws IOException; void sendMessage(String data) throws IOException;
void sendMessage(byte[] data, int offset, int length) throws IOException; void sendMessage(byte[] data, int offset, int length) throws IOException;
/**
* @deprecated Use {@link #close()}
*/
void disconnect(); void disconnect();
/**
* Close the connection with normal close code.
*/
void close();
/** Close the connection with specific closeCode and message.
* @param closeCode The close code to send, or -1 for no close code
* @param message The message to send or null for no message
*/
void close(int closeCode,String message);
boolean isOpen(); boolean isOpen();
/** /**
@ -154,12 +170,6 @@ public interface WebSocket
*/ */
public interface FrameConnection extends Connection public interface FrameConnection extends Connection
{ {
/** Close the connection with specific closeCode and message.
* @param closeCode
* @param message
*/
void close(int closeCode,String message);
/** /**
* @return The opcode of a binary message * @return The opcode of a binary message
*/ */

View File

@ -296,6 +296,12 @@ public class WebSocketConnectionD00 extends AbstractConnection implements WebSoc
/* ------------------------------------------------------------ */ /* ------------------------------------------------------------ */
public void disconnect() public void disconnect()
{
close();
}
/* ------------------------------------------------------------ */
public void close()
{ {
try try
{ {

View File

@ -508,6 +508,12 @@ public class WebSocketConnectionD06 extends AbstractConnection implements WebSoc
/* ------------------------------------------------------------ */ /* ------------------------------------------------------------ */
public void disconnect() public void disconnect()
{
close();
}
/* ------------------------------------------------------------ */
public void close()
{ {
close(CLOSE_NORMAL,null); close(CLOSE_NORMAL,null);
} }

View File

@ -585,9 +585,15 @@ public class WebSocketConnectionD08 extends AbstractConnection implements WebSoc
{ {
return opcode==OP_PONG; return opcode==OP_PONG;
} }
/* ------------------------------------------------------------ */ /* ------------------------------------------------------------ */
public void disconnect() public void disconnect()
{
close();
}
/* ------------------------------------------------------------ */
public void close()
{ {
close(CLOSE_NORMAL,null); close(CLOSE_NORMAL,null);
} }

View File

@ -382,16 +382,18 @@ public class WebSocketConnectionD13 extends AbstractConnection implements WebSoc
{ {
if (!closed_out) if (!closed_out)
{ {
// Close code 1005 (CLOSE No Code) is never to be sent as a status over // Close code 1005/1006 are never to be sent as a status over
// a Close control frame. // a Close control frame. Code<-1 also means no node.
if ( (code<=0) || (code == WebSocketConnectionD13.CLOSE_NO_CODE) )
{ if (code<0 || (code == WebSocketConnectionD13.CLOSE_NO_CODE) || code==WebSocketConnectionD13.CLOSE_NO_CLOSE)
code=-1;
else if (code==0)
code=WebSocketConnectionD13.CLOSE_NORMAL; code=WebSocketConnectionD13.CLOSE_NORMAL;
}
byte[] bytes = ("xx"+(message==null?"":message)).getBytes(StringUtil.__ISO_8859_1); byte[] bytes = ("xx"+(message==null?"":message)).getBytes(StringUtil.__ISO_8859_1);
bytes[0]=(byte)(code/0x100); bytes[0]=(byte)(code/0x100);
bytes[1]=(byte)(code%0x100); bytes[1]=(byte)(code%0x100);
_outbound.addFrame((byte)FLAG_FIN,WebSocketConnectionD13.OP_CLOSE,bytes,0,bytes.length); _outbound.addFrame((byte)FLAG_FIN,WebSocketConnectionD13.OP_CLOSE,bytes,0,code>0?bytes.length:0);
_outbound.flush(); _outbound.flush();
} }
} }
@ -607,6 +609,12 @@ public class WebSocketConnectionD13 extends AbstractConnection implements WebSoc
{ {
close(CLOSE_NORMAL,null); close(CLOSE_NORMAL,null);
} }
/* ------------------------------------------------------------ */
public void close()
{
close(CLOSE_NORMAL,null);
}
/* ------------------------------------------------------------ */ /* ------------------------------------------------------------ */
public void setAllowFrameFragmentation(boolean allowFragmentation) public void setAllowFrameFragmentation(boolean allowFragmentation)
@ -767,7 +775,7 @@ public class WebSocketConnectionD13 extends AbstractConnection implements WebSoc
( code > 1010 && code <= 2999 ) || ( code > 1010 && code <= 2999 ) ||
code >= 5000 ) code >= 5000 )
{ {
errorClose(WebSocketConnectionD13.CLOSE_PROTOCOL,"Invalid close control status code " + code); errorClose(WebSocketConnectionD13.CLOSE_PROTOCOL,"Invalid close code " + code);
return; return;
} }

View File

@ -908,52 +908,233 @@ public class WebSocketMessageD13Test
assertEquals(WebSocketConnectionD13.CLOSE_MESSAGE_TOO_LARGE,code); assertEquals(WebSocketConnectionD13.CLOSE_MESSAGE_TOO_LARGE,code);
lookFor("Message size > 15",input); lookFor("Message size > 15",input);
} }
@Test @Test
public void testCloseCode() throws Exception public void testCloseIn() throws Exception
{ {
Socket socket = new Socket("localhost", __connector.getLocalPort()); int[][] tests =
OutputStream output = socket.getOutputStream(); {
output.write( {-1,0,-1},
("GET /chat HTTP/1.1\r\n"+ {-1,0,-1},
"Host: server.example.com\r\n"+ {1000,2,1000},
"Upgrade: websocket\r\n"+ {1000,2+4,1000},
"Connection: Upgrade\r\n"+ {1005,2+23,1002},
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"+ {1005,2+23,1002},
"Sec-WebSocket-Origin: http://example.com\r\n"+ {1006,2+23,1002},
"Sec-WebSocket-Protocol: chat\r\n" + {1006,2+23,1002},
"Sec-WebSocket-Version: "+WebSocketConnectionD13.VERSION+"\r\n"+ {4000,2,4000},
"\r\n").getBytes("ISO-8859-1")); {4000,2+4,4000},
output.flush(); {9000,2+23,1002},
{9000,2+23,1002}
};
socket.setSoTimeout(100000); String[] mesg =
InputStream input = socket.getInputStream(); {
"",
"",
"",
"mesg",
"",
"mesg",
"",
"mesg",
"",
"mesg",
"",
"mesg"
};
String[] resp =
{
"",
"",
"",
"mesg",
"Invalid close code 1005",
"Invalid close code 1005",
"Invalid close code 1006",
"Invalid close code 1006",
"",
"mesg",
"Invalid close code 9000",
"Invalid close code 9000"
};
lookFor("HTTP/1.1 101 Switching Protocols\r\n",input); for (int t=0;t<tests.length;t++)
skipTo("Sec-WebSocket-Accept: ",input); {
lookFor("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",input); String tst=""+t;
skipTo("\r\n\r\n",input); Socket socket = new Socket("localhost", __connector.getLocalPort());
OutputStream output = socket.getOutputStream();
output.write(
("GET /chat HTTP/1.1\r\n"+
"Host: server.example.com\r\n"+
"Upgrade: websocket\r\n"+
"Connection: Upgrade\r\n"+
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"+
"Sec-WebSocket-Origin: http://example.com\r\n"+
"Sec-WebSocket-Protocol: chat\r\n" +
"Sec-WebSocket-Version: "+WebSocketConnectionD13.VERSION+"\r\n"+
"\r\n").getBytes("ISO-8859-1"));
output.flush();
assertTrue(__serverWebSocket.awaitConnected(1000)); socket.setSoTimeout(100000);
assertNotNull(__serverWebSocket.connection); InputStream input = socket.getInputStream();
__serverWebSocket.getConnection().setMaxBinaryMessageSize(15); lookFor("HTTP/1.1 101 Switching Protocols\r\n",input);
skipTo("Sec-WebSocket-Accept: ",input);
lookFor("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",input);
skipTo("\r\n\r\n",input);
output.write(0x88); assertTrue(__serverWebSocket.awaitConnected(1000));
output.write(0x82); assertNotNull(__serverWebSocket.connection);
output.write(0x00);
output.write(0x00);
output.write(0x00);
output.write(0x00);
output.write(0x81);
output.write(0xFF);
output.flush();
assertEquals(0x80|WebSocketConnectionD13.OP_CLOSE,input.read()); int code=tests[t][0];
assertEquals(41,input.read()); String m=mesg[t];
int code=(0xff&input.read())*0x100+(0xff&input.read());
assertEquals(1002,code); // Invalid code 0x81FF output.write(0x88);
output.write(0x80 + (code<=0?0:(2+m.length())));
output.write(0x00);
output.write(0x00);
output.write(0x00);
output.write(0x00);
if (code>0)
{
output.write(code/0x100);
output.write(code%0x100);
output.write(m.getBytes());
}
output.flush();
__serverWebSocket.awaitDisconnected(1000);
byte[] buf = new byte[128];
int len = input.read(buf);
assertEquals(tst,2+tests[t][1],len);
assertEquals(tst,(byte)0x88,buf[0]);
if (len>=4)
{
code=(0xff&buf[2])*0x100+(0xff&buf[3]);
assertEquals(tst,tests[t][2],code);
if (len>4)
{
m = new String(buf,4,len-4,"UTF-8");
assertEquals(tst,resp[t],m);
}
}
else
assertEquals(tst,tests[t][2],-1);
len = input.read(buf);
assertEquals(tst,-1,len);
}
} }
@Test
public void testCloseOut() throws Exception
{
int[][] tests =
{
{-1,0,-1},
{-1,0,-1},
{0,2,1000},
{0,2+4,1000},
{1000,2,1000},
{1000,2+4,1000},
{1005,0,-1},
{1005,0,-1},
{1006,0,-1},
{1006,0,-1},
{9000,2,9000},
{9000,2+4,9000}
};
String[] mesg =
{
null,
"Not Sent",
null,
"mesg",
null,
"mesg",
null,
"mesg",
null,
"mesg",
null,
"mesg"
};
for (int t=0;t<tests.length;t++)
{
String tst=""+t;
Socket socket = new Socket("localhost", __connector.getLocalPort());
OutputStream output = socket.getOutputStream();
output.write(
("GET /chat HTTP/1.1\r\n"+
"Host: server.example.com\r\n"+
"Upgrade: websocket\r\n"+
"Connection: Upgrade\r\n"+
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"+
"Sec-WebSocket-Origin: http://example.com\r\n"+
"Sec-WebSocket-Protocol: chat\r\n" +
"Sec-WebSocket-Version: "+WebSocketConnectionD13.VERSION+"\r\n"+
"\r\n").getBytes("ISO-8859-1"));
output.flush();
socket.setSoTimeout(100000);
InputStream input = socket.getInputStream();
lookFor("HTTP/1.1 101 Switching Protocols\r\n",input);
skipTo("Sec-WebSocket-Accept: ",input);
lookFor("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",input);
skipTo("\r\n\r\n",input);
assertTrue(__serverWebSocket.awaitConnected(1000));
assertNotNull(__serverWebSocket.connection);
__serverWebSocket.getConnection().close(tests[t][0],mesg[t]);
byte[] buf = new byte[128];
int len = input.read(buf);
assertEquals(tst,2+tests[t][1],len);
assertEquals(tst,(byte)0x88,buf[0]);
if (len>=4)
{
int code=(0xff&buf[2])*0x100+(0xff&buf[3]);
assertEquals(tst,tests[t][2],code);
if (len>4)
{
String m = new String(buf,4,len-4,"UTF-8");
assertEquals(tst,mesg[t],m);
}
}
else
assertEquals(tst,tests[t][2],-1);
output.write(0x88);
output.write(0x80);
output.write(0x00);
output.write(0x00);
output.write(0x00);
output.write(0x00);
output.flush();
len = input.read(buf);
assertEquals(tst,-1,len);
}
}
@Test @Test
public void testNotUTF8() throws Exception public void testNotUTF8() throws Exception