In order to unify model inference and analytics results we need to write the same fields. prediction_probability and prediction_score are now written for inference calls against classification models.
This commit is contained in:
parent
9d4a64e749
commit
76359aaa53
|
@ -27,32 +27,40 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
|
||||
public static final String NAME = "classification";
|
||||
|
||||
public static final String PREDICTION_PROBABILITY = "prediction_probability";
|
||||
public static final String PREDICTION_SCORE = "prediction_score";
|
||||
|
||||
private final String topNumClassesField;
|
||||
private final String resultsField;
|
||||
private final String classificationLabel;
|
||||
private final Double predictionProbability;
|
||||
private final Double predictionScore;
|
||||
private final List<TopClassEntry> topClasses;
|
||||
private final PredictionFieldType predictionFieldType;
|
||||
|
||||
public ClassificationInferenceResults(double value,
|
||||
String classificationLabel,
|
||||
List<TopClassEntry> topClasses,
|
||||
InferenceConfig config) {
|
||||
this(value, classificationLabel, topClasses, Collections.emptyList(), (ClassificationConfig)config);
|
||||
}
|
||||
|
||||
public ClassificationInferenceResults(double value,
|
||||
String classificationLabel,
|
||||
List<TopClassEntry> topClasses,
|
||||
List<FeatureImportance> featureImportance,
|
||||
InferenceConfig config) {
|
||||
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
|
||||
InferenceConfig config,
|
||||
Double predictionProbability,
|
||||
Double predictionScore) {
|
||||
this(value,
|
||||
classificationLabel,
|
||||
topClasses,
|
||||
featureImportance,
|
||||
(ClassificationConfig)config,
|
||||
predictionProbability,
|
||||
predictionScore);
|
||||
}
|
||||
|
||||
private ClassificationInferenceResults(double value,
|
||||
String classificationLabel,
|
||||
List<TopClassEntry> topClasses,
|
||||
List<FeatureImportance> featureImportance,
|
||||
ClassificationConfig classificationConfig) {
|
||||
ClassificationConfig classificationConfig,
|
||||
Double predictionProbability,
|
||||
Double predictionScore) {
|
||||
super(value,
|
||||
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
|
||||
classificationConfig.getNumTopFeatureImportanceValues()));
|
||||
|
@ -61,6 +69,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
|
||||
this.resultsField = classificationConfig.getResultsField();
|
||||
this.predictionFieldType = classificationConfig.getPredictionFieldType();
|
||||
this.predictionProbability = predictionProbability;
|
||||
this.predictionScore = predictionScore;
|
||||
}
|
||||
|
||||
public ClassificationInferenceResults(StreamInput in) throws IOException {
|
||||
|
@ -74,6 +84,13 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
} else {
|
||||
this.predictionFieldType = PredictionFieldType.STRING;
|
||||
}
|
||||
if (in.getVersion().onOrAfter(Version.V_7_9_0)) {
|
||||
this.predictionProbability = in.readOptionalDouble();
|
||||
this.predictionScore = in.readOptionalDouble();
|
||||
} else {
|
||||
this.predictionProbability = topClasses.size() > 0 ? topClasses.get(0).getProbability() : null;
|
||||
this.predictionScore = topClasses.size() > 0 ? topClasses.get(0).getScore() : null;
|
||||
}
|
||||
}
|
||||
|
||||
public String getClassificationLabel() {
|
||||
|
@ -98,6 +115,10 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
|
||||
out.writeEnum(predictionFieldType);
|
||||
}
|
||||
if (out.getVersion().onOrAfter(Version.V_7_9_0)) {
|
||||
out.writeOptionalDouble(predictionProbability);
|
||||
out.writeOptionalDouble(predictionScore);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -111,6 +132,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
&& Objects.equals(topNumClassesField, that.topNumClassesField)
|
||||
&& Objects.equals(topClasses, that.topClasses)
|
||||
&& Objects.equals(predictionFieldType, that.predictionFieldType)
|
||||
&& Objects.equals(predictionProbability, that.predictionProbability)
|
||||
&& Objects.equals(predictionScore, that.predictionScore)
|
||||
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
|
||||
}
|
||||
|
||||
|
@ -121,6 +144,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
topClasses,
|
||||
resultsField,
|
||||
topNumClassesField,
|
||||
predictionProbability,
|
||||
predictionScore,
|
||||
getFeatureImportance(),
|
||||
predictionFieldType);
|
||||
}
|
||||
|
@ -142,6 +167,14 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
document.setFieldValue(parentResultField, asMap());
|
||||
}
|
||||
|
||||
public Double getPredictionProbability() {
|
||||
return predictionProbability;
|
||||
}
|
||||
|
||||
public Double getPredictionScore() {
|
||||
return predictionScore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> asMap() {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
|
@ -149,6 +182,12 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
if (topClasses.isEmpty() == false) {
|
||||
map.put(topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
|
||||
}
|
||||
if (predictionProbability != null) {
|
||||
map.put(PREDICTION_PROBABILITY, predictionProbability);
|
||||
}
|
||||
if (predictionScore != null) {
|
||||
map.put(PREDICTION_SCORE, predictionScore);
|
||||
}
|
||||
if (getFeatureImportance().isEmpty() == false) {
|
||||
map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList()));
|
||||
}
|
||||
|
@ -166,6 +205,12 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
if (topClasses.size() > 0) {
|
||||
builder.field(topNumClassesField, topClasses);
|
||||
}
|
||||
if (predictionProbability != null) {
|
||||
builder.field(PREDICTION_PROBABILITY, predictionProbability);
|
||||
}
|
||||
if (predictionScore != null) {
|
||||
builder.field(PREDICTION_SCORE, predictionScore);
|
||||
}
|
||||
if (getFeatureImportance().size() > 0) {
|
||||
builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ public final class InferenceHelpers {
|
|||
/**
|
||||
* @return Tuple of the highest scored index and the top classes
|
||||
*/
|
||||
public static Tuple<Integer, List<TopClassEntry>> topClasses(double[] probabilities,
|
||||
public static Tuple<TopClassificationValue, List<TopClassEntry>> topClasses(double[] probabilities,
|
||||
List<String> classificationLabels,
|
||||
@Nullable double[] classificationWeights,
|
||||
int numToInclude,
|
||||
|
@ -55,8 +55,11 @@ public final class InferenceHelpers {
|
|||
.mapToInt(i -> i)
|
||||
.toArray();
|
||||
|
||||
final TopClassificationValue topClassificationValue = new TopClassificationValue(sortedIndices[0],
|
||||
probabilities[sortedIndices[0]],
|
||||
scores[sortedIndices[0]]);
|
||||
if (numToInclude == 0) {
|
||||
return Tuple.tuple(sortedIndices[0], Collections.emptyList());
|
||||
return Tuple.tuple(topClassificationValue, Collections.emptyList());
|
||||
}
|
||||
|
||||
List<String> labels = classificationLabels == null ?
|
||||
|
@ -74,7 +77,7 @@ public final class InferenceHelpers {
|
|||
scores[idx]));
|
||||
}
|
||||
|
||||
return Tuple.tuple(sortedIndices[0], topClassEntries);
|
||||
return Tuple.tuple(topClassificationValue, topClassEntries);
|
||||
}
|
||||
|
||||
public static String classificationLabel(Integer inferenceValue, @Nullable List<String> classificationLabels) {
|
||||
|
@ -155,4 +158,28 @@ public final class InferenceHelpers {
|
|||
}
|
||||
return sumTo;
|
||||
}
|
||||
|
||||
public static class TopClassificationValue {
|
||||
private final int value;
|
||||
private final double probability;
|
||||
private final double score;
|
||||
|
||||
TopClassificationValue(int value, double probability, double score) {
|
||||
this.value = value;
|
||||
this.probability = probability;
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
public int getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
public double getProbability() {
|
||||
return probability;
|
||||
}
|
||||
|
||||
public double getScore() {
|
||||
return score;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -206,17 +206,20 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|||
ClassificationConfig classificationConfig = (ClassificationConfig) config;
|
||||
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
|
||||
// Adjust the probabilities according to the thresholds
|
||||
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
processedInferences,
|
||||
classificationLabels,
|
||||
classificationWeights,
|
||||
classificationConfig.getNumTopClasses(),
|
||||
classificationConfig.getPredictionFieldType());
|
||||
return new ClassificationInferenceResults((double)topClasses.v1(),
|
||||
classificationLabel(topClasses.v1(), classificationLabels),
|
||||
final InferenceHelpers.TopClassificationValue value = topClasses.v1();
|
||||
return new ClassificationInferenceResults((double)value.getValue(),
|
||||
classificationLabel(topClasses.v1().getValue(), classificationLabels),
|
||||
topClasses.v2(),
|
||||
transformFeatureImportance(decodedFeatureImportance, classificationLabels),
|
||||
config);
|
||||
config,
|
||||
value.getProbability(),
|
||||
value.getScore());
|
||||
default:
|
||||
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
|
||||
}
|
||||
|
|
|
@ -174,17 +174,20 @@ public class TreeInferenceModel implements InferenceModel {
|
|||
switch (targetType) {
|
||||
case CLASSIFICATION:
|
||||
ClassificationConfig classificationConfig = (ClassificationConfig) config;
|
||||
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
classificationProbability(value),
|
||||
classificationLabels,
|
||||
null,
|
||||
classificationConfig.getNumTopClasses(),
|
||||
classificationConfig.getPredictionFieldType());
|
||||
return new ClassificationInferenceResults(topClasses.v1(),
|
||||
classificationLabel(topClasses.v1(), classificationLabels),
|
||||
final InferenceHelpers.TopClassificationValue classificationValue = topClasses.v1();
|
||||
return new ClassificationInferenceResults(classificationValue.getValue(),
|
||||
classificationLabel(classificationValue.getValue(), classificationLabels),
|
||||
topClasses.v2(),
|
||||
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels),
|
||||
config);
|
||||
config,
|
||||
classificationValue.getProbability(),
|
||||
classificationValue.getScore());
|
||||
case REGRESSION:
|
||||
return new RegressionInferenceResults(value[0],
|
||||
config,
|
||||
|
|
|
@ -135,19 +135,22 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
|
|||
double[] probabilities = softMax(scores);
|
||||
|
||||
ClassificationConfig classificationConfig = (ClassificationConfig) config;
|
||||
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
probabilities,
|
||||
LANGUAGE_NAMES,
|
||||
null,
|
||||
classificationConfig.getNumTopClasses(),
|
||||
PredictionFieldType.STRING);
|
||||
assert topClasses.v1() >= 0 && topClasses.v1() < LANGUAGE_NAMES.size() :
|
||||
final InferenceHelpers.TopClassificationValue classificationValue = topClasses.v1();
|
||||
assert classificationValue.getValue() >= 0 && classificationValue.getValue() < LANGUAGE_NAMES.size() :
|
||||
"Invalid language predicted. Predicted language index " + topClasses.v1();
|
||||
return new ClassificationInferenceResults(topClasses.v1(),
|
||||
LANGUAGE_NAMES.get(topClasses.v1()),
|
||||
return new ClassificationInferenceResults(classificationValue.getValue(),
|
||||
LANGUAGE_NAMES.get(classificationValue.getValue()),
|
||||
topClasses.v2(),
|
||||
Collections.emptyList(),
|
||||
classificationConfig);
|
||||
classificationConfig,
|
||||
classificationValue.getProbability(),
|
||||
classificationValue.getScore());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -49,12 +49,20 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
Stream.generate(featureImportanceCtor)
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList()),
|
||||
config);
|
||||
config,
|
||||
randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false),
|
||||
randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false));
|
||||
}
|
||||
|
||||
public void testWriteResultsWithClassificationLabel() {
|
||||
ClassificationInferenceResults result =
|
||||
new ClassificationInferenceResults(1.0, "foo", Collections.emptyList(), ClassificationConfig.EMPTY_PARAMS);
|
||||
new ClassificationInferenceResults(1.0,
|
||||
"foo",
|
||||
Collections.emptyList(),
|
||||
Collections.emptyList(),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
1.0,
|
||||
1.0);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
|
||||
|
@ -65,7 +73,10 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
|
||||
null,
|
||||
Collections.emptyList(),
|
||||
ClassificationConfig.EMPTY_PARAMS);
|
||||
Collections.emptyList(),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
1.0,
|
||||
1.0);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
|
||||
|
@ -81,7 +92,10 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
|
||||
"foo",
|
||||
entries,
|
||||
new ClassificationConfig(3, "my_results", "bar", null, PredictionFieldType.STRING));
|
||||
Collections.emptyList(),
|
||||
new ClassificationConfig(3, "my_results", "bar", null, PredictionFieldType.STRING),
|
||||
0.7,
|
||||
0.7);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
|
||||
|
@ -108,7 +122,9 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
"foo",
|
||||
Collections.emptyList(),
|
||||
importanceList,
|
||||
new ClassificationConfig(0, "predicted_value", "top_classes", 3, PredictionFieldType.STRING));
|
||||
new ClassificationConfig(0, "predicted_value", "top_classes", 3, PredictionFieldType.STRING),
|
||||
1.0,
|
||||
1.0);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
|
||||
|
@ -142,36 +158,65 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
|
||||
public void testToXContent() {
|
||||
ClassificationConfig toStringConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.STRING);
|
||||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, null, toStringConfig);
|
||||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
|
||||
null,
|
||||
null,
|
||||
Collections.emptyList(),
|
||||
toStringConfig,
|
||||
1.0,
|
||||
1.0);
|
||||
String stringRep = Strings.toString(result);
|
||||
String expected = "{\"predicted_value\":\"1.0\"}";
|
||||
String expected = "{\"predicted_value\":\"1.0\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
ClassificationConfig toDoubleConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.NUMBER);
|
||||
result = new ClassificationInferenceResults(1.0, null, null, toDoubleConfig);
|
||||
result = new ClassificationInferenceResults(1.0, null, null, Collections.emptyList(), toDoubleConfig,
|
||||
1.0,
|
||||
1.0);
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"predicted_value\":1.0}";
|
||||
expected = "{\"predicted_value\":1.0,\"prediction_probability\":1.0,\"prediction_score\":1.0}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
ClassificationConfig boolFieldConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.BOOLEAN);
|
||||
result = new ClassificationInferenceResults(1.0, null, null, boolFieldConfig);
|
||||
result = new ClassificationInferenceResults(1.0, null, null, Collections.emptyList(), boolFieldConfig,
|
||||
1.0,
|
||||
1.0);
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"predicted_value\":true}";
|
||||
expected = "{\"predicted_value\":true,\"prediction_probability\":1.0,\"prediction_score\":1.0}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
ClassificationConfig config = new ClassificationConfig(1);
|
||||
result = new ClassificationInferenceResults(1.0, "label1", null, config);
|
||||
result = new ClassificationInferenceResults(1.0, "label1", null, Collections.emptyList(), config,
|
||||
1.0,
|
||||
1.0);
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"predicted_value\":\"label1\"}";
|
||||
expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
|
||||
TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
|
||||
result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
|
||||
Collections.singletonList(fi), config);
|
||||
Collections.singletonList(fi), config,
|
||||
1.0,
|
||||
1.0);
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"predicted_value\":\"label1\"," +
|
||||
"\"top_classes\":[{\"class_name\":\"class\",\"class_probability\":1.0,\"class_score\":1.0}]}";
|
||||
"\"top_classes\":[{\"class_name\":\"class\",\"class_probability\":1.0,\"class_score\":1.0}]," +
|
||||
"\"prediction_probability\":1.0,\"prediction_score\":1.0}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
|
||||
config = new ClassificationConfig(0);
|
||||
result = new ClassificationInferenceResults(1.0,
|
||||
"label1",
|
||||
Collections.emptyList(),
|
||||
Collections.emptyList(),
|
||||
config,
|
||||
1.0,
|
||||
1.0);
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -206,7 +206,7 @@ public class InferenceIngestIT extends ESRestTestCase {
|
|||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"ml.classification\",\n" +
|
||||
" \"inference_config\": {\"classification\": " +
|
||||
" {\"num_top_classes\":2, " +
|
||||
" {\"num_top_classes\":0, " +
|
||||
" \"top_classes_results_field\": \"result_class_prob\"," +
|
||||
" \"num_top_feature_importance_values\": 2" +
|
||||
" }},\n" +
|
||||
|
@ -245,6 +245,8 @@ public class InferenceIngestIT extends ESRestTestCase {
|
|||
|
||||
Response response = client().performRequest(simulateRequest(source));
|
||||
String responseString = EntityUtils.toString(response.getEntity());
|
||||
assertThat(responseString, containsString("\"prediction_probability\":1.0"));
|
||||
assertThat(responseString, containsString("\"prediction_score\":1.0"));
|
||||
assertThat(responseString, containsString("\"predicted_value\":\"second\""));
|
||||
assertThat(responseString, containsString("\"predicted_value\":1.0"));
|
||||
assertThat(responseString, containsString("\"feature_name\":\"col1\""));
|
||||
|
|
|
@ -95,11 +95,17 @@ public class InferenceRunnerTests extends ESTestCase {
|
|||
LocalModel localModel = localModelInferences(new ClassificationInferenceResults(1.0,
|
||||
"foo",
|
||||
Collections.emptyList(),
|
||||
config),
|
||||
Collections.emptyList(),
|
||||
config,
|
||||
1.0,
|
||||
1.0),
|
||||
new ClassificationInferenceResults(0.0,
|
||||
"bar",
|
||||
Collections.emptyList(),
|
||||
config));
|
||||
Collections.emptyList(),
|
||||
config,
|
||||
.5,
|
||||
.7));
|
||||
|
||||
InferenceRunner inferenceRunner = createInferenceRunner(extractedFields);
|
||||
|
||||
|
@ -117,10 +123,15 @@ public class InferenceRunnerTests extends ESTestCase {
|
|||
|
||||
Map<String, Object> expectedResultsField1 = new HashMap<>();
|
||||
expectedResultsField1.put("predicted_value", "foo");
|
||||
expectedResultsField1.put("prediction_probability", 1.0);
|
||||
expectedResultsField1.put("prediction_score", 1.0);
|
||||
expectedResultsField1.put("predicted_value", "foo");
|
||||
expectedResultsField1.put("is_training", false);
|
||||
|
||||
Map<String, Object> expectedResultsField2 = new HashMap<>();
|
||||
expectedResultsField2.put("predicted_value", "bar");
|
||||
expectedResultsField2.put("prediction_probability", 0.5);
|
||||
expectedResultsField2.put("prediction_score", 0.7);
|
||||
expectedResultsField2.put("is_training", false);
|
||||
|
||||
assertThat(doc1Source.get("test_results_field"), equalTo(expectedResultsField1));
|
||||
|
|
|
@ -76,7 +76,9 @@ public class InternalInferenceAggregationTests extends InternalAggregationTestCa
|
|||
randomResults.getClassificationLabel(),
|
||||
randomResults.getTopClasses(),
|
||||
randomResults.getFeatureImportance(),
|
||||
new ClassificationConfig(null, "value", null, null, randomResults.getPredictionFieldType())
|
||||
new ClassificationConfig(null, "value", null, null, randomResults.getPredictionFieldType()),
|
||||
randomResults.getPredictionProbability(),
|
||||
randomResults.getPredictionScore()
|
||||
);
|
||||
} else if (randomBoolean()) {
|
||||
// build a random result with the result field set to `value`
|
||||
|
|
|
@ -23,6 +23,8 @@ import java.io.IOException;
|
|||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_PROBABILITY;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_SCORE;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults.FEATURE_IMPORTANCE;
|
||||
|
||||
|
||||
|
@ -44,7 +46,7 @@ public class ParsedInference extends ParsedAggregation {
|
|||
private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
|
||||
new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
|
||||
args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
|
||||
(List<TopClassEntry>) args[2], (String) args[3]));
|
||||
(List<TopClassEntry>) args[2], (String) args[3], (Double) args[4], (Double) args[5]));
|
||||
|
||||
static {
|
||||
PARSER.declareField(optionalConstructorArg(), (p, n) -> {
|
||||
|
@ -68,6 +70,8 @@ public class ParsedInference extends ParsedAggregation {
|
|||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p),
|
||||
new ParseField(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD));
|
||||
PARSER.declareString(optionalConstructorArg(), new ParseField(WarningInferenceResults.NAME));
|
||||
PARSER.declareDouble(optionalConstructorArg(), new ParseField(PREDICTION_PROBABILITY));
|
||||
PARSER.declareDouble(optionalConstructorArg(), new ParseField(PREDICTION_SCORE));
|
||||
declareAggregationFields(PARSER);
|
||||
}
|
||||
|
||||
|
@ -81,15 +85,21 @@ public class ParsedInference extends ParsedAggregation {
|
|||
private final List<FeatureImportance> featureImportance;
|
||||
private final List<TopClassEntry> topClasses;
|
||||
private final String warning;
|
||||
private final Double predictionProbability;
|
||||
private final Double predictionScore;
|
||||
|
||||
ParsedInference(Object value,
|
||||
List<FeatureImportance> featureImportance,
|
||||
List<TopClassEntry> topClasses,
|
||||
String warning) {
|
||||
String warning,
|
||||
Double predictionProbability,
|
||||
Double predictionScore) {
|
||||
this.value = value;
|
||||
this.warning = warning;
|
||||
this.featureImportance = featureImportance;
|
||||
this.topClasses = topClasses;
|
||||
this.predictionProbability = predictionProbability;
|
||||
this.predictionScore = predictionScore;
|
||||
}
|
||||
|
||||
public Object getValue() {
|
||||
|
@ -120,6 +130,12 @@ public class ParsedInference extends ParsedAggregation {
|
|||
if (featureImportance != null && featureImportance.size() > 0) {
|
||||
builder.field(FEATURE_IMPORTANCE, featureImportance);
|
||||
}
|
||||
if (predictionProbability != null) {
|
||||
builder.field(PREDICTION_PROBABILITY, predictionProbability);
|
||||
}
|
||||
if (predictionScore != null) {
|
||||
builder.field(PREDICTION_SCORE, predictionScore);
|
||||
}
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
|
|
@ -68,7 +68,10 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
Collections.singletonList(new ClassificationInferenceResults(1.0,
|
||||
"foo",
|
||||
null,
|
||||
ClassificationConfig.EMPTY_PARAMS)),
|
||||
Collections.emptyList(),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
1.0,
|
||||
1.0)),
|
||||
true);
|
||||
inferenceProcessor.mutateDocument(response, document);
|
||||
|
||||
|
@ -98,7 +101,13 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
classes.add(new TopClassEntry("bar", 0.4, 0.4));
|
||||
|
||||
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
|
||||
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
|
||||
Collections.singletonList(new ClassificationInferenceResults(1.0,
|
||||
"foo",
|
||||
classes,
|
||||
Collections.emptyList(),
|
||||
classificationConfig,
|
||||
0.6,
|
||||
0.6)),
|
||||
true);
|
||||
inferenceProcessor.mutateDocument(response, document);
|
||||
|
||||
|
@ -136,7 +145,9 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
"foo",
|
||||
classes,
|
||||
featureInfluence,
|
||||
classificationConfig)),
|
||||
classificationConfig,
|
||||
0.6,
|
||||
0.6)),
|
||||
true);
|
||||
inferenceProcessor.mutateDocument(response, document);
|
||||
|
||||
|
@ -169,7 +180,13 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
classes.add(new TopClassEntry("bar", 0.4, 0.4));
|
||||
|
||||
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
|
||||
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
|
||||
Collections.singletonList(new ClassificationInferenceResults(1.0,
|
||||
"foo",
|
||||
classes,
|
||||
Collections.emptyList(),
|
||||
classificationConfig,
|
||||
0.6,
|
||||
0.6)),
|
||||
true);
|
||||
inferenceProcessor.mutateDocument(response, document);
|
||||
|
||||
|
|
Loading…
Reference in New Issue