implementation of HttpServletRequest.upgrade

Signed-off-by: Ludovic Orban <lorban@bitronix.be>
This commit is contained in:
Ludovic Orban 2021-01-28 16:03:14 +01:00
parent 4b568264f1
commit 421ed6bf8d
6 changed files with 492 additions and 8 deletions

View File

@ -465,6 +465,12 @@ public class HttpGenerator
}
}
public void servletUpgrade()
{
_noContentResponse = false;
_state = State.COMMITTED;
}
private void prepareChunk(ByteBuffer chunk, int remaining)
{
// if we need CRLF add this to header

View File

@ -1685,8 +1685,8 @@ public class HttpParser
{
_contentChunk = buffer.asReadOnlyBuffer();
// limit content by expected size
if (remaining > content)
// limit content by expected size if _contentLength is >= 0 (i.e.: not infinite)
if (_contentLength > -1 && remaining > content)
{
// We can cast remaining to an int as we know that it is smaller than
// or equal to length which is already an int.
@ -1888,6 +1888,13 @@ public class HttpParser
_headerComplete = false;
}
public void servletUpgrade()
{
setState(State.CONTENT);
_endOfContent = EndOfContent.UNKNOWN_CONTENT;
_contentLength = -1;
}
protected void setState(State state)
{
if (debugEnabled)

View File

@ -929,7 +929,7 @@ public abstract class HttpChannel implements Runnable, HttpOutput.Interceptor
commit(response);
_combinedListener.onResponseBegin(_request);
_request.onResponseCommit();
// wrap callback to process 100 responses
final int status = response.getStatus();
final Callback committed = (status < HttpStatus.OK_200 && status >= HttpStatus.CONTINUE_100)

View File

@ -66,6 +66,7 @@ public class HttpChannelOverHttp extends HttpChannel implements HttpParser.Reque
// events like timeout: we get notified and either schedule onError or release the
// blocking semaphore.
private HttpInput.Content _content;
private boolean _servletUpgrade;
public HttpChannelOverHttp(HttpConnection httpConnection, Connector connector, HttpConfiguration config, EndPoint endPoint, HttpTransport transport)
{
@ -262,10 +263,19 @@ public class HttpChannelOverHttp extends HttpChannel implements HttpParser.Reque
{
if (LOG.isDebugEnabled())
LOG.debug("received early EOF, content = {}", _content);
EofException failure = new EofException("Early EOF");
if (_content != null)
_content.failed(failure);
_content = new HttpInput.ErrorContent(failure);
if (_servletUpgrade)
{
if (_content != null)
_content.succeeded();
_content = EOF;
}
else
{
EofException failure = new EofException("Early EOF");
if (_content != null)
_content.failed(failure);
_content = new HttpInput.ErrorContent(failure);
}
}
@Override
@ -555,6 +565,16 @@ public class HttpChannelOverHttp extends HttpChannel implements HttpParser.Reque
if (_content != null && !_content.isSpecial())
throw new AssertionError("unconsumed content: " + _content);
_content = null;
_servletUpgrade = false;
}
public void servletUpgrade()
{
if (_content != null && (!_content.isSpecial() || !_content.isEof()))
throw new IllegalStateException("Cannot perform servlet upgrade with unconsumed content");
_content = null;
_servletUpgrade = true;
_httpConnection.getParser().servletUpgrade();
}
@Override

View File

@ -47,6 +47,7 @@ import javax.servlet.RequestDispatcher;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletRequestAttributeEvent;
import javax.servlet.ServletRequestAttributeListener;
@ -61,6 +62,7 @@ import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpUpgradeHandler;
import javax.servlet.http.Part;
import javax.servlet.http.PushBuilder;
import javax.servlet.http.WebConnection;
import org.eclipse.jetty.http.BadMessageException;
import org.eclipse.jetty.http.ComplianceViolation;
@ -78,6 +80,7 @@ import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.http.MetaData;
import org.eclipse.jetty.http.MimeTypes;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.RuntimeIOException;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.server.handler.ContextHandler.Context;
@ -2140,6 +2143,11 @@ public class Request implements HttpServletRequest
{
if (_asyncNotSupportedSource != null)
throw new IllegalStateException("!asyncSupported: " + _asyncNotSupportedSource);
return forceStartAsync();
}
private AsyncContextState forceStartAsync()
{
HttpChannelState state = getHttpChannelState();
if (_async == null)
_async = new AsyncContextState(state);
@ -2372,7 +2380,95 @@ public class Request implements HttpServletRequest
@Override
public <T extends HttpUpgradeHandler> T upgrade(Class<T> handlerClass) throws IOException, ServletException
{
throw new ServletException("HttpServletRequest.upgrade() not supported in Jetty");
Response response = _channel.getResponse();
if (response.getStatus() != HttpStatus.SWITCHING_PROTOCOLS_101)
throw new IllegalStateException("Response status should be 101");
if (response.getHeader("Upgrade") == null)
throw new IllegalStateException("Missing Upgrade header");
if (!"Upgrade".equalsIgnoreCase(response.getHeader("Connection")))
throw new IllegalStateException("Invalid Connection header");
if (response.isCommitted())
throw new IllegalStateException("Cannot upgrade committed response");
if (_metaData == null || _metaData.getHttpVersion() != HttpVersion.HTTP_1_1)
throw new IllegalStateException("Only requests over HTTP/1.1 can be upgraded");
ServletOutputStream outputStream = response.getOutputStream();
ServletInputStream inputStream = getInputStream();
HttpChannelOverHttp httpChannel11 = (HttpChannelOverHttp)_channel;
HttpConnection httpConnection = (HttpConnection)_channel.getConnection();
T upgradeHandler;
try
{
upgradeHandler = handlerClass.getDeclaredConstructor().newInstance();
}
catch (Exception e)
{
throw new ServletException("Unable to instantiate handler class", e);
}
httpChannel11.servletUpgrade(); // tell the HTTP 1.1 channel that it is now handling an upgraded servlet
AsyncContext asyncContext = forceStartAsync(); // force the servlet in async mode
outputStream.flush(); // commit the 101 response
httpConnection.getGenerator().servletUpgrade(); // tell the generator it can send data as-is
httpConnection.addEventListener(new Connection.Listener()
{
@Override
public void onClosed(Connection connection)
{
try
{
asyncContext.complete();
}
catch (Exception e)
{
LOG.warn("error during upgrade AsyncContext complete", e);
}
try
{
upgradeHandler.destroy();
}
catch (Exception e)
{
LOG.warn("error during upgrade HttpUpgradeHandler destroy", e);
}
}
@Override
public void onOpened(Connection connection)
{
}
});
upgradeHandler.init(new WebConnection()
{
@Override
public void close() throws Exception
{
try
{
inputStream.close();
}
finally
{
outputStream.close();
}
}
@Override
public ServletInputStream getInputStream()
{
return inputStream;
}
@Override
public ServletOutputStream getOutputStream()
{
return outputStream;
}
});
return upgradeHandler;
}
/**

View File

@ -0,0 +1,355 @@
//
// ========================================================================
// Copyright (c) 1995-2021 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.servlet;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import javax.servlet.ReadListener;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpUpgradeHandler;
import javax.servlet.http.WebConnection;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.server.handler.DefaultHandler;
import org.eclipse.jetty.server.handler.HandlerList;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.eclipse.jetty.util.StringUtil.CRLF;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class ServletUpgradeTest
{
private static final Logger LOG = LoggerFactory.getLogger(ServletUpgradeTest.class);
private Server server;
private int port;
@BeforeEach
public void setUp() throws Exception
{
server = new Server();
ServerConnector connector = new ServerConnector(server);
server.addConnector(connector);
ServletContextHandler contextHandler = new ServletContextHandler(ServletContextHandler.NO_SESSIONS);
contextHandler.setContextPath("/");
contextHandler.addServlet(new ServletHolder(new TestServlet()), "/TestServlet");
HandlerList handlers = new HandlerList();
handlers.setHandlers(new Handler[]{contextHandler, new DefaultHandler()});
server.setHandler(handlers);
server.start();
port = connector.getLocalPort();
}
@AfterEach
public void tearDown() throws Exception
{
server.stop();
}
@Test
public void upgradeTest() throws Exception
{
boolean passed1 = false;
boolean passed2 = false;
boolean passed3 = false;
String expectedResponse1 = "TCKHttpUpgradeHandler.init";
String expectedResponse2 = "onDataAvailable|Hello";
String expectedResponse3 = "onDataAvailable|World";
InputStream input = null;
OutputStream output = null;
Socket s = null;
try
{
s = new Socket("localhost", port);
output = s.getOutputStream();
StringBuilder reqStr = new StringBuilder()
.append("POST /TestServlet HTTP/1.1").append(CRLF)
.append("User-Agent: Java/1.6.0_33").append(CRLF)
.append("Host: localhost:").append(port).append(CRLF)
.append("Accept: text/html, image/gif, image/jpeg, *; q=.2, */*; q=.2").append(CRLF)
.append("Upgrade: YES").append(CRLF)
.append("Connection: Upgrade").append(CRLF)
.append("Content-type: application/x-www-form-urlencoded").append(CRLF)
.append(CRLF);
LOG.info("REQUEST=========" + reqStr.toString());
output.write(reqStr.toString().getBytes());
LOG.info("Writing first chunk");
writeChunk(output, "Hello");
LOG.info("Writing second chunk");
writeChunk(output, "World");
LOG.info("Consuming the response from the server");
// Consume the response from the server
input = s.getInputStream();
int len;
byte[] b = new byte[1024];
boolean receivedFirstMessage = false;
boolean receivedSecondMessage = false;
boolean receivedThirdMessage = false;
StringBuilder sb = new StringBuilder();
while ((len = input.read(b)) != -1)
{
String line = new String(b, 0, len);
sb.append(line);
LOG.info("==============Read from server:" + CRLF + sb + CRLF);
if (passed1 = compareString(expectedResponse1, sb.toString()))
{
LOG.info("==============Received first expected response!" + CRLF);
receivedFirstMessage = true;
}
if (passed2 = compareString(expectedResponse2, sb.toString()))
{
LOG.info("==============Received second expected response!" + CRLF);
receivedSecondMessage = true;
}
if (passed3 = compareString(expectedResponse3, sb.toString()))
{
LOG.info("==============Received third expected response!" + CRLF);
receivedThirdMessage = true;
}
LOG.info("receivedFirstMessage : " + receivedFirstMessage);
LOG.info("receivedSecondMessage : " + receivedSecondMessage);
LOG.info("receivedThirdMessage : " + receivedThirdMessage);
if (receivedFirstMessage && receivedSecondMessage && receivedThirdMessage)
{
break;
}
}
}
finally
{
try
{
if (input != null)
{
LOG.info("Closing input...");
input.close();
LOG.info("Input closed.");
}
}
catch (Exception ex)
{
LOG.error("Failed to close input:" + ex.getMessage(), ex);
}
try
{
if (output != null)
{
LOG.info("Closing output...");
output.close();
LOG.info("Output closed .");
}
}
catch (Exception ex)
{
LOG.error("Failed to close output:" + ex.getMessage(), ex);
}
try
{
if (s != null)
{
LOG.info("Closing socket..." + CRLF);
s.close();
LOG.info("Socked closed.");
}
}
catch (Exception ex)
{
LOG.error("Failed to close socket:" + ex.getMessage(), ex);
}
}
assertTrue(passed1 && passed2 && passed3);
}
private static class TestServlet extends HttpServlet
{
public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
if (request.getHeader("Upgrade") != null)
{
response.setStatus(101);
response.setHeader("Upgrade", "YES");
response.setHeader("Connection", "Upgrade");
TestHttpUpgradeHandler handler = request.upgrade(TestHttpUpgradeHandler.class);
assertThat(handler, instanceOf(TestHttpUpgradeHandler.class));
}
else
{
response.getWriter().println("No upgrade");
response.getWriter().println("End of Test");
}
}
}
public static class TestHttpUpgradeHandler implements HttpUpgradeHandler
{
public TestHttpUpgradeHandler()
{
}
@Override
public void destroy()
{
LOG.debug("===============destroy");
}
@Override
public void init(WebConnection wc)
{
try
{
ServletInputStream input = wc.getInputStream();
ServletOutputStream output = wc.getOutputStream();
TestReadListener readListener = new TestReadListener("/", input, output);
input.setReadListener(readListener);
output.println("===============TCKHttpUpgradeHandler.init");
output.flush();
}
catch (Exception ex)
{
throw new RuntimeException(ex);
}
}
}
private static class TestReadListener implements ReadListener
{
private final ServletInputStream input;
private final ServletOutputStream output;
private final String delimiter;
TestReadListener(String del, ServletInputStream in, ServletOutputStream out)
{
input = in;
output = out;
delimiter = del;
}
public void onAllDataRead()
{
try
{
output.println("=onAllDataRead");
output.close();
}
catch (Exception ex)
{
throw new IllegalStateException(ex);
}
}
public void onDataAvailable()
{
try
{
output.println("=onDataAvailable");
StringBuilder sb = new StringBuilder();
int len;
byte[] b = new byte[1024];
while (input.isReady() && (len = input.read(b)) != -1)
{
String data = new String(b, 0, len);
sb.append(data);
}
output.println(delimiter + sb.toString());
output.flush();
}
catch (Exception ex)
{
throw new IllegalStateException(ex);
}
}
public void onError(final Throwable t)
{
LOG.error("TestReadListener error", t);
}
}
private static boolean compareString(String expected, String actual)
{
String[] listExpected = expected.split("[|]");
boolean found = true;
for (int i = 0, n = listExpected.length, startIdx = 0, bodyLength = actual.length(); i < n; i++)
{
String search = listExpected[i];
if (startIdx >= bodyLength)
{
startIdx = bodyLength;
}
int searchIdx = actual.toLowerCase().indexOf(search.toLowerCase(), startIdx);
LOG.debug("[ServletTestUtil] Scanning response for " + "search string: '" + search + "' starting at index " + "location: " + startIdx);
if (searchIdx < 0)
{
found = false;
String s = "[ServletTestUtil] Unable to find the following " +
"search string in the server's " +
"response: '" + search + "' at index: " +
startIdx +
"\n[ServletTestUtil] Server's response:\n" +
"-------------------------------------------\n" +
actual +
"\n-------------------------------------------\n";
LOG.debug(s);
break;
}
LOG.debug("[ServletTestUtil] Found search string: '" + search + "' at index '" + searchIdx + "' in the server's " + "response");
// the new searchIdx is the old index plus the lenght of the
// search string.
startIdx = searchIdx + search.length();
}
return found;
}
private static void writeChunk(OutputStream out, String data) throws IOException
{
if (data != null)
{
out.write(data.getBytes());
}
out.flush();
}
}