Fixes #10679 - Review HTTP/2 rate control. (#10681)

* Bumped the rate control rate from 50 events/s to 128.
* Added rate control for all CONTINUATION frames.
* Added rate control for invalid PUSH_PROMISE frames.
* Added rate control for RST_STREAM frames.
* Added rate control for all SETTINGS frames.
* Fixed growth of header block accumulation buffer.

Signed-off-by: Simone Bordet <simone.bordet@gmail.com>
This commit is contained in:
Simone Bordet 2023-10-09 15:07:52 +02:00 committed by GitHub
parent 90fdd4236d
commit dbb94514dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 184 additions and 48 deletions

View File

@ -39,7 +39,7 @@ import org.slf4j.LoggerFactory;
*/
public abstract class BodyParser
{
protected static final Logger LOG = LoggerFactory.getLogger(BodyParser.class);
private static final Logger LOG = LoggerFactory.getLogger(BodyParser.class);
private final HeaderParser headerParser;
private final Parser.Listener listener;

View File

@ -75,16 +75,28 @@ public class ContinuationBodyParser extends BodyParser
int remaining = buffer.remaining();
if (remaining < length)
{
headerBlockFragments.storeFragment(buffer, remaining, false);
ContinuationFrame frame = new ContinuationFrame(getStreamId(), false);
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_continuation_frame_rate");
if (!headerBlockFragments.storeFragment(buffer, remaining, false))
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_continuation_stream");
length -= remaining;
break;
}
else
{
boolean last = hasFlag(Flags.END_HEADERS);
headerBlockFragments.storeFragment(buffer, length, last);
boolean endHeaders = hasFlag(Flags.END_HEADERS);
ContinuationFrame frame = new ContinuationFrame(getStreamId(), endHeaders);
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_continuation_frame_rate");
if (!headerBlockFragments.storeFragment(buffer, length, endHeaders))
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_continuation_stream");
reset();
if (last)
if (endHeaders)
return onHeaders(buffer);
return true;
}
@ -103,17 +115,21 @@ public class ContinuationBodyParser extends BodyParser
ByteBuffer headerBlock = headerBlockFragments.complete();
MetaData metaData = headerBlockParser.parse(headerBlock, headerBlock.remaining());
headerBlockFragments.getByteBufferPool().release(headerBlock);
if (metaData == null)
return true;
HeadersFrame frame = new HeadersFrame(getStreamId(), metaData, headerBlockFragments.getPriorityFrame(), headerBlockFragments.isEndStream());
headerBlockFragments.reset();
if (metaData == HeaderBlockParser.SESSION_FAILURE)
return false;
HeadersFrame frame = new HeadersFrame(getStreamId(), metaData, headerBlockFragments.getPriorityFrame(), headerBlockFragments.isEndStream());
if (metaData == HeaderBlockParser.STREAM_FAILURE)
if (metaData != HeaderBlockParser.STREAM_FAILURE)
{
notifyHeaders(frame);
}
else
{
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_continuation_frame_rate");
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_headers_frame_rate");
}
notifyHeaders(frame);
return true;
}

View File

@ -21,14 +21,22 @@ import org.eclipse.jetty.io.ByteBufferPool;
public class HeaderBlockFragments
{
private final ByteBufferPool byteBufferPool;
private final int maxCapacity;
private PriorityFrame priorityFrame;
private boolean endStream;
private int streamId;
private boolean endStream;
private ByteBuffer storage;
@Deprecated
public HeaderBlockFragments(ByteBufferPool byteBufferPool)
{
this(byteBufferPool, 8192);
}
public HeaderBlockFragments(ByteBufferPool byteBufferPool, int maxCapacity)
{
this.byteBufferPool = byteBufferPool;
this.maxCapacity = maxCapacity;
}
public ByteBufferPool getByteBufferPool()
@ -36,18 +44,30 @@ public class HeaderBlockFragments
return byteBufferPool;
}
public void storeFragment(ByteBuffer fragment, int length, boolean last)
void reset()
{
priorityFrame = null;
streamId = 0;
endStream = false;
storage = null;
}
public boolean storeFragment(ByteBuffer fragment, int length, boolean last)
{
if (storage == null)
{
int space = last ? length : length * 2;
storage = byteBufferPool.acquire(space, fragment.isDirect());
if (length > maxCapacity)
return false;
int capacity = last ? length : length * 2;
storage = byteBufferPool.acquire(capacity, fragment.isDirect());
storage.clear();
}
// Grow the storage if necessary.
if (storage.remaining() < length)
{
if (storage.position() + length > maxCapacity)
return false;
int space = last ? length : length * 2;
int capacity = storage.position() + space;
ByteBuffer newStorage = byteBufferPool.acquire(capacity, storage.isDirect());
@ -63,6 +83,7 @@ public class HeaderBlockFragments
fragment.limit(fragment.position() + length);
storage.put(fragment);
fragment.limit(limit);
return true;
}
public PriorityFrame getPriorityFrame()
@ -87,10 +108,8 @@ public class HeaderBlockFragments
public ByteBuffer complete()
{
ByteBuffer result = storage;
storage = null;
result.flip();
return result;
storage.flip();
return storage;
}
public int getStreamId()

View File

@ -22,9 +22,13 @@ import org.eclipse.jetty.http2.frames.FrameType;
import org.eclipse.jetty.http2.frames.HeadersFrame;
import org.eclipse.jetty.http2.frames.PriorityFrame;
import org.eclipse.jetty.util.BufferUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class HeadersBodyParser extends BodyParser
{
private static final Logger LOG = LoggerFactory.getLogger(HeadersBodyParser.class);
private final HeaderBlockParser headerBlockParser;
private final HeaderBlockFragments headerBlockFragments;
private State state = State.PREPARE;
@ -71,8 +75,15 @@ public class HeadersBodyParser extends BodyParser
}
else
{
headerBlockFragments.setStreamId(getStreamId());
headerBlockFragments.setEndStream(isEndStream());
if (headerBlockFragments.getStreamId() != 0)
{
connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_headers_frame");
}
else
{
headerBlockFragments.setStreamId(getStreamId());
headerBlockFragments.setEndStream(isEndStream());
}
}
}
@ -167,6 +178,18 @@ public class HeadersBodyParser extends BodyParser
break;
}
case HEADERS:
{
if (!hasFlag(Flags.END_HEADERS))
{
headerBlockFragments.setStreamId(getStreamId());
headerBlockFragments.setEndStream(isEndStream());
if (hasFlag(Flags.PRIORITY))
headerBlockFragments.setPriorityFrame(new PriorityFrame(getStreamId(), parentStreamId, weight, exclusive));
}
state = State.HEADER_BLOCK;
break;
}
case HEADER_BLOCK:
{
if (hasFlag(Flags.END_HEADERS))
{
@ -191,7 +214,7 @@ public class HeadersBodyParser extends BodyParser
{
HeadersFrame frame = new HeadersFrame(getStreamId(), metaData, null, isEndStream());
if (!rateControlOnEvent(frame))
connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_headers_frame_rate");
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_headers_frame_rate");
}
}
}
@ -200,16 +223,14 @@ public class HeadersBodyParser extends BodyParser
int remaining = buffer.remaining();
if (remaining < length)
{
headerBlockFragments.storeFragment(buffer, remaining, false);
if (!headerBlockFragments.storeFragment(buffer, remaining, false))
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_headers_frame");
length -= remaining;
}
else
{
headerBlockFragments.setStreamId(getStreamId());
headerBlockFragments.setEndStream(isEndStream());
if (hasFlag(Flags.PRIORITY))
headerBlockFragments.setPriorityFrame(new PriorityFrame(getStreamId(), parentStreamId, weight, exclusive));
headerBlockFragments.storeFragment(buffer, length, false);
if (!headerBlockFragments.storeFragment(buffer, length, false))
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_headers_frame");
state = State.PADDING;
loop = paddingLength == 0;
}
@ -253,6 +274,6 @@ public class HeadersBodyParser extends BodyParser
private enum State
{
PREPARE, PADDING_LENGTH, EXCLUSIVE, PARENT_STREAM_ID, PARENT_STREAM_ID_BYTES, WEIGHT, HEADERS, PADDING
PREPARE, PADDING_LENGTH, EXCLUSIVE, PARENT_STREAM_ID, PARENT_STREAM_ID_BYTES, WEIGHT, HEADERS, HEADER_BLOCK, PADDING
}
}

View File

@ -85,7 +85,7 @@ public class Parser
this.listener = listener;
unknownBodyParser = new UnknownBodyParser(headerParser, listener);
HeaderBlockParser headerBlockParser = new HeaderBlockParser(headerParser, byteBufferPool, hpackDecoder, unknownBodyParser);
HeaderBlockFragments headerBlockFragments = new HeaderBlockFragments(byteBufferPool);
HeaderBlockFragments headerBlockFragments = new HeaderBlockFragments(byteBufferPool, hpackDecoder.getMaxHeaderListSize());
bodyParsers[FrameType.DATA.getType()] = new DataBodyParser(headerParser, listener);
bodyParsers[FrameType.HEADERS.getType()] = new HeadersBodyParser(headerParser, listener, headerBlockParser, headerBlockFragments);
bodyParsers[FrameType.PRIORITY.getType()] = new PriorityBodyParser(headerParser, listener);

View File

@ -18,6 +18,7 @@ import java.nio.ByteBuffer;
import org.eclipse.jetty.http.MetaData;
import org.eclipse.jetty.http2.ErrorCode;
import org.eclipse.jetty.http2.Flags;
import org.eclipse.jetty.http2.frames.HeadersFrame;
import org.eclipse.jetty.http2.frames.PushPromiseFrame;
public class PushPromiseBodyParser extends BodyParser
@ -65,13 +66,9 @@ public class PushPromiseBodyParser extends BodyParser
length = getBodyLength();
if (isPadding())
{
state = State.PADDING_LENGTH;
}
else
{
state = State.STREAM_ID;
}
break;
}
case PADDING_LENGTH:
@ -131,7 +128,15 @@ public class PushPromiseBodyParser extends BodyParser
state = State.PADDING;
loop = paddingLength == 0;
if (metaData != HeaderBlockParser.STREAM_FAILURE)
{
onPushPromise(streamId, metaData);
}
else
{
HeadersFrame frame = new HeadersFrame(getStreamId(), metaData, null, isEndStream());
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_headers_frame_rate");
}
}
break;
}

View File

@ -58,7 +58,7 @@ public class ResetBodyParser extends BodyParser
{
if (buffer.remaining() >= 4)
{
return onReset(buffer.getInt());
return onReset(buffer, buffer.getInt());
}
else
{
@ -73,7 +73,7 @@ public class ResetBodyParser extends BodyParser
--cursor;
error += currByte << (8 * cursor);
if (cursor == 0)
return onReset(error);
return onReset(buffer, error);
break;
}
default:
@ -85,9 +85,11 @@ public class ResetBodyParser extends BodyParser
return false;
}
private boolean onReset(int error)
private boolean onReset(ByteBuffer buffer, int error)
{
ResetFrame frame = new ResetFrame(getStreamId(), error);
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_rst_stream_frame_rate");
reset();
notifyReset(frame);
return true;

View File

@ -72,10 +72,7 @@ public class SettingsBodyParser extends BodyParser
return;
boolean isReply = hasFlag(Flags.ACK);
SettingsFrame frame = new SettingsFrame(Collections.emptyMap(), isReply);
if (!isReply && !rateControlOnEvent(frame))
connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_settings_frame_rate");
else
onSettings(frame);
onSettings(buffer, frame);
}
private boolean validateFrame(ByteBuffer buffer, int streamId, int bodyLength)
@ -218,11 +215,13 @@ public class SettingsBodyParser extends BodyParser
return connectionFailure(buffer, ErrorCode.PROTOCOL_ERROR.code, "invalid_settings_max_frame_size");
SettingsFrame frame = new SettingsFrame(settings, hasFlag(Flags.ACK));
return onSettings(frame);
return onSettings(buffer, frame);
}
private boolean onSettings(SettingsFrame frame)
private boolean onSettings(ByteBuffer buffer, SettingsFrame frame)
{
if (!rateControlOnEvent(frame))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_settings_frame_rate");
reset();
notifySettings(frame);
return true;

View File

@ -35,7 +35,6 @@ public class UnknownBodyParser extends BodyParser
boolean parsed = cursor == 0;
if (parsed && !rateControlOnEvent(new UnknownFrame(getFrameType())))
return connectionFailure(buffer, ErrorCode.ENHANCE_YOUR_CALM_ERROR.code, "invalid_unknown_frame_rate");
return parsed;
}

View File

@ -16,6 +16,7 @@ package org.eclipse.jetty.http2.frames;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import org.eclipse.jetty.http.HostPortHttpField;
import org.eclipse.jetty.http.HttpField;
@ -32,6 +33,8 @@ import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.MappedByteBufferPool;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
@ -154,4 +157,51 @@ public class ContinuationParseTest
assertNull(priority);
}
}
@Test
public void testLargeHeadersBlock() throws Exception
{
// Use a ByteBufferPool with a small factor, so that the accumulation buffer is not too large.
ByteBufferPool byteBufferPool = new MappedByteBufferPool(128);
// A small max headers size, used for both accumulation and decoding.
int maxHeadersSize = 512;
Parser parser = new Parser(byteBufferPool, maxHeadersSize);
// Specify headers block size to generate CONTINUATION frames.
int maxHeadersBlockFragment = 128;
HeadersGenerator generator = new HeadersGenerator(new HeaderGenerator(), new HpackEncoder(), maxHeadersBlockFragment);
int streamId = 13;
HttpFields fields = HttpFields.build()
.put("Accept", "text/html")
// Large header that generates a large headers block.
.put("User-Agent", "Jetty".repeat(256));
MetaData.Request metaData = new MetaData.Request("GET", HttpScheme.HTTP.asString(), new HostPortHttpField("localhost:8080"), "/path", HttpVersion.HTTP_2, fields, -1);
ByteBufferPool.Lease lease = new ByteBufferPool.Lease(byteBufferPool);
generator.generateHeaders(lease, streamId, metaData, null, true);
List<ByteBuffer> byteBuffers = lease.getByteBuffers();
assertThat(byteBuffers.stream().mapToInt(ByteBuffer::remaining).sum(), greaterThan(maxHeadersSize));
AtomicBoolean failed = new AtomicBoolean();
parser.init(new Parser.Listener.Adapter()
{
@Override
public void onConnectionFailure(int error, String reason)
{
failed.set(true);
}
});
// Set a large max headers size for decoding, to ensure
// the failure is due to accumulation, not decoding.
parser.getHpackDecoder().setMaxHeaderListSize(10 * maxHeadersSize);
for (ByteBuffer byteBuffer : byteBuffers)
{
parser.parse(byteBuffer);
if (failed.get())
break;
}
assertTrue(failed.get());
}
}

View File

@ -89,12 +89,20 @@ public class FrameFloodTest
}
@Test
public void testSettingsFrameFlood()
public void testEmptySettingsFrameFlood()
{
byte[] payload = new byte[0];
testFrameFlood(null, frameFrom(payload.length, FrameType.SETTINGS.getType(), 0, 0, payload));
}
@Test
public void testSettingsFrameFlood()
{
// | Key0 | Key1 | Value0 | Value1 | Value2 | Value3 |
byte[] payload = new byte[]{0, 8, 0, 0, 0, 1};
testFrameFlood(null, frameFrom(payload.length, FrameType.SETTINGS.getType(), 0, 0, payload));
}
@Test
public void testPingFrameFlood()
{
@ -103,7 +111,7 @@ public class FrameFloodTest
}
@Test
public void testContinuationFrameFlood()
public void testEmptyContinuationFrameFlood()
{
int streamId = 13;
byte[] headersPayload = new byte[0];
@ -112,6 +120,23 @@ public class FrameFloodTest
testFrameFlood(headersBytes, frameFrom(continuationPayload.length, FrameType.CONTINUATION.getType(), 0, streamId, continuationPayload));
}
@Test
public void testContinuationFrameFlood()
{
int streamId = 13;
byte[] headersPayload = new byte[0];
byte[] headersBytes = frameFrom(headersPayload.length, FrameType.HEADERS.getType(), 0, streamId, headersPayload);
byte[] continuationPayload = new byte[1];
testFrameFlood(headersBytes, frameFrom(continuationPayload.length, FrameType.CONTINUATION.getType(), 0, streamId, continuationPayload));
}
@Test
public void testResetStreamFrameFlood()
{
byte[] payload = {0, 0, 0, 0};
testFrameFlood(null, frameFrom(payload.length, FrameType.RST_STREAM.getType(), 0, 13, payload));
}
@Test
public void testUnknownFrameFlood()
{

View File

@ -63,7 +63,7 @@ public abstract class AbstractHTTP2ServerConnectionFactory extends AbstractConne
private int maxFrameSize = Frame.DEFAULT_MAX_LENGTH;
private int maxSettingsKeys = SettingsFrame.DEFAULT_MAX_KEYS;
private boolean connectProtocolEnabled = true;
private RateControl.Factory rateControlFactory = new WindowRateControl.Factory(50);
private RateControl.Factory rateControlFactory = new WindowRateControl.Factory(128);
private FlowControlStrategy.Factory flowControlStrategyFactory = () -> new BufferingFlowControlStrategy(0.5F);
private long streamIdleTimeout;
private boolean useInputDirectByteBuffers;