[ML] always write prediction_[score|probability] for classification inference (#60335) (#60397)

In order to unify model inference and analytics results we
need to write the same fields.

prediction_probability and prediction_score are now written
for inference calls against classification models.
This commit is contained in:
Benjamin Trent 2020-07-29 10:58:14 -04:00 committed by GitHub
parent 9d4a64e749
commit 76359aaa53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 231 additions and 57 deletions

View File

@ -27,32 +27,40 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
public static final String NAME = "classification";
public static final String PREDICTION_PROBABILITY = "prediction_probability";
public static final String PREDICTION_SCORE = "prediction_score";
private final String topNumClassesField;
private final String resultsField;
private final String classificationLabel;
private final Double predictionProbability;
private final Double predictionScore;
private final List<TopClassEntry> topClasses;
private final PredictionFieldType predictionFieldType;
public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
InferenceConfig config) {
this(value, classificationLabel, topClasses, Collections.emptyList(), (ClassificationConfig)config);
}
public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
List<FeatureImportance> featureImportance,
InferenceConfig config) {
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
InferenceConfig config,
Double predictionProbability,
Double predictionScore) {
this(value,
classificationLabel,
topClasses,
featureImportance,
(ClassificationConfig)config,
predictionProbability,
predictionScore);
}
private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
List<FeatureImportance> featureImportance,
ClassificationConfig classificationConfig) {
ClassificationConfig classificationConfig,
Double predictionProbability,
Double predictionScore) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
classificationConfig.getNumTopFeatureImportanceValues()));
@ -61,6 +69,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
this.resultsField = classificationConfig.getResultsField();
this.predictionFieldType = classificationConfig.getPredictionFieldType();
this.predictionProbability = predictionProbability;
this.predictionScore = predictionScore;
}
public ClassificationInferenceResults(StreamInput in) throws IOException {
@ -74,6 +84,13 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
} else {
this.predictionFieldType = PredictionFieldType.STRING;
}
if (in.getVersion().onOrAfter(Version.V_7_9_0)) {
this.predictionProbability = in.readOptionalDouble();
this.predictionScore = in.readOptionalDouble();
} else {
this.predictionProbability = topClasses.size() > 0 ? topClasses.get(0).getProbability() : null;
this.predictionScore = topClasses.size() > 0 ? topClasses.get(0).getScore() : null;
}
}
public String getClassificationLabel() {
@ -98,6 +115,10 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeEnum(predictionFieldType);
}
if (out.getVersion().onOrAfter(Version.V_7_9_0)) {
out.writeOptionalDouble(predictionProbability);
out.writeOptionalDouble(predictionScore);
}
}
@Override
@ -111,6 +132,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
&& Objects.equals(topNumClassesField, that.topNumClassesField)
&& Objects.equals(topClasses, that.topClasses)
&& Objects.equals(predictionFieldType, that.predictionFieldType)
&& Objects.equals(predictionProbability, that.predictionProbability)
&& Objects.equals(predictionScore, that.predictionScore)
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}
@ -121,6 +144,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
topClasses,
resultsField,
topNumClassesField,
predictionProbability,
predictionScore,
getFeatureImportance(),
predictionFieldType);
}
@ -142,6 +167,14 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
document.setFieldValue(parentResultField, asMap());
}
/**
 * @return the probability associated with the predicted (top) class,
 *         or {@code null} when no probability was provided and none
 *         could be derived from the top classes
 */
public Double getPredictionProbability() {
return predictionProbability;
}
/**
 * @return the score associated with the predicted (top) class,
 *         or {@code null} when no score was provided and none
 *         could be derived from the top classes
 */
public Double getPredictionScore() {
return predictionScore;
}
@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
@ -149,6 +182,12 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
if (topClasses.isEmpty() == false) {
map.put(topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (predictionProbability != null) {
map.put(PREDICTION_PROBABILITY, predictionProbability);
}
if (predictionScore != null) {
map.put(PREDICTION_SCORE, predictionScore);
}
if (getFeatureImportance().isEmpty() == false) {
map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList()));
}
@ -166,6 +205,12 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
if (topClasses.size() > 0) {
builder.field(topNumClassesField, topClasses);
}
if (predictionProbability != null) {
builder.field(PREDICTION_PROBABILITY, predictionProbability);
}
if (predictionScore != null) {
builder.field(PREDICTION_SCORE, predictionScore);
}
if (getFeatureImportance().size() > 0) {
builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
}

View File

@ -28,7 +28,7 @@ public final class InferenceHelpers {
/**
* @return Tuple of the highest scored index and the top classes
*/
public static Tuple<Integer, List<TopClassEntry>> topClasses(double[] probabilities,
public static Tuple<TopClassificationValue, List<TopClassEntry>> topClasses(double[] probabilities,
List<String> classificationLabels,
@Nullable double[] classificationWeights,
int numToInclude,
@ -55,8 +55,11 @@ public final class InferenceHelpers {
.mapToInt(i -> i)
.toArray();
final TopClassificationValue topClassificationValue = new TopClassificationValue(sortedIndices[0],
probabilities[sortedIndices[0]],
scores[sortedIndices[0]]);
if (numToInclude == 0) {
return Tuple.tuple(sortedIndices[0], Collections.emptyList());
return Tuple.tuple(topClassificationValue, Collections.emptyList());
}
List<String> labels = classificationLabels == null ?
@ -74,7 +77,7 @@ public final class InferenceHelpers {
scores[idx]));
}
return Tuple.tuple(sortedIndices[0], topClassEntries);
return Tuple.tuple(topClassificationValue, topClassEntries);
}
public static String classificationLabel(Integer inferenceValue, @Nullable List<String> classificationLabels) {
@ -155,4 +158,28 @@ public final class InferenceHelpers {
}
return sumTo;
}
/**
 * Immutable holder for the winning class of a classification inference:
 * the predicted class index together with its probability and score
 * (taken from the highest-sorted index computed in {@code topClasses}).
 */
public static class TopClassificationValue {
// index of the predicted class within the label/probability arrays
private final int value;
// probability of the predicted class
private final double probability;
// score of the predicted class (probability adjusted by class weights, where supplied)
private final double score;
TopClassificationValue(int value, double probability, double score) {
this.value = value;
this.probability = probability;
this.score = score;
}
/** @return the predicted class index */
public int getValue() {
return value;
}
/** @return the probability of the predicted class */
public double getProbability() {
return probability;
}
/** @return the score of the predicted class */
public double getScore() {
return score;
}
}
}

View File

@ -206,17 +206,20 @@ public class EnsembleInferenceModel implements InferenceModel {
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
// Adjust the probabilities according to the thresholds
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
processedInferences,
classificationLabels,
classificationWeights,
classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType());
return new ClassificationInferenceResults((double)topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
final InferenceHelpers.TopClassificationValue value = topClasses.v1();
return new ClassificationInferenceResults((double)value.getValue(),
classificationLabel(topClasses.v1().getValue(), classificationLabels),
topClasses.v2(),
transformFeatureImportance(decodedFeatureImportance, classificationLabels),
config);
config,
value.getProbability(),
value.getScore());
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
}

View File

@ -174,17 +174,20 @@ public class TreeInferenceModel implements InferenceModel {
switch (targetType) {
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
classificationProbability(value),
classificationLabels,
null,
classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType());
return new ClassificationInferenceResults(topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
final InferenceHelpers.TopClassificationValue classificationValue = topClasses.v1();
return new ClassificationInferenceResults(classificationValue.getValue(),
classificationLabel(classificationValue.getValue(), classificationLabels),
topClasses.v2(),
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels),
config);
config,
classificationValue.getProbability(),
classificationValue.getScore());
case REGRESSION:
return new RegressionInferenceResults(value[0],
config,

View File

@ -135,19 +135,22 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
double[] probabilities = softMax(scores);
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
probabilities,
LANGUAGE_NAMES,
null,
classificationConfig.getNumTopClasses(),
PredictionFieldType.STRING);
assert topClasses.v1() >= 0 && topClasses.v1() < LANGUAGE_NAMES.size() :
final InferenceHelpers.TopClassificationValue classificationValue = topClasses.v1();
assert classificationValue.getValue() >= 0 && classificationValue.getValue() < LANGUAGE_NAMES.size() :
"Invalid language predicted. Predicted language index " + topClasses.v1();
return new ClassificationInferenceResults(topClasses.v1(),
LANGUAGE_NAMES.get(topClasses.v1()),
return new ClassificationInferenceResults(classificationValue.getValue(),
LANGUAGE_NAMES.get(classificationValue.getValue()),
topClasses.v2(),
Collections.emptyList(),
classificationConfig);
classificationConfig,
classificationValue.getProbability(),
classificationValue.getScore());
}
@Override

View File

@ -49,12 +49,20 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
Stream.generate(featureImportanceCtor)
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()),
config);
config,
randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false),
randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false));
}
public void testWriteResultsWithClassificationLabel() {
ClassificationInferenceResults result =
new ClassificationInferenceResults(1.0, "foo", Collections.emptyList(), ClassificationConfig.EMPTY_PARAMS);
new ClassificationInferenceResults(1.0,
"foo",
Collections.emptyList(),
Collections.emptyList(),
ClassificationConfig.EMPTY_PARAMS,
1.0,
1.0);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
@ -65,7 +73,10 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
null,
Collections.emptyList(),
ClassificationConfig.EMPTY_PARAMS);
Collections.emptyList(),
ClassificationConfig.EMPTY_PARAMS,
1.0,
1.0);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
@ -81,7 +92,10 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
"foo",
entries,
new ClassificationConfig(3, "my_results", "bar", null, PredictionFieldType.STRING));
Collections.emptyList(),
new ClassificationConfig(3, "my_results", "bar", null, PredictionFieldType.STRING),
0.7,
0.7);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
@ -108,7 +122,9 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
"foo",
Collections.emptyList(),
importanceList,
new ClassificationConfig(0, "predicted_value", "top_classes", 3, PredictionFieldType.STRING));
new ClassificationConfig(0, "predicted_value", "top_classes", 3, PredictionFieldType.STRING),
1.0,
1.0);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
@ -142,36 +158,65 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
public void testToXContent() {
ClassificationConfig toStringConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.STRING);
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, null, toStringConfig);
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
null,
null,
Collections.emptyList(),
toStringConfig,
1.0,
1.0);
String stringRep = Strings.toString(result);
String expected = "{\"predicted_value\":\"1.0\"}";
String expected = "{\"predicted_value\":\"1.0\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
assertEquals(expected, stringRep);
ClassificationConfig toDoubleConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.NUMBER);
result = new ClassificationInferenceResults(1.0, null, null, toDoubleConfig);
result = new ClassificationInferenceResults(1.0, null, null, Collections.emptyList(), toDoubleConfig,
1.0,
1.0);
stringRep = Strings.toString(result);
expected = "{\"predicted_value\":1.0}";
expected = "{\"predicted_value\":1.0,\"prediction_probability\":1.0,\"prediction_score\":1.0}";
assertEquals(expected, stringRep);
ClassificationConfig boolFieldConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.BOOLEAN);
result = new ClassificationInferenceResults(1.0, null, null, boolFieldConfig);
result = new ClassificationInferenceResults(1.0, null, null, Collections.emptyList(), boolFieldConfig,
1.0,
1.0);
stringRep = Strings.toString(result);
expected = "{\"predicted_value\":true}";
expected = "{\"predicted_value\":true,\"prediction_probability\":1.0,\"prediction_score\":1.0}";
assertEquals(expected, stringRep);
ClassificationConfig config = new ClassificationConfig(1);
result = new ClassificationInferenceResults(1.0, "label1", null, config);
result = new ClassificationInferenceResults(1.0, "label1", null, Collections.emptyList(), config,
1.0,
1.0);
stringRep = Strings.toString(result);
expected = "{\"predicted_value\":\"label1\"}";
expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
assertEquals(expected, stringRep);
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
Collections.singletonList(fi), config);
Collections.singletonList(fi), config,
1.0,
1.0);
stringRep = Strings.toString(result);
expected = "{\"predicted_value\":\"label1\"," +
"\"top_classes\":[{\"class_name\":\"class\",\"class_probability\":1.0,\"class_score\":1.0}]}";
"\"top_classes\":[{\"class_name\":\"class\",\"class_probability\":1.0,\"class_score\":1.0}]," +
"\"prediction_probability\":1.0,\"prediction_score\":1.0}";
assertEquals(expected, stringRep);
config = new ClassificationConfig(0);
result = new ClassificationInferenceResults(1.0,
"label1",
Collections.emptyList(),
Collections.emptyList(),
config,
1.0,
1.0);
stringRep = Strings.toString(result);
expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
assertEquals(expected, stringRep);
}
}

View File

@ -206,7 +206,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"inference\": {\n" +
" \"target_field\": \"ml.classification\",\n" +
" \"inference_config\": {\"classification\": " +
" {\"num_top_classes\":2, " +
" {\"num_top_classes\":0, " +
" \"top_classes_results_field\": \"result_class_prob\"," +
" \"num_top_feature_importance_values\": 2" +
" }},\n" +
@ -245,6 +245,8 @@ public class InferenceIngestIT extends ESRestTestCase {
Response response = client().performRequest(simulateRequest(source));
String responseString = EntityUtils.toString(response.getEntity());
assertThat(responseString, containsString("\"prediction_probability\":1.0"));
assertThat(responseString, containsString("\"prediction_score\":1.0"));
assertThat(responseString, containsString("\"predicted_value\":\"second\""));
assertThat(responseString, containsString("\"predicted_value\":1.0"));
assertThat(responseString, containsString("\"feature_name\":\"col1\""));

View File

@ -95,11 +95,17 @@ public class InferenceRunnerTests extends ESTestCase {
LocalModel localModel = localModelInferences(new ClassificationInferenceResults(1.0,
"foo",
Collections.emptyList(),
config),
Collections.emptyList(),
config,
1.0,
1.0),
new ClassificationInferenceResults(0.0,
"bar",
Collections.emptyList(),
config));
Collections.emptyList(),
config,
.5,
.7));
InferenceRunner inferenceRunner = createInferenceRunner(extractedFields);
@ -117,10 +123,15 @@ public class InferenceRunnerTests extends ESTestCase {
Map<String, Object> expectedResultsField1 = new HashMap<>();
expectedResultsField1.put("predicted_value", "foo");
expectedResultsField1.put("prediction_probability", 1.0);
expectedResultsField1.put("prediction_score", 1.0);
expectedResultsField1.put("predicted_value", "foo");
expectedResultsField1.put("is_training", false);
Map<String, Object> expectedResultsField2 = new HashMap<>();
expectedResultsField2.put("predicted_value", "bar");
expectedResultsField2.put("prediction_probability", 0.5);
expectedResultsField2.put("prediction_score", 0.7);
expectedResultsField2.put("is_training", false);
assertThat(doc1Source.get("test_results_field"), equalTo(expectedResultsField1));

View File

@ -76,7 +76,9 @@ public class InternalInferenceAggregationTests extends InternalAggregationTestCa
randomResults.getClassificationLabel(),
randomResults.getTopClasses(),
randomResults.getFeatureImportance(),
new ClassificationConfig(null, "value", null, null, randomResults.getPredictionFieldType())
new ClassificationConfig(null, "value", null, null, randomResults.getPredictionFieldType()),
randomResults.getPredictionProbability(),
randomResults.getPredictionScore()
);
} else if (randomBoolean()) {
// build a random result with the result field set to `value`

View File

@ -23,6 +23,8 @@ import java.io.IOException;
import java.util.List;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_PROBABILITY;
import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_SCORE;
import static org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults.FEATURE_IMPORTANCE;
@ -44,7 +46,7 @@ public class ParsedInference extends ParsedAggregation {
private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
(List<TopClassEntry>) args[2], (String) args[3]));
(List<TopClassEntry>) args[2], (String) args[3], (Double) args[4], (Double) args[5]));
static {
PARSER.declareField(optionalConstructorArg(), (p, n) -> {
@ -68,6 +70,8 @@ public class ParsedInference extends ParsedAggregation {
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p),
new ParseField(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD));
PARSER.declareString(optionalConstructorArg(), new ParseField(WarningInferenceResults.NAME));
PARSER.declareDouble(optionalConstructorArg(), new ParseField(PREDICTION_PROBABILITY));
PARSER.declareDouble(optionalConstructorArg(), new ParseField(PREDICTION_SCORE));
declareAggregationFields(PARSER);
}
@ -81,15 +85,21 @@ public class ParsedInference extends ParsedAggregation {
private final List<FeatureImportance> featureImportance;
private final List<TopClassEntry> topClasses;
private final String warning;
private final Double predictionProbability;
private final Double predictionScore;
ParsedInference(Object value,
List<FeatureImportance> featureImportance,
List<TopClassEntry> topClasses,
String warning) {
String warning,
Double predictionProbability,
Double predictionScore) {
this.value = value;
this.warning = warning;
this.featureImportance = featureImportance;
this.topClasses = topClasses;
this.predictionProbability = predictionProbability;
this.predictionScore = predictionScore;
}
public Object getValue() {
@ -120,6 +130,12 @@ public class ParsedInference extends ParsedAggregation {
if (featureImportance != null && featureImportance.size() > 0) {
builder.field(FEATURE_IMPORTANCE, featureImportance);
}
if (predictionProbability != null) {
builder.field(PREDICTION_PROBABILITY, predictionProbability);
}
if (predictionScore != null) {
builder.field(PREDICTION_SCORE, predictionScore);
}
}
return builder;
}

View File

@ -68,7 +68,10 @@ public class InferenceProcessorTests extends ESTestCase {
Collections.singletonList(new ClassificationInferenceResults(1.0,
"foo",
null,
ClassificationConfig.EMPTY_PARAMS)),
Collections.emptyList(),
ClassificationConfig.EMPTY_PARAMS,
1.0,
1.0)),
true);
inferenceProcessor.mutateDocument(response, document);
@ -98,7 +101,13 @@ public class InferenceProcessorTests extends ESTestCase {
classes.add(new TopClassEntry("bar", 0.4, 0.4));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
Collections.singletonList(new ClassificationInferenceResults(1.0,
"foo",
classes,
Collections.emptyList(),
classificationConfig,
0.6,
0.6)),
true);
inferenceProcessor.mutateDocument(response, document);
@ -136,7 +145,9 @@ public class InferenceProcessorTests extends ESTestCase {
"foo",
classes,
featureInfluence,
classificationConfig)),
classificationConfig,
0.6,
0.6)),
true);
inferenceProcessor.mutateDocument(response, document);
@ -169,7 +180,13 @@ public class InferenceProcessorTests extends ESTestCase {
classes.add(new TopClassEntry("bar", 0.4, 0.4));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
Collections.singletonList(new ClassificationInferenceResults(1.0,
"foo",
classes,
Collections.emptyList(),
classificationConfig,
0.6,
0.6)),
true);
inferenceProcessor.mutateDocument(response, document);