[7.x] [ML] fix bugs with prediction field value settings (#55333) (#55394)

* [ML] fix bugs with prediction field value settings (#55333)

This fixes two unreleased bugs:

1. Prediction value type of `number` might show unexpected classes

Analytics created models may have class labels like `1, 5, 10` (or some collection of discrete, whole numbers). These labels are passed to the inference model config in the `classification_labels` field.

When the predicted value format is `numeric` it should attempt to see if the classification labels are provided and are numeric. If so, use those. If not, use the underlying value.

2. When supplying an update overwrite, inference was losing the default prediction field value. This is because it was not copied over in the copy ctor in the ClassificationConfig.Builder class. 

closes #55332
This commit is contained in:
Benjamin Trent 2020-04-17 14:45:02 -04:00 committed by GitHub
parent 592b5516c1
commit 4be3663968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 88 additions and 8 deletions

View File

@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType; import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@ -137,6 +138,15 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed); return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed);
} }
@Override
public String toString() {
return "Request{" +
"modelId='" + modelId + '\'' +
", objectsToInfer=" + objectsToInfer +
", update=" + Strings.toString(update) +
", previouslyLicensed=" + previouslyLicensed +
'}';
}
} }
public static class Response extends ActionResponse { public static class Response extends ActionResponse {

View File

@ -214,6 +214,7 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
this.topClassesResultsField = config.topClassesResultsField; this.topClassesResultsField = config.topClassesResultsField;
this.resultsField = config.resultsField; this.resultsField = config.resultsField;
this.numTopFeatureImportanceValues = config.numTopFeatureImportanceValues; this.numTopFeatureImportanceValues = config.numTopFeatureImportanceValues;
this.predictionFieldType = config.predictionFieldType;
} }
public Builder setNumTopClasses(Integer numTopClasses) { public Builder setNumTopClasses(Integer numTopClasses) {

View File

@ -262,11 +262,15 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
return this; return this;
} }
private Builder setPredictionFieldType(String predictionFieldType) { public Builder setPredictionFieldType(PredictionFieldType predictionFieldtype) {
this.predictionFieldType = PredictionFieldType.fromString(predictionFieldType); this.predictionFieldType = predictionFieldtype;
return this; return this;
} }
private Builder setPredictionFieldType(String predictionFieldType) {
return setPredictionFieldType(PredictionFieldType.fromString(predictionFieldType));
}
public ClassificationConfigUpdate build() { public ClassificationConfigUpdate build() {
return new ClassificationConfigUpdate(numTopClasses, return new ClassificationConfigUpdate(numTopClasses,
resultsField, resultsField,

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel; package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
@ -57,6 +58,19 @@ public enum PredictionFieldType implements Writeable {
} }
return areClose(value, 1.0D); return areClose(value, 1.0D);
case NUMBER: case NUMBER:
if (Strings.isNullOrEmpty(stringRep)) {
return value;
}
// Quick check to verify that the string rep is LIKELY a number
// Still handles the case where it throws and then returns the underlying value
if (stringRep.charAt(0) == '-' || Character.isDigit(stringRep.charAt(0))) {
try {
return Long.parseLong(stringRep);
} catch (NumberFormatException nfe) {
return value;
}
}
return value;
default: default:
return value; return value;
} }

View File

@ -48,8 +48,8 @@ public class InternalInferModelActionRequestTests extends AbstractBWCWireSeriali
} }
private static InferenceConfigUpdate randomInferenceConfigUpdate() { private static InferenceConfigUpdate randomInferenceConfigUpdate() {
return randomFrom(RegressionConfigUpdateTests.randomRegressionConfig(), return randomFrom(RegressionConfigUpdateTests.randomRegressionConfigUpdate(),
ClassificationConfigUpdateTests.randomClassificationConfig()); ClassificationConfigUpdateTests.randomClassificationConfigUpdate());
} }
private static Map<String, Object> randomMap() { private static Map<String, Object> randomMap() {

View File

@ -29,6 +29,9 @@ public class ClassificationConfigTests extends AbstractBWCSerializationTestCase<
public static ClassificationConfig mutateForVersion(ClassificationConfig instance, Version version) { public static ClassificationConfig mutateForVersion(ClassificationConfig instance, Version version) {
ClassificationConfig.Builder builder = new ClassificationConfig.Builder(instance); ClassificationConfig.Builder builder = new ClassificationConfig.Builder(instance);
if (version.before(Version.V_7_8_0)) {
builder.setPredictionFieldType(PredictionFieldType.STRING);
}
if (version.before(Version.V_7_7_0)) { if (version.before(Version.V_7_7_0)) {
builder.setNumTopFeatureImportanceValues(0); builder.setNumTopFeatureImportanceValues(0);
} }

View File

@ -16,11 +16,12 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests.randomClassificationConfig;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<ClassificationConfigUpdate> { public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<ClassificationConfigUpdate> {
public static ClassificationConfigUpdate randomClassificationConfig() { public static ClassificationConfigUpdate randomClassificationConfigUpdate() {
return new ClassificationConfigUpdate(randomBoolean() ? null : randomIntBetween(-1, 10), return new ClassificationConfigUpdate(randomBoolean() ? null : randomIntBetween(-1, 10),
randomBoolean() ? null : randomAlphaOfLength(10), randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomAlphaOfLength(10), randomBoolean() ? null : randomAlphaOfLength(10),
@ -49,9 +50,33 @@ public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTes
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
} }
public void testApply() {
ClassificationConfig originalConfig = randomClassificationConfig();
assertThat(originalConfig, equalTo(ClassificationConfigUpdate.EMPTY_PARAMS.apply(originalConfig)));
assertThat(new ClassificationConfig.Builder(originalConfig).setNumTopClasses(5).build(),
equalTo(new ClassificationConfigUpdate.Builder().setNumTopClasses(5).build().apply(originalConfig)));
assertThat(new ClassificationConfig.Builder()
.setNumTopClasses(5)
.setNumTopFeatureImportanceValues(1)
.setPredictionFieldType(PredictionFieldType.BOOLEAN)
.setResultsField("foo")
.setTopClassesResultsField("bar").build(),
equalTo(new ClassificationConfigUpdate.Builder()
.setNumTopClasses(5)
.setNumTopFeatureImportanceValues(1)
.setPredictionFieldType(PredictionFieldType.BOOLEAN)
.setResultsField("foo")
.setTopClassesResultsField("bar")
.build()
.apply(originalConfig)
));
}
@Override @Override
protected ClassificationConfigUpdate createTestInstance() { protected ClassificationConfigUpdate createTestInstance() {
return randomClassificationConfig(); return randomClassificationConfigUpdate();
} }
@Override @Override

View File

@ -38,6 +38,9 @@ public class PredictionFieldTypeTests extends ESTestCase {
is(nullValue())); is(nullValue()));
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, "foo"), equalTo(1.0)); assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, "foo"), equalTo(1.0));
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, null), equalTo(1.0)); assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, null), equalTo(1.0));
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, ""), equalTo(1.0));
long expected = randomLong();
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, String.valueOf(expected)), equalTo(expected));
} }
} }

View File

@ -16,11 +16,12 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests.randomRegressionConfig;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCase<RegressionConfigUpdate> { public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCase<RegressionConfigUpdate> {
public static RegressionConfigUpdate randomRegressionConfig() { public static RegressionConfigUpdate randomRegressionConfigUpdate() {
return new RegressionConfigUpdate(randomBoolean() ? null : randomAlphaOfLength(10), return new RegressionConfigUpdate(randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomIntBetween(0, 10)); randomBoolean() ? null : randomIntBetween(0, 10));
} }
@ -40,9 +41,28 @@ public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCas
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
} }
public void testApply() {
RegressionConfig originalConfig = randomRegressionConfig();
assertThat(originalConfig, equalTo(RegressionConfigUpdate.EMPTY_PARAMS.apply(originalConfig)));
assertThat(new RegressionConfig.Builder(originalConfig).setNumTopFeatureImportanceValues(5).build(),
equalTo(new RegressionConfigUpdate.Builder().setNumTopFeatureImportanceValues(5).build().apply(originalConfig)));
assertThat(new RegressionConfig.Builder()
.setNumTopFeatureImportanceValues(1)
.setResultsField("foo")
.build(),
equalTo(new RegressionConfigUpdate.Builder()
.setNumTopFeatureImportanceValues(1)
.setResultsField("foo")
.build()
.apply(originalConfig)
));
}
@Override @Override
protected RegressionConfigUpdate createTestInstance() { protected RegressionConfigUpdate createTestInstance() {
return randomRegressionConfig(); return randomRegressionConfigUpdate();
} }
@Override @Override