Three fixes for when the `compressed_definition` is utilized on PUT * Update the inflate byte limit to be the minimum of 10% of the max heap and 1GB (what it was previously) * Stream data directly to the JSON parser, so if it is invalid, we don't have to inflate the whole stream to find out * Throw when the maximum number of bytes is reached, indicating that this is why the request was rejected
This commit is contained in:
parent
9c7a63214c
commit
2a2a0941af
|
@ -9,7 +9,6 @@ package org.elasticsearch.xpack.core.ml.inference;
|
||||||
import org.elasticsearch.common.CheckedFunction;
|
import org.elasticsearch.common.CheckedFunction;
|
||||||
import org.elasticsearch.common.bytes.BytesArray;
|
import org.elasticsearch.common.bytes.BytesArray;
|
||||||
import org.elasticsearch.common.bytes.BytesReference;
|
import org.elasticsearch.common.bytes.BytesReference;
|
||||||
import org.elasticsearch.common.io.Streams;
|
|
||||||
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
||||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
|
@ -17,6 +16,8 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.common.xcontent.XContentType;
|
import org.elasticsearch.common.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||||
|
import org.elasticsearch.monitor.jvm.JvmInfo;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.utils.SimpleBoundedInputStream;
|
import org.elasticsearch.xpack.core.ml.inference.utils.SimpleBoundedInputStream;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -33,7 +34,10 @@ import java.util.zip.GZIPOutputStream;
|
||||||
*/
|
*/
|
||||||
public final class InferenceToXContentCompressor {
|
public final class InferenceToXContentCompressor {
|
||||||
private static final int BUFFER_SIZE = 4096;
|
private static final int BUFFER_SIZE = 4096;
|
||||||
private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum
|
// Either 10% of the configured JVM heap, or 1 GB, which ever is smaller
|
||||||
|
private static final long MAX_INFLATED_BYTES = Math.min(
|
||||||
|
(long)((0.10) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()),
|
||||||
|
1_000_000_000); // 1 gb maximum
|
||||||
|
|
||||||
private InferenceToXContentCompressor() {}
|
private InferenceToXContentCompressor() {}
|
||||||
|
|
||||||
|
@ -45,33 +49,34 @@ public final class InferenceToXContentCompressor {
|
||||||
static <T> T inflate(String compressedString,
|
static <T> T inflate(String compressedString,
|
||||||
CheckedFunction<XContentParser, T, IOException> parserFunction,
|
CheckedFunction<XContentParser, T, IOException> parserFunction,
|
||||||
NamedXContentRegistry xContentRegistry) throws IOException {
|
NamedXContentRegistry xContentRegistry) throws IOException {
|
||||||
try(XContentParser parser = XContentHelper.createParser(xContentRegistry,
|
try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
|
||||||
LoggingDeprecationHandler.INSTANCE,
|
LoggingDeprecationHandler.INSTANCE,
|
||||||
inflate(compressedString, MAX_INFLATED_BYTES),
|
inflate(compressedString, MAX_INFLATED_BYTES))) {
|
||||||
XContentType.JSON)) {
|
|
||||||
return parserFunction.apply(parser);
|
return parserFunction.apply(parser);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static Map<String, Object> inflateToMap(String compressedString) throws IOException {
|
static Map<String, Object> inflateToMap(String compressedString) throws IOException {
|
||||||
// Don't need the xcontent registry as we are not deflating named objects.
|
// Don't need the xcontent registry as we are not deflating named objects.
|
||||||
try(XContentParser parser = XContentHelper.createParser(NamedXContentRegistry.EMPTY,
|
try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY,
|
||||||
LoggingDeprecationHandler.INSTANCE,
|
LoggingDeprecationHandler.INSTANCE,
|
||||||
inflate(compressedString, MAX_INFLATED_BYTES),
|
inflate(compressedString, MAX_INFLATED_BYTES))) {
|
||||||
XContentType.JSON)) {
|
|
||||||
return parser.mapOrdered();
|
return parser.mapOrdered();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static BytesReference inflate(String compressedString, long streamSize) throws IOException {
|
static InputStream inflate(String compressedString, long streamSize) throws IOException {
|
||||||
byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
|
byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
|
||||||
|
// If the compressed length is already too large, it make sense that the inflated length would be as well
|
||||||
|
// In the extremely small string case, the compressed data could actually be longer than the compressed stream
|
||||||
|
if (compressedBytes.length > Math.max(100L, streamSize)) {
|
||||||
|
throw new IOException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]");
|
||||||
|
}
|
||||||
InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
|
InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
|
||||||
InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
|
return new SimpleBoundedInputStream(gzipStream, streamSize);
|
||||||
return Streams.readFully(inflateStream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//Public for testing (for now)
|
private static String deflate(BytesReference reference) throws IOException {
|
||||||
public static String deflate(BytesReference reference) throws IOException {
|
|
||||||
BytesStreamOutput out = new BytesStreamOutput();
|
BytesStreamOutput out = new BytesStreamOutput();
|
||||||
try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
|
try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
|
||||||
reference.writeTo(compressedOutput);
|
reference.writeTo(compressedOutput);
|
||||||
|
|
|
@ -28,17 +28,16 @@ public final class SimpleBoundedInputStream extends InputStream {
|
||||||
this.maxBytes = maxBytes;
|
this.maxBytes = maxBytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
|
* A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
|
||||||
* @return The byte read. -1 on internal stream completion or when maxBytes is exceeded.
|
* @return The byte read.
|
||||||
* @throws IOException on failure
|
* @throws IOException on failure or when byte limit is exceeded
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public int read() throws IOException {
|
public int read() throws IOException {
|
||||||
// We have reached the maximum, signal stream completion.
|
// We have reached the maximum, signal stream completion.
|
||||||
if (numBytes >= maxBytes) {
|
if (numBytes >= maxBytes) {
|
||||||
return -1;
|
throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]");
|
||||||
}
|
}
|
||||||
numBytes++;
|
numBytes++;
|
||||||
return in.read();
|
return in.read();
|
||||||
|
|
|
@ -5,15 +5,17 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.core.ml.inference;
|
package org.elasticsearch.xpack.core.ml.inference;
|
||||||
|
|
||||||
import org.elasticsearch.common.bytes.BytesReference;
|
import org.elasticsearch.common.io.Streams;
|
||||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
|
||||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentType;
|
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
|
@ -33,20 +35,22 @@ public class InferenceToXContentCompressorTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testInflateTooLargeStream() throws IOException {
|
public void testInflateTooLargeStream() throws IOException {
|
||||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
|
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder()
|
||||||
|
.setPreProcessors(Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||||
|
OneHotEncodingTests.createRandom(),
|
||||||
|
TargetMeanEncodingTests.createRandom()))
|
||||||
|
.limit(100)
|
||||||
|
.collect(Collectors.toList()))
|
||||||
|
.build();
|
||||||
String firstDeflate = InferenceToXContentCompressor.deflate(definition);
|
String firstDeflate = InferenceToXContentCompressor.deflate(definition);
|
||||||
BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L);
|
int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10;
|
||||||
assertThat(inflatedBytes.length(), equalTo(10));
|
IOException ex = expectThrows(IOException.class,
|
||||||
try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
() -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max)));
|
||||||
LoggingDeprecationHandler.INSTANCE,
|
assertThat(ex.getMessage(), equalTo("input stream exceeded maximum bytes of [" + max + "]"));
|
||||||
inflatedBytes,
|
|
||||||
XContentType.JSON)) {
|
|
||||||
expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser, true));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testInflateGarbage() {
|
public void testInflateGarbage() {
|
||||||
expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L));
|
expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
Loading…
Reference in New Issue