Better error message when the model cannot be parsed due to its size (#59166) (#59209)

The actual cause can be lost in a long list of parse exceptions;
this change surfaces the cause when the problem is the size of the model. A sketch of the
cause-unwrapping idea follows the change summary below.
Author: David Kyle, 2020-07-09 13:43:46 +01:00 (committed by GitHub)
Parent: c5443f78ce
Commit: dbb9c802b1
3 changed files with 60 additions and 4 deletions

File: InferenceToXContentCompressor.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.inference;
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.common.CheckedFunction;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
@@ -16,6 +17,7 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParseException;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.common.xcontent.json.JsonXContent;
@@ -51,10 +53,29 @@ public final class InferenceToXContentCompressor {
     public static <T> T inflate(String compressedString,
                                 CheckedFunction<XContentParser, T, IOException> parserFunction,
                                 NamedXContentRegistry xContentRegistry) throws IOException {
+        return inflate(compressedString, parserFunction, xContentRegistry, MAX_INFLATED_BYTES);
+    }
+
+    static <T> T inflate(String compressedString,
+                         CheckedFunction<XContentParser, T, IOException> parserFunction,
+                         NamedXContentRegistry xContentRegistry,
+                         long maxBytes) throws IOException {
         try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
             LoggingDeprecationHandler.INSTANCE,
-            inflate(compressedString, MAX_INFLATED_BYTES))) {
+            inflate(compressedString, maxBytes))) {
             return parserFunction.apply(parser);
+        } catch (XContentParseException parseException) {
+            SimpleBoundedInputStream.StreamSizeExceededException streamSizeCause =
+                (SimpleBoundedInputStream.StreamSizeExceededException)
+                    ExceptionsHelper.unwrap(parseException, SimpleBoundedInputStream.StreamSizeExceededException.class);
+            if (streamSizeCause != null) {
+                // The root cause is that the model is too big.
+                throw new IOException("Cannot parse model definition as the content is larger than the maximum stream size of ["
+                    + streamSizeCause.getMaxBytes() + "] bytes. Max stream size is 10% of the JVM heap or 1GB whichever is smallest");
+            } else {
+                throw parseException;
+            }
         }
     }

File: SimpleBoundedInputStream.java

@@ -20,6 +20,19 @@ public final class SimpleBoundedInputStream extends InputStream {
     private final long maxBytes;
     private long numBytes;
 
+    public static class StreamSizeExceededException extends IOException {
+        private final long maxBytes;
+
+        public StreamSizeExceededException(String message, long maxBytes) {
+            super(message);
+            this.maxBytes = maxBytes;
+        }
+
+        public long getMaxBytes() {
+            return maxBytes;
+        }
+    }
+
     public SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
         this.in = ExceptionsHelper.requireNonNull(inputStream, "inputStream");
         if (maxBytes < 0) {
@@ -31,13 +44,14 @@ public final class SimpleBoundedInputStream extends InputStream {
     /**
      * A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
      * @return The byte read.
-     * @throws IOException on failure or when byte limit is exceeded
+     * @throws StreamSizeExceededException when byte limit is exceeded
+     * @throws IOException on failure
      */
     @Override
     public int read() throws IOException {
         // We have reached the maximum, signal stream completion.
         if (numBytes >= maxBytes) {
-            throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]");
+            throw new StreamSizeExceededException("input stream exceeded maximum bytes of [" + maxBytes + "]", maxBytes);
         }
         numBytes++;
         return in.read();

File: InferenceToXContentCompressorTests.java

@@ -46,13 +46,34 @@ public class InferenceToXContentCompressorTests extends ESTestCase {
         int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10;
         IOException ex = expectThrows(IOException.class,
             () -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max)));
-        assertThat(ex.getMessage(), equalTo("input stream exceeded maximum bytes of [" + max + "]"));
+        assertThat(ex.getMessage(), equalTo("" +
+            "input stream exceeded maximum bytes of [" + max + "]"));
     }
 
     public void testInflateGarbage() {
         expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)));
     }
 
+    public void testInflateParsingTooLargeStream() throws IOException {
+        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder()
+            .setPreProcessors(Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
+                OneHotEncodingTests.createRandom(),
+                TargetMeanEncodingTests.createRandom()))
+                .limit(100)
+                .collect(Collectors.toList()))
+            .build();
+        String compressedString = InferenceToXContentCompressor.deflate(definition);
+        int max = compressedString.getBytes(StandardCharsets.UTF_8).length + 10;
+
+        IOException e = expectThrows(IOException.class, ()-> InferenceToXContentCompressor.inflate(compressedString,
+            parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
+            xContentRegistry(),
+            max));
+
+        assertThat(e.getMessage(), equalTo("Cannot parse model definition as the content is larger than the maximum stream size of ["
+            + max + "] bytes. Max stream size is 10% of the JVM heap or 1GB whichever is smallest"));
+    }
+
     @Override
     protected NamedXContentRegistry xContentRegistry() {
         return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());