Three fixes for when the `compressed_definition` is utilized on PUT * Update the inflate byte limit to be the minimum of 10% the max heap, or 1GB (what it was previously) * Stream data directly to the JSON parser, so if it is invalid, we don't have to inflate the whole stream to find out * Throw when the maximum bytes are reach indicating that is why the request was rejected
This commit is contained in:
parent
9c7a63214c
commit
2a2a0941af
|
@ -9,7 +9,6 @@ package org.elasticsearch.xpack.core.ml.inference;
|
|||
import org.elasticsearch.common.CheckedFunction;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.io.Streams;
|
||||
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
|
@ -17,6 +16,8 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
|
|||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.monitor.jvm.JvmInfo;
|
||||
import org.elasticsearch.xpack.core.ml.inference.utils.SimpleBoundedInputStream;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -33,7 +34,10 @@ import java.util.zip.GZIPOutputStream;
|
|||
*/
|
||||
public final class InferenceToXContentCompressor {
|
||||
private static final int BUFFER_SIZE = 4096;
|
||||
private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum
|
||||
// Either 10% of the configured JVM heap, or 1 GB, which ever is smaller
|
||||
private static final long MAX_INFLATED_BYTES = Math.min(
|
||||
(long)((0.10) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()),
|
||||
1_000_000_000); // 1 gb maximum
|
||||
|
||||
private InferenceToXContentCompressor() {}
|
||||
|
||||
|
@ -45,33 +49,34 @@ public final class InferenceToXContentCompressor {
|
|||
static <T> T inflate(String compressedString,
|
||||
CheckedFunction<XContentParser, T, IOException> parserFunction,
|
||||
NamedXContentRegistry xContentRegistry) throws IOException {
|
||||
try(XContentParser parser = XContentHelper.createParser(xContentRegistry,
|
||||
try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
|
||||
LoggingDeprecationHandler.INSTANCE,
|
||||
inflate(compressedString, MAX_INFLATED_BYTES),
|
||||
XContentType.JSON)) {
|
||||
inflate(compressedString, MAX_INFLATED_BYTES))) {
|
||||
return parserFunction.apply(parser);
|
||||
}
|
||||
}
|
||||
|
||||
static Map<String, Object> inflateToMap(String compressedString) throws IOException {
|
||||
// Don't need the xcontent registry as we are not deflating named objects.
|
||||
try(XContentParser parser = XContentHelper.createParser(NamedXContentRegistry.EMPTY,
|
||||
try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY,
|
||||
LoggingDeprecationHandler.INSTANCE,
|
||||
inflate(compressedString, MAX_INFLATED_BYTES),
|
||||
XContentType.JSON)) {
|
||||
inflate(compressedString, MAX_INFLATED_BYTES))) {
|
||||
return parser.mapOrdered();
|
||||
}
|
||||
}
|
||||
|
||||
static BytesReference inflate(String compressedString, long streamSize) throws IOException {
|
||||
static InputStream inflate(String compressedString, long streamSize) throws IOException {
|
||||
byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
|
||||
// If the compressed length is already too large, it make sense that the inflated length would be as well
|
||||
// In the extremely small string case, the compressed data could actually be longer than the compressed stream
|
||||
if (compressedBytes.length > Math.max(100L, streamSize)) {
|
||||
throw new IOException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]");
|
||||
}
|
||||
InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
|
||||
InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
|
||||
return Streams.readFully(inflateStream);
|
||||
return new SimpleBoundedInputStream(gzipStream, streamSize);
|
||||
}
|
||||
|
||||
//Public for testing (for now)
|
||||
public static String deflate(BytesReference reference) throws IOException {
|
||||
private static String deflate(BytesReference reference) throws IOException {
|
||||
BytesStreamOutput out = new BytesStreamOutput();
|
||||
try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
|
||||
reference.writeTo(compressedOutput);
|
||||
|
|
|
@ -28,17 +28,16 @@ public final class SimpleBoundedInputStream extends InputStream {
|
|||
this.maxBytes = maxBytes;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
|
||||
* @return The byte read. -1 on internal stream completion or when maxBytes is exceeded.
|
||||
* @throws IOException on failure
|
||||
* @return The byte read.
|
||||
* @throws IOException on failure or when byte limit is exceeded
|
||||
*/
|
||||
@Override
|
||||
public int read() throws IOException {
|
||||
// We have reached the maximum, signal stream completion.
|
||||
if (numBytes >= maxBytes) {
|
||||
return -1;
|
||||
throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]");
|
||||
}
|
||||
numBytes++;
|
||||
return in.read();
|
||||
|
|
|
@ -5,15 +5,17 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference;
|
||||
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||
import org.elasticsearch.common.io.Streams;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
|
@ -33,20 +35,22 @@ public class InferenceToXContentCompressorTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testInflateTooLargeStream() throws IOException {
|
||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
|
||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder()
|
||||
.setPreProcessors(Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||
OneHotEncodingTests.createRandom(),
|
||||
TargetMeanEncodingTests.createRandom()))
|
||||
.limit(100)
|
||||
.collect(Collectors.toList()))
|
||||
.build();
|
||||
String firstDeflate = InferenceToXContentCompressor.deflate(definition);
|
||||
BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L);
|
||||
assertThat(inflatedBytes.length(), equalTo(10));
|
||||
try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
||||
LoggingDeprecationHandler.INSTANCE,
|
||||
inflatedBytes,
|
||||
XContentType.JSON)) {
|
||||
expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser, true));
|
||||
}
|
||||
int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10;
|
||||
IOException ex = expectThrows(IOException.class,
|
||||
() -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max)));
|
||||
assertThat(ex.getMessage(), equalTo("input stream exceeded maximum bytes of [" + max + "]"));
|
||||
}
|
||||
|
||||
public void testInflateGarbage() {
|
||||
expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L));
|
||||
expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
Loading…
Reference in New Issue