The actual cause can be lost in a long list of parse exceptions this surfaces the cause when the problem is size.
This commit is contained in:
parent
c5443f78ce
commit
dbb9c802b1
|
@ -6,6 +6,7 @@
|
|||
|
||||
package org.elasticsearch.xpack.core.ml.inference;
|
||||
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
import org.elasticsearch.common.CheckedFunction;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
|
@ -16,6 +17,7 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
|||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
|
@ -51,10 +53,29 @@ public final class InferenceToXContentCompressor {
|
|||
public static <T> T inflate(String compressedString,
|
||||
CheckedFunction<XContentParser, T, IOException> parserFunction,
|
||||
NamedXContentRegistry xContentRegistry) throws IOException {
|
||||
return inflate(compressedString, parserFunction, xContentRegistry, MAX_INFLATED_BYTES);
|
||||
}
|
||||
|
||||
static <T> T inflate(String compressedString,
|
||||
CheckedFunction<XContentParser, T, IOException> parserFunction,
|
||||
NamedXContentRegistry xContentRegistry,
|
||||
long maxBytes) throws IOException {
|
||||
try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
|
||||
LoggingDeprecationHandler.INSTANCE,
|
||||
inflate(compressedString, MAX_INFLATED_BYTES))) {
|
||||
inflate(compressedString, maxBytes))) {
|
||||
return parserFunction.apply(parser);
|
||||
} catch (XContentParseException parseException) {
|
||||
SimpleBoundedInputStream.StreamSizeExceededException streamSizeCause =
|
||||
(SimpleBoundedInputStream.StreamSizeExceededException)
|
||||
ExceptionsHelper.unwrap(parseException, SimpleBoundedInputStream.StreamSizeExceededException.class);
|
||||
|
||||
if (streamSizeCause != null) {
|
||||
// The root cause is that the model is too big.
|
||||
throw new IOException("Cannot parse model definition as the content is larger than the maximum stream size of ["
|
||||
+ streamSizeCause.getMaxBytes() + "] bytes. Max stream size is 10% of the JVM heap or 1GB whichever is smallest");
|
||||
} else {
|
||||
throw parseException;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,19 @@ public final class SimpleBoundedInputStream extends InputStream {
|
|||
private final long maxBytes;
|
||||
private long numBytes;
|
||||
|
||||
public static class StreamSizeExceededException extends IOException {
|
||||
private final long maxBytes;
|
||||
|
||||
public StreamSizeExceededException(String message, long maxBytes) {
|
||||
super(message);
|
||||
this.maxBytes = maxBytes;
|
||||
}
|
||||
|
||||
public long getMaxBytes() {
|
||||
return maxBytes;
|
||||
}
|
||||
}
|
||||
|
||||
public SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
|
||||
this.in = ExceptionsHelper.requireNonNull(inputStream, "inputStream");
|
||||
if (maxBytes < 0) {
|
||||
|
@ -31,13 +44,14 @@ public final class SimpleBoundedInputStream extends InputStream {
|
|||
/**
|
||||
* A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
|
||||
* @return The byte read.
|
||||
* @throws IOException on failure or when byte limit is exceeded
|
||||
* @throws StreamSizeExceededException when byte limit is exceeded
|
||||
* @throws IOException on failure
|
||||
*/
|
||||
@Override
|
||||
public int read() throws IOException {
|
||||
// We have reached the maximum, signal stream completion.
|
||||
if (numBytes >= maxBytes) {
|
||||
throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]");
|
||||
throw new StreamSizeExceededException("input stream exceeded maximum bytes of [" + maxBytes + "]", maxBytes);
|
||||
}
|
||||
numBytes++;
|
||||
return in.read();
|
||||
|
|
|
@ -46,13 +46,34 @@ public class InferenceToXContentCompressorTests extends ESTestCase {
|
|||
int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10;
|
||||
IOException ex = expectThrows(IOException.class,
|
||||
() -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max)));
|
||||
assertThat(ex.getMessage(), equalTo("input stream exceeded maximum bytes of [" + max + "]"));
|
||||
assertThat(ex.getMessage(), equalTo("" +
|
||||
"input stream exceeded maximum bytes of [" + max + "]"));
|
||||
}
|
||||
|
||||
public void testInflateGarbage() {
|
||||
expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)));
|
||||
}
|
||||
|
||||
public void testInflateParsingTooLargeStream() throws IOException {
|
||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder()
|
||||
.setPreProcessors(Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||
OneHotEncodingTests.createRandom(),
|
||||
TargetMeanEncodingTests.createRandom()))
|
||||
.limit(100)
|
||||
.collect(Collectors.toList()))
|
||||
.build();
|
||||
String compressedString = InferenceToXContentCompressor.deflate(definition);
|
||||
int max = compressedString.getBytes(StandardCharsets.UTF_8).length + 10;
|
||||
|
||||
IOException e = expectThrows(IOException.class, ()-> InferenceToXContentCompressor.inflate(compressedString,
|
||||
parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
|
||||
xContentRegistry(),
|
||||
max));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("Cannot parse model definition as the content is larger than the maximum stream size of ["
|
||||
+ max + "] bytes. Max stream size is 10% of the JVM heap or 1GB whichever is smallest"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
|
|
Loading…
Reference in New Issue