Alternate DelayedHandler & ThreadLimitHandler implementations #9051 (#9056)

* Improved javadoc
* Refactored ThreadLimitHandler to avoid lambda creation and to always execute
* Refactored DelayedHandler to avoid lambda creation and to execute only if needed
* added modules for the DelayedHandler

Signed-off-by: Simone Bordet <simone.bordet@gmail.com>
Signed-off-by: Greg Wilkins <gregw@webtide.com>
Co-authored-by: Simone Bordet <simone.bordet@gmail.com>
This commit is contained in:
Greg Wilkins 2022-12-24 10:49:29 +11:00 committed by GitHub
parent 0e95953be3
commit 838091d2be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 1151 additions and 426 deletions

View File

@ -208,6 +208,25 @@ public class MimeTypes
}) })
.build(); .build();
public static Type getBaseType(String contentType)
{
if (StringUtil.isEmpty(contentType))
return null;
Type type = CACHE.getBest(contentType);
if (type == null)
return null;
if (type.asString().length() == contentType.length())
return type.getBaseType();
if (contentType.charAt(type.asString().length()) == ';')
return type.getBaseType();
contentType = contentType.replace(" ", "");
if (type.asString().length() == contentType.length())
return type.getBaseType();
if (contentType.charAt(type.asString().length()) == ';')
return type.getBaseType();
return null;
}
protected final Map<String, String> _mimeMap = new HashMap<>(); protected final Map<String, String> _mimeMap = new HashMap<>();
protected final Map<String, String> _inferredEncodings = new HashMap<>(); protected final Map<String, String> _inferredEncodings = new HashMap<>();
protected final Map<String, String> _assumedEncodings = new HashMap<>(); protected final Map<String, String> _assumedEncodings = new HashMap<>();

View File

@ -127,6 +127,29 @@ public class MimeTypesTest
MimeTypes.getContentTypeWithoutCharset(contentTypeWithCharset), is(expectedContentType)); MimeTypes.getContentTypeWithoutCharset(contentTypeWithCharset), is(expectedContentType));
} }
public static Stream<Arguments> mimeTypesGetBaseTypeCases()
{
return Stream.of(
Arguments.of("foo/bar", null),
Arguments.of("foo/bar;charset=abc;some=else", null),
Arguments.of("text/html", MimeTypes.Type.TEXT_HTML),
Arguments.of("text/html;charset=utf-8", MimeTypes.Type.TEXT_HTML),
Arguments.of("text/html; charset=iso-8859-1", MimeTypes.Type.TEXT_HTML),
Arguments.of("text/html;charset=utf-8;other=param", MimeTypes.Type.TEXT_HTML),
Arguments.of("text/html;other=param;charset=iso-8859-1", MimeTypes.Type.TEXT_HTML),
Arguments.of(null, null)
);
}
@ParameterizedTest
@MethodSource("mimeTypesGetBaseTypeCases")
public void testMimeTypesGetBaseType(String contentTypeWithCharset, MimeTypes.Type expectedType)
{
MimeTypes.CACHE.keySet().forEach(System.err::println);
assertThat(MimeTypes.getBaseType(contentTypeWithCharset), is(expectedType));
}
@Test @Test
public void testWrapper() public void testWrapper()
{ {

View File

@ -305,9 +305,12 @@ public class Content
/** /**
* <p>Demands to invoke the given demand callback parameter when a chunk of content is available.</p> * <p>Demands to invoke the given demand callback parameter when a chunk of content is available.</p>
* <p>See how to use this method <a href="#idiom">idiomatically</a>.</p> * <p>See how to use this method <a href="#idiom">idiomatically</a>.</p>
* <p>Implementations must guarantee that calls to this method are safely reentrant, to avoid * <p>Implementations guarantee that calls to this method are safely reentrant so that
* stack overflows in the case of mutual recursion between the execution of the {@code Runnable} * stack overflows are avoided in the case of mutual recursion between the execution of
* callback and a call to this method.</p> * the {@code Runnable} callback and a call to this method. Invocations of the passed
* {@code Runnable} are serialized and a callback for {@code demand} call is
* not invoked until any previous {@code demand} callback has returned.
* Thus the {@code Runnable} should not block waiting for a callback of a future demand call.</p>
* <p>The demand callback may be invoked <em>spuriously</em>: a subsequent call to {@link #read()} * <p>The demand callback may be invoked <em>spuriously</em>: a subsequent call to {@link #read()}
* may return {@code null}.</p> * may return {@code null}.</p>
* <p>Calling this method establishes a <em>pending demand</em>, which is fulfilled when the demand * <p>Calling this method establishes a <em>pending demand</em>, which is fulfilled when the demand
@ -399,7 +402,9 @@ public class Content
* *
* @param last whether the String is the last to write * @param last whether the String is the last to write
* @param utf8Content the String to write * @param utf8Content the String to write
* @param callback the callback to notify when the write operation is complete * @param callback the callback to notify when the write operation is complete.
* Implementations have the same guarantees for invocation of this
* callback as for {@link #write(boolean, ByteBuffer, Callback)}.
*/ */
static void write(Sink sink, boolean last, String utf8Content, Callback callback) static void write(Sink sink, boolean last, String utf8Content, Callback callback)
{ {
@ -409,6 +414,9 @@ public class Content
/** /**
* <p>Writes the given {@link ByteBuffer}, notifying the {@link Callback} * <p>Writes the given {@link ByteBuffer}, notifying the {@link Callback}
* when the write is complete.</p> * when the write is complete.</p>
* <p>Implementations guarantee that calls to this method are safely reentrant so that
* stack overflows are avoided in the case of mutual recursion between the execution of
* the {@code Callback} and a call to this method.</p>
* *
* @param last whether the ByteBuffer is the last to write * @param last whether the ByteBuffer is the last to write
* @param byteBuffer the ByteBuffer to write * @param byteBuffer the ByteBuffer to write

View File

@ -0,0 +1,15 @@
<?xml version="1.0"?>
<!DOCTYPE Configure PUBLIC "-//Jetty//Configure//EN" "https://www.eclipse.org/jetty/configure_10_0.dtd">
<!-- =============================================================== -->
<!-- Mixin the Thread Limit Handler to the entire server -->
<!-- =============================================================== -->
<Configure id="Server" class="org.eclipse.jetty.server.Server">
<Call name="insertHandler">
<Arg>
<New id="DelayedHandler" class="org.eclipse.jetty.server.handler.DelayedHandler">
</New>
</Arg>
</Call>
</Configure>

View File

@ -0,0 +1,20 @@
[description]
Applies DelayedHandler to entire server.
Delays request handling until any body content has arrived, to minimize blocking.
For form data and multipart, the handling is delayed until the entire request body has
been asynchronously read. For all other content types, the delay is until the first byte
has arrived.
[tags]
server
[depend]
server
[after]
threadlimit
[xml]
etc/jetty-delayed.xml

View File

@ -1,13 +1,10 @@
[description] [description]
Applies ThreadLimiteHandler to entire server Applies ThreadLimitHandler to entire server, to limit the threads per IP address for DOS protection.
[tags] [tags]
server server
[description]
Limit the threads per IP address for DOS protection.
[depend] [depend]
server server

View File

@ -49,7 +49,6 @@ public class FormFields extends CompletableFuture<Fields> implements Runnable
if (request.getLength() == 0 || StringUtil.isBlank(contentType)) if (request.getLength() == 0 || StringUtil.isBlank(contentType))
return null; return null;
// TODO mimeTypes from context
MimeTypes.Type type = MimeTypes.CACHE.get(MimeTypes.getContentTypeWithoutCharset(contentType)); MimeTypes.Type type = MimeTypes.CACHE.get(MimeTypes.getContentTypeWithoutCharset(contentType));
if (MimeTypes.Type.FORM_ENCODED != type) if (MimeTypes.Type.FORM_ENCODED != type)
return null; return null;
@ -60,21 +59,20 @@ public class FormFields extends CompletableFuture<Fields> implements Runnable
public static CompletableFuture<Fields> from(Request request) public static CompletableFuture<Fields> from(Request request)
{ {
Object attr = request.getAttribute(FormFields.class.getName()); // TODO make this attributes provided by the ContextRequest wrapper
if (attr instanceof FormFields futureFormFields)
return futureFormFields;
Charset charset = getFormEncodedCharset(request);
if (charset == null)
return EMPTY;
int maxFields = getRequestAttribute(request, FormFields.MAX_FIELDS_ATTRIBUTE); int maxFields = getRequestAttribute(request, FormFields.MAX_FIELDS_ATTRIBUTE);
int maxLength = getRequestAttribute(request, FormFields.MAX_LENGTH_ATTRIBUTE); int maxLength = getRequestAttribute(request, FormFields.MAX_LENGTH_ATTRIBUTE);
FormFields futureFormFields = new FormFields(request, charset, maxFields, maxLength); return from(request, maxFields, maxLength);
futureFormFields.run(); }
request.setAttribute(FormFields.class.getName(), futureFormFields);
return futureFormFields; public static CompletableFuture<Fields> from(Request request, Charset charset)
{
// TODO make this attributes provided by the ContextRequest wrapper
int maxFields = getRequestAttribute(request, FormFields.MAX_FIELDS_ATTRIBUTE);
int maxLength = getRequestAttribute(request, FormFields.MAX_LENGTH_ATTRIBUTE);
return from(request, charset, maxFields, maxLength);
} }
public static CompletableFuture<Fields> from(Request request, int maxFields, int maxLength) public static CompletableFuture<Fields> from(Request request, int maxFields, int maxLength)
@ -87,9 +85,14 @@ public class FormFields extends CompletableFuture<Fields> implements Runnable
if (charset == null) if (charset == null)
return EMPTY; return EMPTY;
return from(request, charset, maxFields, maxLength);
}
public static CompletableFuture<Fields> from(Request request, Charset charset, int maxFields, int maxLength)
{
FormFields futureFormFields = new FormFields(request, charset, maxFields, maxLength); FormFields futureFormFields = new FormFields(request, charset, maxFields, maxLength);
futureFormFields.run();
request.setAttribute(FormFields.class.getName(), futureFormFields); request.setAttribute(FormFields.class.getName(), futureFormFields);
futureFormFields.run();
return futureFormFields; return futureFormFields;
} }

View File

@ -17,6 +17,7 @@ import java.io.InputStream;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -183,6 +184,18 @@ public interface Request extends Attributes, Content.Source
*/ */
HttpFields getHeaders(); HttpFields getHeaders();
/**
* {@inheritDoc}
* @param demandCallback the demand callback to invoke when there is a content chunk available.
* In addition to the invocation guarantees of {@link Content.Source#demand(Runnable)},
* this implementation serializes the invocation of the {@code Runnable} with
* invocations of any {@link Response#write(boolean, ByteBuffer, Callback)}
* {@code Callback} invocations.
* @see Content.Source#demand(Runnable)
*/
@Override
void demand(Runnable demandCallback);
/** /**
* @return the HTTP trailers of this request, or {@code null} if they are not present * @return the HTTP trailers of this request, or {@code null} if they are not present
*/ */

View File

@ -65,6 +65,22 @@ public interface Response extends Content.Sink
CompletableFuture<Void> writeInterim(int status, HttpFields headers); CompletableFuture<Void> writeInterim(int status, HttpFields headers);
/**
* {@inheritDoc}
* <p>Invocations of the passed {@code Callback} are serialized and a callback for a completed {@code write} call is
* not invoked until any previous {@code write} callback has returned.
* Thus the {@code Callback} should not block waiting for a callback of a future write call.</p>
* @param last whether the ByteBuffer is the last to write
* @param byteBuffer the ByteBuffer to write
* @param callback the callback to notify when the write operation is complete
* In addition to the invocation guarantees of {@link Content.Sink#write(boolean, ByteBuffer, Callback)},
* this implementation serializes the invocation of the {@code Callback} with
* invocations of any {@link Request#demand(Runnable)} {@code Runnable} invocations.
* @see Content.Sink#write(boolean, ByteBuffer, Callback)
*/
@Override
void write(boolean last, ByteBuffer byteBuffer, Callback callback);
/** /**
* <p>Returns a chunk processor suitable to be passed to the * <p>Returns a chunk processor suitable to be passed to the
* {@link Content#copy(Content.Source, Content.Sink, Content.Chunk.Processor, Callback)} * {@link Content#copy(Content.Source, Content.Sink, Content.Chunk.Processor, Callback)}

View File

@ -34,7 +34,8 @@ public class ContextRequest extends Request.Wrapper implements Invocable
@Override @Override
public void demand(Runnable demandCallback) public void demand(Runnable demandCallback)
{ {
super.demand(() -> _context.run(demandCallback, this)); // inner class used instead of lambda for clarity in stack traces
super.demand(new OnContextDemand(demandCallback));
} }
@Override @Override
@ -66,4 +67,20 @@ public class ContextRequest extends Request.Wrapper implements Invocable
default -> super.getAttribute(name); default -> super.getAttribute(name);
}; };
} }
private class OnContextDemand implements Runnable
{
private final Runnable _demandCallback;
public OnContextDemand(Runnable demandCallback)
{
_demandCallback = demandCallback;
}
@Override
public void run()
{
_context.run(_demandCallback, ContextRequest.this);
}
}
} }

View File

@ -13,11 +13,16 @@
package org.eclipse.jetty.server.handler; package org.eclipse.jetty.server.handler;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Objects; import java.util.Objects;
import java.util.function.BiConsumer; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.jetty.http.HttpField; import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpHeader; import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpHeaderValue;
import org.eclipse.jetty.http.HttpStatus; import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.MimeTypes; import org.eclipse.jetty.http.MimeTypes;
import org.eclipse.jetty.http.MultiPart; import org.eclipse.jetty.http.MultiPart;
@ -27,10 +32,13 @@ import org.eclipse.jetty.server.FormFields;
import org.eclipse.jetty.server.Handler; import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response; import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.Fields; import org.eclipse.jetty.util.Fields;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.StringUtil;
public abstract class DelayedHandler extends Handler.Wrapper public class DelayedHandler extends Handler.Wrapper
{ {
@Override @Override
public boolean process(Request request, Response response, Callback callback) throws Exception public boolean process(Request request, Response response, Callback callback) throws Exception
@ -39,29 +47,79 @@ public abstract class DelayedHandler extends Handler.Wrapper
if (next == null) if (next == null)
return false; return false;
DelayedProcess delayed = newDelayedProcess(next, request, response, callback); boolean contentExpected = false;
String contentType = null;
loop: for (HttpField field : request.getHeaders())
{
HttpHeader header = field.getHeader();
if (header == null)
continue;
switch (header)
{
case CONTENT_TYPE:
contentType = field.getValue();
break;
case CONTENT_LENGTH:
contentExpected = field.getLongValue() > 0;
break;
case TRANSFER_ENCODING:
contentExpected = field.contains(HttpHeaderValue.CHUNKED.asString());
break;
case EXPECT:
if (field.contains(HttpHeaderValue.CONTINUE.asString()))
{
contentExpected = false;
break loop;
}
break;
default:
break;
}
}
MimeTypes.Type mimeType = MimeTypes.getBaseType(contentType);
DelayedProcess delayed = newDelayedProcess(contentExpected, contentType, mimeType, next, request, response, callback);
if (delayed == null) if (delayed == null)
return next.process(request, response, callback); return next.process(request, response, callback);
delay(delayed); delayed.delay();
return true; return true;
} }
protected DelayedProcess newDelayedProcess(Handler next, Request request, Response response, Callback callback) protected DelayedProcess newDelayedProcess(boolean contentExpected, String contentType, MimeTypes.Type mimeType, Handler handler, Request request, Response response, Callback callback)
{ {
return new DelayedProcess(next, request, response, callback); // if no content is expected, then no delay
if (!contentExpected)
return null;
// if we are not configured to delay dispatch, then no delay
if (!request.getConnectionMetaData().getHttpConfiguration().isDelayDispatchUntilContent())
return null;
// If there is no known content type, then delay only until content is available
if (mimeType == null)
return new UntilContentDelayedProcess(handler, request, response, callback);
// Otherwise, delay until a known content type is fully read; or if the type is not known then until the content is available
return switch (mimeType)
{
case FORM_ENCODED -> new UntilFormDelayedProcess(handler, request, response, callback, contentType);
case MULTIPART_FORM_DATA -> new UntilMultiPartDelayedProcess(handler, request, response, callback, contentType);
default -> new UntilContentDelayedProcess(handler, request, response, callback);
};
} }
protected abstract void delay(DelayedProcess delay) throws Exception; protected abstract static class DelayedProcess
protected static class DelayedProcess implements Runnable
{ {
private final Handler _handler; private final Handler _handler;
private final Request _request; private final Request _request;
private final Response _response; private final Response _response;
private final Callback _callback; private final Callback _callback;
public DelayedProcess(Handler handler, Request request, Response response, Callback callback) protected DelayedProcess(Handler handler, Request request, Response response, Callback callback)
{ {
_handler = Objects.requireNonNull(handler); _handler = Objects.requireNonNull(handler);
_request = Objects.requireNonNull(request); _request = Objects.requireNonNull(request);
@ -89,17 +147,11 @@ public abstract class DelayedHandler extends Handler.Wrapper
return _callback; return _callback;
} }
protected boolean process() throws Exception protected void process()
{
return getHandler().process(getRequest(), getResponse(), getCallback());
}
@Override
public void run()
{ {
try try
{ {
if (!process()) if (!getHandler().process(getRequest(), getResponse(), getCallback()))
Response.writeError(getRequest(), getResponse(), getCallback(), HttpStatus.NOT_FOUND_404); Response.writeError(getRequest(), getResponse(), getCallback(), HttpStatus.NOT_FOUND_404);
} }
catch (Throwable t) catch (Throwable t)
@ -107,148 +159,187 @@ public abstract class DelayedHandler extends Handler.Wrapper
Response.writeError(getRequest(), getResponse(), getCallback(), t); Response.writeError(getRequest(), getResponse(), getCallback(), t);
} }
} }
protected abstract void delay() throws Exception;
} }
public static class UntilContent extends DelayedHandler protected static class UntilContentDelayedProcess extends DelayedProcess
{ {
@Override public UntilContentDelayedProcess(Handler handler, Request request, Response response, Callback callback)
protected DelayedProcess newDelayedProcess(Handler next, Request request, Response response, Callback callback)
{ {
if (!request.getConnectionMetaData().getHttpConfiguration().isDelayDispatchUntilContent()) super(handler, request, response, callback);
return null;
if (request.getLength() == 0 && !request.getHeaders().contains(HttpHeader.CONTENT_TYPE))
return null;
// TODO: add logic to not delay if it's a CONNECT request.
// TODO: also add logic to not delay if it's a request that expects 100 Continue.
return new DelayedProcess(next, request, response, callback);
} }
@Override @Override
protected void delay(DelayedProcess request) protected void delay()
{ {
request.getRequest().demand(request); Content.Chunk chunk = super.getRequest().read();
} if (chunk == null)
}
public static class UntilFormFields extends DelayedHandler
{
@Override
protected FormDelayedProcess newDelayedProcess(Handler next, Request request, Response response, Callback callback)
{
if (!request.getConnectionMetaData().getHttpConfiguration().isDelayDispatchUntilContent())
return null;
if (FormFields.getFormEncodedCharset(request) == null)
return null;
return new FormDelayedProcess(next, request, response, callback);
}
@Override
protected void delay(DelayedProcess delayed)
{
FormFields.from(delayed.getRequest()).whenComplete((FormDelayedProcess)delayed);
}
protected static class FormDelayedProcess extends DelayedProcess implements BiConsumer<Fields, Throwable>
{
public FormDelayedProcess(Handler handler, Request wrapped, Response response, Callback callback)
{ {
super(handler, wrapped, response, callback); getRequest().demand(this::onContent);
} }
else
@Override
public void accept(Fields fields, Throwable x)
{ {
if (x == null) try
run();
else
Response.writeError(getRequest(), getResponse(), getCallback(), x);
}
}
}
public static class UntilMultiPartFormData extends DelayedHandler
{
@Override
protected MultiPartDelayedProcess newDelayedProcess(Handler next, Request request, Response response, Callback callback)
{
if (!request.getConnectionMetaData().getHttpConfiguration().isDelayDispatchUntilContent())
return null;
String contentType = request.getHeaders().get(HttpHeader.CONTENT_TYPE);
if (contentType == null)
return null;
String contentTypeValue = HttpField.valueParameters(contentType, null);
if (!MimeTypes.Type.MULTIPART_FORM_DATA.is(contentTypeValue))
return null;
String boundary = MultiPart.extractBoundary(contentType);
if (boundary == null)
return null;
return new MultiPartDelayedProcess(next, request, response, callback, boundary);
}
@Override
protected void delay(DelayedProcess request)
{
request.run();
((MultiPartDelayedProcess)request).whenDone();
}
protected static class MultiPartDelayedProcess extends DelayedProcess implements BiConsumer<MultiPartFormData.Parts, Throwable>
{
private final MultiPartFormData _formData;
public MultiPartDelayedProcess(Handler handler, Request wrapped, Response response, Callback callback, String boundary)
{
super(handler, wrapped, response, callback);
_formData = new MultiPartFormData(boundary);
getRequest().setAttribute(MultiPartFormData.class.getName(), _formData);
}
@Override
public void accept(MultiPartFormData.Parts parts, Throwable x)
{
if (x == null)
super.run();
else
Response.writeError(getRequest(), getResponse(), getCallback(), x);
}
@Override
public void run()
{
while (true)
{ {
Content.Chunk chunk = getRequest().read(); getHandler().process(new RewindChunkRequest(getRequest(), chunk), getResponse(), getCallback());
if (chunk == null) }
{ catch (Exception e)
getRequest().demand(this); {
return; Response.writeError(getRequest(), getResponse(), getCallback(), e);
}
if (chunk instanceof Content.Chunk.Error error)
{
_formData.completeExceptionally(error.getCause());
return;
}
_formData.parse(chunk);
chunk.release();
if (chunk.isLast())
return;
} }
} }
}
public void whenDone() public void onContent()
{
// We must execute here, because demand callbacks are serialized and process may block on a demand callback
getRequest().getContext().execute(this::process);
}
private static class RewindChunkRequest extends Request.Wrapper
{
private final AtomicReference<Content.Chunk> _chunk;
public RewindChunkRequest(Request wrapped, Content.Chunk chunk)
{ {
if (_formData.isDone()) super(wrapped);
super.run(); _chunk = new AtomicReference<>(chunk);
else }
_formData.whenComplete(this);
@Override
public Content.Chunk read()
{
Content.Chunk chunk = _chunk.getAndSet(null);
if (chunk != null)
return chunk;
return super.read();
}
}
}
protected static class UntilFormDelayedProcess extends DelayedProcess
{
private final Charset _charset;
public UntilFormDelayedProcess(Handler handler, Request wrapped, Response response, Callback callback, String contentType)
{
super(handler, wrapped, response, callback);
String cs = MimeTypes.getCharsetFromContentType(contentType);
_charset = StringUtil.isEmpty(cs) ? StandardCharsets.UTF_8 : Charset.forName(cs);
}
@Override
protected void delay()
{
CompletableFuture<Fields> futureFormFields = FormFields.from(getRequest(), _charset);
// if we are done already, then we are still in the scope of the original process call and can
// process directly, otherwise we must execute a call to process as we are within a serialized
// demand callback.
futureFormFields.whenComplete(futureFormFields.isDone() ? this::process : this::executeProcess);
}
private void process(Fields fields, Throwable x)
{
if (x == null)
super.process();
else
Response.writeError(getRequest(), getResponse(), getCallback(), x);
}
private void executeProcess(Fields fields, Throwable x)
{
if (x == null)
// We must execute here as even though we have consumed all the input, we are probably
// invoked in a demand runnable that is serialized with any write callbacks that might be done in process
getRequest().getContext().execute(super::process);
else
Response.writeError(getRequest(), getResponse(), getCallback(), x);
}
}
protected static class UntilMultiPartDelayedProcess extends DelayedProcess
{
private final MultiPartFormData _formData;
public UntilMultiPartDelayedProcess(Handler handler, Request wrapped, Response response, Callback callback, String contentType)
{
super(handler, wrapped, response, callback);
String boundary = MultiPart.extractBoundary(contentType);
_formData = boundary == null ? null : new MultiPartFormData(boundary);
getRequest().setAttribute(MultiPartFormData.class.getName(), _formData);
}
private void process(MultiPartFormData.Parts parts, Throwable x)
{
if (x == null)
{
super.process();
}
else
{
Response.writeError(getRequest(), getResponse(), getCallback(), x);
}
}
private void executeProcess(MultiPartFormData.Parts parts, Throwable x)
{
if (x == null)
{
// We must execute here as even though we have consumed all the input, we are probably
// invoked in a demand runnable that is serialized with any write callbacks that might be done in process
getRequest().getContext().execute(super::process);
}
else
{
Response.writeError(getRequest(), getResponse(), getCallback(), x);
}
}
@Override
public void delay()
{
if (_formData == null)
{
super.process();
}
else
{
Object baseTempDirectory = getRequest().getContext().getAttribute(Server.BASE_TEMP_DIR_ATTR);
_formData.setFilesDirectory(IO.asFile(baseTempDirectory == null ? System.getProperty("java.io.tmpdir") : baseTempDirectory).toPath());
readAndParse();
// if we are done already, then we are still in the scope of the original process call and can
// process directly, otherwise we must execute a call to process as we are within a serialized
// demand callback.
_formData.whenComplete(_formData.isDone() ? this::process : this::executeProcess);
}
}
private void readAndParse()
{
while (true)
{
Content.Chunk chunk = getRequest().read();
if (chunk == null)
{
getRequest().demand(this::readAndParse);
return;
}
if (chunk instanceof Content.Chunk.Error error)
{
_formData.completeExceptionally(error.getCause());
return;
}
_formData.parse(chunk);
chunk.release();
if (chunk.isLast())
{
if (!_formData.isDone())
process(null, new IOException("Incomplete multipart"));
return;
}
} }
} }
} }

View File

@ -13,16 +13,19 @@
package org.eclipse.jetty.server.handler; package org.eclipse.jetty.server.handler;
import java.io.Closeable;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.WritePendingException;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.Deque; import java.util.Deque;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.eclipse.jetty.http.HostPortHttpField; import org.eclipse.jetty.http.HostPortHttpField;
import org.eclipse.jetty.http.HttpField; import org.eclipse.jetty.http.HttpField;
@ -34,7 +37,6 @@ import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response; import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.IncludeExcludeSet; import org.eclipse.jetty.util.IncludeExcludeSet;
import org.eclipse.jetty.util.InetAddressSet; import org.eclipse.jetty.util.InetAddressSet;
import org.eclipse.jetty.util.StringUtil; import org.eclipse.jetty.util.StringUtil;
@ -78,7 +80,7 @@ public class ThreadLimitHandler extends Handler.Wrapper
public ThreadLimitHandler() public ThreadLimitHandler()
{ {
this(null, false); this(null, true);
} }
public ThreadLimitHandler(@Name("forwardedHeader") String forwardedHeader) public ThreadLimitHandler(@Name("forwardedHeader") String forwardedHeader)
@ -182,14 +184,6 @@ public class ThreadLimitHandler extends Handler.Wrapper
return true; return true;
} }
private static void getAndClose(CompletableFuture<Closeable> cf)
{
LOG.debug("getting {}", cf);
Closeable closeable = cf.getNow(null);
LOG.debug("closing {}", closeable);
IO.close(closeable);
}
private Remote getRemote(Request baseRequest) private Remote getRemote(Request baseRequest)
{ {
String ip = getRemoteIP(baseRequest); String ip = getRemoteIP(baseRequest);
@ -204,7 +198,7 @@ public class ThreadLimitHandler extends Handler.Wrapper
Remote remote = _remotes.get(ip); Remote remote = _remotes.get(ip);
if (remote == null) if (remote == null)
{ {
Remote r = new Remote(ip, limit); Remote r = new Remote(baseRequest.getContext(), ip, limit);
remote = _remotes.putIfAbsent(ip, r); remote = _remotes.putIfAbsent(ip, r);
if (remote == null) if (remote == null)
remote = r; remote = r;
@ -278,6 +272,7 @@ public class ThreadLimitHandler extends Handler.Wrapper
private final Handler _handler; private final Handler _handler;
private final LimitedResponse _response; private final LimitedResponse _response;
private final Callback _callback; private final Callback _callback;
private final AtomicReference<Runnable> _onContent = new AtomicReference<>();
public LimitedRequest(Remote remote, Handler handler, Request request, Response response, Callback callback) public LimitedRequest(Remote remote, Handler handler, Request request, Response response, Callback callback)
{ {
@ -305,79 +300,75 @@ public class ThreadLimitHandler extends Handler.Wrapper
protected void process() throws Exception protected void process() throws Exception
{ {
CompletableFuture<Closeable> futurePermit = _remote.acquire(); Permit permit = _remote.acquire();
// Did we get a permit? // Did we get a permit?
if (futurePermit.isDone()) if (permit.isAllocated())
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Threadpermitted {}", _remote); LOG.debug("Thread permitted {} {} {}", _remote, getWrapped(), _handler);
process(futurePermit); process(permit);
} }
else else
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Threadlimited {}", _remote); LOG.debug("Thread limited {} {} {}", _remote, getWrapped(), _handler);
futurePermit.thenAccept(c -> process(futurePermit)); permit.whenAllocated(this::process);
} }
} }
protected void process(CompletableFuture<Closeable> futurePermit) protected void process(Permit permit)
{ {
Callback callback = Callback.from(_callback, () -> getAndClose(futurePermit));
try try
{ {
if (!_handler.process(this, _response, callback)) if (!_handler.process(this, _response, _callback))
Response.writeError(this, _response, callback, HttpStatus.NOT_FOUND_404); Response.writeError(this, _response, _callback, HttpStatus.NOT_FOUND_404);
} }
catch (Throwable x) catch (Throwable x)
{ {
callback.failed(x); _callback.failed(x);
}
finally
{
permit.release();
} }
} }
@Override @Override
public void demand(Runnable onContent) public void demand(Runnable onContent)
{ {
Runnable permittedDemand = () -> if (!_onContent.compareAndSet(null, Objects.requireNonNull(onContent)))
throw new IllegalStateException("Pending demand");
super.demand(this::onContent);
}
private void onContent()
{
Permit permit = _remote.acquire();
if (permit.isAllocated())
onPermittedContent(permit);
else
permit.whenAllocated(this::onPermittedContent);
}
private void onPermittedContent(Permit permit)
{
try
{ {
// TODO need to consider if we already have a permit! Runnable onContent = _onContent.getAndSet(null);
CompletableFuture<Closeable> futurePermit = _remote.acquire(); onContent.run();
}
if (futurePermit.isDone()) finally
{ {
try permit.release();
{ }
onContent.run();
}
finally
{
getAndClose(futurePermit);
}
}
else
{
futurePermit.thenAccept(c ->
{
try
{
onContent.run();
}
finally
{
IO.close(c);
}
});
}
};
super.demand(permittedDemand);
} }
} }
private static class LimitedResponse extends Response.Wrapper private static class LimitedResponse extends Response.Wrapper implements Callback
{ {
private final Remote _remote; private final Remote _remote;
private final AtomicReference<Callback> _writeCallback = new AtomicReference<>();
public LimitedResponse(LimitedRequest limitedRequest, Response response) public LimitedResponse(LimitedRequest limitedRequest, Response response)
{ {
@ -388,143 +379,218 @@ public class ThreadLimitHandler extends Handler.Wrapper
@Override @Override
public void write(boolean last, ByteBuffer byteBuffer, Callback callback) public void write(boolean last, ByteBuffer byteBuffer, Callback callback)
{ {
Callback permittedCallback = new Callback() if (!_writeCallback.compareAndSet(null, Objects.requireNonNull(callback)))
throw new WritePendingException();
super.write(last, byteBuffer, this);
}
@Override
public void succeeded()
{
Permit permit = _remote.acquire();
if (permit.isAllocated())
permittedSuccess(permit);
else
permit.whenAllocated(this::permittedSuccess);
}
private void permittedSuccess(Permit permit)
{
try
{ {
@Override _writeCallback.getAndSet(null).succeeded();
public void succeeded() }
{ finally
// TODO need to consider if we already have a permit! {
CompletableFuture<Closeable> futurePermit = _remote.acquire(); permit.release();
if (futurePermit.isDone()) }
{ }
try
{
callback.succeeded();
}
finally
{
getAndClose(futurePermit);
}
}
else
{
futurePermit.thenAccept(c ->
{
try
{
callback.succeeded();
}
finally
{
IO.close(c);
}
});
}
}
@Override @Override
public void failed(Throwable x) public void failed(Throwable x)
{ {
CompletableFuture<Closeable> futurePermit = _remote.acquire(); Permit permit = _remote.acquire();
if (futurePermit.isDone()) if (permit.isAllocated())
{ permittedFailure(permit, x);
try else
{ permit.whenAllocated(p -> permittedFailure(p, x));
callback.failed(x); }
}
finally
{
getAndClose(futurePermit);
}
}
else
{
futurePermit.thenAccept(c ->
{
try
{
callback.failed(x);
}
finally
{
IO.close(c);
}
});
}
}
@Override private void permittedFailure(Permit permit, Throwable x)
public InvocationType getInvocationType() {
{ try
return callback.getInvocationType(); {
} _writeCallback.getAndSet(null).failed(x);
}; }
finally
super.write(last, byteBuffer, permittedCallback); {
permit.release();
}
} }
} }
private static final class Remote implements Closeable private interface Permit
{ {
boolean isAllocated();
void whenAllocated(Consumer<Permit> permitConsumer);
void release();
}
private static class NoopPermit implements Permit
{
@Override
public boolean isAllocated()
{
return true;
}
@Override
public void whenAllocated(Consumer<Permit> permitConsumer)
{
throw new UnsupportedOperationException();
}
@Override
public void release()
{
}
}
private static class AllocatedPermit implements Permit
{
private final Remote _remote;
private AllocatedPermit(Remote remote)
{
_remote = remote;
}
@Override
public boolean isAllocated()
{
return true;
}
@Override
public void whenAllocated(Consumer<Permit> permitConsumer)
{
throw new UnsupportedOperationException();
}
@Override
public void release()
{
_remote.release();
}
@Override
public String toString()
{
return "AllocatedPermit:" + _remote;
}
}
private static class FuturePermit implements Permit
{
private final CompletableFuture<Permit> _future = new CompletableFuture<>();
private final Remote _remote;
private FuturePermit(Remote remote)
{
_remote = remote;
}
public boolean isAllocated()
{
return _future.isDone();
}
public void whenAllocated(Consumer<Permit> permitConsumer)
{
_future.thenAccept(permitConsumer);
}
void complete()
{
if (!_future.complete(this))
throw new IllegalStateException();
}
public void release()
{
_remote.release();
}
}
private static final class Remote
{
private final Executor _executor;
private final String _ip; private final String _ip;
private final int _limit; private final int _limit;
private final AutoLock _lock = new AutoLock(); private final AutoLock _lock = new AutoLock();
private int _permits; private int _permits;
private final Deque<CompletableFuture<Closeable>> _queue = new ArrayDeque<>(); private final Deque<FuturePermit> _queue = new ArrayDeque<>();
private final CompletableFuture<Closeable> _permitted = CompletableFuture.completedFuture(this); private final Permit _permitted = new AllocatedPermit(this);
private final ThreadLocal<Boolean> _threadPermit = new ThreadLocal<>();
private static final Permit NOOP = new NoopPermit();
public Remote(String ip, int limit) public Remote(Executor executor, String ip, int limit)
{ {
_executor = executor;
_ip = ip; _ip = ip;
_limit = limit; _limit = limit;
} }
public CompletableFuture<Closeable> acquire() Permit acquire()
{ {
try (AutoLock lock = _lock.lock()) try (AutoLock lock = _lock.lock())
{ {
// Does this thread already have an available pass
if (_threadPermit.get() == Boolean.TRUE)
return NOOP;
// Do we have available passes? // Do we have available passes?
if (_permits < _limit) if (_permits < _limit)
{ {
// Yes - increment the allocated passes // Yes - increment the allocated passes
_permits++; _permits++;
_threadPermit.set(Boolean.TRUE);
// return the already completed future // return the already completed future
return _permitted; // TODO is it OK to share/reuse this? return _permitted;
} }
// No pass available, so queue a new future // No pass available, so queue a new future
CompletableFuture<Closeable> pass = new CompletableFuture<>();
_queue.addLast(pass); FuturePermit futurePermit = new FuturePermit(this);
return pass; _queue.addLast(futurePermit);
return futurePermit;
} }
} }
@Override public void release()
public void close()
{ {
FuturePermit pending;
try (AutoLock lock = _lock.lock()) try (AutoLock lock = _lock.lock())
{ {
// reduce the allocated passes // reduce the allocated passes
_permits--; _permits--;
while (true) _threadPermit.set(Boolean.FALSE);
{ // Are there any future passes pending?
// Are there any future passes waiting? pending = _queue.pollFirst();
CompletableFuture<Closeable> permit = _queue.pollFirst();
// No - we are done // yes, allocate them a permit
if (permit == null) if (pending != null)
break; _permits++;
}
// Yes - if we can complete them, we are done if (pending != null)
if (permit.complete(this)) {
{ // We cannot complete the pending in this thread, as we may be in a process, demand or write callback
_permits++; // that is serialized and other actions are waiting for the return. Thus, we must execute.
break; _executor.execute(pending::complete);
}
// Somebody else must have completed/failed that future pass,
// so let's try for another.
}
} }
} }

View File

@ -134,50 +134,7 @@ public class HttpChannelState implements HttpChannel, Components
{ {
_connectionMetaData = connectionMetaData; _connectionMetaData = connectionMetaData;
// The SerializedInvoker is used to prevent infinite recursion of callbacks calling methods calling callbacks etc. // The SerializedInvoker is used to prevent infinite recursion of callbacks calling methods calling callbacks etc.
_serializedInvoker = new SerializedInvoker() _serializedInvoker = new HttpChannelSerializedInvoker();
{
@Override
protected void onError(Runnable task, Throwable failure)
{
ChannelRequest request;
Content.Chunk.Error error;
boolean callbackCompleted;
try (AutoLock ignore = _lock.lock())
{
callbackCompleted = _callbackCompleted;
request = _request;
error = _request == null ? null : _error;
}
if (request == null || callbackCompleted)
{
// It is too late to handle error, so just log it
super.onError(task, failure);
}
else if (error == null)
{
// Try to fail the request, but we might lose a race.
try
{
request._callback.failed(failure);
}
catch (Throwable t)
{
if (ExceptionUtil.areNotAssociated(failure, t))
failure.addSuppressed(t);
super.onError(task, failure);
}
}
else
{
// We are already in error, so we will not handle this one,
// but we will add as suppressed if we have not seen it already.
Throwable cause = error.getCause();
if (cause != null && ExceptionUtil.areNotAssociated(cause, failure))
error.getCause().addSuppressed(failure);
}
}
};
} }
@Override @Override
@ -709,8 +666,6 @@ public class HttpChannelState implements HttpChannel, Components
public static class ChannelRequest implements Attributes, Request public static class ChannelRequest implements Attributes, Request
{ {
private static final Logger LOG = LoggerFactory.getLogger(ChannelResponse.class);
private final long _timeStamp = System.currentTimeMillis(); private final long _timeStamp = System.currentTimeMillis();
private final ChannelCallback _callback = new ChannelCallback(this); private final ChannelCallback _callback = new ChannelCallback(this);
private final String _id; private final String _id;
@ -927,6 +882,9 @@ public class HttpChannelState implements HttpChannel, Components
{ {
HttpChannelState httpChannel = lockedGetHttpChannel(); HttpChannelState httpChannel = lockedGetHttpChannel();
if (LOG.isDebugEnabled())
LOG.debug("demand {}", httpChannel);
error = httpChannel._error != null; error = httpChannel._error != null;
if (!error) if (!error)
{ {
@ -1006,8 +964,6 @@ public class HttpChannelState implements HttpChannel, Components
public static class ChannelResponse implements Response, Callback public static class ChannelResponse implements Response, Callback
{ {
private static final Logger LOG = LoggerFactory.getLogger(ChannelResponse.class);
private final ChannelRequest _request; private final ChannelRequest _request;
private int _status; private int _status;
private long _contentBytesWritten; private long _contentBytesWritten;
@ -1254,8 +1210,6 @@ public class HttpChannelState implements HttpChannel, Components
private static class ChannelCallback implements Callback private static class ChannelCallback implements Callback
{ {
private static final Logger LOG = LoggerFactory.getLogger(ChannelCallback.class);
private final ChannelRequest _request; private final ChannelRequest _request;
private Throwable _completedBy; private Throwable _completedBy;
@ -1501,4 +1455,49 @@ public class HttpChannelState implements HttpChannel, Components
_request.getHttpChannel()._handlerInvoker.failed(_failure); _request.getHttpChannel()._handlerInvoker.failed(_failure);
} }
} }
private class HttpChannelSerializedInvoker extends SerializedInvoker
{
@Override
protected void onError(Runnable task, Throwable failure)
{
ChannelRequest request;
Content.Chunk.Error error;
boolean callbackCompleted;
try (AutoLock ignore = _lock.lock())
{
callbackCompleted = _callbackCompleted;
request = _request;
error = _request == null ? null : _error;
}
if (request == null || callbackCompleted)
{
// It is too late to handle error, so just log it
super.onError(task, failure);
}
else if (error == null)
{
// Try to fail the request, but we might lose a race.
try
{
request._callback.failed(failure);
}
catch (Throwable t)
{
if (ExceptionUtil.areNotAssociated(failure, t))
failure.addSuppressed(t);
super.onError(task, failure);
}
}
else
{
// We are already in error, so we will not handle this one,
// but we will add as suppressed if we have not seen it already.
Throwable cause = error.getCause();
if (cause != null && ExceptionUtil.areNotAssociated(cause, failure))
error.getCause().addSuppressed(failure);
}
}
}
} }

View File

@ -13,7 +13,9 @@
package org.eclipse.jetty.server.handler; package org.eclipse.jetty.server.handler;
import java.io.ByteArrayOutputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.io.PrintStream;
import java.net.Socket; import java.net.Socket;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@ -40,6 +42,8 @@ import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.sameInstance;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
@ -70,16 +74,10 @@ public class DelayedHandlerTest
DelayedHandler delayedHandler = new DelayedHandler() DelayedHandler delayedHandler = new DelayedHandler()
{ {
@Override @Override
protected DelayedProcess newDelayedProcess(Handler next, Request request, Response response, Callback callback) protected DelayedProcess newDelayedProcess(boolean contentExpected, String contentType, MimeTypes.Type mimeType, Handler handler, Request request, Response response, Callback callback)
{ {
return null; return null;
} }
@Override
protected void delay(DelayedProcess request)
{
throw new UnsupportedOperationException();
}
}; };
_server.setHandler(delayedHandler); _server.setHandler(delayedHandler);
@ -113,16 +111,16 @@ public class DelayedHandlerTest
DelayedHandler delayedHandler = new DelayedHandler() DelayedHandler delayedHandler = new DelayedHandler()
{ {
@Override @Override
protected DelayedProcess newDelayedProcess(Handler next, Request request, Response response, Callback callback) protected DelayedProcess newDelayedProcess(boolean contentExpected, String contentType, MimeTypes.Type mimeType, Handler handler, Request request, Response response, Callback callback)
{ {
return new DelayedProcess(next, request, response, callback); return new DelayedProcess(handler, request, response, callback)
} {
@Override
@Override protected void delay() throws Exception
protected void delay(DelayedProcess request) throws InterruptedException {
{ handleEx.exchange(this::process);
handleEx.exchange(request); }
};
} }
}; };
@ -168,9 +166,9 @@ public class DelayedHandlerTest
} }
@Test @Test
public void testOnContent() throws Exception public void testDelayedUntilContent() throws Exception
{ {
DelayedHandler delayedHandler = new DelayedHandler.UntilContent(); DelayedHandler delayedHandler = new DelayedHandler();
_server.setHandler(delayedHandler); _server.setHandler(delayedHandler);
CountDownLatch processing = new CountDownLatch(1); CountDownLatch processing = new CountDownLatch(1);
@ -179,6 +177,15 @@ public class DelayedHandlerTest
@Override @Override
public boolean process(Request request, Response response, Callback callback) throws Exception public boolean process(Request request, Response response, Callback callback) throws Exception
{ {
// Check that we are not called via any demand callback
ByteArrayOutputStream out = new ByteArrayOutputStream(8192);
new Throwable().printStackTrace(new PrintStream(out));
String stack = out.toString(StandardCharsets.ISO_8859_1);
assertThat(stack, not(containsString("DemandContentCallback.succeeded")));
assertThat(stack, not(containsString("%s.%s".formatted(
DelayedHandler.UntilContentDelayedProcess.class.getSimpleName(),
DelayedHandler.UntilContentDelayedProcess.class.getMethod("onContent").getName()))));
processing.countDown(); processing.countDown();
return super.process(request, response, callback); return super.process(request, response, callback);
} }
@ -213,28 +220,135 @@ public class DelayedHandlerTest
} }
} }
@Test
public void testDelayedUntilContentInContext() throws Exception
{
ContextHandler context = new ContextHandler();
_server.setHandler(context);
DelayedHandler delayedHandler = new DelayedHandler();
context.setHandler(delayedHandler);
CountDownLatch processing = new CountDownLatch(1);
delayedHandler.setHandler(new HelloHandler()
{
@Override
public boolean process(Request request, Response response, Callback callback) throws Exception
{
// Check that we are not called via any demand callback
ByteArrayOutputStream out = new ByteArrayOutputStream(8192);
new Throwable().printStackTrace(new PrintStream(out));
String stack = out.toString(StandardCharsets.ISO_8859_1);
assertThat(stack, not(containsString("DemandContentCallback.succeeded")));
assertThat(stack, not(containsString("%s.%s".formatted(
DelayedHandler.UntilContentDelayedProcess.class.getSimpleName(),
DelayedHandler.UntilContentDelayedProcess.class.getMethod("onContent").getName()))));
// Check the thread is in the context
assertThat(ContextHandler.getCurrentContext(), sameInstance(context.getContext()));
// Check the request is wrapped in the context
assertThat(request.getContext(), sameInstance(context.getContext()));
processing.countDown();
return super.process(request, response, callback);
}
});
_server.start();
try (Socket socket = new Socket("localhost", _connector.getLocalPort()))
{
String request = """
POST / HTTP/1.1\r
Host: localhost\r
Content-Length: 10\r
\r
""";
OutputStream output = socket.getOutputStream();
output.write(request.getBytes(StandardCharsets.UTF_8));
output.flush();
assertFalse(processing.await(250, TimeUnit.MILLISECONDS));
output.write("01234567\r\n".getBytes(StandardCharsets.UTF_8));
output.flush();
assertTrue(processing.await(10, TimeUnit.SECONDS));
HttpTester.Input input = HttpTester.from(socket.getInputStream());
HttpTester.Response response = HttpTester.parseResponse(input);
assertNotNull(response);
assertEquals(HttpStatus.OK_200, response.getStatus());
String content = new String(response.getContentBytes(), StandardCharsets.UTF_8);
assertThat(content, containsString("Hello"));
}
}
@Test
public void testNoDelayWithContent() throws Exception
{
DelayedHandler delayedHandler = new DelayedHandler();
_server.setHandler(delayedHandler);
delayedHandler.setHandler(new HelloHandler()
{
@Override
public boolean process(Request request, Response response, Callback callback) throws Exception
{
// Check that we are called directly from HttpConnection.onFillable
ByteArrayOutputStream out = new ByteArrayOutputStream(8192);
new Throwable().printStackTrace(new PrintStream(out));
String stack = out.toString(StandardCharsets.ISO_8859_1);
assertThat(stack, containsString("org.eclipse.jetty.server.internal.HttpConnection.onFillable"));
assertThat(stack, containsString("org.eclipse.jetty.server.handler.DelayedHandler.process"));
// Check the content is available
String content = Content.Source.asString(request);
assertThat(content, equalTo("1234567890"));
return super.process(request, response, callback);
}
});
_server.start();
try (Socket socket = new Socket("localhost", _connector.getLocalPort()))
{
String request = """
POST / HTTP/1.1\r
Host: localhost\r
Content-Length: 10\r
\r
1234567890\r
""";
OutputStream output = socket.getOutputStream();
output.write(request.getBytes(StandardCharsets.UTF_8));
output.flush();
HttpTester.Input input = HttpTester.from(socket.getInputStream());
HttpTester.Response response = HttpTester.parseResponse(input);
assertNotNull(response);
assertEquals(HttpStatus.OK_200, response.getStatus());
String content = new String(response.getContentBytes(), StandardCharsets.UTF_8);
assertThat(content, containsString("Hello"));
}
}
@Test @Test
public void testDelayed404() throws Exception public void testDelayed404() throws Exception
{ {
DelayedHandler delayedHandler = new DelayedHandler() DelayedHandler delayedHandler = new DelayedHandler()
{ {
@Override @Override
protected void delay(DelayedProcess delayed) throws Exception protected DelayedProcess newDelayedProcess(boolean contentExpected, String contentType, MimeTypes.Type mimeType, Handler handler, Request request, Response response, Callback callback)
{ {
delayed.getRequest().getContext().execute(() -> return new DelayedProcess(handler, request, response, callback)
{ {
try @Override
protected void delay()
{ {
if (!getHandler().process(delayed.getRequest(), delayed.getResponse(), delayed.getCallback())) getRequest().getContext().execute(this::process);
Response.writeError(delayed.getRequest(), delayed.getResponse(), delayed.getCallback(), HttpStatus.NOT_FOUND_404);
} }
catch (Throwable t) };
{
Response.writeError(delayed.getRequest(), delayed.getResponse(), delayed.getCallback(), t);
}
});
} }
}; };
_server.setHandler(delayedHandler); _server.setHandler(delayedHandler);
@ -272,7 +386,7 @@ public class DelayedHandlerTest
@Test @Test
public void testDelayedFormFields() throws Exception public void testDelayedFormFields() throws Exception
{ {
DelayedHandler delayedHandler = new DelayedHandler.UntilFormFields(); DelayedHandler delayedHandler = new DelayedHandler();
_server.setHandler(delayedHandler); _server.setHandler(delayedHandler);
CountDownLatch processing = new CountDownLatch(2); CountDownLatch processing = new CountDownLatch(2);
@ -339,4 +453,54 @@ public class DelayedHandlerTest
assertThat(content, containsString("x=[1, 2, 3]")); assertThat(content, containsString("x=[1, 2, 3]"));
} }
} }
@Test
public void testNoDelayFormFields() throws Exception
{
DelayedHandler delayedHandler = new DelayedHandler();
_server.setHandler(delayedHandler);
delayedHandler.setHandler(new Handler.Abstract()
{
@Override
public boolean process(Request request, Response response, Callback callback) throws Exception
{
// Check that we are called directly from HttpConnection.onFillable via DelayedHandler.process
ByteArrayOutputStream out = new ByteArrayOutputStream(8192);
new Throwable().printStackTrace(new PrintStream(out));
String stack = out.toString(StandardCharsets.ISO_8859_1);
assertThat(stack, containsString("org.eclipse.jetty.server.internal.HttpConnection.onFillable"));
assertThat(stack, containsString("org.eclipse.jetty.server.handler.DelayedHandler.process"));
Fields fields = FormFields.from(request).get(1, TimeUnit.NANOSECONDS);
Content.Sink.write(response, true, String.valueOf(fields), callback);
return true;
}
});
_server.start();
try (Socket socket = new Socket("localhost", _connector.getLocalPort()))
{
OutputStream output = socket.getOutputStream();
output.write("""
POST / HTTP/1.1
Host: localhost
Content-Type: %s
Content-Length: 22
name=value&x=1&x=2&x=3
""".formatted(MimeTypes.Type.FORM_ENCODED).getBytes(StandardCharsets.UTF_8));
output.flush();
HttpTester.Input input = HttpTester.from(socket.getInputStream());
HttpTester.Response response = HttpTester.parseResponse(input);
assertNotNull(response);
assertEquals(HttpStatus.OK_200, response.getStatus());
String content = new String(response.getContentBytes(), StandardCharsets.UTF_8);
assertThat(content, containsString("name=[value]"));
assertThat(content, containsString("x=[1, 2, 3]"));
}
}
} }

View File

@ -118,7 +118,7 @@ public class MultiPartFormDataHandlerTest
@Test @Test
public void testDelayedUntilFormData() throws Exception public void testDelayedUntilFormData() throws Exception
{ {
DelayedHandler.UntilMultiPartFormData delayedHandler = new DelayedHandler.UntilMultiPartFormData(); DelayedHandler delayedHandler = new DelayedHandler();
CountDownLatch processLatch = new CountDownLatch(1); CountDownLatch processLatch = new CountDownLatch(1);
delayedHandler.setHandler(new Handler.Abstract.NonBlocking() delayedHandler.setHandler(new Handler.Abstract.NonBlocking()
{ {

View File

@ -17,9 +17,12 @@ import java.net.Socket;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpStatus; import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.io.Content;
import org.eclipse.jetty.server.Connector; import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Handler; import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.LocalConnector; import org.eclipse.jetty.server.LocalConnector;
@ -29,6 +32,7 @@ import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.IO;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -36,7 +40,9 @@ import org.junit.jupiter.api.Test;
import static org.awaitility.Awaitility.await; import static org.awaitility.Awaitility.await;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class ThreadLimitHandlerTest public class ThreadLimitHandlerTest
{ {
@ -243,4 +249,123 @@ public class ThreadLimitHandlerTest
await().atMost(10, TimeUnit.SECONDS).until(total::get, is(10)); await().atMost(10, TimeUnit.SECONDS).until(total::get, is(10));
await().atMost(10, TimeUnit.SECONDS).until(count::get, is(0)); await().atMost(10, TimeUnit.SECONDS).until(count::get, is(0));
} }
@Test
public void testDemandLimit() throws Exception
{
ThreadLimitHandler handler = new ThreadLimitHandler("Forwarded");
handler.setThreadLimit(4);
AtomicInteger count = new AtomicInteger(0);
CountDownLatch processed = new CountDownLatch(5);
CountDownLatch latch = new CountDownLatch(1);
handler.setHandler(new Handler.Abstract()
{
@Override
public boolean process(Request request, Response response, Callback callback) throws Exception
{
processed.countDown();
Runnable onContent = new Runnable()
{
private final AtomicLong read = new AtomicLong();
@Override
public void run()
{
count.incrementAndGet();
try
{
latch.await();
while (true)
{
Content.Chunk chunk = request.read();
if (chunk == null)
{
request.demand(this);
return;
}
if (chunk instanceof Error error)
throw error.getCause();
if (chunk.hasRemaining())
read.addAndGet(chunk.remaining());
chunk.release();
if (chunk.isLast())
{
Content.Sink.write(response, true, request.getHttpURI() + " read " + read.get(), callback);
return;
}
}
}
catch (Throwable t)
{
callback.failed(t);
}
finally
{
count.decrementAndGet();
}
}
};
if (request.getHeaders().get(HttpHeader.CONTENT_LENGTH) == null)
callback.succeeded();
else
request.demand(onContent);
return true;
}
});
_server.setHandler(handler);
_server.start();
Socket[] client = new Socket[5];
for (int i = 0; i < client.length; i++)
{
client[i] = new Socket("127.0.0.1", _connector.getLocalPort());
client[i].getOutputStream().write(("POST /" + i + " HTTP/1.0\r\nForwarded: for=1.2.3.4\r\nContent-Length: 2\r\n\r\n").getBytes());
client[i].getOutputStream().flush();
}
// wait until all 5 requests are processed
assertTrue(processed.await(10, TimeUnit.SECONDS));
// wait until we are threadlessly waiting for demand
await().atMost(10, TimeUnit.SECONDS).until(count::get, is(0));
// Send some content for the clients
for (Socket socket : client)
{
socket.getOutputStream().write('X');
socket.getOutputStream().flush();
}
// wait until we 4 threads are blocked in onContent
await().atMost(10, TimeUnit.SECONDS).until(count::get, is(4));
// check that other requests are not blocked
String response = _local.getResponse("GET /other HTTP/1.0\r\nForwarded: for=6.6.6.6\r\n\r\n");
assertThat(response, Matchers.containsString(" 200 OK"));
// let the requests go
latch.countDown();
// Wait until we are threadlessly waiting again
await().atMost(10, TimeUnit.SECONDS).until(count::get, is(0));
// Send the rest of the content for the clients
for (Socket socket : client)
{
socket.getOutputStream().write('Y');
socket.getOutputStream().flush();
}
// read all the responses
for (Socket socket : client)
{
response = IO.toString(socket.getInputStream());
assertThat(response, containsString(" 200 OK"));
assertThat(response, containsString(" read 2"));
}
}
} }

View File

@ -83,7 +83,7 @@ public class HandlerBenchmark
{ {
_server.addConnector(_connector); _server.addConnector(_connector);
_connector.getConnectionFactory(HttpConnectionFactory.class).getHttpConfiguration().addCustomizer(new ForwardedRequestCustomizer()); _connector.getConnectionFactory(HttpConnectionFactory.class).getHttpConfiguration().addCustomizer(new ForwardedRequestCustomizer());
DelayedHandler.UntilContent delayedHandler = new DelayedHandler.UntilContent(); DelayedHandler delayedHandler = new DelayedHandler();
_server.setHandler(delayedHandler); _server.setHandler(delayedHandler);
ContextHandlerCollection contexts = new ContextHandlerCollection(); ContextHandlerCollection contexts = new ContextHandlerCollection();
delayedHandler.setHandler(contexts); delayedHandler.setHandler(contexts);

View File

@ -559,6 +559,25 @@ public class IO
return total; return total;
} }
/**
* <p>Convert an object to a {@link File} if possible.</p>
* @param fileObject A File, String, Path or null to be converted into a File
* @return A File representation of the passed argument or null.
*/
public static File asFile(Object fileObject)
{
if (fileObject == null)
return null;
if (fileObject instanceof File)
return (File)fileObject;
if (fileObject instanceof String)
return new File((String)fileObject);
if (fileObject instanceof Path)
return ((Path)fileObject).toFile();
return null;
}
} }

View File

@ -138,7 +138,7 @@ public class SerializedInvoker
@Override @Override
public String toString() public String toString()
{ {
return String.format("%s@%x", getClass().getSimpleName(), hashCode()); return String.format("%s@%x{tail=%s}", getClass().getSimpleName(), hashCode(), _tail);
} }
protected void onError(Runnable task, Throwable t) protected void onError(Runnable task, Throwable t)

View File

@ -53,14 +53,18 @@ public class ServletMultiPartFormData
* @param request the HTTP request with multipart content * @param request the HTTP request with multipart content
* @return a {@link Parts} object to access the individual {@link Part}s * @return a {@link Parts} object to access the individual {@link Part}s
* @throws IOException if reading the request content fails * @throws IOException if reading the request content fails
* @see org.eclipse.jetty.server.handler.DelayedHandler
*/ */
public static Parts from(ServletContextRequest.ServletApiRequest request) throws IOException public static Parts from(ServletContextRequest.ServletApiRequest request) throws IOException
{ {
try try
{ {
// Look for a previously read and parsed MultiPartFormData from the DelayedHandler
MultiPartFormData formData = (MultiPartFormData)request.getAttribute(MultiPartFormData.class.getName()); MultiPartFormData formData = (MultiPartFormData)request.getAttribute(MultiPartFormData.class.getName());
if (formData != null) if (formData != null)
return new Parts(formData); return new Parts(formData);
// TODO set the files directory
return new ServletMultiPartFormData().parse(request); return new ServletMultiPartFormData().parse(request);
} }
catch (Throwable x) catch (Throwable x)
@ -188,6 +192,7 @@ public class ServletMultiPartFormData
@Override @Override
public void write(String fileName) throws IOException public void write(String fileName) throws IOException
{ {
// TODO This should simply move a part that is already on the file system.
Path filePath = Path.of(fileName); Path filePath = Path.of(fileName);
if (!filePath.isAbsolute()) if (!filePath.isAbsolute())
filePath = _formData.getFilesDirectory().resolve(filePath).normalize(); filePath = _formData.getFilesDirectory().resolve(filePath).normalize();

View File

@ -197,21 +197,12 @@ public class WebInfConfiguration extends AbstractConfiguration
* Given an Object, return File reference for object. * Given an Object, return File reference for object.
* Typically used to convert anonymous Object from getAttribute() calls to a File object. * Typically used to convert anonymous Object from getAttribute() calls to a File object.
* *
* @param fileattr the file attribute to analyze and return from (supports type File, Path, and String). * @param fileObject the file object to analyze and return from (supports type File, Path, and String).
* @return the File object if it can be converted otherwise null. * @return the File object if it can be converted otherwise null.
*/ */
private File asFile(Object fileattr) private File asFile(Object fileObject)
{ {
if (fileattr == null) return IO.asFile(fileObject);
return null;
if (fileattr instanceof File)
return (File)fileattr;
if (fileattr instanceof String)
return new File((String)fileattr);
if (fileattr instanceof Path)
return ((Path)fileattr).toFile();
return null;
} }
public void makeTempDirectory(File parent, WebAppContext context) public void makeTempDirectory(File parent, WebAppContext context)

View File

@ -27,7 +27,10 @@ import java.nio.file.Path;
import java.nio.file.StandardOpenOption; import java.nio.file.StandardOpenOption;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream; import java.util.stream.Stream;
import jakarta.servlet.MultipartConfigElement; import jakarta.servlet.MultipartConfigElement;
@ -43,20 +46,24 @@ import org.eclipse.jetty.client.http.HttpClientTransportOverHTTP;
import org.eclipse.jetty.client.util.InputStreamResponseListener; import org.eclipse.jetty.client.util.InputStreamResponseListener;
import org.eclipse.jetty.client.util.MultiPartRequestContent; import org.eclipse.jetty.client.util.MultiPartRequestContent;
import org.eclipse.jetty.client.util.PathRequestContent; import org.eclipse.jetty.client.util.PathRequestContent;
import org.eclipse.jetty.ee10.servlet.DefaultServlet;
import org.eclipse.jetty.ee10.servlet.ServletHolder; import org.eclipse.jetty.ee10.servlet.ServletHolder;
import org.eclipse.jetty.http.HttpFields; import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader; import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpMethod; import org.eclipse.jetty.http.HttpMethod;
import org.eclipse.jetty.http.MultiPart; import org.eclipse.jetty.http.MultiPart;
import org.eclipse.jetty.io.ClientConnector; import org.eclipse.jetty.io.ClientConnector;
import org.eclipse.jetty.io.Content;
import org.eclipse.jetty.server.Handler; import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.HttpConfiguration; import org.eclipse.jetty.server.HttpConfiguration;
import org.eclipse.jetty.server.HttpConnectionFactory; import org.eclipse.jetty.server.HttpConnectionFactory;
import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.server.handler.DefaultHandler; import org.eclipse.jetty.server.handler.DefaultHandler;
import org.eclipse.jetty.server.handler.DelayedHandler;
import org.eclipse.jetty.toolchain.test.FS; import org.eclipse.jetty.toolchain.test.FS;
import org.eclipse.jetty.toolchain.test.MavenTestingUtils; import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.IO; import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.resource.FileSystemPool; import org.eclipse.jetty.util.resource.FileSystemPool;
import org.eclipse.jetty.util.thread.QueuedThreadPool; import org.eclipse.jetty.util.thread.QueuedThreadPool;
@ -65,7 +72,6 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Tag;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
@ -75,6 +81,7 @@ import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag("large-disk-resource") @Tag("large-disk-resource")
public class HugeResourceTest public class HugeResourceTest
@ -87,6 +94,7 @@ public class HugeResourceTest
public static Path multipartTempDir; public static Path multipartTempDir;
public Server server; public Server server;
HttpConfiguration httpConfig;
public HttpClient client; public HttpClient client;
@BeforeAll @BeforeAll
@ -106,8 +114,9 @@ public class HugeResourceTest
String.format("FileStore %s of %s needs at least 30GB of free space for this test (only had %,.2fGB)", String.format("FileStore %s of %s needs at least 30GB of free space for this test (only had %,.2fGB)",
baseFileStore, staticBase, (double)(baseFileStore.getUnallocatedSpace() / GB))); baseFileStore, staticBase, (double)(baseFileStore.getUnallocatedSpace() / GB)));
makeStaticFile(staticBase.resolve("test-1m.dat"), MB);
makeStaticFile(staticBase.resolve("test-1g.dat"), GB); makeStaticFile(staticBase.resolve("test-1g.dat"), GB);
makeStaticFile(staticBase.resolve("test-4g.dat"), 4 * GB); // makeStaticFile(staticBase.resolve("test-4g.dat"), 4 * GB);
// makeStaticFile(staticBase.resolve("test-10g.dat"), 10 * GB); // makeStaticFile(staticBase.resolve("test-10g.dat"), 10 * GB);
outputDir = MavenTestingUtils.getTargetTestingPath(HugeResourceTest.class.getSimpleName() + "-outputdir"); outputDir = MavenTestingUtils.getTargetTestingPath(HugeResourceTest.class.getSimpleName() + "-outputdir");
@ -121,8 +130,9 @@ public class HugeResourceTest
{ {
ArrayList<Arguments> ret = new ArrayList<>(); ArrayList<Arguments> ret = new ArrayList<>();
ret.add(Arguments.of("test-1m.dat", MB));
ret.add(Arguments.of("test-1g.dat", GB)); ret.add(Arguments.of("test-1g.dat", GB));
ret.add(Arguments.of("test-4g.dat", 4 * GB)); // ret.add(Arguments.of("test-4g.dat", 4 * GB));
// ret.add(Arguments.of("test-10g.dat", 10 * GB)); // ret.add(Arguments.of("test-10g.dat", 10 * GB));
return ret.stream(); return ret.stream();
@ -188,9 +198,10 @@ public class HugeResourceTest
assertThat(FileSystemPool.INSTANCE.mounts(), empty()); assertThat(FileSystemPool.INSTANCE.mounts(), empty());
QueuedThreadPool serverThreads = new QueuedThreadPool(); QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setDetailedDump(true);
serverThreads.setName("server"); serverThreads.setName("server");
server = new Server(serverThreads); server = new Server(serverThreads);
HttpConfiguration httpConfig = new HttpConfiguration(); httpConfig = new HttpConfiguration();
ServerConnector connector = new ServerConnector(server, 1, 1, new HttpConnectionFactory(httpConfig)); ServerConnector connector = new ServerConnector(server, 1, 1, new HttpConnectionFactory(httpConfig));
connector.setPort(0); connector.setPort(0);
server.addConnector(connector); server.addConnector(connector);
@ -201,6 +212,7 @@ public class HugeResourceTest
context.addServlet(PostServlet.class, "/post"); context.addServlet(PostServlet.class, "/post");
context.addServlet(ChunkedServlet.class, "/chunked/*"); context.addServlet(ChunkedServlet.class, "/chunked/*");
context.addServlet(DefaultServlet.class, "/");
String location = multipartTempDir.toString(); String location = multipartTempDir.toString();
long maxFileSize = Long.MAX_VALUE; long maxFileSize = Long.MAX_VALUE;
@ -211,10 +223,11 @@ public class HugeResourceTest
ServletHolder holder = context.addServlet(MultipartServlet.class, "/multipart"); ServletHolder holder = context.addServlet(MultipartServlet.class, "/multipart");
holder.getRegistration().setMultipartConfig(multipartConfig); holder.getRegistration().setMultipartConfig(multipartConfig);
DefaultHandler defaultHandler = new DefaultHandler(); DelayedHandler delayedHandler = new DelayedHandler();
defaultHandler.setServer(server); server.setHandler(delayedHandler);
httpConfig.setDelayDispatchUntilContent(false);
server.setHandler(new Handler.Collection(context, defaultHandler)); delayedHandler.setHandler(new Handler.Collection(context, new DefaultHandler()));
server.start(); server.start();
} }
@ -364,7 +377,64 @@ public class HugeResourceTest
@ParameterizedTest @ParameterizedTest
@MethodSource("staticFiles") @MethodSource("staticFiles")
@Disabled // TODO public void testUploadDelayed(String filename, long expectedSize) throws Exception
{
httpConfig.setDelayDispatchUntilContent(true);
Path inputFile = staticBase.resolve(filename);
AtomicBoolean stalled = new AtomicBoolean(true);
AtomicReference<Runnable> demand = new AtomicReference<>();
PathRequestContent content = new PathRequestContent(inputFile)
{
@Override
public Content.Chunk read()
{
if (stalled.get())
return null;
return super.read();
}
@Override
public void demand(Runnable demandCallback)
{
if (stalled.get())
demand.set(demandCallback);
else
super.demand(demandCallback);
}
};
URI destUri = server.getURI().resolve("/post");
Request request = client.newRequest(destUri).method(HttpMethod.POST).body(content);
StringBuilder responseBody = new StringBuilder();
request.onResponseContent((r, b) ->
{
if (b.hasRemaining())
responseBody.append(BufferUtil.toString(b));
});
AtomicReference<Response> responseRef = new AtomicReference<>();
CountDownLatch complete = new CountDownLatch(1);
request.send(e ->
{
responseRef.set(e.getResponse());
complete.countDown();
});
while (demand.get() == null)
Thread.onSpinWait();
Thread.sleep(100);
stalled.set(false);
demand.get().run();
assertTrue(complete.await(30, TimeUnit.SECONDS));
Response response = responseRef.get();
assertThat("HTTP Response Code", response.getStatus(), is(200));
assertThat("Response", responseBody.toString(), containsString("bytes-received=" + expectedSize));
}
@ParameterizedTest
@MethodSource("staticFiles")
public void testUploadMultipart(String filename, long expectedSize) throws Exception public void testUploadMultipart(String filename, long expectedSize) throws Exception
{ {
MultiPartRequestContent multipart = new MultiPartRequestContent(); MultiPartRequestContent multipart = new MultiPartRequestContent();
@ -385,6 +455,70 @@ public class HugeResourceTest
assertThat("Response", responseBody, containsString(expectedResponse)); assertThat("Response", responseBody, containsString(expectedResponse));
} }
@ParameterizedTest
@MethodSource("staticFiles")
public void testUploadMultipartDelayed(String filename, long expectedSize) throws Exception
{
httpConfig.setDelayDispatchUntilContent(true);
AtomicBoolean stalled = new AtomicBoolean(true);
AtomicReference<Runnable> demand = new AtomicReference<>();
MultiPartRequestContent multipart = new MultiPartRequestContent()
{
@Override
public Content.Chunk read()
{
if (stalled.get())
return null;
return super.read();
}
@Override
public void demand(Runnable demandCallback)
{
if (stalled.get())
demand.set(demandCallback);
else
super.demand(demandCallback);
}
};
Path inputFile = staticBase.resolve(filename);
String name = String.format("file-%d", expectedSize);
multipart.addPart(new MultiPart.PathPart(name, filename, HttpFields.EMPTY, inputFile));
multipart.close();
URI destUri = server.getURI().resolve("/multipart");
client.setIdleTimeout(90_000);
Request request = client.newRequest(destUri).method(HttpMethod.POST).body(multipart);
StringBuilder responseBody = new StringBuilder();
request.onResponseContent((r, b) ->
{
if (b.hasRemaining())
responseBody.append(BufferUtil.toString(b));
});
AtomicReference<Response> responseRef = new AtomicReference<>();
CountDownLatch complete = new CountDownLatch(1);
request.send(e ->
{
responseRef.set(e.getResponse());
complete.countDown();
});
while (demand.get() == null)
Thread.onSpinWait();
Thread.sleep(100);
stalled.set(false);
demand.get().run();
assertTrue(complete.await(30, TimeUnit.SECONDS));
Response response = responseRef.get();
assertThat("HTTP Response Code", response.getStatus(), is(200));
// dumpResponse(response);
String expectedResponse = String.format("part[%s].size=%d", name, expectedSize);
assertThat("Response", responseBody.toString(), containsString(expectedResponse));
}
private void dumpResponse(Response response) private void dumpResponse(Response response)
{ {
System.out.printf(" %s %d %s%n", response.getVersion(), response.getStatus(), response.getReason()); System.out.printf(" %s %d %s%n", response.getVersion(), response.getStatus(), response.getReason());