[7.x] [ML] adding prediction_field_type to inference config (#55128) (#55230)

* [ML] adding prediction_field_type to inference config (#55128)

Data frame analytics dynamically determines the classification field type. This field type then dictates the encoded JSON that is written to Elasticsearch. 

Inference needs to know about this field type so that it may provide the EXACT SAME predicted values as analytics. 

Here is added a new field `prediction_field_type` which indicates the desired type. Options are: `string` (DEFAULT), `number`, `boolean` (where close_to(1.0) == true, false otherwise). 

Analytics provides the default `prediction_field_type` when the model is created from the process.
This commit is contained in:
Benjamin Trent 2020-04-15 09:45:22 -04:00 committed by GitHub
parent 2f91e2aab7
commit 8ff2cbf1a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 423 additions and 80 deletions

View File

@ -71,6 +71,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field]
`prediction_field_type`::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type]
[discrete]
[[inference-processor-config-example]]
==== `inference_config` examples

View File

@ -375,6 +375,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values]
`prediction_field_type`::::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type]
`results_field`::::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]

View File

@ -1027,6 +1027,12 @@ Specifies the field to which the top classes are written. Defaults to
`top_classes`.
end::inference-config-classification-top-classes-results-field[]
tag::inference-config-classification-prediction-field-type[]
Specifies the type of the predicted field to write.
Acceptable values are: `string`, `number`, `boolean`. When `boolean` is provided
`1.0` is transformed to `true` and `0.0` to `false`.
end::inference-config-classification-prediction-field-type[]
tag::inference-config-regression-num-top-feature-importance-values[]
Specifies the maximum number of
{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature

View File

@ -16,6 +16,7 @@ import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.FieldAliasMapper;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -236,7 +237,7 @@ public class Classification implements DataFrameAnalysis {
if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
String predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable));
String predictionFieldType = getPredictionFieldTypeParamString(getPredictionFieldType(fieldInfo.getTypes(dependentVariable)));
if (predictionFieldType != null) {
params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
}
@ -245,19 +246,36 @@ public class Classification implements DataFrameAnalysis {
return params;
}
private static String getPredictionFieldType(Set<String> dependentVariableTypes) {
private static String getPredictionFieldTypeParamString(PredictionFieldType predictionFieldType) {
if (predictionFieldType == null) {
return null;
}
switch(predictionFieldType)
{
case NUMBER:
// C++ process uses int64_t type, so it is safe for the dependent variable to use long numbers.
return "int";
case STRING:
return "string";
case BOOLEAN:
return "bool";
default:
return null;
}
}
public static PredictionFieldType getPredictionFieldType(Set<String> dependentVariableTypes) {
if (dependentVariableTypes == null) {
return null;
}
if (Types.categorical().containsAll(dependentVariableTypes)) {
return "string";
return PredictionFieldType.STRING;
}
if (Types.bool().containsAll(dependentVariableTypes)) {
return "bool";
return PredictionFieldType.BOOLEAN;
}
if (Types.discreteNumerical().containsAll(dependentVariableTypes)) {
// C++ process uses int64_t type, so it is safe for the dependent variable to use long numbers.
return "int";
return PredictionFieldType.NUMBER;
}
return null;
}

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -12,6 +13,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -30,6 +32,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
private final String resultsField;
private final String classificationLabel;
private final List<TopClassEntry> topClasses;
private final PredictionFieldType predictionFieldType;
public ClassificationInferenceResults(double value,
String classificationLabel,
@ -58,6 +61,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
this.resultsField = classificationConfig.getResultsField();
this.predictionFieldType = classificationConfig.getPredictionFieldType();
}
public ClassificationInferenceResults(StreamInput in) throws IOException {
@ -66,6 +70,11 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
this.topNumClassesField = in.readString();
this.resultsField = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.predictionFieldType = in.readEnum(PredictionFieldType.class);
} else {
this.predictionFieldType = PredictionFieldType.STRING;
}
}
public String getClassificationLabel() {
@ -83,6 +92,9 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
out.writeCollection(topClasses);
out.writeString(topNumClassesField);
out.writeString(resultsField);
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeEnum(predictionFieldType);
}
}
@Override
@ -95,12 +107,19 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(topNumClassesField, that.topNumClassesField)
&& Objects.equals(topClasses, that.topClasses)
&& Objects.equals(predictionFieldType, that.predictionFieldType)
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}
@Override
public int hashCode() {
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance());
return Objects.hash(value(),
classificationLabel,
topClasses,
resultsField,
topNumClassesField,
getFeatureImportance(),
predictionFieldType);
}
@Override
@ -112,7 +131,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, valueAsString());
document.setFieldValue(parentResultField + "." + this.resultsField,
predictionFieldType.transformPredictedValue(value(), valueAsString()));
if (topClasses.size() > 0) {
document.setFieldValue(parentResultField + "." + topNumClassesField,
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
@ -130,34 +150,33 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
return NAME;
}
public static class TopClassEntry implements Writeable {
public final ParseField CLASS_NAME = new ParseField("class_name");
public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
public final ParseField CLASS_SCORE = new ParseField("class_score");
private final String classification;
private final Object classification;
private final double probability;
private final double score;
public TopClassEntry(String classification, double probability) {
this(classification, probability, probability);
}
public TopClassEntry(String classification, double probability, double score) {
public TopClassEntry(Object classification, double probability, double score) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
this.probability = probability;
this.score = score;
}
public TopClassEntry(StreamInput in) throws IOException {
this.classification = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.classification = in.readGenericValue();
} else {
this.classification = in.readString();
}
this.probability = in.readDouble();
this.score = in.readDouble();
}
public String getClassification() {
public Object getClassification() {
return classification;
}
@ -179,7 +198,11 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(classification);
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeGenericValue(classification);
} else {
out.writeString(classification.toString());
}
out.writeDouble(probability);
out.writeDouble(score);
}

View File

@ -27,15 +27,17 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
public static final ParseField PREDICTION_FIELD_TYPE = new ParseField("prediction_field_type");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static ClassificationConfig EMPTY_PARAMS =
new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null);
new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null, null);
private final int numTopClasses;
private final String topClassesResultsField;
private final String resultsField;
private final int numTopFeatureImportanceValues;
private final PredictionFieldType predictionFieldType;
private static final ObjectParser<ClassificationConfig.Builder, Void> LENIENT_PARSER = createParser(true);
private static final ObjectParser<ClassificationConfig.Builder, Void> STRICT_PARSER = createParser(false);
@ -49,6 +51,17 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
parser.declareString(ClassificationConfig.Builder::setResultsField, RESULTS_FIELD);
parser.declareString(ClassificationConfig.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
parser.declareInt(ClassificationConfig.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
parser.declareField(ClassificationConfig.Builder::setPredictionFieldType,
(p, c) -> {
try {
return PredictionFieldType.fromString(p.text());
} catch (IllegalArgumentException iae) {
if (lenient) {
return PredictionFieldType.STRING;
}
throw iae;
}
}, PREDICTION_FIELD_TYPE, ObjectParser.ValueType.STRING);
return parser;
}
@ -61,14 +74,14 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
}
public ClassificationConfig(Integer numTopClasses) {
this(numTopClasses, null, null, null);
this(numTopClasses, null, null, null, null);
}
public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField) {
this(numTopClasses, resultsField, topClassesResultsField, 0);
}
public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) {
public ClassificationConfig(Integer numTopClasses,
String resultsField,
String topClassesResultsField,
Integer featureImportance,
PredictionFieldType predictionFieldType) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULTS_FIELD : topClassesResultsField;
this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
@ -77,6 +90,7 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = featureImportance == null ? 0 : featureImportance;
this.predictionFieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
}
public ClassificationConfig(StreamInput in) throws IOException {
@ -88,6 +102,11 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
} else {
this.numTopFeatureImportanceValues = 0;
}
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.predictionFieldType = PredictionFieldType.fromStream(in);
} else {
this.predictionFieldType = PredictionFieldType.STRING;
}
}
public int getNumTopClasses() {
@ -106,6 +125,10 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
return numTopFeatureImportanceValues;
}
public PredictionFieldType getPredictionFieldType() {
return predictionFieldType;
}
@Override
public boolean requestingImportance() {
return numTopFeatureImportanceValues > 0;
@ -119,6 +142,9 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeVInt(numTopFeatureImportanceValues);
}
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
predictionFieldType.writeTo(out);
}
}
@Override
@ -129,12 +155,13 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
return Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(topClassesResultsField, that.topClassesResultsField)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
&& Objects.equals(predictionFieldType, that.predictionFieldType);
}
@Override
public int hashCode() {
return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues);
return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues, predictionFieldType);
}
@Override
@ -144,6 +171,7 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField);
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
builder.field(PREDICTION_FIELD_TYPE.getPreferredName(), predictionFieldType.toString());
builder.endObject();
return builder;
}
@ -176,6 +204,7 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
private Integer numTopClasses;
private String topClassesResultsField;
private String resultsField;
private PredictionFieldType predictionFieldType;
private Integer numTopFeatureImportanceValues;
Builder() {}
@ -207,8 +236,17 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
return this;
}
public Builder setPredictionFieldType(PredictionFieldType predictionFieldType) {
this.predictionFieldType = predictionFieldType;
return this;
}
public ClassificationConfig build() {
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues);
return new ClassificationConfig(numTopClasses,
resultsField,
topClassesResultsField,
numTopFeatureImportanceValues,
predictionFieldType);
}
}
}

View File

@ -20,6 +20,7 @@ import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.NUM_TOP_CLASSES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.PREDICTION_FIELD_TYPE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.RESULTS_FIELD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.TOP_CLASSES_RESULTS_FIELD;
@ -28,12 +29,13 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
public static final ParseField NAME = new ParseField("classification");
public static ClassificationConfigUpdate EMPTY_PARAMS =
new ClassificationConfigUpdate(null, null, null, null);
new ClassificationConfigUpdate(null, null, null, null, null);
private final Integer numTopClasses;
private final String topClassesResultsField;
private final String resultsField;
private final Integer numTopFeatureImportanceValues;
private final PredictionFieldType predictionFieldType;
public static ClassificationConfigUpdate fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
@ -41,18 +43,24 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName());
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
String predictionFieldTypeStr = (String)options.remove(PREDICTION_FIELD_TYPE.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new ClassificationConfigUpdate(numTopClasses, resultsField, topClassesResultsField, featureImportance);
return new ClassificationConfigUpdate(numTopClasses,
resultsField,
topClassesResultsField,
featureImportance,
predictionFieldTypeStr == null ? null : PredictionFieldType.fromString(predictionFieldTypeStr));
}
public static ClassificationConfigUpdate fromConfig(ClassificationConfig config) {
return new ClassificationConfigUpdate(config.getNumTopClasses(),
config.getResultsField(),
config.getTopClassesResultsField(),
config.getNumTopFeatureImportanceValues());
config.getNumTopFeatureImportanceValues(),
config.getPredictionFieldType());
}
private static final ObjectParser<ClassificationConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
@ -66,6 +74,7 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
parser.declareString(ClassificationConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
parser.declareString(ClassificationConfigUpdate.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
parser.declareString(ClassificationConfigUpdate.Builder::setPredictionFieldType, PREDICTION_FIELD_TYPE);
return parser;
}
@ -76,7 +85,8 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
public ClassificationConfigUpdate(Integer numTopClasses,
String resultsField,
String topClassesResultsField,
Integer featureImportance) {
Integer featureImportance,
PredictionFieldType predictionFieldType) {
this.numTopClasses = numTopClasses;
this.topClassesResultsField = topClassesResultsField;
this.resultsField = resultsField;
@ -85,6 +95,7 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = featureImportance;
this.predictionFieldType = predictionFieldType;
}
public ClassificationConfigUpdate(StreamInput in) throws IOException {
@ -92,6 +103,7 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
this.topClassesResultsField = in.readOptionalString();
this.resultsField = in.readOptionalString();
this.numTopFeatureImportanceValues = in.readOptionalVInt();
this.predictionFieldType = in.readOptionalWriteable(PredictionFieldType::fromStream);
}
public Integer getNumTopClasses() {
@ -110,12 +122,17 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
return numTopFeatureImportanceValues;
}
public PredictionFieldType getPredictionFieldType() {
return predictionFieldType;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(numTopClasses);
out.writeOptionalString(topClassesResultsField);
out.writeOptionalString(resultsField);
out.writeOptionalVInt(numTopFeatureImportanceValues);
out.writeOptionalWriteable(predictionFieldType);
}
@Override
@ -126,12 +143,13 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
return Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(topClassesResultsField, that.topClassesResultsField)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
&& Objects.equals(predictionFieldType, that.predictionFieldType);
}
@Override
public int hashCode() {
return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues);
return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues, predictionFieldType);
}
@Override
@ -149,6 +167,9 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
if (numTopFeatureImportanceValues != null) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
if (predictionFieldType != null) {
builder.field(PREDICTION_FIELD_TYPE.getPreferredName(), predictionFieldType.toString());
}
builder.endObject();
return builder;
}
@ -181,6 +202,9 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
if (numTopClasses != null) {
builder.setNumTopClasses(numTopClasses);
}
if (predictionFieldType != null) {
builder.setPredictionFieldType(predictionFieldType);
}
return builder.build();
}
@ -199,7 +223,8 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
&& (numTopFeatureImportanceValues == null
|| originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues)
&& (topClassesResultsField == null || topClassesResultsField.equals(originalConfig.getTopClassesResultsField()))
&& (numTopClasses == null || originalConfig.getNumTopClasses() == numTopClasses);
&& (numTopClasses == null || originalConfig.getNumTopClasses() == numTopClasses)
&& (predictionFieldType == null || predictionFieldType.equals(originalConfig.getPredictionFieldType()));
}
public static class Builder {
@ -207,6 +232,7 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
private String topClassesResultsField;
private String resultsField;
private Integer numTopFeatureImportanceValues;
private PredictionFieldType predictionFieldType;
public Builder setNumTopClasses(int numTopClasses) {
this.numTopClasses = numTopClasses;
@ -228,8 +254,17 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate<Classif
return this;
}
private Builder setPredictionFieldType(String predictionFieldType) {
this.predictionFieldType = PredictionFieldType.fromString(predictionFieldType);
return this;
}
public ClassificationConfigUpdate build() {
return new ClassificationConfigUpdate(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues);
return new ClassificationConfigUpdate(numTopClasses,
resultsField,
topClassesResultsField,
numTopFeatureImportanceValues,
predictionFieldType);
}
}
}

View File

@ -31,7 +31,8 @@ public final class InferenceHelpers {
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(double[] probabilities,
List<String> classificationLabels,
@Nullable double[] classificationWeights,
int numToInclude) {
int numToInclude,
PredictionFieldType predictionFieldType) {
if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
throw ExceptionsHelper
@ -67,7 +68,10 @@ public final class InferenceHelpers {
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
for(int i = 0; i < count; i++) {
int idx = sortedIndices[i];
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities[idx], scores[idx]));
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(
predictionFieldType.transformPredictedValue((double)idx, labels.get(idx)),
probabilities[idx],
scores[idx]));
}
return Tuple.tuple(sortedIndices[0], topClassEntries);

View File

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import java.io.IOException;
import java.util.Locale;
/**
* The type of the prediction field.
* This modifies how the predicted class values are written for classification models
*/
public enum PredictionFieldType implements Writeable {
STRING,
NUMBER,
BOOLEAN;
private static final double EPS = 1.0E-9;
public static PredictionFieldType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
public static PredictionFieldType fromStream(StreamInput in) throws IOException {
return in.readEnum(PredictionFieldType.class);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(this);
}
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
public Object transformPredictedValue(Double value, String stringRep) {
if (value == null) {
return null;
}
switch(this) {
case STRING:
return stringRep == null ? value.toString() : stringRep;
case BOOLEAN:
if ((areClose(value, 1.0D) || areClose(value, 0.0D)) == false) {
throw new IllegalArgumentException(
"Cannot transform numbers other than 0.0 or 1.0 to boolean. Provided number [" + value + "]");
}
return areClose(value, 1.0D);
case NUMBER:
default:
return value;
}
}
private static boolean areClose(double value1, double value2) {
return Math.abs(value1 - value2) < EPS;
}
}

View File

@ -186,7 +186,8 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
processedInferences,
classificationLabels,
classificationWeights,
classificationConfig.getNumTopClasses());
classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType());
return new ClassificationInferenceResults((double)topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),

View File

@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -136,7 +137,8 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
probabilities,
LANGUAGE_NAMES,
null,
classificationConfig.getNumTopClasses());
classificationConfig.getNumTopClasses(),
PredictionFieldType.STRING);
assert topClasses.v1() >= 0 && topClasses.v1() < LANGUAGE_NAMES.size() :
"Invalid language predicted. Predicted language index " + topClasses.v1();
return new ClassificationInferenceResults(topClasses.v1(),

View File

@ -162,7 +162,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
classificationProbability(value),
classificationLabels,
null,
classificationConfig.getNumTopClasses());
classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType());
return new ClassificationInferenceResults(topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),

View File

@ -10,6 +10,7 @@ import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import java.util.Arrays;
import java.util.Collections;
@ -44,7 +45,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
}
private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() {
return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble());
return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble(), randomDouble());
}
public void testWriteResultsWithClassificationLabel() {
@ -70,13 +71,13 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
@SuppressWarnings("unchecked")
public void testWriteResultsWithTopClasses() {
List<ClassificationInferenceResults.TopClassEntry> entries = Arrays.asList(
new ClassificationInferenceResults.TopClassEntry("foo", 0.7),
new ClassificationInferenceResults.TopClassEntry("bar", 0.2),
new ClassificationInferenceResults.TopClassEntry("baz", 0.1));
new ClassificationInferenceResults.TopClassEntry("foo", 0.7, 0.7),
new ClassificationInferenceResults.TopClassEntry("bar", 0.2, 0.2),
new ClassificationInferenceResults.TopClassEntry("baz", 0.1, 0.1));
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
"foo",
entries,
new ClassificationConfig(3, "my_results", "bar"));
new ClassificationConfig(3, "my_results", "bar", null, PredictionFieldType.STRING));
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
@ -103,7 +104,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
"foo",
Collections.emptyList(),
importanceList,
new ClassificationConfig(0, "predicted_value", "top_classes", 3));
new ClassificationConfig(0, "predicted_value", "top_classes", 3, PredictionFieldType.STRING));
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");

View File

@ -21,7 +21,9 @@ public class ClassificationConfigTests extends AbstractBWCSerializationTestCase<
public static ClassificationConfig randomClassificationConfig() {
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomAlphaOfLength(10)
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomIntBetween(0, 10),
randomFrom(PredictionFieldType.values())
);
}

View File

@ -24,20 +24,22 @@ public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTes
return new ClassificationConfigUpdate(randomBoolean() ? null : randomIntBetween(-1, 10),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomIntBetween(0, 10)
randomBoolean() ? null : randomIntBetween(0, 10),
randomBoolean() ? null : randomFrom(PredictionFieldType.values())
);
}
public void testFromMap() {
ClassificationConfigUpdate expected = new ClassificationConfigUpdate(null, null, null, null);
ClassificationConfigUpdate expected = ClassificationConfigUpdate.EMPTY_PARAMS;
assertThat(ClassificationConfigUpdate.fromMap(Collections.emptyMap()), equalTo(expected));
expected = new ClassificationConfigUpdate(3, "foo", "bar", 2);
expected = new ClassificationConfigUpdate(3, "foo", "bar", 2, PredictionFieldType.NUMBER);
Map<String, Object> configMap = new HashMap<>();
configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo");
configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar");
configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2);
configMap.put(ClassificationConfig.PREDICTION_FIELD_TYPE.getPreferredName(), PredictionFieldType.NUMBER.toString());
assertThat(ClassificationConfigUpdate.fromMap(configMap), equalTo(expected));
}

View File

@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.test.ESTestCase;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.nullValue;
public class PredictionFieldTypeTests extends ESTestCase {
public void testTransformPredictedValueBoolean() {
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(null, randomBoolean() ? null : randomAlphaOfLength(10)),
is(nullValue()));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, randomBoolean() ? null : randomAlphaOfLength(10)),
is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, randomBoolean() ? null : randomAlphaOfLength(10)),
is(false));
expectThrows(IllegalArgumentException.class,
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(0.1, randomBoolean() ? null : randomAlphaOfLength(10)));
expectThrows(IllegalArgumentException.class,
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(1.1, randomBoolean() ? null : randomAlphaOfLength(10)));
}
public void testTransformPredictedValueString() {
assertThat(PredictionFieldType.STRING.transformPredictedValue(null, randomBoolean() ? null : randomAlphaOfLength(10)),
is(nullValue()));
assertThat(PredictionFieldType.STRING.transformPredictedValue(1.0, "foo"), equalTo("foo"));
assertThat(PredictionFieldType.STRING.transformPredictedValue(1.0, null), equalTo("1.0"));
}
public void testTransformPredictedValueNumber() {
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(null, randomBoolean() ? null : randomAlphaOfLength(10)),
is(nullValue()));
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, "foo"), equalTo(1.0));
assertThat(PredictionFieldType.NUMBER.transformPredictedValue(1.0, null), equalTo(1.0));
}
}

View File

@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
@ -47,6 +48,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@ -243,9 +245,11 @@ public class AnalyticsResultProcessor {
case CLASSIFICATION:
assert analytics.getAnalysis() instanceof Classification;
Classification classification = ((Classification)analytics.getAnalysis());
PredictionFieldType predictionFieldType = getPredictionFieldType(classification);
return ClassificationConfig.builder()
.setNumTopClasses(classification.getNumTopClasses())
.setNumTopFeatureImportanceValues(classification.getBoostedTreeParams().getNumTopFeatureImportanceValues())
.setPredictionFieldType(predictionFieldType)
.build();
case REGRESSION:
assert analytics.getAnalysis() instanceof Regression;
@ -254,14 +258,24 @@ public class AnalyticsResultProcessor {
.setNumTopFeatureImportanceValues(regression.getBoostedTreeParams().getNumTopFeatureImportanceValues())
.build();
default:
setAndReportFailure(ExceptionsHelper.serverError(
throw ExceptionsHelper.serverError(
"process created a model with an unsupported target type [{}]",
null,
targetType));
return null;
targetType);
}
}
PredictionFieldType getPredictionFieldType(Classification classification) {
String dependentVariable = classification.getDependentVariable();
Optional<ExtractedField> extractedField = fieldNames.stream()
.filter(f -> f.getName().equals(dependentVariable))
.findAny();
PredictionFieldType predictionFieldType = Classification.getPredictionFieldType(
extractedField.isPresent() ? extractedField.get().getTypes() : null
);
return predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
}
private String getDependentVariable() {
if (analytics.getAnalysis() instanceof Classification) {
return ((Classification)analytics.getAnalysis()).getDependentVariable();

View File

@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.collect.Set;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
@ -15,11 +16,13 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
@ -210,6 +213,19 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
Mockito.verifyNoMoreInteractions(auditor);
}
public void testGetPredictionFieldType() {
List<ExtractedField> extractedFieldList = Arrays.asList(
new DocValueField("foo", Collections.emptySet()),
new DocValueField("bar", Set.of("keyword")),
new DocValueField("baz", Set.of("long")),
new DocValueField("bingo", Set.of("boolean")));
AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);
assertThat(resultProcessor.getPredictionFieldType(new Classification("foo")), equalTo(PredictionFieldType.STRING));
assertThat(resultProcessor.getPredictionFieldType(new Classification("bar")), equalTo(PredictionFieldType.STRING));
assertThat(resultProcessor.getPredictionFieldType(new Classification("baz")), equalTo(PredictionFieldType.NUMBER));
assertThat(resultProcessor.getPredictionFieldType(new Classification("bingo")), equalTo(PredictionFieldType.BOOLEAN));
}
@SuppressWarnings("unchecked")
public void testProcess_GivenInferenceModelFailedToStore() {
givenDataFrameRows(0);

View File

@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResu
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
@ -77,8 +78,8 @@ public class InferenceProcessorTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void testMutateDocumentClassificationTopNClasses() {
ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, null);
ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, null);
ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, null, null);
ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, null, PredictionFieldType.STRING);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
@ -92,8 +93,8 @@ public class InferenceProcessorTests extends ESTestCase {
IngestDocument document = new IngestDocument(source, ingestMetadata);
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
@ -107,8 +108,8 @@ public class InferenceProcessorTests extends ESTestCase {
}
public void testMutateDocumentClassificationFeatureInfluence() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2);
ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, 2);
ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2, PredictionFieldType.STRING);
ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, 2, null);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
@ -122,8 +123,8 @@ public class InferenceProcessorTests extends ESTestCase {
IngestDocument document = new IngestDocument(source, ingestMetadata);
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
List<FeatureImportance> featureInfluence = new ArrayList<>();
featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
@ -148,8 +149,8 @@ public class InferenceProcessorTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void testMutateDocumentClassificationTopNClassesWithSpecificField() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops");
ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, "result", "tops", null);
ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops", null, PredictionFieldType.STRING);
ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, "result", "tops", null, null);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
@ -163,8 +164,8 @@ public class InferenceProcessorTests extends ESTestCase {
IngestDocument document = new IngestDocument(source, ingestMetadata);
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
@ -240,7 +241,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
"my_field",
modelId,
new ClassificationConfigUpdate(topNClasses, null, null, null),
new ClassificationConfigUpdate(topNClasses, null, null, null, null),
Collections.emptyMap());
Map<String, Object> source = new HashMap<String, Object>(){{
@ -269,7 +270,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
"my_field",
modelId,
new ClassificationConfigUpdate(topNClasses, null, null, null),
new ClassificationConfigUpdate(topNClasses, null, null, null, null),
fieldMapping);
Map<String, Object> source = new HashMap<String, Object>(5){{
@ -305,7 +306,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
"my_field",
modelId,
new ClassificationConfigUpdate(topNClasses, null, null, null),
new ClassificationConfigUpdate(topNClasses, null, null, null, null),
fieldMapping);
Map<String, Object> source = new HashMap<String, Object>(5){{

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
@ -17,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
@ -74,13 +76,13 @@ public class LocalModelTests extends ESTestCase {
put("categorical", "dog");
}};
SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
SingleValueInferenceResults result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), is("0"));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));
ClassificationInferenceResults classificationResult =
(ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null));
(ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null, null));
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0"));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));
@ -97,30 +99,88 @@ public class LocalModelTests extends ESTestCase {
Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS,
modelStatsService);
result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), equalTo("not_to_be"));
classificationResult = (ClassificationInferenceResults)getSingleValue(model,
fields,
new ClassificationConfigUpdate(1, null, null, null));
new ClassificationConfigUpdate(1, null, null, null, null));
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L));
classificationResult = (ClassificationInferenceResults)getSingleValue(model,
fields,
new ClassificationConfigUpdate(2, null, null, null));
new ClassificationConfigUpdate(2, null, null, null, null));
assertThat(classificationResult.getTopClasses(), hasSize(2));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));
classificationResult = (ClassificationInferenceResults)getSingleValue(model,
fields,
new ClassificationConfigUpdate(-1, null, null, null));
new ClassificationConfigUpdate(-1, null, null, null, null));
assertThat(classificationResult.getTopClasses(), hasSize(2));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));
}
@SuppressWarnings("unchecked")
public void testClassificationInferWithDifferentPredictionFieldTypes() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(true))
.build();
Model<ClassificationConfig> model = new LocalModel<>(modelId,
"test-node",
definition,
new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS,
modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
put("field.bar", 0.5);
put("categorical", "dog");
}};
InferenceResults result = getInferenceResult(
model,
fields,
new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.STRING));
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("not_to_be"));
List<?> list = document.getFieldValue("result_field.top_classes", List.class);
assertThat(list.size(), equalTo(2));
assertThat(((Map<String, Object>)list.get(0)).get("class_name"), equalTo("not_to_be"));
assertThat(((Map<String, Object>)list.get(1)).get("class_name"), equalTo("to_be"));
result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.NUMBER));
document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.0));
list = document.getFieldValue("result_field.top_classes", List.class);
assertThat(list.size(), equalTo(2));
assertThat(((Map<String, Object>)list.get(0)).get("class_name"), equalTo(0.0));
assertThat(((Map<String, Object>)list.get(1)).get("class_name"), equalTo(1.0));
result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.BOOLEAN));
document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
assertThat(document.getFieldValue("result_field.predicted_value", Boolean.class), equalTo(false));
list = document.getFieldValue("result_field.top_classes", List.class);
assertThat(list.size(), equalTo(2));
assertThat(((Map<String, Object>)list.get(0)).get("class_name"), equalTo(false));
assertThat(((Map<String, Object>)list.get(1)).get("class_name"), equalTo(true));
}
public void testRegression() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
@ -201,9 +261,9 @@ public class LocalModelTests extends ESTestCase {
}};
for(int i = 0; i < 100; i++) {
getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
}
SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
SingleValueInferenceResults result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), is("0"));
// Should have reset after persistence, so only 2 docs have been seen since last persistence

View File

@ -168,7 +168,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
contains("not_to_be", "to_be"));
// Get top classes
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null), true);
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null, null), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
ClassificationInferenceResults classificationInferenceResults =
@ -187,7 +187,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
// Test that top classes restrict the number returned
request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfigUpdate(1, null, null, null), true);
request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfigUpdate(1, null, null, null, null), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
@ -281,7 +281,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
// Get top classes
request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null), true);
request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null, null), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
ClassificationInferenceResults classificationInferenceResults =