Abort non-fully consumed S3 input stream (#62167) (#62370)

Today, when an S3RetryingInputStream is closed, the remaining bytes
that were not consumed are drained right before closing the underlying
stream. In some contexts it might be more efficient not to consume the
remaining bytes and to just drop the connection.

This is for example the case with snapshot-backed indices prewarming,
where there is no point in reading potentially large blobs if we know
that the cache file the blob's content would be written to has already
been evicted. Draining all the bytes there needlessly occupies a slot
in the prewarming thread pool.
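
To illustrate the trade-off (a minimal sketch, not the committed implementation; the closeOrAbort helper and its bytesRemaining parameter are hypothetical): draining reads and discards the rest of the body so the pooled HTTP connection stays reusable, while aborting gives up the connection immediately.

import com.amazonaws.services.s3.model.S3ObjectInputStream;
import java.io.IOException;

class CloseStrategySketch {
    // Hypothetical helper: when unread bytes remain and the caller no longer
    // needs them, abort() drops the underlying HTTP connection instead of
    // draining the body; a fully consumed stream is closed normally so the
    // connection can go back to the pool.
    static void closeOrAbort(S3ObjectInputStream stream, long bytesRemaining) throws IOException {
        if (bytesRemaining > 0) {
            stream.abort(); // nothing more is read from the wire
        }
        stream.close();
    }
}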

plugins/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3RetryingInputStream.java

@@ -21,12 +21,13 @@ package org.elasticsearch.repositories.s3;
 import com.amazonaws.AmazonClientException;
 import com.amazonaws.services.s3.model.AmazonS3Exception;
 import com.amazonaws.services.s3.model.GetObjectRequest;
+import com.amazonaws.services.s3.model.ObjectMetadata;
 import com.amazonaws.services.s3.model.S3Object;
+import com.amazonaws.services.s3.model.S3ObjectInputStream;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.Version;
-import org.elasticsearch.common.io.Streams;
 import org.elasticsearch.core.internal.io.IOUtils;
 
 import java.io.IOException;
@@ -53,12 +54,14 @@ class S3RetryingInputStream extends InputStream {
     private final long start;
     private final long end;
     private final int maxAttempts;
+    private final List<IOException> failures;
 
-    private InputStream currentStream;
+    private S3ObjectInputStream currentStream;
+    private long currentStreamLastOffset;
     private int attempt = 1;
-    private List<IOException> failures = new ArrayList<>(MAX_SUPPRESSED_EXCEPTIONS);
     private long currentOffset;
     private boolean closed;
+    private boolean eof;
 
     S3RetryingInputStream(S3BlobStore blobStore, String blobKey) throws IOException {
         this(blobStore, blobKey, 0, Long.MAX_VALUE - 1);
@@ -75,12 +78,13 @@ class S3RetryingInputStream extends InputStream {
         this.blobStore = blobStore;
         this.blobKey = blobKey;
         this.maxAttempts = blobStore.getMaxRetries() + 1;
+        this.failures = new ArrayList<>(MAX_SUPPRESSED_EXCEPTIONS);
         this.start = start;
        this.end = end;
-        currentStream = openStream();
+        openStream();
     }
 
-    private InputStream openStream() throws IOException {
+    private void openStream() throws IOException {
         try (AmazonS3Reference clientReference = blobStore.clientReference()) {
             final GetObjectRequest getObjectRequest = new GetObjectRequest(blobStore.bucket(), blobKey);
             getObjectRequest.setRequestMetricCollector(blobStore.getMetricCollector);
@@ -90,7 +94,8 @@ class S3RetryingInputStream extends InputStream {
                 getObjectRequest.setRange(Math.addExact(start, currentOffset), end);
             }
             final S3Object s3Object = SocketAccess.doPrivileged(() -> clientReference.client().getObject(getObjectRequest));
-            return s3Object.getObjectContent();
+            this.currentStreamLastOffset = Math.addExact(Math.addExact(start, currentOffset), getStreamLength(s3Object));
+            this.currentStream = s3Object.getObjectContent();
         } catch (final AmazonClientException e) {
             if (e instanceof AmazonS3Exception) {
                 if (404 == ((AmazonS3Exception) e).getStatusCode()) {
@@ -101,12 +106,35 @@ class S3RetryingInputStream extends InputStream {
         }
     }
 
+    private long getStreamLength(final S3Object object) {
+        final ObjectMetadata metadata = object.getObjectMetadata();
+        try {
+            // Returns the content range of the object if response contains the Content-Range header.
+            final Long[] range = metadata.getContentRange();
+            if (range != null) {
+                assert range[1] >= range[0] : range[1] + " vs " + range[0];
+                assert range[0] == start + currentOffset :
+                    "Content-Range start value [" + range[0] + "] exceeds start [" + start + "] + current offset [" + currentOffset + ']';
+                assert range[1] == end : "Content-Range end value [" + range[1] + "] exceeds end [" + end + ']';
+                return range[1] - range[0] + 1L;
+            }
+            return metadata.getContentLength();
+        } catch (Exception e) {
+            assert false : e;
+            return Long.MAX_VALUE - 1L; // assume a large stream so that the underlying stream is aborted on closing, unless eof is reached
+        }
+    }
+
     @Override
     public int read() throws IOException {
         ensureOpen();
         while (true) {
             try {
                 final int result = currentStream.read();
                 if (result == -1) {
+                    eof = true;
                     return -1;
                 }
                 currentOffset += 1;
                 return result;
             } catch (IOException e) {
@@ -122,6 +150,7 @@ class S3RetryingInputStream extends InputStream {
             try {
                 final int bytesRead = currentStream.read(b, off, len);
                 if (bytesRead == -1) {
+                    eof = true;
                     return -1;
                 }
                 currentOffset += bytesRead;
@@ -151,24 +180,36 @@ class S3RetryingInputStream extends InputStream {
         if (failures.size() < MAX_SUPPRESSED_EXCEPTIONS) {
             failures.add(e);
         }
-        try {
-            Streams.consumeFully(currentStream);
-        } catch (Exception e2) {
-            logger.trace("Failed to fully consume stream on close", e);
-        }
+        maybeAbort(currentStream);
         IOUtils.closeWhileHandlingException(currentStream);
-        currentStream = openStream();
+        openStream();
     }
 
     @Override
     public void close() throws IOException {
+        maybeAbort(currentStream);
         try {
-            Streams.consumeFully(currentStream);
-        } catch (Exception e) {
-            logger.trace("Failed to fully consume stream on close", e);
+            currentStream.close();
+        } finally {
+            closed = true;
+        }
+    }
+
+    /**
+     * Abort the {@link S3ObjectInputStream} if it wasn't read completely at the time this method is called,
+     * suppressing all thrown exceptions.
+     */
+    private void maybeAbort(S3ObjectInputStream stream) {
+        if (eof) {
+            return;
+        }
+        try {
+            if (start + currentOffset < currentStreamLastOffset) {
+                stream.abort();
+            }
+        } catch (Exception e) {
+            logger.warn("Failed to abort stream before closing", e);
        }
-        currentStream.close();
-        closed = true;
     }
 
     @Override
@@ -187,4 +228,17 @@ class S3RetryingInputStream extends InputStream {
         }
         return e;
     }
+
+    // package-private for tests
+    boolean isEof() {
+        return eof;
+    }
+
+    // package-private for tests
+    boolean isAborted() {
+        if (currentStream == null || currentStream.getHttpRequest() == null) {
+            return false;
+        }
+        return currentStream.getHttpRequest().isAborted();
+    }
 }
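
As a worked example of the offset bookkeeping introduced above (plain Java, with variable names mirroring the fields of S3RetryingInputStream): a ranged GET answered with "Content-Range: bytes 5-9/100" carries 9 - 5 + 1 = 5 bytes, so with start = 5 and currentOffset = 0 at open time the stream ends at offset 10, and closing before start + currentOffset reaches that offset triggers an abort.

class ContentRangeArithmetic {
    public static void main(String[] args) {
        // Response header: "Content-Range: bytes 5-9/100"
        final long rangeStart = 5, rangeEnd = 9;
        final long streamLength = rangeEnd - rangeStart + 1; // 5 payload bytes

        final long start = 5;  // first requested byte
        // As openStream() computes it, with currentOffset = 0 at open time:
        final long currentStreamLastOffset = start + streamLength; // 10

        long currentOffset = 2; // only 2 of the 5 bytes were consumed
        // Unread bytes remain, so maybeAbort() would call abort() here:
        System.out.println(start + currentOffset < currentStreamLastOffset); // true
    }
}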

plugins/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3BlobContainerRetriesTests.java

@@ -43,6 +43,8 @@ import org.elasticsearch.repositories.blobstore.AbstractBlobContainerRetriesTestCase;
 import org.junit.After;
 import org.junit.Before;
 
 import java.io.ByteArrayInputStream;
+import java.io.FilterInputStream;
+import java.io.IOException;
 import java.io.InputStream;
 import java.net.InetSocketAddress;
@@ -133,7 +135,17 @@ public class S3BlobContainerRetriesTests extends AbstractBlobContainerRetriesTestCase {
             bufferSize == null ? S3Repository.BUFFER_SIZE_SETTING.getDefault(Settings.EMPTY) : bufferSize,
             S3Repository.CANNED_ACL_SETTING.getDefault(Settings.EMPTY),
             S3Repository.STORAGE_CLASS_SETTING.getDefault(Settings.EMPTY),
-            repositoryMetadata));
+            repositoryMetadata)) {
+
+            @Override
+            public InputStream readBlob(String blobName) throws IOException {
+                return new AssertingInputStream(super.readBlob(blobName), blobName);
+            }
+
+            @Override
+            public InputStream readBlob(String blobName, long position, long length) throws IOException {
+                return new AssertingInputStream(super.readBlob(blobName, position, length), blobName, position, length);
+            }
+        };
     }
 
     public void testWriteBlobWithRetries() throws Exception {
@@ -292,4 +304,55 @@ public class S3BlobContainerRetriesTests extends AbstractBlobContainerRetriesTestCase {
         assertThat(countDownUploads.get(), equalTo(0));
         assertThat(countDownComplete.isCountedDown(), is(true));
     }
+
+    /**
+     * Asserts that an InputStream is fully consumed, or aborted, when it is closed
+     */
+    private static class AssertingInputStream extends FilterInputStream {
+
+        private final String blobName;
+        private final boolean range;
+        private final long position;
+        private final long length;
+
+        AssertingInputStream(InputStream in, String blobName) {
+            super(in);
+            this.blobName = blobName;
+            this.position = 0L;
+            this.length = Long.MAX_VALUE;
+            this.range = false;
+        }
+
+        AssertingInputStream(InputStream in, String blobName, long position, long length) {
+            super(in);
+            this.blobName = blobName;
+            this.position = position;
+            this.length = length;
+            this.range = true;
+        }
+
+        @Override
+        public String toString() {
+            String description = "[blobName='" + blobName + "', range=" + range;
+            if (range) {
+                description += ", position=" + position;
+                description += ", length=" + length;
+            }
+            description += ']';
+            return description;
+        }
+
+        @Override
+        public void close() throws IOException {
+            super.close();
+            if (in instanceof S3RetryingInputStream) {
+                final S3RetryingInputStream s3Stream = (S3RetryingInputStream) in;
+                assertTrue("Stream " + toString() + " should have reached EOF or should have been aborted but got [eof=" + s3Stream.isEof()
+                    + ", aborted=" + s3Stream.isAborted() + ']', s3Stream.isEof() || s3Stream.isAborted());
+            } else {
+                assertThat(in, instanceOf(ByteArrayInputStream.class));
+                assertThat(((ByteArrayInputStream) in).available(), equalTo(0));
+            }
+        }
+    }
 }
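
The AssertingInputStream above leans on the Elasticsearch test framework; the underlying pattern, a FilterInputStream that enforces an invariant at close time, can be sketched standalone (assumptions: plain JDK only, and an available()-based check standing in for the eof/aborted flags):

import java.io.ByteArrayInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;

class CloseCheckingStream extends FilterInputStream {

    CloseCheckingStream(InputStream in) {
        super(in);
    }

    @Override
    public void close() throws IOException {
        // Fail loudly if the delegate still reports unread bytes, mirroring
        // the "fully consumed or aborted" assertion in the test above.
        if (in.available() > 0) {
            throw new AssertionError("stream closed before being fully consumed");
        }
        super.close();
    }

    public static void main(String[] args) throws IOException {
        byte[] data = new byte[4];
        try (InputStream in = new CloseCheckingStream(new ByteArrayInputStream(data))) {
            in.read(data); // consume everything so close() passes the check
        }
    }
}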

plugins/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3RetryingInputStreamTests.java

@@ -0,0 +1,116 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.repositories.s3;
+
+import com.amazonaws.services.s3.AmazonS3;
+import com.amazonaws.services.s3.model.GetObjectRequest;
+import com.amazonaws.services.s3.model.S3Object;
+import com.amazonaws.services.s3.model.S3ObjectInputStream;
+import org.apache.http.client.methods.HttpGet;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.io.Streams;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.util.Arrays;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class S3RetryingInputStreamTests extends ESTestCase {
+
+    public void testInputStreamFullyConsumed() throws IOException {
+        final byte[] expectedBytes = randomByteArrayOfLength(randomIntBetween(1, 512));
+
+        final S3RetryingInputStream stream = createInputStream(expectedBytes, null, null);
+        Streams.consumeFully(stream);
+
+        assertThat(stream.isEof(), is(true));
+        assertThat(stream.isAborted(), is(false));
+    }
+
+    public void testInputStreamIsAborted() throws IOException {
+        final byte[] expectedBytes = randomByteArrayOfLength(randomIntBetween(10, 512));
+        final byte[] actualBytes = new byte[randomIntBetween(1, Math.max(1, expectedBytes.length - 1))];
+
+        final S3RetryingInputStream stream = createInputStream(expectedBytes, null, null);
+        stream.read(actualBytes);
+        stream.close();
+
+        assertArrayEquals(Arrays.copyOf(expectedBytes, actualBytes.length), actualBytes);
+        assertThat(stream.isEof(), is(false));
+        assertThat(stream.isAborted(), is(true));
+    }
+
+    public void testRangeInputStreamFullyConsumed() throws IOException {
+        final byte[] bytes = randomByteArrayOfLength(randomIntBetween(1, 512));
+        final int position = randomIntBetween(0, bytes.length - 1);
+        final int length = randomIntBetween(1, bytes.length - position);
+
+        final S3RetryingInputStream stream = createInputStream(bytes, position, length);
+        Streams.consumeFully(stream);
+
+        assertThat(stream.isEof(), is(true));
+        assertThat(stream.isAborted(), is(false));
+    }
+
+    public void testRangeInputStreamIsAborted() throws IOException {
+        final byte[] expectedBytes = randomByteArrayOfLength(randomIntBetween(10, 512));
+        final byte[] actualBytes = new byte[randomIntBetween(1, Math.max(1, expectedBytes.length - 1))];
+
+        final int length = randomIntBetween(actualBytes.length + 1, expectedBytes.length);
+        final int position = randomIntBetween(0, Math.max(1, expectedBytes.length - length));
+
+        final S3RetryingInputStream stream = createInputStream(expectedBytes, position, length);
+        stream.read(actualBytes);
+        stream.close();
+
+        assertArrayEquals(Arrays.copyOfRange(expectedBytes, position, position + actualBytes.length), actualBytes);
+        assertThat(stream.isEof(), is(false));
+        assertThat(stream.isAborted(), is(true));
+    }
+
+    private S3RetryingInputStream createInputStream(
+        final byte[] data,
+        @Nullable final Integer position,
+        @Nullable final Integer length
+    ) throws IOException {
+        final S3Object s3Object = new S3Object();
+        final AmazonS3 client = mock(AmazonS3.class);
+        when(client.getObject(any(GetObjectRequest.class))).thenReturn(s3Object);
+        final AmazonS3Reference clientReference = mock(AmazonS3Reference.class);
+        when(clientReference.client()).thenReturn(client);
+        final S3BlobStore blobStore = mock(S3BlobStore.class);
+        when(blobStore.clientReference()).thenReturn(clientReference);
+
+        if (position != null && length != null) {
+            s3Object.getObjectMetadata().setContentLength(length);
+            s3Object.setObjectContent(new S3ObjectInputStream(new ByteArrayInputStream(data, position, length), new HttpGet()));
+            return new S3RetryingInputStream(blobStore, "_blob", position, Math.addExact(position, length - 1));
+        } else {
+            s3Object.getObjectMetadata().setContentLength(data.length);
+            s3Object.setObjectContent(new S3ObjectInputStream(new ByteArrayInputStream(data), new HttpGet()));
+            return new S3RetryingInputStream(blobStore, "_blob");
+        }
+    }
+}