Fix `IllegalArgumentException: demand pending` (#11721)

* Do not attempt to read from the underlying content source when there's a demand pending, i.e.: when inputState is unready
* document the inputState FSM and improve the doc of its internal API

Signed-off-by: Ludovic Orban <lorban@bitronix.be>
Signed-off-by: Simone Bordet <simone.bordet@gmail.com>
Co-authored-by: Simone Bordet <simone.bordet@gmail.com>
This commit is contained in:
Ludovic Orban 2024-05-03 14:09:35 +02:00 committed by GitHub
parent b11d1cb27f
commit 8e07ede5f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 129 additions and 190 deletions

View File

@ -219,9 +219,10 @@ class AsyncContentProducer implements ContentProducer
assertLocked(); assertLocked();
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("reclaim {} {}", chunk, this); LOG.debug("reclaim {} {}", chunk, this);
assert chunk == _chunk; if (chunk != _chunk)
throw new IllegalArgumentException("Cannot reclaim unknown chunk");
chunk.release(); chunk.release();
_chunk = null; _chunk = Content.Chunk.next(_chunk);
} }
@Override @Override
@ -270,6 +271,9 @@ class AsyncContentProducer implements ContentProducer
return _servletChannel.getServletRequestState().isInputUnready(); return _servletChannel.getServletRequestState().isInputUnready();
} }
/**
* Never returns an empty chunk that isn't a failure and/or last.
*/
private Content.Chunk produceChunk() private Content.Chunk produceChunk()
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
@ -309,13 +313,19 @@ class AsyncContentProducer implements ContentProducer
LOG.debug("channel has no new chunk {}", this); LOG.debug("channel has no new chunk {}", this);
return null; return null;
} }
_servletChannel.getServletRequestState().onContentAdded();
} }
} }
} }
private Content.Chunk readChunk() private Content.Chunk readChunk()
{ {
if (_servletChannel.getServletRequestState().isInputUnready())
{
if (LOG.isDebugEnabled())
LOG.debug("readChunk() in unready state, returning null {}", this);
return null;
}
Content.Chunk chunk = _servletChannel.getRequest().read(); Content.Chunk chunk = _servletChannel.getRequest().read();
if (chunk != null) if (chunk != null)
{ {

View File

@ -94,7 +94,7 @@ public interface ContentProducer
* After this call, state can be either of UNREADY or IDLE. * After this call, state can be either of UNREADY or IDLE.
* *
* @return the next content chunk that can be read from or null if the implementation does not block * @return the next content chunk that can be read from or null if the implementation does not block
* and has no available content. * and has no available content. The returned chunk can be empty IFF it is a failure and/or last.
*/ */
Content.Chunk nextChunk(); Content.Chunk nextChunk();

View File

@ -188,14 +188,14 @@ public class HttpInput extends ServletInputStream implements Runnable
LOG.debug("setting read listener to {} {}", readListener, this); LOG.debug("setting read listener to {} {}", readListener, this);
if (_readListener != null) if (_readListener != null)
throw new IllegalStateException("ReadListener already set"); throw new IllegalStateException("ReadListener already set");
//illegal if async not started // illegal if async not started
if (!_channelState.isAsyncStarted()) if (!_channelState.isAsyncStarted())
throw new IllegalStateException("Async not started"); throw new IllegalStateException("Async not started");
_readListener = Objects.requireNonNull(readListener); _readListener = Objects.requireNonNull(readListener);
_contentProducer = _asyncContentProducer; _contentProducer = _asyncContentProducer;
// trigger content production // trigger content production
if (isReady() && _channelState.onReadEof()) // onReadEof b/c we want to transition from WAITING to WOKEN if (isReady() && _channelState.onReadListenerReady()) // onReadListenerReady b/c we want to transition from WAITING to WOKEN
scheduleReadListenerNotification(); // this is needed by AsyncServletIOTest.testStolenAsyncRead scheduleReadListenerNotification(); // this is needed by AsyncServletIOTest.testStolenAsyncRead
} }
@ -244,6 +244,8 @@ public class HttpInput extends ServletInputStream implements Runnable
Content.Chunk chunk = _contentProducer.nextChunk(); Content.Chunk chunk = _contentProducer.nextChunk();
if (chunk == null) if (chunk == null)
throw new IllegalStateException("read on unready input"); throw new IllegalStateException("read on unready input");
// Is it not empty?
if (chunk.hasRemaining()) if (chunk.hasRemaining())
{ {
int read = buffer == null ? get(chunk, b, off, len) : get(chunk, buffer); int read = buffer == null ? get(chunk, b, off, len) : get(chunk, buffer);
@ -254,6 +256,7 @@ public class HttpInput extends ServletInputStream implements Runnable
return read; return read;
} }
// Is it a failure?
if (Content.Chunk.isFailure(chunk)) if (Content.Chunk.isFailure(chunk))
{ {
Throwable failure = chunk.getFailure(); Throwable failure = chunk.getFailure();
@ -264,10 +267,11 @@ public class HttpInput extends ServletInputStream implements Runnable
throw new IOException(failure); throw new IOException(failure);
} }
// Empty and not a failure; can only be EOF as per ContentProducer.nextChunk() contract.
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("read at EOF, setting consumed EOF to true {}", this); LOG.debug("read at EOF, setting consumed EOF to true {}", this);
_consumedEof = true; _consumedEof = true;
// If EOF do we need to wake for allDataRead callback? // Do we need to wake for allDataRead callback?
if (onContentProducible()) if (onContentProducible())
scheduleReadListenerNotification(); scheduleReadListenerNotification();
return -1; return -1;
@ -276,6 +280,8 @@ public class HttpInput extends ServletInputStream implements Runnable
private void scheduleReadListenerNotification() private void scheduleReadListenerNotification()
{ {
if (LOG.isDebugEnabled())
LOG.debug("scheduling ReadListener notification {}", this);
_servletChannel.execute(_servletChannel::handle); _servletChannel.execute(_servletChannel::handle);
} }

View File

@ -388,9 +388,10 @@ public class ServletChannel
*/ */
void recycle(Throwable x) void recycle(Throwable x)
{ {
_state.recycle(); // _httpInput must be recycled before _state.
_httpInput.recycle(); _httpInput.recycle();
_httpOutput.recycle(); _httpOutput.recycle();
_state.recycle();
_servletContextRequest = null; _servletContextRequest = null;
_request = null; _request = null;
_response = null; _response = null;

View File

@ -28,6 +28,7 @@ import org.eclipse.jetty.io.QuietException;
import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response; import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.handler.ErrorHandler; import org.eclipse.jetty.server.handler.ErrorHandler;
import org.eclipse.jetty.util.ExceptionUtil;
import org.eclipse.jetty.util.thread.AutoLock; import org.eclipse.jetty.util.thread.AutoLock;
import org.eclipse.jetty.util.thread.Scheduler; import org.eclipse.jetty.util.thread.Scheduler;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -102,12 +103,51 @@ public class ServletChannelState
/* /*
* The input readiness state. * The input readiness state.
* <pre>
* read() without preceding
* isReady() ------
* \ \ unhandle() returns Action.READ_CALLBACK to call the ReadListener,
* \ \ or read() stole available content after setReadListener()
* --> IDLE <--------------
* blocking read() unblocked ^ \
* | \
* | \ setReadListener() called while
* registering demand v v content is available
* UNREADY ------------> READY
* demand
* serviced
* </pre>
*/ */
private enum InputState private enum InputState
{ {
IDLE, // No isReady; No data /**
UNREADY, // isReady()==false; No data * The 'default' state, when there is no pending demand nor a pending notification to the ReadListener.
READY // isReady() was false; data is available * There are 3 ways to transition to this state:
* <ul>
* <li>from IDLE: when an async read() is called without a preceding call to isReady()</li>
* <li>from READY: just before unhandle() returns Action.READ_CALLBACK to call read listener or
* when read() steals available content after setReadListener()</li>
* <li>from UNREADY: when a blocking read() got unblocked</li>
* </ul>
*/
IDLE,
/**
* The 'demand registered' state. There is only 1 way to transition to this state:
* <ul>
* <li>from IDLE: when isReady() is called and there is no content available, so a demand is registered</li>
* </ul>
*/
UNREADY,
/**
* The 'dispatch a notification to the ReadListener' state. There are 2 ways to transition to this state:
* <ul>
* <li>from IDLE: when setReadListener() is called while there is content available</li>
* <li>from UNREADY: when demand is serviced because content is now available</li>
* </ul>
*/
READY
} }
/* /*
@ -1027,13 +1067,9 @@ public class ServletChannelState
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("completing {}", toStringLocked()); LOG.debug("completing {}", toStringLocked());
switch (_requestState) if (_requestState == RequestState.COMPLETED)
{ throw new IllegalStateException(getStatusStringLocked());
case COMPLETED: _requestState = RequestState.COMPLETING;
throw new IllegalStateException(getStatusStringLocked());
default:
_requestState = RequestState.COMPLETING;
}
} }
} }
@ -1049,7 +1085,7 @@ public class ServletChannelState
LOG.debug("completed {}", toStringLocked()); LOG.debug("completed {}", toStringLocked());
if (_requestState != RequestState.COMPLETING) if (_requestState != RequestState.COMPLETING)
throw new IllegalStateException(this.getStatusStringLocked()); failure = ExceptionUtil.combine(failure, new IllegalStateException(getStatusStringLocked()));
if (failure != null) if (failure != null)
abortResponse(failure); abortResponse(failure);
@ -1154,18 +1190,14 @@ public class ServletChannelState
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("upgrade {}", toStringLocked()); LOG.debug("upgrade {}", toStringLocked());
switch (_state) if (_state != State.IDLE)
{ throw new IllegalStateException(getStatusStringLocked());
case IDLE: if (_inputState != InputState.IDLE)
break; throw new IllegalStateException(getStatusStringLocked());
default:
throw new IllegalStateException(getStatusStringLocked());
}
_asyncListeners = null; _asyncListeners = null;
_state = State.UPGRADED; _state = State.UPGRADED;
_requestState = RequestState.BLOCKING; _requestState = RequestState.BLOCKING;
_initial = true; _initial = true;
_inputState = InputState.IDLE;
_asyncWritePossible = false; _asyncWritePossible = false;
_timeoutMs = DEFAULT_TIMEOUT; _timeoutMs = DEFAULT_TIMEOUT;
_event = null; _event = null;
@ -1306,19 +1338,17 @@ public class ServletChannelState
return woken; return woken;
} }
public boolean onReadEof() public boolean onReadListenerReady()
{ {
boolean woken = false; boolean woken = false;
try (AutoLock ignored = lock()) try (AutoLock ignored = lock())
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("onReadEof {}", toStringLocked()); LOG.debug("onReadListenerReady {}", toStringLocked());
switch (_inputState) switch (_inputState)
{ {
case IDLE: case IDLE:
case READY:
case UNREADY:
_inputState = InputState.READY; _inputState = InputState.READY;
if (_state == State.WAITING) if (_state == State.WAITING)
{ {
@ -1327,6 +1357,8 @@ public class ServletChannelState
} }
break; break;
case READY:
case UNREADY:
default: default:
throw new IllegalStateException(toStringLocked()); throw new IllegalStateException(toStringLocked());
} }
@ -1334,31 +1366,6 @@ public class ServletChannelState
return woken; return woken;
} }
/**
* Called to indicate that some content was produced and is
* ready for consumption.
*/
public void onContentAdded()
{
try (AutoLock ignored = lock())
{
if (LOG.isDebugEnabled())
LOG.debug("onContentAdded {}", toStringLocked());
switch (_inputState)
{
case IDLE:
case UNREADY:
case READY:
_inputState = InputState.READY;
break;
default:
throw new IllegalStateException(toStringLocked());
}
}
}
/** /**
* Called to indicate that the content is being consumed. * Called to indicate that the content is being consumed.
*/ */
@ -1398,11 +1405,11 @@ public class ServletChannelState
switch (_inputState) switch (_inputState)
{ {
case IDLE: case IDLE:
case UNREADY:
case READY: // READY->UNREADY is needed by AsyncServletIOTest.testStolenAsyncRead
_inputState = InputState.UNREADY; _inputState = InputState.UNREADY;
break; break;
case READY:
case UNREADY:
default: default:
throw new IllegalStateException(toStringLocked()); throw new IllegalStateException(toStringLocked());
} }

View File

@ -26,8 +26,6 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.UnaryOperator;
import jakarta.servlet.AsyncContext; import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncEvent; import jakarta.servlet.AsyncEvent;
@ -41,14 +39,13 @@ import jakarta.servlet.WriteListener;
import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpTester;
import org.eclipse.jetty.server.Connector; import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.HttpConfiguration; import org.eclipse.jetty.server.HttpConfiguration;
import org.eclipse.jetty.server.HttpConnectionFactory; import org.eclipse.jetty.server.HttpConnectionFactory;
import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -63,6 +60,7 @@ import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.startsWith; import static org.hamcrest.Matchers.startsWith;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@ -76,7 +74,6 @@ public class AsyncServletIOTest
protected AsyncIOServlet4 _servlet4 = new AsyncIOServlet4(); protected AsyncIOServlet4 _servlet4 = new AsyncIOServlet4();
protected StolenAsyncReadServlet _servletStolenAsyncRead = new StolenAsyncReadServlet(); protected StolenAsyncReadServlet _servletStolenAsyncRead = new StolenAsyncReadServlet();
protected int _port; protected int _port;
protected WrappingQTP _wQTP;
protected Server _server; protected Server _server;
protected ServletHandler _servletHandler; protected ServletHandler _servletHandler;
protected ServerConnector _connector; protected ServerConnector _connector;
@ -84,8 +81,7 @@ public class AsyncServletIOTest
@BeforeEach @BeforeEach
public void setUp() throws Exception public void setUp() throws Exception
{ {
_wQTP = new WrappingQTP(); _server = new Server();
_server = new Server(_wQTP);
HttpConfiguration httpConfig = new HttpConfiguration(); HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.setOutputBufferSize(4096); httpConfig.setOutputBufferSize(4096);
@ -794,125 +790,78 @@ public class AsyncServletIOTest
@Test @Test
public void testStolenAsyncRead() throws Exception public void testStolenAsyncRead() throws Exception
{ {
StringBuilder request = new StringBuilder(512); String request = """
request.append("POST /ctx/stolen/info HTTP/1.1\r\n") POST /ctx/stolen/info HTTP/1.1
.append("Host: localhost\r\n") Host: localhost
.append("Content-Type: text/plain\r\n") Content-Type: text/plain
.append("Content-Length: 2\r\n") Content-Length: 2
.append("\r\n")
.append("1"); 1""";
int port = _port;
try (Socket socket = new Socket("localhost", port)) try (Socket socket = new Socket("localhost", _port))
{ {
socket.setSoTimeout(10000); socket.setSoTimeout(10000);
OutputStream out = socket.getOutputStream(); OutputStream out = socket.getOutputStream();
out.write(request.toString().getBytes(ISO_8859_1)); out.write(request.getBytes(ISO_8859_1));
out.flush(); out.flush();
// wait until server is ready // Because the read was stolen, onDataAvailable() is not called.
_servletStolenAsyncRead.ready.await(); // The wait guarantees that the Servlet thread is out of doPost().
final CountDownLatch wait = new CountDownLatch(1);
final CountDownLatch held = new CountDownLatch(1);
// Stop any dispatches until we want them
UnaryOperator<Runnable> old = _wQTP.wrapper.getAndSet(r ->
() ->
{
try
{
held.countDown();
wait.await();
r.run();
}
catch (InterruptedException e)
{
e.printStackTrace();
}
}
);
// We are an unrelated thread, let's mess with the input stream
ServletInputStream sin = _servletStolenAsyncRead.listener.in;
sin.setReadListener(_servletStolenAsyncRead.listener);
// thread should be dispatched to handle, but held by our wQTP wait.
assertTrue(held.await(10, TimeUnit.SECONDS));
// Let's steal our read
assertTrue(sin.isReady());
assertThat(sin.read(), Matchers.is((int)'1'));
assertFalse(sin.isReady());
// let the ODA call go
_wQTP.wrapper.set(old);
wait.countDown();
// ODA should not be called
assertFalse(_servletStolenAsyncRead.oda.await(500, TimeUnit.MILLISECONDS)); assertFalse(_servletStolenAsyncRead.oda.await(500, TimeUnit.MILLISECONDS));
// Send some more data // Send some more data.
out.write((int)'2'); out.write('2');
out.flush(); out.flush();
// ODA should now be called!! // onDataAvailable() should now be called.
assertTrue(_servletStolenAsyncRead.oda.await(500, TimeUnit.MILLISECONDS)); assertTrue(_servletStolenAsyncRead.oda.await(500, TimeUnit.MILLISECONDS));
// We can not read some more ServletInputStream in = _servletStolenAsyncRead.listener.in;
assertTrue(sin.isReady());
assertThat(sin.read(), Matchers.is((int)'2'));
// read EOF // We can now read some more.
assertTrue(sin.isReady()); assertTrue(in.isReady());
assertThat(sin.read(), Matchers.is(-1)); assertEquals('2', in.read());
BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); // All content has been sent, must read EOF.
assertTrue(in.isReady());
assertEquals(-1, in.read());
// response line HttpTester.Response response = HttpTester.parseResponse(socket.getInputStream());
String line = in.readLine(); assertNotNull(response);
LOG.debug("response-line: " + line); assertEquals(200, response.getStatus());
assertThat(line, startsWith("HTTP/1.1 200 OK"));
// Skip headers
while (line != null)
{
line = in.readLine();
LOG.debug("header-line: " + line);
if (line.length() == 0)
break;
}
} }
assertTrue(_servletStolenAsyncRead.completed.await(5, TimeUnit.SECONDS));
} }
@SuppressWarnings("serial") public static class StolenAsyncReadServlet extends HttpServlet
public class StolenAsyncReadServlet extends HttpServlet
{ {
public CountDownLatch ready = new CountDownLatch(1); private final CountDownLatch oda = new CountDownLatch(1);
public CountDownLatch oda = new CountDownLatch(1); private volatile StealingListener listener;
public CountDownLatch completed = new CountDownLatch(1);
public volatile StealingListener listener;
@Override @Override
public void doPost(final HttpServletRequest request, final HttpServletResponse response) throws IOException public void doPost(final HttpServletRequest request, final HttpServletResponse response) throws IOException
{ {
listener = new StealingListener(request); listener = new StealingListener(request);
ready.countDown();
// Steal the read.
assertEquals('1', listener.in.read());
// Make sure the ReadListener is called when more content is available.
assertFalse(listener.in.isReady());
// Exit from doPost() so that ReadListener methods can now be invoked.
} }
public class StealingListener implements ReadListener, AsyncListener public class StealingListener implements ReadListener
{ {
final HttpServletRequest request; private final ServletInputStream in;
final ServletInputStream in; private final AsyncContext asyncContext;
final AsyncContext asyncContext;
StealingListener(HttpServletRequest request) throws IOException public StealingListener(HttpServletRequest request) throws IOException
{ {
asyncContext = request.startAsync(); asyncContext = request.startAsync();
asyncContext.setTimeout(10000L); asyncContext.setTimeout(0);
asyncContext.addListener(this);
this.request = request;
in = request.getInputStream(); in = request.getInputStream();
in.setReadListener(this);
} }
@Override @Override
@ -922,51 +871,17 @@ public class AsyncServletIOTest
} }
@Override @Override
public void onAllDataRead() throws IOException public void onAllDataRead()
{ {
asyncContext.complete(); asyncContext.complete();
} }
@Override @Override
public void onError(final Throwable t) public void onError(Throwable t)
{ {
t.printStackTrace(); t.printStackTrace();
asyncContext.complete(); asyncContext.complete();
} }
@Override
public void onComplete(final AsyncEvent event)
{
completed.countDown();
}
@Override
public void onTimeout(final AsyncEvent event)
{
asyncContext.complete();
}
@Override
public void onError(final AsyncEvent event)
{
asyncContext.complete();
}
@Override
public void onStartAsync(AsyncEvent event)
{
}
}
}
private class WrappingQTP extends QueuedThreadPool
{
AtomicReference<UnaryOperator<Runnable>> wrapper = new AtomicReference<>(UnaryOperator.identity());
@Override
public void execute(Runnable job)
{
super.execute(wrapper.get().apply(job));
} }
} }
} }