[ML][Inference] stream inflate to parser + throw when byte limit is reached (#51644) (#51679)

Three fixes for when the `compressed_definition` is used on PUT:

* Update the inflate byte limit to be the minimum of 10% of the max heap and 1 GB (the previous fixed limit)
* Stream data directly to the JSON parser, so an invalid payload fails fast instead of requiring the whole stream to be inflated first
* Throw when the maximum number of bytes is reached, so the error makes clear why the request was rejected
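
At a glance, the change replaces "inflate fully into a `BytesReference`, then parse" with "hand the parser a bounded, lazily inflating stream". Below is a minimal sketch of that pattern using only JDK types; it is illustrative, not the Elasticsearch code itself (which uses `SimpleBoundedInputStream` and the XContent parser wiring shown in the diff):

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Base64;
import java.util.zip.GZIPInputStream;

class BoundedInflateSketch {
    // Decode and decompress lazily; every byte the consumer (e.g. a JSON parser)
    // pulls counts against a budget, and exceeding the budget fails loudly.
    static InputStream boundedInflate(String compressedString, long maxBytes) throws IOException {
        byte[] compressed = Base64.getDecoder().decode(compressedString);
        InputStream gzip = new GZIPInputStream(new ByteArrayInputStream(compressed), 4096);
        return new InputStream() {
            private long count = 0;

            @Override
            public int read() throws IOException {
                if (count >= maxBytes) {
                    // Mirrors the new behavior: an explicit error, not a silent EOF.
                    throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]");
                }
                count++;
                return gzip.read();
            }
        };
    }
}
```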
Benjamin Trent 2020-01-30 10:16:14 -05:00 committed by GitHub
parent 9c7a63214c
commit 2a2a0941af
3 changed files with 40 additions and 32 deletions

InferenceToXContentCompressor.java

@@ -9,7 +9,6 @@ package org.elasticsearch.xpack.core.ml.inference;
 import org.elasticsearch.common.CheckedFunction;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.Streams;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -17,6 +16,8 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.common.xcontent.json.JsonXContent;
+import org.elasticsearch.monitor.jvm.JvmInfo;
 import org.elasticsearch.xpack.core.ml.inference.utils.SimpleBoundedInputStream;
 import java.io.IOException;
@@ -33,7 +34,10 @@ import java.util.zip.GZIPOutputStream;
  */
 public final class InferenceToXContentCompressor {
     private static final int BUFFER_SIZE = 4096;
-    private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum
+    // Either 10% of the configured JVM heap, or 1 GB, whichever is smaller
+    private static final long MAX_INFLATED_BYTES = Math.min(
+        (long)((0.10) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()),
+        1_000_000_000); // 1 gb maximum

     private InferenceToXContentCompressor() {}
@@ -45,33 +49,34 @@ public final class InferenceToXContentCompressor {
     static <T> T inflate(String compressedString,
                          CheckedFunction<XContentParser, T, IOException> parserFunction,
                          NamedXContentRegistry xContentRegistry) throws IOException {
-        try(XContentParser parser = XContentHelper.createParser(xContentRegistry,
+        try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
             LoggingDeprecationHandler.INSTANCE,
-            inflate(compressedString, MAX_INFLATED_BYTES),
-            XContentType.JSON)) {
+            inflate(compressedString, MAX_INFLATED_BYTES))) {
             return parserFunction.apply(parser);
         }
     }

     static Map<String, Object> inflateToMap(String compressedString) throws IOException {
         // Don't need the xcontent registry as we are not deflating named objects.
-        try(XContentParser parser = XContentHelper.createParser(NamedXContentRegistry.EMPTY,
+        try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY,
             LoggingDeprecationHandler.INSTANCE,
-            inflate(compressedString, MAX_INFLATED_BYTES),
-            XContentType.JSON)) {
+            inflate(compressedString, MAX_INFLATED_BYTES))) {
             return parser.mapOrdered();
         }
     }

-    static BytesReference inflate(String compressedString, long streamSize) throws IOException {
+    static InputStream inflate(String compressedString, long streamSize) throws IOException {
         byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
         // If the compressed length is already too large, it makes sense that the inflated length would be as well
         // In the extremely small string case, the compressed data could actually be longer than the compressed stream
         if (compressedBytes.length > Math.max(100L, streamSize)) {
             throw new IOException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]");
         }
         InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
-        InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
-        return Streams.readFully(inflateStream);
+        return new SimpleBoundedInputStream(gzipStream, streamSize);
     }

-    //Public for testing (for now)
-    public static String deflate(BytesReference reference) throws IOException {
+    private static String deflate(BytesReference reference) throws IOException {
         BytesStreamOutput out = new BytesStreamOutput();
         try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
             reference.writeTo(compressedOutput);
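
To put the new `MAX_INFLATED_BYTES` in concrete terms, here is the arithmetic the constant performs, sketched with `Runtime.getRuntime().maxMemory()` standing in for `JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()`:

```java
// With a 2 GiB heap:  min(0.10 * 2_147_483_648, 1_000_000_000)  ≈ 215 MB
// With a 32 GiB heap: min(0.10 * 34_359_738_368, 1_000_000_000) = 1 GB (the previous fixed cap)
long heapMax = Runtime.getRuntime().maxMemory();
long maxInflatedBytes = Math.min((long) (0.10 * heapMax), 1_000_000_000L);
```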

SimpleBoundedInputStream.java

@@ -28,17 +28,16 @@ public final class SimpleBoundedInputStream extends InputStream {
         this.maxBytes = maxBytes;
     }

     /**
      * A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
-     * @return The byte read. -1 on internal stream completion or when maxBytes is exceeded.
-     * @throws IOException on failure
+     * @return The byte read.
+     * @throws IOException on failure or when byte limit is exceeded
      */
     @Override
     public int read() throws IOException {
-        // We have reached the maximum, signal stream completion.
         if (numBytes >= maxBytes) {
-            return -1;
+            throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]");
         }
         numBytes++;
         return in.read();
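
The contract change matters for callers: `-1` from `read()` now only ever means genuine end of data, so a consumer can loop to EOF without risking a silently truncated result. A minimal sketch of a consumer under the new contract (`bounded` being any stream wrapped as in the earlier sketch):

```java
import java.io.IOException;
import java.io.InputStream;

class BoundedReadSketch {
    // Counts the inflated bytes; an oversized payload throws IOException
    // mid-loop instead of terminating the loop early with a short count.
    static long consume(InputStream bounded) throws IOException {
        long total = 0;
        while (bounded.read() != -1) { // -1 is now unambiguous end-of-stream
            total++;
        }
        return total;
    }
}
```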

InferenceToXContentCompressorTests.java

@@ -5,15 +5,17 @@
  */
 package org.elasticsearch.xpack.core.ml.inference;

-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.io.Streams;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
-import org.elasticsearch.common.xcontent.XContentHelper;
-import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import static org.hamcrest.Matchers.equalTo;
@@ -33,20 +35,22 @@ public class InferenceToXContentCompressorTests extends ESTestCase {
     }

     public void testInflateTooLargeStream() throws IOException {
-        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder()
+            .setPreProcessors(Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
+                OneHotEncodingTests.createRandom(),
+                TargetMeanEncodingTests.createRandom()))
+                .limit(100)
+                .collect(Collectors.toList()))
+            .build();
         String firstDeflate = InferenceToXContentCompressor.deflate(definition);
-        BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L);
-        assertThat(inflatedBytes.length(), equalTo(10));
-        try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
-            LoggingDeprecationHandler.INSTANCE,
-            inflatedBytes,
-            XContentType.JSON)) {
-            expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser, true));
-        }
+        int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10;
+        IOException ex = expectThrows(IOException.class,
+            () -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max)));
+        assertThat(ex.getMessage(), equalTo("input stream exceeded maximum bytes of [" + max + "]"));
     }

     public void testInflateGarbage() {
-        expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L));
+        expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)));
     }
@Override