* [ML] fix bugs with prediction field value settings (#55333) This fixes two unreleased bugs: 1. Prediction value type of `number` might show unexpected classes Analytics created models may have class labels like `1, 5, 10` (or some collection of discrete, whole numbers). These labels are passed to the inference model config in the `classification_labels` field. When the predicted value format is `numeric` it should attempt to see if the classification labels are provided and are numeric. If so, use those. If not, use the underlying value. 2. When supplying an update overwrite, inference was losing the default prediction field value. This is because it was not copied over in the copy ctor in the ClassificationConfig.Builder class. closes #55332
This commit is contained in:
parent
592b5516c1
commit
4be3663968
|
@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionRequest;
|
||||||
import org.elasticsearch.action.ActionRequestValidationException;
|
import org.elasticsearch.action.ActionRequestValidationException;
|
||||||
import org.elasticsearch.action.ActionResponse;
|
import org.elasticsearch.action.ActionResponse;
|
||||||
import org.elasticsearch.action.ActionType;
|
import org.elasticsearch.action.ActionType;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
|
@ -137,6 +138,15 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
|
||||||
return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed);
|
return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "Request{" +
|
||||||
|
"modelId='" + modelId + '\'' +
|
||||||
|
", objectsToInfer=" + objectsToInfer +
|
||||||
|
", update=" + Strings.toString(update) +
|
||||||
|
", previouslyLicensed=" + previouslyLicensed +
|
||||||
|
'}';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Response extends ActionResponse {
|
public static class Response extends ActionResponse {
|
||||||
|
|
|
@ -214,6 +214,7 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
|
||||||
this.topClassesResultsField = config.topClassesResultsField;
|
this.topClassesResultsField = config.topClassesResultsField;
|
||||||
this.resultsField = config.resultsField;
|
this.resultsField = config.resultsField;
|
||||||
this.numTopFeatureImportanceValues = config.numTopFeatureImportanceValues;
|
this.numTopFeatureImportanceValues = config.numTopFeatureImportanceValues;
|
||||||
|
this.predictionFieldType = config.predictionFieldType;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Builder setNumTopClasses(Integer numTopClasses) {
|
public Builder setNumTopClasses(Integer numTopClasses) {
|
||||||
|
|
|
@ -262,11 +262,15 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Builder setPredictionFieldType(String predictionFieldType) {
|
public Builder setPredictionFieldType(PredictionFieldType predictionFieldtype) {
|
||||||
this.predictionFieldType = PredictionFieldType.fromString(predictionFieldType);
|
this.predictionFieldType = predictionFieldtype;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Builder setPredictionFieldType(String predictionFieldType) {
|
||||||
|
return setPredictionFieldType(PredictionFieldType.fromString(predictionFieldType));
|
||||||
|
}
|
||||||
|
|
||||||
public ClassificationConfigUpdate build() {
|
public ClassificationConfigUpdate build() {
|
||||||
return new ClassificationConfigUpdate(numTopClasses,
|
return new ClassificationConfigUpdate(numTopClasses,
|
||||||
resultsField,
|
resultsField,
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
@ -57,6 +58,19 @@ public enum PredictionFieldType implements Writeable {
|
||||||
}
|
}
|
||||||
return areClose(value, 1.0D);
|
return areClose(value, 1.0D);
|
||||||
case NUMBER:
|
case NUMBER:
|
||||||
|
if (Strings.isNullOrEmpty(stringRep)) {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
// Quick check to verify that the string rep is LIKELY a number
|
||||||
|
// Still handles the case where it throws and then returns the underlying value
|
||||||
|
if (stringRep.charAt(0) == '-' || Character.isDigit(stringRep.charAt(0))) {
|
||||||
|
try {
|
||||||
|
return Long.parseLong(stringRep);
|
||||||
|
} catch (NumberFormatException nfe) {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value;
|
||||||
default:
|
default:
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,8 +48,8 @@ public class InternalInferModelActionRequestTests extends AbstractBWCWireSeriali
|
||||||
}
|
}
|
||||||
|
|
||||||
private static InferenceConfigUpdate randomInferenceConfigUpdate() {
|
private static InferenceConfigUpdate randomInferenceConfigUpdate() {
|
||||||
return randomFrom(RegressionConfigUpdateTests.randomRegressionConfig(),
|
return randomFrom(RegressionConfigUpdateTests.randomRegressionConfigUpdate(),
|
||||||
ClassificationConfigUpdateTests.randomClassificationConfig());
|
ClassificationConfigUpdateTests.randomClassificationConfigUpdate());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Map<String, Object> randomMap() {
|
private static Map<String, Object> randomMap() {
|
||||||
|
|
|
@ -29,6 +29,9 @@ public class ClassificationConfigTests extends AbstractBWCSerializationTestCase<
|
||||||
|
|
||||||
public static ClassificationConfig mutateForVersion(ClassificationConfig instance, Version version) {
|
public static ClassificationConfig mutateForVersion(ClassificationConfig instance, Version version) {
|
||||||
ClassificationConfig.Builder builder = new ClassificationConfig.Builder(instance);
|
ClassificationConfig.Builder builder = new ClassificationConfig.Builder(instance);
|
||||||
|
if (version.before(Version.V_7_8_0)) {
|
||||||
|
builder.setPredictionFieldType(PredictionFieldType.STRING);
|
||||||
|
}
|
||||||
if (version.before(Version.V_7_7_0)) {
|
if (version.before(Version.V_7_7_0)) {
|
||||||
builder.setNumTopFeatureImportanceValues(0);
|
builder.setNumTopFeatureImportanceValues(0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,11 +16,12 @@ import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests.randomClassificationConfig;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<ClassificationConfigUpdate> {
|
public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<ClassificationConfigUpdate> {
|
||||||
|
|
||||||
public static ClassificationConfigUpdate randomClassificationConfig() {
|
public static ClassificationConfigUpdate randomClassificationConfigUpdate() {
|
||||||
return new ClassificationConfigUpdate(randomBoolean() ? null : randomIntBetween(-1, 10),
|
return new ClassificationConfigUpdate(randomBoolean() ? null : randomIntBetween(-1, 10),
|
||||||
randomBoolean() ? null : randomAlphaOfLength(10),
|
randomBoolean() ? null : randomAlphaOfLength(10),
|
||||||
randomBoolean() ? null : randomAlphaOfLength(10),
|
randomBoolean() ? null : randomAlphaOfLength(10),
|
||||||
|
@ -49,9 +50,33 @@ public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTes
|
||||||
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
|
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testApply() {
|
||||||
|
ClassificationConfig originalConfig = randomClassificationConfig();
|
||||||
|
|
||||||
|
assertThat(originalConfig, equalTo(ClassificationConfigUpdate.EMPTY_PARAMS.apply(originalConfig)));
|
||||||
|
|
||||||
|
assertThat(new ClassificationConfig.Builder(originalConfig).setNumTopClasses(5).build(),
|
||||||
|
equalTo(new ClassificationConfigUpdate.Builder().setNumTopClasses(5).build().apply(originalConfig)));
|
||||||
|
assertThat(new ClassificationConfig.Builder()
|
||||||
|
.setNumTopClasses(5)
|
||||||
|
.setNumTopFeatureImportanceValues(1)
|
||||||
|
.setPredictionFieldType(PredictionFieldType.BOOLEAN)
|
||||||
|
.setResultsField("foo")
|
||||||
|
.setTopClassesResultsField("bar").build(),
|
||||||
|
equalTo(new ClassificationConfigUpdate.Builder()
|
||||||
|
.setNumTopClasses(5)
|
||||||
|
.setNumTopFeatureImportanceValues(1)
|
||||||
|
.setPredictionFieldType(PredictionFieldType.BOOLEAN)
|
||||||
|
.setResultsField("foo")
|
||||||
|
.setTopClassesResultsField("bar")
|
||||||
|
.build()
|
||||||
|
.apply(originalConfig)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected ClassificationConfigUpdate createTestInstance() {
|
protected ClassificationConfigUpdate createTestInstance() {
|
||||||
return randomClassificationConfig();
|
return randomClassificationConfigUpdate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -38,6 +38,9 @@ public class PredictionFieldTypeTests extends ESTestCase {
|
||||||
is(nullValue()));
|
is(nullValue()));
|
||||||
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, "foo"), equalTo(1.0));
|
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, "foo"), equalTo(1.0));
|
||||||
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, null), equalTo(1.0));
|
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, null), equalTo(1.0));
|
||||||
|
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, ""), equalTo(1.0));
|
||||||
|
long expected = randomLong();
|
||||||
|
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, String.valueOf(expected)), equalTo(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,11 +16,12 @@ import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests.randomRegressionConfig;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCase<RegressionConfigUpdate> {
|
public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCase<RegressionConfigUpdate> {
|
||||||
|
|
||||||
public static RegressionConfigUpdate randomRegressionConfig() {
|
public static RegressionConfigUpdate randomRegressionConfigUpdate() {
|
||||||
return new RegressionConfigUpdate(randomBoolean() ? null : randomAlphaOfLength(10),
|
return new RegressionConfigUpdate(randomBoolean() ? null : randomAlphaOfLength(10),
|
||||||
randomBoolean() ? null : randomIntBetween(0, 10));
|
randomBoolean() ? null : randomIntBetween(0, 10));
|
||||||
}
|
}
|
||||||
|
@ -40,9 +41,28 @@ public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCas
|
||||||
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
|
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testApply() {
|
||||||
|
RegressionConfig originalConfig = randomRegressionConfig();
|
||||||
|
|
||||||
|
assertThat(originalConfig, equalTo(RegressionConfigUpdate.EMPTY_PARAMS.apply(originalConfig)));
|
||||||
|
|
||||||
|
assertThat(new RegressionConfig.Builder(originalConfig).setNumTopFeatureImportanceValues(5).build(),
|
||||||
|
equalTo(new RegressionConfigUpdate.Builder().setNumTopFeatureImportanceValues(5).build().apply(originalConfig)));
|
||||||
|
assertThat(new RegressionConfig.Builder()
|
||||||
|
.setNumTopFeatureImportanceValues(1)
|
||||||
|
.setResultsField("foo")
|
||||||
|
.build(),
|
||||||
|
equalTo(new RegressionConfigUpdate.Builder()
|
||||||
|
.setNumTopFeatureImportanceValues(1)
|
||||||
|
.setResultsField("foo")
|
||||||
|
.build()
|
||||||
|
.apply(originalConfig)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected RegressionConfigUpdate createTestInstance() {
|
protected RegressionConfigUpdate createTestInstance() {
|
||||||
return randomRegressionConfig();
|
return randomRegressionConfigUpdate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
Loading…
Reference in New Issue