[ML][Inference] stream inflate to parser + throw when byte limit is reached (#51644) (#51679)

Three fixes for when the `compressed_definition` is used on PUT:

* Update the inflate byte limit to the minimum of 10% of the max heap and 1GB (the previous fixed limit)
* Stream the inflated data directly to the JSON parser, so invalid content fails fast instead of requiring the whole stream to be inflated first
* Throw an exception when the maximum number of bytes is reached, making it clear why the request was rejected (see the sketch below)
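
As a rough, self-contained illustration of the streaming idea (not the actual Elasticsearch code: it uses Jackson directly rather than the XContent wrappers, and the class and method names are made up for the example):

```java
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.zip.GZIPInputStream;

public class StreamingInflateSketch {

    // Decode and decompress lazily, handing the stream straight to a JSON parser.
    // Invalid content fails on the first bad token instead of after inflating the whole payload.
    static void parse(String compressedString) throws IOException {
        byte[] compressed = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
        InputStream gzipStream = new GZIPInputStream(new ByteArrayInputStream(compressed));
        try (JsonParser parser = new JsonFactory().createParser(gzipStream)) {
            for (JsonToken token = parser.nextToken(); token != null; token = parser.nextToken()) {
                // A real implementation would build the model definition from the tokens here.
            }
        }
    }
}
```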
Benjamin Trent 2020-01-30 10:16:14 -05:00 committed by GitHub
parent 9c7a63214c
commit 2a2a0941af
3 changed files with 40 additions and 32 deletions

InferenceToXContentCompressor.java

@@ -9,7 +9,6 @@ package org.elasticsearch.xpack.core.ml.inference;
 import org.elasticsearch.common.CheckedFunction;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.Streams;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -17,6 +16,8 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.common.xcontent.json.JsonXContent;
+import org.elasticsearch.monitor.jvm.JvmInfo;
 import org.elasticsearch.xpack.core.ml.inference.utils.SimpleBoundedInputStream;
 import java.io.IOException;
@@ -33,7 +34,10 @@ import java.util.zip.GZIPOutputStream;
  */
 public final class InferenceToXContentCompressor {
     private static final int BUFFER_SIZE = 4096;
-    private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum
+    // Either 10% of the configured JVM heap, or 1 GB, whichever is smaller
+    private static final long MAX_INFLATED_BYTES = Math.min(
+        (long)((0.10) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()),
+        1_000_000_000); // 1 gb maximum

     private InferenceToXContentCompressor() {}
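
To make the new cap concrete (illustrative heap sizes, not values taken from the commit): with a 4 GiB heap the 10% bound is roughly 429 MB and wins; the fixed 1 GB ceiling only takes over once the heap exceeds 10 GB.

```java
public class InflateCapExample {
    public static void main(String[] args) {
        long oneGb = 1_000_000_000L;
        long smallHeap = 4L * 1024 * 1024 * 1024;   // 4 GiB heap
        long largeHeap = 32L * 1024 * 1024 * 1024;  // 32 GiB heap
        // 10% of 4 GiB is below the 1 GB ceiling, so the heap-based bound applies.
        System.out.println(Math.min((long) (0.10 * smallHeap), oneGb)); // 429496729
        // 10% of 32 GiB is well above 1 GB, so the fixed ceiling applies instead.
        System.out.println(Math.min((long) (0.10 * largeHeap), oneGb)); // 1000000000
    }
}
```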
@@ -45,33 +49,34 @@ public final class InferenceToXContentCompressor {
     static <T> T inflate(String compressedString,
                          CheckedFunction<XContentParser, T, IOException> parserFunction,
                          NamedXContentRegistry xContentRegistry) throws IOException {
-        try(XContentParser parser = XContentHelper.createParser(xContentRegistry,
+        try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
             LoggingDeprecationHandler.INSTANCE,
-            inflate(compressedString, MAX_INFLATED_BYTES),
-            XContentType.JSON)) {
+            inflate(compressedString, MAX_INFLATED_BYTES))) {
             return parserFunction.apply(parser);
         }
     }

     static Map<String, Object> inflateToMap(String compressedString) throws IOException {
         // Don't need the xcontent registry as we are not deflating named objects.
-        try(XContentParser parser = XContentHelper.createParser(NamedXContentRegistry.EMPTY,
+        try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY,
             LoggingDeprecationHandler.INSTANCE,
-            inflate(compressedString, MAX_INFLATED_BYTES),
-            XContentType.JSON)) {
+            inflate(compressedString, MAX_INFLATED_BYTES))) {
             return parser.mapOrdered();
         }
     }

-    static BytesReference inflate(String compressedString, long streamSize) throws IOException {
+    static InputStream inflate(String compressedString, long streamSize) throws IOException {
         byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
+        // If the compressed length is already too large, it makes sense that the inflated length would be as well.
+        // In the extremely small string case, the compressed data could actually be longer than the inflated stream.
+        if (compressedBytes.length > Math.max(100L, streamSize)) {
+            throw new IOException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]");
+        }
         InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
-        InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
-        return Streams.readFully(inflateStream);
+        return new SimpleBoundedInputStream(gzipStream, streamSize);
     }

-    //Public for testing (for now)
-    public static String deflate(BytesReference reference) throws IOException {
+    private static String deflate(BytesReference reference) throws IOException {
         BytesStreamOutput out = new BytesStreamOutput();
         try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
             reference.writeTo(compressedOutput);

SimpleBoundedInputStream.java

@@ -28,17 +28,16 @@ public final class SimpleBoundedInputStream extends InputStream {
         this.maxBytes = maxBytes;
     }

     /**
      * A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
-     * @return The byte read. -1 on internal stream completion or when maxBytes is exceeded.
-     * @throws IOException on failure
+     * @return The byte read.
+     * @throws IOException on failure or when byte limit is exceeded
      */
     @Override
     public int read() throws IOException {
         // We have reached the maximum, signal stream completion.
         if (numBytes >= maxBytes) {
-            return -1;
+            throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]");
         }
         numBytes++;
         return in.read();

InferenceToXContentCompressorTests.java

@@ -5,15 +5,17 @@
  */
 package org.elasticsearch.xpack.core.ml.inference;

-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.io.Streams;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
-import org.elasticsearch.common.xcontent.XContentHelper;
-import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;

 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;

 import static org.hamcrest.Matchers.equalTo;
@@ -33,20 +35,22 @@ public class InferenceToXContentCompressorTests extends ESTestCase {
     }

     public void testInflateTooLargeStream() throws IOException {
-        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder()
+            .setPreProcessors(Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
+                OneHotEncodingTests.createRandom(),
+                TargetMeanEncodingTests.createRandom()))
+                .limit(100)
+                .collect(Collectors.toList()))
+            .build();
         String firstDeflate = InferenceToXContentCompressor.deflate(definition);
-        BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L);
-        assertThat(inflatedBytes.length(), equalTo(10));
-        try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
-            LoggingDeprecationHandler.INSTANCE,
-            inflatedBytes,
-            XContentType.JSON)) {
-            expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser, true));
-        }
+        int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10;
+        IOException ex = expectThrows(IOException.class,
+            () -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max)));
+        assertThat(ex.getMessage(), equalTo("input stream exceeded maximum bytes of [" + max + "]"));
     }

     public void testInflateGarbage() {
-        expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L));
+        expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)));
     }

     @Override