diff --git a/jetty-http/src/main/java/org/eclipse/jetty/http/HttpGenerator.java b/jetty-http/src/main/java/org/eclipse/jetty/http/HttpGenerator.java index 81398ea59b1..00e795a7853 100644 --- a/jetty-http/src/main/java/org/eclipse/jetty/http/HttpGenerator.java +++ b/jetty-http/src/main/java/org/eclipse/jetty/http/HttpGenerator.java @@ -465,6 +465,12 @@ public class HttpGenerator } } + public void servletUpgrade() + { + _noContentResponse = false; + _state = State.COMMITTED; + } + private void prepareChunk(ByteBuffer chunk, int remaining) { // if we need CRLF add this to header diff --git a/jetty-http/src/main/java/org/eclipse/jetty/http/HttpParser.java b/jetty-http/src/main/java/org/eclipse/jetty/http/HttpParser.java index 07d597d81d6..52844a5d9eb 100644 --- a/jetty-http/src/main/java/org/eclipse/jetty/http/HttpParser.java +++ b/jetty-http/src/main/java/org/eclipse/jetty/http/HttpParser.java @@ -1685,8 +1685,8 @@ public class HttpParser { _contentChunk = buffer.asReadOnlyBuffer(); - // limit content by expected size - if (remaining > content) + // limit content by expected size if _contentLength is >= 0 (i.e.: not infinite) + if (_contentLength > -1 && remaining > content) { // We can cast remaining to an int as we know that it is smaller than // or equal to length which is already an int. @@ -1888,6 +1888,13 @@ public class HttpParser _headerComplete = false; } + public void servletUpgrade() + { + setState(State.CONTENT); + _endOfContent = EndOfContent.UNKNOWN_CONTENT; + _contentLength = -1; + } + protected void setState(State state) { if (debugEnabled) diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannel.java b/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannel.java index 36e9f81eadf..a7c1eb35d4f 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannel.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannel.java @@ -929,7 +929,7 @@ public abstract class HttpChannel implements Runnable, HttpOutput.Interceptor commit(response); _combinedListener.onResponseBegin(_request); _request.onResponseCommit(); - + // wrap callback to process 100 responses final int status = response.getStatus(); final Callback committed = (status < HttpStatus.OK_200 && status >= HttpStatus.CONTINUE_100) diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannelOverHttp.java b/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannelOverHttp.java index ca4b7047a87..689ffaadfbe 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannelOverHttp.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannelOverHttp.java @@ -66,6 +66,7 @@ public class HttpChannelOverHttp extends HttpChannel implements HttpParser.Reque // events like timeout: we get notified and either schedule onError or release the // blocking semaphore. private HttpInput.Content _content; + private boolean _servletUpgrade; public HttpChannelOverHttp(HttpConnection httpConnection, Connector connector, HttpConfiguration config, EndPoint endPoint, HttpTransport transport) { @@ -262,10 +263,19 @@ public class HttpChannelOverHttp extends HttpChannel implements HttpParser.Reque { if (LOG.isDebugEnabled()) LOG.debug("received early EOF, content = {}", _content); - EofException failure = new EofException("Early EOF"); - if (_content != null) - _content.failed(failure); - _content = new HttpInput.ErrorContent(failure); + if (_servletUpgrade) + { + if (_content != null) + _content.succeeded(); + _content = EOF; + } + else + { + EofException failure = new EofException("Early EOF"); + if (_content != null) + _content.failed(failure); + _content = new HttpInput.ErrorContent(failure); + } } @Override @@ -555,6 +565,16 @@ public class HttpChannelOverHttp extends HttpChannel implements HttpParser.Reque if (_content != null && !_content.isSpecial()) throw new AssertionError("unconsumed content: " + _content); _content = null; + _servletUpgrade = false; + } + + public void servletUpgrade() + { + if (_content != null && (!_content.isSpecial() || !_content.isEof())) + throw new IllegalStateException("Cannot perform servlet upgrade with unconsumed content"); + _content = null; + _servletUpgrade = true; + _httpConnection.getParser().servletUpgrade(); } @Override diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/Request.java b/jetty-server/src/main/java/org/eclipse/jetty/server/Request.java index 0393fa47477..7423dc172d2 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/Request.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/Request.java @@ -47,6 +47,7 @@ import javax.servlet.RequestDispatcher; import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.ServletInputStream; +import javax.servlet.ServletOutputStream; import javax.servlet.ServletRequest; import javax.servlet.ServletRequestAttributeEvent; import javax.servlet.ServletRequestAttributeListener; @@ -61,6 +62,7 @@ import javax.servlet.http.HttpSession; import javax.servlet.http.HttpUpgradeHandler; import javax.servlet.http.Part; import javax.servlet.http.PushBuilder; +import javax.servlet.http.WebConnection; import org.eclipse.jetty.http.BadMessageException; import org.eclipse.jetty.http.ComplianceViolation; @@ -78,6 +80,7 @@ import org.eclipse.jetty.http.HttpURI; import org.eclipse.jetty.http.HttpVersion; import org.eclipse.jetty.http.MetaData; import org.eclipse.jetty.http.MimeTypes; +import org.eclipse.jetty.io.Connection; import org.eclipse.jetty.io.RuntimeIOException; import org.eclipse.jetty.server.handler.ContextHandler; import org.eclipse.jetty.server.handler.ContextHandler.Context; @@ -2140,6 +2143,11 @@ public class Request implements HttpServletRequest { if (_asyncNotSupportedSource != null) throw new IllegalStateException("!asyncSupported: " + _asyncNotSupportedSource); + return forceStartAsync(); + } + + private AsyncContextState forceStartAsync() + { HttpChannelState state = getHttpChannelState(); if (_async == null) _async = new AsyncContextState(state); @@ -2372,7 +2380,95 @@ public class Request implements HttpServletRequest @Override public T upgrade(Class handlerClass) throws IOException, ServletException { - throw new ServletException("HttpServletRequest.upgrade() not supported in Jetty"); + Response response = _channel.getResponse(); + if (response.getStatus() != HttpStatus.SWITCHING_PROTOCOLS_101) + throw new IllegalStateException("Response status should be 101"); + if (response.getHeader("Upgrade") == null) + throw new IllegalStateException("Missing Upgrade header"); + if (!"Upgrade".equalsIgnoreCase(response.getHeader("Connection"))) + throw new IllegalStateException("Invalid Connection header"); + if (response.isCommitted()) + throw new IllegalStateException("Cannot upgrade committed response"); + if (_metaData == null || _metaData.getHttpVersion() != HttpVersion.HTTP_1_1) + throw new IllegalStateException("Only requests over HTTP/1.1 can be upgraded"); + + ServletOutputStream outputStream = response.getOutputStream(); + ServletInputStream inputStream = getInputStream(); + HttpChannelOverHttp httpChannel11 = (HttpChannelOverHttp)_channel; + HttpConnection httpConnection = (HttpConnection)_channel.getConnection(); + + T upgradeHandler; + try + { + upgradeHandler = handlerClass.getDeclaredConstructor().newInstance(); + } + catch (Exception e) + { + throw new ServletException("Unable to instantiate handler class", e); + } + + httpChannel11.servletUpgrade(); // tell the HTTP 1.1 channel that it is now handling an upgraded servlet + AsyncContext asyncContext = forceStartAsync(); // force the servlet in async mode + + outputStream.flush(); // commit the 101 response + httpConnection.getGenerator().servletUpgrade(); // tell the generator it can send data as-is + httpConnection.addEventListener(new Connection.Listener() + { + @Override + public void onClosed(Connection connection) + { + try + { + asyncContext.complete(); + } + catch (Exception e) + { + LOG.warn("error during upgrade AsyncContext complete", e); + } + try + { + upgradeHandler.destroy(); + } + catch (Exception e) + { + LOG.warn("error during upgrade HttpUpgradeHandler destroy", e); + } + } + + @Override + public void onOpened(Connection connection) + { + } + }); + + upgradeHandler.init(new WebConnection() + { + @Override + public void close() throws Exception + { + try + { + inputStream.close(); + } + finally + { + outputStream.close(); + } + } + + @Override + public ServletInputStream getInputStream() + { + return inputStream; + } + + @Override + public ServletOutputStream getOutputStream() + { + return outputStream; + } + }); + return upgradeHandler; } /** diff --git a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletUpgradeTest.java b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletUpgradeTest.java new file mode 100644 index 00000000000..430d543387a --- /dev/null +++ b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletUpgradeTest.java @@ -0,0 +1,355 @@ +// +// ======================================================================== +// Copyright (c) 1995-2021 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.servlet; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import javax.servlet.ReadListener; +import javax.servlet.ServletException; +import javax.servlet.ServletInputStream; +import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpUpgradeHandler; +import javax.servlet.http.WebConnection; + +import org.eclipse.jetty.server.Handler; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.server.handler.DefaultHandler; +import org.eclipse.jetty.server.handler.HandlerList; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.eclipse.jetty.util.StringUtil.CRLF; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ServletUpgradeTest +{ + private static final Logger LOG = LoggerFactory.getLogger(ServletUpgradeTest.class); + + private Server server; + private int port; + + @BeforeEach + public void setUp() throws Exception + { + server = new Server(); + + ServerConnector connector = new ServerConnector(server); + server.addConnector(connector); + + ServletContextHandler contextHandler = new ServletContextHandler(ServletContextHandler.NO_SESSIONS); + contextHandler.setContextPath("/"); + contextHandler.addServlet(new ServletHolder(new TestServlet()), "/TestServlet"); + + HandlerList handlers = new HandlerList(); + handlers.setHandlers(new Handler[]{contextHandler, new DefaultHandler()}); + server.setHandler(handlers); + + server.start(); + port = connector.getLocalPort(); + } + + @AfterEach + public void tearDown() throws Exception + { + server.stop(); + } + + @Test + public void upgradeTest() throws Exception + { + boolean passed1 = false; + boolean passed2 = false; + boolean passed3 = false; + String expectedResponse1 = "TCKHttpUpgradeHandler.init"; + String expectedResponse2 = "onDataAvailable|Hello"; + String expectedResponse3 = "onDataAvailable|World"; + + InputStream input = null; + OutputStream output = null; + Socket s = null; + + try + { + s = new Socket("localhost", port); + output = s.getOutputStream(); + + StringBuilder reqStr = new StringBuilder() + .append("POST /TestServlet HTTP/1.1").append(CRLF) + .append("User-Agent: Java/1.6.0_33").append(CRLF) + .append("Host: localhost:").append(port).append(CRLF) + .append("Accept: text/html, image/gif, image/jpeg, *; q=.2, */*; q=.2").append(CRLF) + .append("Upgrade: YES").append(CRLF) + .append("Connection: Upgrade").append(CRLF) + .append("Content-type: application/x-www-form-urlencoded").append(CRLF) + .append(CRLF); + + LOG.info("REQUEST=========" + reqStr.toString()); + output.write(reqStr.toString().getBytes()); + + LOG.info("Writing first chunk"); + writeChunk(output, "Hello"); + + LOG.info("Writing second chunk"); + writeChunk(output, "World"); + + LOG.info("Consuming the response from the server"); + + // Consume the response from the server + input = s.getInputStream(); + int len; + byte[] b = new byte[1024]; + boolean receivedFirstMessage = false; + boolean receivedSecondMessage = false; + boolean receivedThirdMessage = false; + StringBuilder sb = new StringBuilder(); + while ((len = input.read(b)) != -1) + { + String line = new String(b, 0, len); + sb.append(line); + LOG.info("==============Read from server:" + CRLF + sb + CRLF); + if (passed1 = compareString(expectedResponse1, sb.toString())) + { + LOG.info("==============Received first expected response!" + CRLF); + receivedFirstMessage = true; + } + if (passed2 = compareString(expectedResponse2, sb.toString())) + { + LOG.info("==============Received second expected response!" + CRLF); + receivedSecondMessage = true; + } + if (passed3 = compareString(expectedResponse3, sb.toString())) + { + LOG.info("==============Received third expected response!" + CRLF); + receivedThirdMessage = true; + } + LOG.info("receivedFirstMessage : " + receivedFirstMessage); + LOG.info("receivedSecondMessage : " + receivedSecondMessage); + LOG.info("receivedThirdMessage : " + receivedThirdMessage); + if (receivedFirstMessage && receivedSecondMessage && receivedThirdMessage) + { + break; + } + } + } + finally + { + try + { + if (input != null) + { + LOG.info("Closing input..."); + input.close(); + LOG.info("Input closed."); + } + } + catch (Exception ex) + { + LOG.error("Failed to close input:" + ex.getMessage(), ex); + } + + try + { + if (output != null) + { + LOG.info("Closing output..."); + output.close(); + LOG.info("Output closed ."); + } + } + catch (Exception ex) + { + LOG.error("Failed to close output:" + ex.getMessage(), ex); + } + + try + { + if (s != null) + { + LOG.info("Closing socket..." + CRLF); + s.close(); + LOG.info("Socked closed."); + } + } + catch (Exception ex) + { + LOG.error("Failed to close socket:" + ex.getMessage(), ex); + } + } + + assertTrue(passed1 && passed2 && passed3); + } + + private static class TestServlet extends HttpServlet + { + public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + if (request.getHeader("Upgrade") != null) + { + response.setStatus(101); + response.setHeader("Upgrade", "YES"); + response.setHeader("Connection", "Upgrade"); + TestHttpUpgradeHandler handler = request.upgrade(TestHttpUpgradeHandler.class); + assertThat(handler, instanceOf(TestHttpUpgradeHandler.class)); + } + else + { + response.getWriter().println("No upgrade"); + response.getWriter().println("End of Test"); + } + } + } + + public static class TestHttpUpgradeHandler implements HttpUpgradeHandler + { + public TestHttpUpgradeHandler() + { + } + + @Override + public void destroy() + { + LOG.debug("===============destroy"); + } + + @Override + public void init(WebConnection wc) + { + try + { + ServletInputStream input = wc.getInputStream(); + ServletOutputStream output = wc.getOutputStream(); + TestReadListener readListener = new TestReadListener("/", input, output); + input.setReadListener(readListener); + output.println("===============TCKHttpUpgradeHandler.init"); + output.flush(); + } + catch (Exception ex) + { + throw new RuntimeException(ex); + } + } + } + + private static class TestReadListener implements ReadListener + { + private final ServletInputStream input; + private final ServletOutputStream output; + private final String delimiter; + + TestReadListener(String del, ServletInputStream in, ServletOutputStream out) + { + input = in; + output = out; + delimiter = del; + } + + public void onAllDataRead() + { + try + { + output.println("=onAllDataRead"); + output.close(); + } + catch (Exception ex) + { + throw new IllegalStateException(ex); + } + } + + public void onDataAvailable() + { + try + { + output.println("=onDataAvailable"); + StringBuilder sb = new StringBuilder(); + int len; + byte[] b = new byte[1024]; + while (input.isReady() && (len = input.read(b)) != -1) + { + String data = new String(b, 0, len); + sb.append(data); + } + output.println(delimiter + sb.toString()); + output.flush(); + } + catch (Exception ex) + { + throw new IllegalStateException(ex); + } + } + + public void onError(final Throwable t) + { + LOG.error("TestReadListener error", t); + } + } + + private static boolean compareString(String expected, String actual) + { + String[] listExpected = expected.split("[|]"); + boolean found = true; + for (int i = 0, n = listExpected.length, startIdx = 0, bodyLength = actual.length(); i < n; i++) + { + String search = listExpected[i]; + if (startIdx >= bodyLength) + { + startIdx = bodyLength; + } + + int searchIdx = actual.toLowerCase().indexOf(search.toLowerCase(), startIdx); + + LOG.debug("[ServletTestUtil] Scanning response for " + "search string: '" + search + "' starting at index " + "location: " + startIdx); + if (searchIdx < 0) + { + found = false; + String s = "[ServletTestUtil] Unable to find the following " + + "search string in the server's " + + "response: '" + search + "' at index: " + + startIdx + + "\n[ServletTestUtil] Server's response:\n" + + "-------------------------------------------\n" + + actual + + "\n-------------------------------------------\n"; + LOG.debug(s); + break; + } + + LOG.debug("[ServletTestUtil] Found search string: '" + search + "' at index '" + searchIdx + "' in the server's " + "response"); + // the new searchIdx is the old index plus the lenght of the + // search string. + startIdx = searchIdx + search.length(); + } + return found; + } + + private static void writeChunk(OutputStream out, String data) throws IOException + { + if (data != null) + { + out.write(data.getBytes()); + } + out.flush(); + } +}