Fix issue with pipeline releasing bytes early (#54474)

Currently there is an issue with the InboundPipeline releasing bytes
earlier than appropriate. This can lead to the bytes being reused before
the message is handled. This commit fixes that issue and adds a test to
detect when it is occurring.
This commit is contained in:
Tim Brooks 2020-03-30 22:39:15 -06:00 committed by GitHub
parent 5d760051a9
commit 915435bbe4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 29 deletions

View File

@ -69,37 +69,22 @@ public class InboundPipeline implements Releasable {
public void handleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException {
pending.add(reference.retain());
final ReleasableBytesReference composite;
if (pending.size() == 1) {
composite = pending.peekFirst();
} else {
final ReleasableBytesReference[] bytesReferences = pending.toArray(new ReleasableBytesReference[0]);
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
composite = new ReleasableBytesReference(new CompositeBytesReference(bytesReferences), releasable);
}
final ArrayList<Object> fragments = fragmentList.get();
int bytesConsumed = 0;
boolean continueHandling = true;
while (continueHandling && isClosed == false) {
boolean continueDecoding = true;
while (continueDecoding) {
final int remaining = composite.length() - bytesConsumed;
if (remaining != 0) {
try (ReleasableBytesReference slice = composite.retainedSlice(bytesConsumed, remaining)) {
final int bytesDecoded = decoder.decode(slice, fragments::add);
if (bytesDecoded != 0) {
bytesConsumed += bytesDecoded;
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
while (continueDecoding && pending.isEmpty() == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
}
@ -118,8 +103,6 @@ public class InboundPipeline implements Releasable {
}
}
}
releasePendingBytes(bytesConsumed);
}
private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments) throws IOException {
@ -155,11 +138,22 @@ public class InboundPipeline implements Releasable {
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
}
private void releasePendingBytes(int bytesConsumed) {
if (isClosed) {
// Are released by the close method
return;
private ReleasableBytesReference getPendingBytes() {
if (pending.size() == 1) {
return pending.peekFirst().retain();
} else {
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
int index = 0;
for (ReleasableBytesReference pendingReference : pending) {
bytesReferences[index] = pendingReference.retain();
++index;
}
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
return new ReleasableBytesReference(new CompositeBytesReference(bytesReferences), releasable);
}
}
private void releasePendingBytes(int bytesConsumed) {
int bytesToRelease = bytesConsumed;
while (bytesToRelease != 0) {
try (ReleasableBytesReference reference = pending.pollFirst()) {

View File

@ -25,6 +25,7 @@ import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.common.util.concurrent.ThreadContext;
@ -35,6 +36,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import static org.hamcrest.Matchers.instanceOf;
@ -172,6 +174,53 @@ public class InboundPipelineTests extends ESTestCase {
}
}
public void testEnsureBodyIsNotPrematurelyReleased() throws IOException {
final PageCacheRecycler recycler = PageCacheRecycler.NON_RECYCLING_INSTANCE;
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {};
BiConsumer<TcpChannel, Tuple<Header, Exception>> errorHandler = (c, e) -> {};
final InboundPipeline pipeline = new InboundPipeline(Version.CURRENT, recycler, messageHandler, errorHandler);
try (BytesStreamOutput streamOutput = new BytesStreamOutput()) {
String actionName = "actionName";
final Version version = Version.CURRENT;
final String value = randomAlphaOfLength(1000);
final boolean isRequest = randomBoolean();
final long requestId = randomNonNegativeLong();
OutboundMessage message;
if (isRequest) {
message = new OutboundMessage.Request(threadContext, new String[0], new TestRequest(value),
version, actionName, requestId, false, false);
} else {
message = new OutboundMessage.Response(threadContext, Collections.emptySet(), new TestResponse(value),
version, requestId, false, false);
}
final BytesReference reference = message.serialize(streamOutput);
final int fixedHeaderSize = TcpHeader.headerSize(Version.CURRENT);
final int variableHeaderSize = reference.getInt(fixedHeaderSize - 4);
final int totalHeaderSize = fixedHeaderSize + variableHeaderSize;
final AtomicBoolean bodyReleased = new AtomicBoolean(false);
for (int i = 0; i < totalHeaderSize - 1; ++i) {
try (ReleasableBytesReference slice = ReleasableBytesReference.wrap(reference.slice(i, 1))) {
pipeline.handleBytes(new FakeTcpChannel(), slice);
}
}
final Releasable releasable = () -> bodyReleased.set(true);
final int from = totalHeaderSize - 1;
final BytesReference partHeaderPartBody = reference.slice(from, reference.length() - from - 1);
try (ReleasableBytesReference slice = new ReleasableBytesReference(partHeaderPartBody, releasable)) {
pipeline.handleBytes(new FakeTcpChannel(), slice);
}
assertFalse(bodyReleased.get());
try (ReleasableBytesReference slice = new ReleasableBytesReference(reference.slice(reference.length() - 1, 1), releasable)) {
pipeline.handleBytes(new FakeTcpChannel(), slice);
}
assertTrue(bodyReleased.get());
}
}
private static class MessageData {
private final Version version;