diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannel.java b/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannel.java index add0b3301bc..05d47a91ace 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannel.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannel.java @@ -318,7 +318,12 @@ public class HttpChannel implements Runnable, HttpOutput.Interceptor { case TERMINATED: case WAIT: + // break loop without calling unhandle break loop; + + case NOOP: + // do nothing other than call unhandle + break; case DISPATCH: { diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannelState.java b/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannelState.java index c5e43b59516..3abddd890d0 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannelState.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/HttpChannelState.java @@ -73,6 +73,7 @@ public class HttpChannelState */ public enum Action { + NOOP, // No action DISPATCH, // handle a normal request dispatch ASYNC_DISPATCH, // handle an async request dispatch ERROR_DISPATCH, // handle a normal error @@ -243,6 +244,8 @@ public class HttpChannelState case IDLE: case REGISTERED: break; + default: + throw new IllegalStateException(getStatusStringLocked()); } if (_asyncWritePossible) @@ -269,14 +272,13 @@ public class HttpChannelState case STARTED: case EXPIRING: case ERRORING: - return Action.WAIT; + _state=State.ASYNC_WAIT; + return Action.NOOP; case NOT_ASYNC: - break; default: throw new IllegalStateException(getStatusStringLocked()); } - return Action.WAIT; case ASYNC_ERROR: return Action.ASYNC_ERROR; @@ -408,6 +410,7 @@ public class HttpChannelState case DISPATCHED: case ASYNC_IO: case ASYNC_ERROR: + case ASYNC_WAIT: break; default: diff --git a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/AsyncServletIOTest.java b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/AsyncServletIOTest.java index 55d9a47767d..690c3ebb95b 100644 --- a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/AsyncServletIOTest.java +++ b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/AsyncServletIOTest.java @@ -22,7 +22,6 @@ import static java.nio.charset.StandardCharsets.ISO_8859_1; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.startsWith; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; @@ -41,6 +40,9 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.function.UnaryOperator; import javax.servlet.AsyncContext; import javax.servlet.AsyncEvent; @@ -63,6 +65,7 @@ import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; +import org.eclipse.jetty.util.thread.QueuedThreadPool; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Assert; @@ -77,14 +80,19 @@ public class AsyncServletIOTest protected AsyncIOServlet2 _servlet2=new AsyncIOServlet2(); protected AsyncIOServlet3 _servlet3=new AsyncIOServlet3(); protected AsyncIOServlet4 _servlet4=new AsyncIOServlet4(); + protected StolenAsyncReadServlet _servletStolenAsyncRead=new StolenAsyncReadServlet(); protected int _port; - protected Server _server = new Server(); + protected WrappingQTP _wQTP; + protected Server _server; protected ServletHandler _servletHandler; protected ServerConnector _connector; @Before public void setUp() throws Exception { + _wQTP = new WrappingQTP(); + _server = new Server(_wQTP); + HttpConfiguration http_config = new HttpConfiguration(); http_config.setOutputBufferSize(4096); _connector = new ServerConnector(_server,new HttpConnectionFactory(http_config)); @@ -113,6 +121,10 @@ public class AsyncServletIOTest holder4.setAsyncSupported(true); _servletHandler.addServletWithMapping(holder4,"/path4/*"); + ServletHolder holder5=new ServletHolder(_servletStolenAsyncRead); + holder5.setAsyncSupported(true); + _servletHandler.addServletWithMapping(holder5,"/stolen/*"); + _server.start(); _port=_connector.getLocalPort(); @@ -787,5 +799,179 @@ public class AsyncServletIOTest } } + + @Test + public void testStolenAsyncRead() throws Exception + { + StringBuilder request = new StringBuilder(512); + request.append("POST /ctx/stolen/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("Content-Type: text/plain\r\n") + .append("Content-Length: 2\r\n") + .append("\r\n") + .append("1"); + int port=_port; + try (Socket socket = new Socket("localhost",port)) + { + socket.setSoTimeout(10000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(ISO_8859_1)); + out.flush(); + + // wait until server is ready + _servletStolenAsyncRead.ready.await(); + final CountDownLatch wait = new CountDownLatch(1); + + // Stop any dispatches until we want them + UnaryOperator old = _wQTP.wrapper.getAndSet(r-> + ()-> + { + try + { + wait.await(); + r.run(); + } + catch (InterruptedException e) + { + e.printStackTrace(); + } + } + ); + + // We are an unrelated thread, let's mess with the input stream + ServletInputStream sin = _servletStolenAsyncRead.listener.in; + sin.setReadListener(_servletStolenAsyncRead.listener); + // thread should be dispatched to handle, but held by our wQTP wait. + + // Let's steal our read + Assert.assertTrue(sin.isReady()); + Assert.assertThat(sin.read(),Matchers.is((int)'1')); + Assert.assertFalse(sin.isReady()); + + // let the ODA call go + _wQTP.wrapper.set(old); + wait.countDown(); + + // ODA should not be called + Assert.assertFalse(_servletStolenAsyncRead.oda.await(500,TimeUnit.MILLISECONDS)); + + // Send some more data + out.write((int)'2'); + out.flush(); + + // ODA should now be called!! + Assert.assertTrue(_servletStolenAsyncRead.oda.await(500,TimeUnit.MILLISECONDS)); + + // We can not read some more + Assert.assertTrue(sin.isReady()); + Assert.assertThat(sin.read(),Matchers.is((int)'2')); + + // read EOF + Assert.assertTrue(sin.isReady()); + Assert.assertThat(sin.read(),Matchers.is(-1)); + + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + + // response line + String line = in.readLine(); + LOG.debug("response-line: "+line); + Assert.assertThat(line,startsWith("HTTP/1.1 200 OK")); + + // Skip headers + while (line!=null) + { + line = in.readLine(); + LOG.debug("header-line: "+line); + if (line.length()==0) + break; + } + } + + assertTrue(_servletStolenAsyncRead.completed.await(5, TimeUnit.SECONDS)); + } + + @SuppressWarnings("serial") + public class StolenAsyncReadServlet extends HttpServlet + { + public CountDownLatch ready = new CountDownLatch(1); + public CountDownLatch oda = new CountDownLatch(1); + public CountDownLatch completed = new CountDownLatch(1); + public volatile StealingListener listener; + + @Override + public void doPost(final HttpServletRequest request, final HttpServletResponse response) throws IOException + { + listener = new StealingListener(request); + ready.countDown(); + } + + public class StealingListener implements ReadListener, AsyncListener + { + final HttpServletRequest request; + final ServletInputStream in; + final AsyncContext asyncContext; + + StealingListener(HttpServletRequest request) throws IOException + { + asyncContext = request.startAsync(); + asyncContext.setTimeout(10000L); + asyncContext.addListener(this); + this.request=request; + in = request.getInputStream(); + } + + @Override + public void onDataAvailable() + { + oda.countDown(); + } + + @Override + public void onAllDataRead() throws IOException + { + asyncContext.complete(); + } + + @Override + public void onError(final Throwable t) + { + t.printStackTrace(); + asyncContext.complete(); + } + + @Override + public void onComplete(final AsyncEvent event) + { + completed.countDown(); + } + + @Override + public void onTimeout(final AsyncEvent event) + { + asyncContext.complete(); + } + + @Override + public void onError(final AsyncEvent event) + { + asyncContext.complete(); + } + + @Override + public void onStartAsync(AsyncEvent event) + { + } + } + } + private class WrappingQTP extends QueuedThreadPool + { + AtomicReference> wrapper = new AtomicReference<>(UnaryOperator.identity()); + + @Override + public void execute(Runnable job) + { + super.execute(wrapper.get().apply(job)); + } + } }