[7.x] Pipeline Inference Aggregation (#58965)

Adds a pipeline aggregation that loads a model and performs inference on the
input aggregation results.
This commit is contained in:
David Kyle 2020-07-03 09:29:04 +01:00 committed by GitHub
parent d22dd437f1
commit f6a0c2c59d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
45 changed files with 2127 additions and 189 deletions

View File

@ -0,0 +1,74 @@
[role="xpack"]
[testenv="basic"]
[[search-aggregations-pipeline-inference-bucket-aggregation]]
=== Inference Bucket Aggregation
A parent pipeline aggregation which loads a pre-trained model and performs inference on the
collated result field from the parent bucket aggregation.
[[inference-bucket-agg-syntax]]
==== Syntax
A `inference` aggregation looks like this in isolation:
[source,js]
--------------------------------------------------
{
"inference": {
"model_id": "a_model_for_inference", <1>
"inference_config": { <2>
"regression_config": {
"num_top_feature_importance_values": 2
}
},
"buckets_path": {
"avg_cost": "avg_agg", <3>
"max_cost": "max_agg"
}
}
}
--------------------------------------------------
// NOTCONSOLE
<1> The ID of model to use.
<2> The optional inference config which overrides the model's default settings
<3> Map the value of `avg_agg` to the model's input field `avg_cost`
[[inference-bucket-params]]
.`inference` Parameters
[options="header"]
|===
|Parameter Name |Description |Required |Default Value
| `model_id` | The ID of the model to load and infer against | Required | -
| `inference_config` | Contains the inference type and its options. There are two types: <<inference-agg-regression-opt,`regression`>> and <<inference-agg-classification-opt,`classification`>> | Optional | -
| `buckets_path` | Defines the paths to the input aggregations and maps the aggregation names to the field names expected by the model.
See <<buckets-path-syntax>> for more details | Required | -
|===
==== Configuration options for {infer} models
The `inference_config` setting is optional and usaully isn't required as the pre-trained models come equipped with sensible defaults.
In the context of aggregations some options can overridden for each of the 2 types of model.
[discrete]
[[inference-agg-regression-opt]]
===== Configuration options for {regression} models
`num_top_feature_importance_values`::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values]
[discrete]
[[inference-agg-classification-opt]]
===== Configuration options for {classification} models
`num_top_classes`::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes]
`num_top_feature_importance_values`::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values]
`prediction_field_type`::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type]

View File

@ -84,8 +84,11 @@ public abstract class BasePipelineAggregationTestCase<AF extends AbstractPipelin
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(indicesModule.getNamedWriteables());
entries.addAll(searchModule.getNamedWriteables());
entries.addAll(additionalNamedWriteables());
namedWriteableRegistry = new NamedWriteableRegistry(entries);
xContentRegistry = new NamedXContentRegistry(searchModule.getNamedXContents());
List<NamedXContentRegistry.Entry> xContentEntries = searchModule.getNamedXContents();
xContentEntries.addAll(additionalNamedContents());
xContentRegistry = new NamedXContentRegistry(xContentEntries);
//create some random type with some default field, those types will stick around for all of the subclasses
currentTypes = new String[randomIntBetween(0, 5)];
for (int i = 0; i < currentTypes.length; i++) {
@ -101,6 +104,20 @@ public abstract class BasePipelineAggregationTestCase<AF extends AbstractPipelin
return emptyList();
}
/**
* Any extra named writeables required not registered by {@link SearchModule}
*/
protected List<NamedWriteableRegistry.Entry> additionalNamedWriteables() {
return emptyList();
}
/**
* Any extra named xcontents required not registered by {@link SearchModule}
*/
protected List<NamedXContentRegistry.Entry> additionalNamedContents() {
return emptyList();
}
/**
* Generic test that creates new AggregatorFactory from the test
* AggregatorFactory and checks both for equality and asserts equality on

View File

@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbeddi
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@ -20,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInf
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
@ -121,6 +123,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
ClassificationConfigUpdate::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
RegressionConfigUpdate::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ResultsFieldUpdate.NAME,
ResultsFieldUpdate::fromXContent));
// Inference models
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
@ -170,6 +174,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
RegressionInferenceResults.NAME,
RegressionInferenceResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
WarningInferenceResults.NAME,
WarningInferenceResults::new));
// Inference Configs
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,

View File

@ -6,10 +6,9 @@
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@ -18,9 +17,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
@ -85,6 +82,10 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
return topClasses;
}
public PredictionFieldType getPredictionFieldType() {
return predictionFieldType;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
@ -127,6 +128,11 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
return classificationLabel == null ? super.valueAsString() : classificationLabel;
}
@Override
public Object predictedValue() {
return predictionFieldType.transformPredictedValue(value(), valueAsString());
}
@Override
public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(document, "document");
@ -138,7 +144,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
document.setFieldValue(parentResultField + "." + FEATURE_IMPORTANCE, getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
@ -150,74 +156,15 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
return NAME;
}
public static class TopClassEntry implements Writeable {
public final ParseField CLASS_NAME = new ParseField("class_name");
public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
public final ParseField CLASS_SCORE = new ParseField("class_score");
private final Object classification;
private final double probability;
private final double score;
public TopClassEntry(Object classification, double probability, double score) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
this.probability = probability;
this.score = score;
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
if (topClasses.size() > 0) {
builder.field(topNumClassesField, topClasses);
}
public TopClassEntry(StreamInput in) throws IOException {
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.classification = in.readGenericValue();
} else {
this.classification = in.readString();
}
this.probability = in.readDouble();
this.score = in.readDouble();
}
public Object getClassification() {
return classification;
}
public double getProbability() {
return probability;
}
public double getScore() {
return score;
}
public Map<String, Object> asValueMap() {
Map<String, Object> map = new HashMap<>(3, 1.0f);
map.put(CLASS_NAME.getPreferredName(), classification);
map.put(CLASS_PROBABILITY.getPreferredName(), probability);
map.put(CLASS_SCORE.getPreferredName(), score);
return map;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeGenericValue(classification);
} else {
out.writeString(classification.toString());
}
out.writeDouble(probability);
out.writeDouble(score);
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
TopClassEntry that = (TopClassEntry) object;
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
}
@Override
public int hashCode() {
return Objects.hash(classification, probability, score);
if (getFeatureImportance().size() > 0) {
builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
}
return builder;
}
}

View File

@ -5,23 +5,33 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
public class FeatureImportance implements Writeable {
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class FeatureImportance implements Writeable, ToXContentObject {
private final Map<String, Double> classImportance;
private final double importance;
private final String featureName;
private static final String IMPORTANCE = "importance";
private static final String FEATURE_NAME = "feature_name";
static final String IMPORTANCE = "importance";
static final String FEATURE_NAME = "feature_name";
static final String CLASS_IMPORTANCE = "class_importance";
public static FeatureImportance forRegression(String featureName, double importance) {
return new FeatureImportance(featureName, importance, null);
@ -31,7 +41,24 @@ public class FeatureImportance implements Writeable {
return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
}
private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
new ConstructingObjectParser<>("feature_importance",
a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
);
static {
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
new ParseField(FeatureImportance.CLASS_IMPORTANCE));
}
public static FeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
this.featureName = Objects.requireNonNull(featureName);
this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
@ -79,6 +106,22 @@ public class FeatureImportance implements Writeable {
return map;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FEATURE_NAME, featureName);
builder.field(IMPORTANCE, importance);
if (classImportance != null && classImportance.isEmpty() == false) {
builder.startObject(CLASS_IMPORTANCE);
for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
@ -93,5 +136,4 @@ public class FeatureImportance implements Writeable {
public int hashCode() {
return Objects.hash(featureName, importance, classImportance);
}
}

View File

@ -6,10 +6,12 @@
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.ingest.IngestDocument;
public interface InferenceResults extends NamedWriteable {
public interface InferenceResults extends NamedWriteable, ToXContentFragment {
void writeResult(IngestDocument document, String parentResultField);
Object predictedValue();
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import java.io.IOException;
@ -56,9 +57,18 @@ public class RawInferenceResults implements InferenceResults {
throw new UnsupportedOperationException("[raw] does not support writing inference results");
}
@Override
public Object predictedValue() {
return null;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
throw new UnsupportedOperationException("[raw] does not support toXContent");
}
}

View File

@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
@ -25,18 +26,28 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField;
public RegressionInferenceResults(double value, InferenceConfig config) {
this(value, (RegressionConfig) config, Collections.emptyList());
this(value, config, Collections.emptyList());
}
public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
this(value, (RegressionConfig)config, featureImportance);
this(value, ((RegressionConfig)config).getResultsField(),
((RegressionConfig)config).getNumTopFeatureImportanceValues(), featureImportance);
}
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, List<FeatureImportance> featureImportance) {
public RegressionInferenceResults(double value, String resultsField) {
this(value, resultsField, 0, Collections.emptyList());
}
public RegressionInferenceResults(double value, String resultsField,
List<FeatureImportance> featureImportance) {
this(value, resultsField, featureImportance.size(), featureImportance);
}
public RegressionInferenceResults(double value, String resultsField, int topNFeatures,
List<FeatureImportance> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
this.resultsField = regressionConfig.getResultsField();
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance, topNFeatures));
this.resultsField = resultsField;
}
public RegressionInferenceResults(StreamInput in) throws IOException {
@ -65,6 +76,11 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
return Objects.hash(value(), resultsField, getFeatureImportance());
}
@Override
public Object predictedValue() {
return super.value();
}
@Override
public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(document, "document");
@ -78,9 +94,17 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
}
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(resultsField, value());
if (getFeatureImportance().size() > 0) {
builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
}
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
}

View File

@ -16,6 +16,8 @@ import java.util.stream.Collectors;
public abstract class SingleValueInferenceResults implements InferenceResults {
public static final String FEATURE_IMPORTANCE = "feature_importance";
private final double value;
private final List<FeatureImportance> featureImportance;

View File

@ -0,0 +1,138 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
public class TopClassEntry implements Writeable, ToXContentObject {
public static final ParseField CLASS_NAME = new ParseField("class_name");
public static final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
public static final ParseField CLASS_SCORE = new ParseField("class_score");
public static final String NAME = "top_class";
private static final ConstructingObjectParser<TopClassEntry, Void> PARSER =
new ConstructingObjectParser<>(NAME, a -> new TopClassEntry(a[0], (Double) a[1], (Double) a[2]));
static {
PARSER.declareField(constructorArg(), (p, n) -> {
Object o;
XContentParser.Token token = p.currentToken();
if (token == XContentParser.Token.VALUE_STRING) {
o = p.text();
} else if (token == XContentParser.Token.VALUE_BOOLEAN) {
o = p.booleanValue();
} else if (token == XContentParser.Token.VALUE_NUMBER) {
o = p.doubleValue();
} else {
throw new XContentParseException(p.getTokenLocation(),
"[" + NAME + "] failed to parse field [" + CLASS_NAME + "] value [" + token
+ "] is not a string, boolean or number");
}
return o;
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
PARSER.declareDouble(constructorArg(), CLASS_PROBABILITY);
PARSER.declareDouble(constructorArg(), CLASS_SCORE);
}
public static TopClassEntry fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}
private final Object classification;
private final double probability;
private final double score;
public TopClassEntry(Object classification, double probability, double score) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
this.probability = probability;
this.score = score;
}
public TopClassEntry(StreamInput in) throws IOException {
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.classification = in.readGenericValue();
} else {
this.classification = in.readString();
}
this.probability = in.readDouble();
this.score = in.readDouble();
}
public Object getClassification() {
return classification;
}
public double getProbability() {
return probability;
}
public double getScore() {
return score;
}
public Map<String, Object> asValueMap() {
Map<String, Object> map = new HashMap<>(3, 1.0f);
map.put(CLASS_NAME.getPreferredName(), classification);
map.put(CLASS_PROBABILITY.getPreferredName(), probability);
map.put(CLASS_SCORE.getPreferredName(), score);
return map;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), classification);
builder.field(CLASS_PROBABILITY.getPreferredName(), probability);
builder.field(CLASS_SCORE.getPreferredName(), score);
builder.endObject();
return builder;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeGenericValue(classification);
} else {
out.writeString(classification.toString());
}
out.writeDouble(probability);
out.writeDouble(score);
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
TopClassEntry that = (TopClassEntry) object;
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
}
@Override
public int hashCode() {
return Objects.hash(classification, probability, score);
}
}

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -55,12 +56,22 @@ public class WarningInferenceResults implements InferenceResults {
public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(parentResultField, "resultField");
document.setFieldValue(parentResultField + "." + "warning", warning);
document.setFieldValue(parentResultField + "." + NAME, warning);
}
@Override
public Object predictedValue() {
return null;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(NAME, warning);
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
}

View File

@ -63,18 +63,18 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
config.getPredictionFieldType());
}
private static final ObjectParser<ClassificationConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
private static final ObjectParser<Builder, Void> STRICT_PARSER = createParser(false);
private static ObjectParser<ClassificationConfigUpdate.Builder, Void> createParser(boolean lenient) {
ObjectParser<ClassificationConfigUpdate.Builder, Void> parser = new ObjectParser<>(
private static ObjectParser<Builder, Void> createParser(boolean lenient) {
ObjectParser<Builder, Void> parser = new ObjectParser<>(
NAME.getPreferredName(),
lenient,
ClassificationConfigUpdate.Builder::new);
parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopClasses, NUM_TOP_CLASSES);
parser.declareString(ClassificationConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
parser.declareString(ClassificationConfigUpdate.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
parser.declareString(ClassificationConfigUpdate.Builder::setPredictionFieldType, PREDICTION_FIELD_TYPE);
Builder::new);
parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
parser.declareString(Builder::setResultsField, RESULTS_FIELD);
parser.declareString(Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
parser.declareInt(Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
parser.declareString(Builder::setPredictionFieldType, PREDICTION_FIELD_TYPE);
return parser;
}
@ -96,6 +96,8 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
}
this.numTopFeatureImportanceValues = featureImportance;
this.predictionFieldType = predictionFieldType;
InferenceConfigUpdate.checkFieldUniqueness(resultsField, topClassesResultsField);
}
public ClassificationConfigUpdate(StreamInput in) throws IOException {
@ -118,6 +120,16 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
return resultsField;
}
@Override
public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
return new Builder()
.setNumTopClasses(numTopClasses)
.setTopClassesResultsField(topClassesResultsField)
.setResultsField(resultsField)
.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues)
.setPredictionFieldType(predictionFieldType);
}
public Integer getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}
@ -235,14 +247,14 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
&& (predictionFieldType == null || predictionFieldType.equals(originalConfig.getPredictionFieldType()));
}
public static class Builder {
public static class Builder implements InferenceConfigUpdate.Builder<Builder, ClassificationConfigUpdate> {
private Integer numTopClasses;
private String topClassesResultsField;
private String resultsField;
private Integer numTopFeatureImportanceValues;
private PredictionFieldType predictionFieldType;
public Builder setNumTopClasses(int numTopClasses) {
public Builder setNumTopClasses(Integer numTopClasses) {
this.numTopClasses = numTopClasses;
return this;
}
@ -252,12 +264,13 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
return this;
}
@Override
public Builder setResultsField(String resultsField) {
this.resultsField = resultsField;
return this;
}
public Builder setNumTopFeatureImportanceValues(int numTopFeatureImportanceValues) {
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
return this;
}
@ -271,6 +284,7 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
return setPredictionFieldType(PredictionFieldType.fromString(predictionFieldType));
}
@Override
public ClassificationConfigUpdate build() {
return new ClassificationConfigUpdate(numTopClasses,
resultsField,

View File

@ -6,14 +6,53 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
public interface InferenceConfigUpdate extends NamedXContentObject, NamedWriteable {
Set<String> RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList(
WarningInferenceResults.WARNING.getPreferredName(),
TrainedModelConfig.MODEL_ID.getPreferredName()));
InferenceConfig apply(InferenceConfig originalConfig);
InferenceConfig toConfig();
boolean isSupported(InferenceConfig config);
String getResultsField();
interface Builder<T extends Builder<T, U>, U extends InferenceConfigUpdate> {
U build();
T setResultsField(String resultsField);
}
Builder<? extends Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder();
static void checkFieldUniqueness(String... fieldNames) {
Set<String> duplicatedFieldNames = new HashSet<>();
Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
for(String fieldName : fieldNames) {
if (fieldName == null) {
continue;
}
if (currentFieldNames.contains(fieldName)) {
duplicatedFieldNames.add(fieldName);
} else {
currentFieldNames.add(fieldName);
}
}
if (duplicatedFieldNames.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Cannot apply inference config." +
" More than one field is configured as {}",
duplicatedFieldNames);
}
}
}

View File

@ -7,8 +7,8 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.ArrayList;
@ -28,11 +28,11 @@ public final class InferenceHelpers {
/**
* @return Tuple of the highest scored index and the top classes
*/
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(double[] probabilities,
List<String> classificationLabels,
@Nullable double[] classificationWeights,
int numToInclude,
PredictionFieldType predictionFieldType) {
public static Tuple<Integer, List<TopClassEntry>> topClasses(double[] probabilities,
List<String> classificationLabels,
@Nullable double[] classificationWeights,
int numToInclude,
PredictionFieldType predictionFieldType) {
if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
throw ExceptionsHelper
@ -65,10 +65,10 @@ public final class InferenceHelpers {
classificationLabels;
int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
List<TopClassEntry> topClassEntries = new ArrayList<>(count);
for(int i = 0; i < count; i++) {
int idx = sortedIndices[i];
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(
topClassEntries.add(new TopClassEntry(
predictionFieldType.transformPredictedValue((double)idx, labels.get(idx)),
probabilities[idx],
scores[idx]));

View File

@ -18,7 +18,6 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.DEFAULT_RESULTS_FIELD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.RESULTS_FIELD;
@ -68,6 +67,8 @@ public class RegressionConfigUpdate implements InferenceConfigUpdate {
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
InferenceConfigUpdate.checkFieldUniqueness(resultsField);
}
public RegressionConfigUpdate(StreamInput in) throws IOException {
@ -75,12 +76,19 @@ public class RegressionConfigUpdate implements InferenceConfigUpdate {
this.numTopFeatureImportanceValues = in.readOptionalVInt();
}
public int getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues;
public Integer getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}
public String getResultsField() {
return resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
return resultsField;
}
@Override
public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
return new Builder()
.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues)
.setResultsField(resultsField);
}
@Override
@ -165,10 +173,11 @@ public class RegressionConfigUpdate implements InferenceConfigUpdate {
|| originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues);
}
public static class Builder {
public static class Builder implements InferenceConfigUpdate.Builder<Builder, RegressionConfigUpdate> {
private String resultsField;
private Integer numTopFeatureImportanceValues;
@Override
public Builder setResultsField(String resultsField) {
this.resultsField = resultsField;
return this;

View File

@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.RESULTS_FIELD;
/**
* A config update that sets the results field only.
* Supports any type of {@link InferenceConfig}
*/
public class ResultsFieldUpdate implements InferenceConfigUpdate {
public static final ParseField NAME = new ParseField("field_update");
private static final ConstructingObjectParser<ResultsFieldUpdate, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ResultsFieldUpdate((String) args[0]));
static {
PARSER.declareString(constructorArg(), RESULTS_FIELD);
}
public static ResultsFieldUpdate fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final String resultsField;
public ResultsFieldUpdate(String resultsField) {
this.resultsField = Objects.requireNonNull(resultsField);
}
public ResultsFieldUpdate(StreamInput in) throws IOException {
resultsField = in.readString();
}
@Override
public InferenceConfig apply(InferenceConfig originalConfig) {
if (originalConfig instanceof ClassificationConfig) {
ClassificationConfigUpdate update = new ClassificationConfigUpdate(null, resultsField, null, null, null);
return update.apply(originalConfig);
} else if (originalConfig instanceof RegressionConfig) {
RegressionConfigUpdate update = new RegressionConfigUpdate(resultsField, null);
return update.apply(originalConfig);
} else {
throw ExceptionsHelper.badRequestException(
"Inference config of unknown type [{}] can not be updated", originalConfig.getName());
}
}
@Override
public InferenceConfig toConfig() {
return new RegressionConfig(resultsField);
}
@Override
public boolean isSupported(InferenceConfig config) {
return true;
}
@Override
public String getResultsField() {
return resultsField;
}
@Override
public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
return new Builder().setResultsField(resultsField);
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(resultsField);
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ResultsFieldUpdate that = (ResultsFieldUpdate) o;
return Objects.equals(resultsField, that.resultsField);
}
@Override
public int hashCode() {
return Objects.hashCode(resultsField);
}
public static class Builder implements InferenceConfigUpdate.Builder<Builder, ResultsFieldUpdate> {
private String resultsField;
@Override
public Builder setResultsField(String resultsField) {
this.resultsField = resultsField;
return this;
}
@Override
public ResultsFieldUpdate build() {
return new ResultsFieldUpdate(resultsField);
}
}
}

View File

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
@ -14,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@ -85,12 +87,12 @@ public class EnsembleInferenceModel implements InferenceModel {
private EnsembleInferenceModel(List<InferenceModel> models,
OutputAggregator outputAggregator,
TargetType targetType,
List<String> classificationLabels,
@Nullable List<String> classificationLabels,
List<Double> classificationWeights) {
this.models = ExceptionsHelper.requireNonNull(models, TRAINED_MODELS);
this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : classificationLabels;
this.classificationLabels = classificationLabels;
this.classificationWeights = classificationWeights == null ?
null :
classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
@ -204,7 +206,7 @@ public class EnsembleInferenceModel implements InferenceModel {
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
// Adjust the probabilities according to the thresholds
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
processedInferences,
classificationLabels,
classificationWeights,

View File

@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@ -173,7 +174,7 @@ public class TreeInferenceModel implements InferenceModel {
switch (targetType) {
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
classificationProbability(value),
classificationLabels,
null,

View File

@ -16,6 +16,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@ -134,7 +135,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
double[] probabilities = softMax(scores);
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
probabilities,
LANGUAGE_NAMES,
null,

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
@ -31,21 +32,24 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
FeatureImportanceTests::randomClassification :
FeatureImportanceTests::randomRegression;
return new ClassificationInferenceResults(randomDouble(),
ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig();
Double value = randomDouble();
if (config.getPredictionFieldType() == PredictionFieldType.BOOLEAN) {
// value must be close to 0 or 1
value = randomBoolean() ? 0.0 : 1.0;
}
return new ClassificationInferenceResults(value,
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null :
Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry)
Stream.generate(TopClassEntryTests::createRandomTopClassEntry)
.limit(randomIntBetween(0, 10))
.collect(Collectors.toList()),
randomBoolean() ? null :
Stream.generate(featureImportanceCtor)
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()),
ClassificationConfigTests.randomClassificationConfig());
}
private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() {
return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble(), randomDouble());
config);
}
public void testWriteResultsWithClassificationLabel() {
@ -70,10 +74,10 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
@SuppressWarnings("unchecked")
public void testWriteResultsWithTopClasses() {
List<ClassificationInferenceResults.TopClassEntry> entries = Arrays.asList(
new ClassificationInferenceResults.TopClassEntry("foo", 0.7, 0.7),
new ClassificationInferenceResults.TopClassEntry("bar", 0.2, 0.2),
new ClassificationInferenceResults.TopClassEntry("baz", 0.1, 0.1));
List<TopClassEntry> entries = Arrays.asList(
new TopClassEntry("foo", 0.7, 0.7),
new TopClassEntry("bar", 0.2, 0.2),
new TopClassEntry("baz", 0.1, 0.1));
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
"foo",
entries,
@ -84,8 +88,8 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
List<?> list = document.getFieldValue("result_field.bar", List.class);
assertThat(list.size(), equalTo(3));
for(int i = 0; i < 3; i++) {
Map<String, Object> map = (Map<String, Object>)list.get(i);
for (int i = 0; i < 3; i++) {
Map<String, Object> map = (Map<String, Object>) list.get(i);
assertThat(map, equalTo(entries.get(i).asValueMap()));
}
@ -110,11 +114,11 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo"));
@SuppressWarnings("unchecked")
List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>)document.getFieldValue(
List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>) document.getFieldValue(
"result_field.feature_importance",
List.class);
assertThat(writtenImportance, hasSize(3));
importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
importanceList.sort((l, r) -> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
for (int i = 0; i < 3; i++) {
Map<String, Object> objectMap = writtenImportance.get(i);
FeatureImportance importance = importanceList.get(i);
@ -135,4 +139,39 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
protected Writeable.Reader<ClassificationInferenceResults> instanceReader() {
return ClassificationInferenceResults::new;
}
public void testToXContent() {
ClassificationConfig toStringConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.STRING);
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, null, toStringConfig);
String stringRep = Strings.toString(result);
String expected = "{\"predicted_value\":\"1.0\"}";
assertEquals(expected, stringRep);
ClassificationConfig toDoubleConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.NUMBER);
result = new ClassificationInferenceResults(1.0, null, null, toDoubleConfig);
stringRep = Strings.toString(result);
expected = "{\"predicted_value\":1.0}";
assertEquals(expected, stringRep);
ClassificationConfig boolFieldConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.BOOLEAN);
result = new ClassificationInferenceResults(1.0, null, null, boolFieldConfig);
stringRep = Strings.toString(result);
expected = "{\"predicted_value\":true}";
assertEquals(expected, stringRep);
ClassificationConfig config = new ClassificationConfig(1);
result = new ClassificationInferenceResults(1.0, "label1", null, config);
stringRep = Strings.toString(result);
expected = "{\"predicted_value\":\"label1\"}";
assertEquals(expected, stringRep);
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
Collections.singletonList(fi), config);
stringRep = Strings.toString(result);
expected = "{\"predicted_value\":\"label1\"," +
"\"top_classes\":[{\"class_name\":\"class\",\"class_probability\":1.0,\"class_score\":1.0}]}";
assertEquals(expected, stringRep);
}
}

View File

@ -6,14 +6,15 @@
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class FeatureImportanceTests extends AbstractWireSerializingTestCase<FeatureImportance> {
public class FeatureImportanceTests extends AbstractSerializingTestCase<FeatureImportance> {
public static FeatureImportance createRandomInstance() {
return randomBoolean() ? randomClassification() : randomRegression();
@ -29,7 +30,6 @@ public class FeatureImportanceTests extends AbstractWireSerializingTestCase<Feat
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(2, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
}
@Override
@ -41,4 +41,9 @@ public class FeatureImportanceTests extends AbstractWireSerializingTestCase<Feat
protected Writeable.Reader<FeatureImportance> instanceReader() {
return FeatureImportance::new;
}
@Override
protected FeatureImportance doParseInstance(XContentParser parser) throws IOException {
return FeatureImportance.fromXContent(parser);
}
}

View File

@ -5,12 +5,14 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -75,4 +77,18 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
protected Writeable.Reader<RegressionInferenceResults> instanceReader() {
return RegressionInferenceResults::new;
}
public void testToXContent() {
String resultsField = "ml.results";
RegressionInferenceResults result = new RegressionInferenceResults(1.0, resultsField);
String stringRep = Strings.toString(result);
String expected = "{\"" + resultsField + "\":1.0}";
assertEquals(expected, stringRep);
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi));
stringRep = Strings.toString(result);
expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";
assertEquals(expected, stringRep);
}
}

View File

@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
public class TopClassEntryTests extends AbstractSerializingTestCase<TopClassEntry> {
public static TopClassEntry createRandomTopClassEntry() {
Object classification;
if (randomBoolean()) {
classification = randomAlphaOfLength(10);
} else if (randomBoolean()) {
classification = randomBoolean();
} else {
classification = randomDouble();
}
return new TopClassEntry(classification, randomDouble(), randomDouble());
}
@Override
protected TopClassEntry doParseInstance(XContentParser parser) throws IOException {
return TopClassEntry.fromXContent(parser);
}
@Override
protected Writeable.Reader<TopClassEntry> instanceReader() {
return TopClassEntry::new;
}
@Override
protected TopClassEntry createTestInstance() {
return createRandomTopClassEntry();
}
}

View File

@ -5,15 +5,29 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.HashMap;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.hamcrest.Matchers.equalTo;
public class WarningInferenceResultsTests extends AbstractWireSerializingTestCase<WarningInferenceResults> {
public class WarningInferenceResultsTests extends AbstractSerializingTestCase<WarningInferenceResults> {
private static final ConstructingObjectParser<WarningInferenceResults, Void> PARSER =
new ConstructingObjectParser<>("inference_warning",
a -> new WarningInferenceResults((String) a[0])
);
static {
PARSER.declareString(constructorArg(), new ParseField(WarningInferenceResults.NAME));
}
public static WarningInferenceResults createRandomResults() {
return new WarningInferenceResults(randomAlphaOfLength(10));
@ -36,4 +50,9 @@ public class WarningInferenceResultsTests extends AbstractWireSerializingTestCas
protected Writeable.Reader<WarningInferenceResults> instanceReader() {
return WarningInferenceResults::new;
}
@Override
protected WarningInferenceResults doParseInstance(XContentParser parser) throws IOException {
return PARSER.apply(parser, null);
}
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
@ -18,6 +19,7 @@ import java.util.Map;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests.randomClassificationConfig;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<ClassificationConfigUpdate> {
@ -74,6 +76,30 @@ public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTes
));
}
public void testDuplicateFieldNamesThrow() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new ClassificationConfigUpdate(5, "foo", "foo", 1, PredictionFieldType.BOOLEAN));
assertEquals("Cannot apply inference config. More than one field is configured as [foo]", e.getMessage());
}
public void testDuplicateWithResultsField() {
ClassificationConfigUpdate update = randomClassificationConfigUpdate();
String newFieldName = update.getResultsField() + "_value";
InferenceConfigUpdate updateWithField = update.newBuilder().setResultsField(newFieldName).build();
assertNotSame(updateWithField, update);
assertEquals(newFieldName, updateWithField.getResultsField());
// other fields are the same
assertThat(updateWithField, instanceOf(ClassificationConfigUpdate.class));
ClassificationConfigUpdate classUpdate = (ClassificationConfigUpdate)updateWithField;
assertEquals(update.getTopClassesResultsField(), classUpdate.getTopClassesResultsField());
assertEquals(update.getNumTopClasses(), classUpdate.getNumTopClasses());
assertEquals(update.getPredictionFieldType(), classUpdate.getPredictionFieldType());
assertEquals(update.getNumTopFeatureImportanceValues(), classUpdate.getNumTopFeatureImportanceValues());
}
@Override
protected ClassificationConfigUpdate createTestInstance() {
return randomClassificationConfigUpdate();

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
@ -18,6 +19,7 @@ import java.util.Map;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests.randomRegressionConfig;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCase<RegressionConfigUpdate> {
@ -60,6 +62,26 @@ public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCas
));
}
public void testInvalidResultFieldNotUnique() {
ElasticsearchStatusException e =
expectThrows(ElasticsearchStatusException.class, () -> new RegressionConfigUpdate("warning", 0));
assertEquals("Cannot apply inference config. More than one field is configured as [warning]", e.getMessage());
}
public void testNewBuilder() {
RegressionConfigUpdate update = randomRegressionConfigUpdate();
String newFieldName = update.getResultsField() + "_value";
InferenceConfigUpdate updateWithField = update.newBuilder().setResultsField(newFieldName).build();
assertNotSame(updateWithField, update);
assertEquals(newFieldName, updateWithField.getResultsField());
// other fields are the same
assertThat(updateWithField, instanceOf(RegressionConfigUpdate.class));
assertEquals(update.getNumTopFeatureImportanceValues(),
((RegressionConfigUpdate)updateWithField).getNumTopFeatureImportanceValues());
}
@Override
protected RegressionConfigUpdate createTestInstance() {
return randomRegressionConfigUpdate();

View File

@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.mock;
public class ResultsFieldUpdateTests extends AbstractSerializingTestCase<ResultsFieldUpdate> {
@Override
protected ResultsFieldUpdate doParseInstance(XContentParser parser) throws IOException {
return ResultsFieldUpdate.fromXContent(parser);
}
@Override
protected Writeable.Reader<ResultsFieldUpdate> instanceReader() {
return ResultsFieldUpdate::new;
}
@Override
protected ResultsFieldUpdate createTestInstance() {
return new ResultsFieldUpdate(randomAlphaOfLength(4));
}
public void testIsSupported() {
ResultsFieldUpdate update = new ResultsFieldUpdate("foo");
assertTrue(update.isSupported(mock(InferenceConfig.class)));
}
public void testApply_OnlyTheResultsFieldIsChanged() {
if (randomBoolean()) {
ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig();
String newResultsField = config.getResultsField() + "foobar";
ResultsFieldUpdate update = new ResultsFieldUpdate(newResultsField);
InferenceConfig applied = update.apply(config);
assertThat(applied, instanceOf(ClassificationConfig.class));
ClassificationConfig appliedConfig = (ClassificationConfig)applied;
assertEquals(newResultsField, appliedConfig.getResultsField());
assertEquals(appliedConfig, new ClassificationConfig.Builder(config).setResultsField(newResultsField).build());
} else {
RegressionConfig config = RegressionConfigTests.randomRegressionConfig();
String newResultsField = config.getResultsField() + "foobar";
ResultsFieldUpdate update = new ResultsFieldUpdate(newResultsField);
InferenceConfig applied = update.apply(config);
assertThat(applied, instanceOf(RegressionConfig.class));
RegressionConfig appliedConfig = (RegressionConfig)applied;
assertEquals(newResultsField, appliedConfig.getResultsField());
assertEquals(appliedConfig, new RegressionConfig.Builder(config).setResultsField(newResultsField).build());
}
}
}

View File

@ -12,6 +12,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
@ -137,7 +138,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
List<Double> expected = Arrays.asList(0.768524783, 0.231475216);
List<Double> scores = Arrays.asList(0.230557435, 0.162032651);
double eps = 0.000001;
List<ClassificationInferenceResults.TopClassEntry> probabilities =
List<TopClassEntry> probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {

View File

@ -12,6 +12,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
@ -179,7 +180,7 @@ public class TreeInferenceModelTests extends ESTestCase {
List<Double> expectedProbs = Arrays.asList(1.0, 0.0);
List<String> expectedFields = Arrays.asList("dog", "cat");
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
List<ClassificationInferenceResults.TopClassEntry> probabilities =
List<TopClassEntry> probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {

View File

@ -11,7 +11,7 @@ dependencies {
// bring in machine learning rest test suite
restResources {
restApi {
includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest'
includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest', 'bulk'
includeXpack 'ml', 'cat'
}
restTests {

View File

@ -7,7 +7,7 @@ minimal:
# Give all users involved in these tests access to the indices where the data to
# be analyzed is stored, because the ML roles alone do not provide access to
# non-ML indices
- names: [ 'airline-data', 'index-*', 'unavailable-data', 'utopia' ]
- names: [ 'airline-data', 'index-*', 'unavailable-data', 'utopia', 'store' ]
privileges:
- create_index
- indices:admin/refresh

View File

@ -8,6 +8,8 @@ package org.elasticsearch.smoketest;
import com.carrotsearch.randomizedtesting.annotations.Name;
import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
import org.elasticsearch.test.rest.yaml.section.DoSection;
import org.elasticsearch.test.rest.yaml.section.ExecutableSection;
import java.io.IOException;
@ -16,8 +18,11 @@ import static org.hamcrest.Matchers.either;
public class MlWithSecurityInsufficientRoleIT extends MlWithSecurityIT {
private final ClientYamlTestCandidate testCandidate;
public MlWithSecurityInsufficientRoleIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
super(testCandidate);
this.testCandidate = testCandidate;
}
@Override
@ -26,7 +31,18 @@ public class MlWithSecurityInsufficientRoleIT extends MlWithSecurityIT {
// Cannot use expectThrows here because blacklisted tests will throw an
// InternalAssumptionViolatedException rather than an AssertionError
super.test();
fail("should have failed because of missing role");
// We should have got here if and only if no ML endpoints were called
for (ExecutableSection section : testCandidate.getTestSection().getExecutableSections()) {
if (section instanceof DoSection) {
String apiName = ((DoSection) section).getApiCallSection().getApi();
if (apiName.startsWith("ml.")) {
fail("call to ml endpoint should have failed because of missing role");
}
}
}
} catch (AssertionError ae) {
// Some tests assert on searches of wildcarded ML indices rather than on ML endpoints. For these we expect no hits.
if (ae.getMessage().contains("hits.total didn't match expected value")) {

View File

@ -47,7 +47,7 @@ public class MlWithSecurityUserRoleIT extends MlWithSecurityIT {
if (section instanceof DoSection) {
String apiName = ((DoSection) section).getApiCallSection().getApi();
if (((DoSection) section).getApiCallSection().getApi().startsWith("ml.") && isAllowed(apiName) == false) {
if (apiName.startsWith("ml.") && isAllowed(apiName) == false) {
fail("should have failed because of missing role");
}
}

View File

@ -22,8 +22,8 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Module;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.inject.Module;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
@ -35,6 +35,7 @@ import org.elasticsearch.common.settings.SettingsModule;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ContextParser;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.env.Environment;
import org.elasticsearch.env.NodeEnvironment;
@ -53,11 +54,13 @@ import org.elasticsearch.plugins.CircuitBreakerPlugin;
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.PersistentTaskPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.plugins.SystemIndexPlugin;
import org.elasticsearch.repositories.RepositoriesService;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
@ -220,6 +223,8 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationP
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
import org.elasticsearch.xpack.ml.inference.aggs.InternalInferenceAggregation;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
@ -336,7 +341,8 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
AnalysisPlugin,
CircuitBreakerPlugin,
IngestPlugin,
PersistentTaskPlugin {
PersistentTaskPlugin,
SearchPlugin {
public static final String NAME = "ml";
public static final String BASE_PATH = "/_ml/";
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
@ -454,6 +460,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
private final SetOnce<ActionFilter> mlUpgradeModeActionFilter = new SetOnce<>();
private final SetOnce<CircuitBreaker> inferenceModelBreaker = new SetOnce<>();
private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();
public MachineLearning(Settings settings, Path configPath) {
this.settings = settings;
@ -683,6 +690,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
settings,
clusterService.getNodeName(),
inferenceModelBreaker.get());
this.modelLoadingService.set(modelLoadingService);
// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
@ -962,6 +970,18 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
return Collections.singletonMap(MlClassicTokenizer.NAME, MlClassicTokenizerFactory::new);
}
@Override
public List<PipelineAggregationSpec> getPipelineAggregations() {
PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
in -> new InferencePipelineAggregationBuilder(in, modelLoadingService),
(ContextParser<String, ? extends PipelineAggregationBuilder>)
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, name, parser
));
spec.addResultReader(InternalInferenceAggregation::new);
return Collections.singletonList(spec);
}
@Override
public UnaryOperator<Map<String, IndexTemplateMetadata>> getIndexTemplateMetadataUpgrader() {
return UnaryOperator.identity();

View File

@ -0,0 +1,214 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.aggs;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.concurrent.CountDownLatch;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {
public static String NAME = "inference";
public static final ParseField MODEL_ID = new ParseField("model_id");
private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
static String AGGREGATIONS_RESULTS_FIELD = "value";
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder,
Tuple<SetOnce<ModelLoadingService>, String>> PARSER = new ConstructingObjectParser<>(
NAME, false,
(args, context) -> new InferencePipelineAggregationBuilder(context.v2(), context.v1(), (Map<String, String>) args[0])
);
static {
PARSER.declareObject(constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD);
PARSER.declareString(InferencePipelineAggregationBuilder::setModelId, MODEL_ID);
PARSER.declareNamedObject(InferencePipelineAggregationBuilder::setInferenceConfig,
(p, c, n) -> p.namedObject(InferenceConfigUpdate.class, n, c), INFERENCE_CONFIG);
}
private final Map<String, String> bucketPathMap;
private String modelId;
private InferenceConfigUpdate inferenceConfig;
private final SetOnce<ModelLoadingService> modelLoadingService;
public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService,
String pipelineAggregatorName,
XContentParser parser) {
Tuple<SetOnce<ModelLoadingService>, String> context = new Tuple<>(modelLoadingService, pipelineAggregatorName);
return PARSER.apply(parser, context);
}
public InferencePipelineAggregationBuilder(String name, SetOnce<ModelLoadingService> modelLoadingService,
Map<String, String> bucketsPath) {
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
this.modelLoadingService = modelLoadingService;
this.bucketPathMap = bucketsPath;
}
public InferencePipelineAggregationBuilder(StreamInput in, SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
super(in, NAME);
modelId = in.readString();
bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
this.modelLoadingService = modelLoadingService;
}
void setModelId(String modelId) {
this.modelId = modelId;
}
void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
this.inferenceConfig = inferenceConfig;
}
@Override
protected void validate(ValidationContext context) {
context.validateHasParent(NAME, name);
if (modelId == null) {
context.addValidationError("[model_id] must be set");
}
if (inferenceConfig != null) {
// error if the results field is set and not equal to the only acceptable value
String resultsField = inferenceConfig.getResultsField();
if (Strings.isNullOrEmpty(resultsField) == false && AGGREGATIONS_RESULTS_FIELD.equals(resultsField) == false) {
context.addValidationError("setting option [" + ClassificationConfig.RESULTS_FIELD.getPreferredName()
+ "] to [" + resultsField + "] is not valid for inference aggregations");
}
if (inferenceConfig instanceof ClassificationConfigUpdate) {
ClassificationConfigUpdate classUpdate = (ClassificationConfigUpdate)inferenceConfig;
// error if the top classes result field is set and not equal to the only acceptable value
String topClassesField = classUpdate.getTopClassesResultsField();
if (Strings.isNullOrEmpty(topClassesField) == false &&
ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD.equals(topClassesField) == false) {
context.addValidationError("setting option [" + ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD
+ "] to [" + topClassesField + "] is not valid for inference aggregations");
}
}
}
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeMap(bucketPathMap, StreamOutput::writeString, StreamOutput::writeString);
out.writeOptionalNamedWriteable(inferenceConfig);
}
@Override
protected PipelineAggregator createInternal(Map<String, Object> metaData) {
SetOnce<LocalModel> model = new SetOnce<>();
SetOnce<Exception> error = new SetOnce<>();
CountDownLatch latch = new CountDownLatch(1);
ActionListener<LocalModel> listener = new LatchedActionListener<>(
ActionListener.wrap(model::set, error::set), latch);
modelLoadingService.get().getModelForSearch(modelId, listener);
try {
// TODO Avoid the blocking wait
latch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Inference aggregation interrupted loading model", e);
}
Exception e = error.get();
if (e != null) {
if (e instanceof RuntimeException) {
throw (RuntimeException)e;
} else {
throw new RuntimeException(error.get());
}
}
InferenceConfigUpdate update = adaptForAggregation(inferenceConfig);
return new InferencePipelineAggregator(name, bucketPathMap, metaData, update, model.get());
}
static InferenceConfigUpdate adaptForAggregation(InferenceConfigUpdate originalUpdate) {
InferenceConfigUpdate updated;
if (originalUpdate == null) {
updated = new ResultsFieldUpdate(AGGREGATIONS_RESULTS_FIELD);
} else {
// Create an update that changes the default results field.
// This isn't necessary for top classes as the default is the same one used here
updated = originalUpdate.newBuilder().setResultsField(AGGREGATIONS_RESULTS_FIELD).build();
}
return updated;
}
@Override
protected boolean overrideBucketsPath() {
return true;
}
@Override
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketPathMap);
if (inferenceConfig != null) {
builder.startObject(INFERENCE_CONFIG.getPreferredName());
builder.field(inferenceConfig.getName(), inferenceConfig);
builder.endObject();
}
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), bucketPathMap, modelId, inferenceConfig);
}
@Override
public boolean equals(Object obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (super.equals(obj) == false) return false;
InferencePipelineAggregationBuilder other = (InferencePipelineAggregationBuilder) obj;
return Objects.equals(bucketPathMap, other.bucketPathMap)
&& Objects.equals(modelId, other.modelId)
&& Objects.equals(inferenceConfig, other.inferenceConfig);
}
}

View File

@ -0,0 +1,133 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.aggs;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation;
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class InferencePipelineAggregator extends PipelineAggregator {
private final Map<String, String> bucketPathMap;
private final InferenceConfigUpdate configUpdate;
private final LocalModel model;
public InferencePipelineAggregator(String name, Map<String,
String> bucketPathMap,
Map<String, Object> metaData,
InferenceConfigUpdate configUpdate,
LocalModel model) {
super(name, bucketPathMap.values().toArray(new String[] {}), metaData);
this.bucketPathMap = bucketPathMap;
this.configUpdate = configUpdate;
this.model = model;
}
@SuppressWarnings({"rawtypes", "unchecked"})
@Override
public InternalAggregation reduce(InternalAggregation aggregation, InternalAggregation.ReduceContext reduceContext) {
InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg =
(InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket>) aggregation;
List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = originalAgg.getBuckets();
List<InternalMultiBucketAggregation.InternalBucket> newBuckets = new ArrayList<>();
for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
Map<String, Object> inputFields = new HashMap<>();
if (bucket.getDocCount() == 0) {
// ignore this empty bucket unless the doc count is used
if (bucketPathMap.containsKey("_count") == false) {
newBuckets.add(bucket);
continue;
}
}
for (Map.Entry<String, String> entry : bucketPathMap.entrySet()) {
String aggName = entry.getKey();
String bucketPath = entry.getValue();
Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath);
if (propertyValue instanceof Number) {
double doubleVal = ((Number) propertyValue).doubleValue();
// NaN or infinite values indicate a missing value or a
// valid result of an invalid calculation. Either way only
// a valid number will do
if (Double.isFinite(doubleVal)) {
inputFields.put(aggName, doubleVal);
}
} else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) {
double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value();
if (Double.isFinite(doubleVal)) {
inputFields.put(aggName, doubleVal);
}
} else if (propertyValue instanceof StringTerms.Bucket) {
StringTerms.Bucket b = (StringTerms.Bucket) propertyValue;
inputFields.put(aggName, b.getKeyAsString());
} else if (propertyValue instanceof String) {
inputFields.put(aggName, propertyValue);
} else if (propertyValue != null) {
// Doubles, String terms or null are valid, any other type is an error
throw invalidAggTypeError(bucketPath, propertyValue);
}
}
InferenceResults inference;
try {
inference = model.infer(inputFields, configUpdate);
} catch (Exception e) {
inference = new WarningInferenceResults(e.getMessage());
}
final List<InternalAggregation> aggs = bucket.getAggregations().asList().stream().map(
(p) -> (InternalAggregation) p).collect(Collectors.toList());
InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference);
aggs.add(aggResult);
InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(new InternalAggregations(aggs), bucket);
newBuckets.add(newBucket);
}
return originalAgg.create(newBuckets);
}
public static Object resolveBucketValue(MultiBucketsAggregation agg,
InternalMultiBucketAggregation.InternalBucket bucket,
String aggPath) {
List<String> aggPathsList = AggregationPath.parse(aggPath).getPathElementsAsStringList();
return bucket.getProperty(agg.getName(), aggPathsList);
}
private static AggregationExecutionException invalidAggTypeError(String aggPath, Object propertyValue) {
String msg = AbstractPipelineAggregationBuilder.BUCKETS_PATH_FIELD.getPreferredName() +
" must reference either a number value, a single value numeric metric aggregation or a string: got [" +
propertyValue + "] of type [" + propertyValue.getClass().getSimpleName() + "] " +
"] at aggregation [" + aggPath + "]";
return new AggregationExecutionException(msg);
}
}

View File

@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.aggs;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InvalidAggregationPathException;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class InternalInferenceAggregation extends InternalAggregation {
private final InferenceResults inferenceResult;
protected InternalInferenceAggregation(String name, Map<String, Object> metadata,
InferenceResults inferenceResult) {
super(name, metadata);
this.inferenceResult = inferenceResult;
}
public InternalInferenceAggregation(StreamInput in) throws IOException {
super(in);
inferenceResult = in.readNamedWriteable(InferenceResults.class);
}
public InferenceResults getInferenceResult() {
return inferenceResult;
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(inferenceResult);
}
@Override
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
throw new UnsupportedOperationException("Reducing an inference aggregation is not supported");
}
@Override
public Object getProperty(List<String> path) {
Object propertyValue;
if (path.isEmpty()) {
propertyValue = this;
} else if (path.size() == 1) {
if (CommonFields.VALUE.getPreferredName().equals(path.get(0))) {
propertyValue = inferenceResult.predictedValue();
} else {
throw invalidPathException(path);
}
} else {
throw invalidPathException(path);
}
return propertyValue;
}
private InvalidAggregationPathException invalidPathException(List<String> path) {
return new InvalidAggregationPathException("unknown property " + path + " for " +
InferencePipelineAggregationBuilder.NAME + " aggregation [" + getName() + "]");
}
@Override
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
return inferenceResult.toXContent(builder, params);
}
@Override
public String getWriteableName() {
return "inference";
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), inferenceResult);
}
@Override
public boolean equals(Object obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (super.equals(obj) == false) return false;
InternalInferenceAggregation other = (InternalInferenceAggregation) obj;
return Objects.equals(inferenceResult, other.inferenceResult);
}
}

View File

@ -0,0 +1,131 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.aggs;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.ParsedAggregation;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults.FEATURE_IMPORTANCE;
/**
* There isn't enough information in toXContent representation of the
* {@link org.elasticsearch.xpack.core.ml.inference.results.InferenceResults}
* objects to fully reconstruct them. In particular, depending on which
* fields are written (result value, feature importance) it is not possible to
* distinguish between a Regression result and a Classification result.
*
* This class parses the union all possible fields that may be written by
* InferenceResults.
*
* The warning field is mutually exclusive with all the other fields.
*/
public class ParsedInference extends ParsedAggregation {
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
(List<TopClassEntry>) args[2], (String) args[3]));
static {
PARSER.declareField(optionalConstructorArg(), (p, n) -> {
Object o;
XContentParser.Token token = p.currentToken();
if (token == XContentParser.Token.VALUE_STRING) {
o = p.text();
} else if (token == XContentParser.Token.VALUE_BOOLEAN) {
o = p.booleanValue();
} else if (token == XContentParser.Token.VALUE_NUMBER) {
o = p.doubleValue();
} else {
throw new XContentParseException(p.getTokenLocation(),
"[" + ParsedInference.class.getSimpleName() + "] failed to parse field [" + CommonFields.VALUE + "] "
+ "value [" + token + "] is not a string, boolean or number");
}
return o;
}, CommonFields.VALUE, ObjectParser.ValueType.VALUE);
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p),
new ParseField(SingleValueInferenceResults.FEATURE_IMPORTANCE));
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p),
new ParseField(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD));
PARSER.declareString(optionalConstructorArg(), new ParseField(WarningInferenceResults.NAME));
declareAggregationFields(PARSER);
}
public static ParsedInference fromXContent(XContentParser parser, final String name) {
ParsedInference parsed = PARSER.apply(parser, null);
parsed.setName(name);
return parsed;
}
private final Object value;
private final List<FeatureImportance> featureImportance;
private final List<TopClassEntry> topClasses;
private final String warning;
ParsedInference(Object value,
List<FeatureImportance> featureImportance,
List<TopClassEntry> topClasses,
String warning) {
this.value = value;
this.warning = warning;
this.featureImportance = featureImportance;
this.topClasses = topClasses;
}
public Object getValue() {
return value;
}
public List<FeatureImportance> getFeatureImportance() {
return featureImportance;
}
public List<TopClassEntry> getTopClasses() {
return topClasses;
}
public String getWarning() {
return warning;
}
@Override
protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
if (warning != null) {
builder.field(WarningInferenceResults.WARNING.getPreferredName(), warning);
} else {
builder.field(CommonFields.VALUE.getPreferredName(), value);
if (topClasses != null && topClasses.size() > 0) {
builder.field(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses);
}
if (featureImportance != null && featureImportance.size() > 0) {
builder.field(FEATURE_IMPORTANCE, featureImportance);
}
}
return builder;
}
@Override
public String getType() {
return InferencePipelineAggregationBuilder.NAME;
}
}

View File

@ -28,7 +28,6 @@ import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@ -42,10 +41,8 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@ -172,10 +169,6 @@ public class InferenceProcessor extends AbstractProcessor {
private static final int MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS = 10;
private static final Logger logger = LogManager.getLogger(Factory.class);
private static final Set<String> RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList(
WarningInferenceResults.WARNING.getPreferredName(),
MODEL_ID));
private final Client client;
private final InferenceAuditor auditor;
private volatile int currentInferenceProcessors;
@ -333,12 +326,10 @@ public class InferenceProcessor extends AbstractProcessor {
if (inferenceConfig.containsKey(ClassificationConfig.NAME.getPreferredName())) {
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
ClassificationConfigUpdate config = ClassificationConfigUpdate.fromMap(valueMap);
checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
return config;
} else if (inferenceConfig.containsKey(RegressionConfig.NAME.getPreferredName())) {
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
RegressionConfigUpdate config = RegressionConfigUpdate.fromMap(valueMap);
checkFieldUniqueness(config.getResultsField());
return config;
} else {
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
@ -347,26 +338,6 @@ public class InferenceProcessor extends AbstractProcessor {
}
}
private static void checkFieldUniqueness(String... fieldNames) {
Set<String> duplicatedFieldNames = new HashSet<>();
Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
for(String fieldName : fieldNames) {
if (fieldName == null) {
continue;
}
if (currentFieldNames.contains(fieldName)) {
duplicatedFieldNames.add(fieldName);
} else {
currentFieldNames.add(fieldName);
}
}
if (duplicatedFieldNames.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Cannot create processor as configured." +
" More than one field is configured as {}",
duplicatedFieldNames);
}
}
void checkSupportedVersion(InferenceConfig config) {
if (config.getMinimalSupportedVersion().after(minNodeVersion)) {
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION,

View File

@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAdder;
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
@ -116,6 +117,22 @@ public class LocalModel {
}
}
public InferenceResults infer(Map<String, Object> fields, InferenceConfigUpdate update) throws Exception {
AtomicReference<InferenceResults> result = new AtomicReference<>();
AtomicReference<Exception> exception = new AtomicReference<>();
ActionListener<InferenceResults> listener = ActionListener.wrap(
result::set,
exception::set
);
infer(fields, update, listener);
if (exception.get() != null) {
throw exception.get();
}
return result.get();
}
/**
* Used for translating field names in according to the passed `fieldMappings` parameter.
*

View File

@ -0,0 +1,130 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.aggs;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
public class InferencePipelineAggregationBuilderTests extends BasePipelineAggregationTestCase<InferencePipelineAggregationBuilder> {
private static final String NAME = "inf-agg";
@Override
protected List<SearchPlugin> plugins() {
return Collections.singletonList(new MachineLearning(Settings.EMPTY, null));
}
@Override
protected List<NamedXContentRegistry.Entry> additionalNamedContents() {
return new MlInferenceNamedXContentProvider().getNamedXContentParsers();
}
@Override
protected List<NamedWriteableRegistry.Entry> additionalNamedWriteables() {
return new MlInferenceNamedXContentProvider().getNamedWriteables();
}
@Override
protected InferencePipelineAggregationBuilder createTestAggregatorFactory() {
Map<String, String> bucketPaths = Stream.generate(() -> randomAlphaOfLength(8))
.limit(randomIntBetween(1, 4))
.collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5)));
InferencePipelineAggregationBuilder builder =
new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)), bucketPaths);
builder.setModelId(randomAlphaOfLength(6));
if (randomBoolean()) {
InferenceConfigUpdate config;
if (randomBoolean()) {
config = ClassificationConfigUpdateTests.randomClassificationConfigUpdate();
} else {
config = RegressionConfigUpdateTests.randomRegressionConfigUpdate();
}
builder.setInferenceConfig(config);
}
return builder;
}
public void testAdaptForAggregation_givenNull() {
InferenceConfigUpdate update = InferencePipelineAggregationBuilder.adaptForAggregation(null);
assertThat(update, is(instanceOf(ResultsFieldUpdate.class)));
assertEquals(InferencePipelineAggregationBuilder.AGGREGATIONS_RESULTS_FIELD, update.getResultsField());
}
public void testAdaptForAggregation() {
RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate(null, 20);
InferenceConfigUpdate update = InferencePipelineAggregationBuilder.adaptForAggregation(regressionConfigUpdate);
assertEquals(InferencePipelineAggregationBuilder.AGGREGATIONS_RESULTS_FIELD, update.getResultsField());
ClassificationConfigUpdate configUpdate = new ClassificationConfigUpdate(1, null, null, null, null);
update = InferencePipelineAggregationBuilder.adaptForAggregation(configUpdate);
assertEquals(InferencePipelineAggregationBuilder.AGGREGATIONS_RESULTS_FIELD, update.getResultsField());
}
public void testValidate() {
InferencePipelineAggregationBuilder aggregationBuilder = createTestAggregatorFactory();
PipelineAggregationBuilder.ValidationContext validationContext =
PipelineAggregationBuilder.ValidationContext.forInsideTree(mock(AggregationBuilder.class), null);
aggregationBuilder.setModelId(null);
aggregationBuilder.validate(validationContext);
List<String> errors = validationContext.getValidationException().validationErrors();
assertEquals("[model_id] must be set", errors.get(0));
}
public void testValidate_invalidResultsField() {
InferencePipelineAggregationBuilder aggregationBuilder = createTestAggregatorFactory();
PipelineAggregationBuilder.ValidationContext validationContext =
PipelineAggregationBuilder.ValidationContext.forInsideTree(mock(AggregationBuilder.class), null);
RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", null);
aggregationBuilder.setInferenceConfig(regressionConfigUpdate);
aggregationBuilder.validate(validationContext);
List<String> errors = validationContext.getValidationException().validationErrors();
assertEquals("setting option [results_field] to [foo] is not valid for inference aggregations", errors.get(0));
}
public void testValidate_invalidTopClassesField() {
InferencePipelineAggregationBuilder aggregationBuilder = createTestAggregatorFactory();
PipelineAggregationBuilder.ValidationContext validationContext =
PipelineAggregationBuilder.ValidationContext.forInsideTree(mock(AggregationBuilder.class), null);
ClassificationConfigUpdate configUpdate = new ClassificationConfigUpdate(1, null, "some_other_field", null, null);
aggregationBuilder.setInferenceConfig(configUpdate);
aggregationBuilder.validate(validationContext);
List<String> errors = validationContext.getValidationException().validationErrors();
assertEquals("setting option [top_classes] to [some_other_field] is not valid for inference aggregations", errors.get(0));
}
}

View File

@ -0,0 +1,227 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.aggs;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.InvalidAggregationPathException;
import org.elasticsearch.search.aggregations.ParsedAggregation;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.ml.MachineLearning;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import static org.hamcrest.Matchers.sameInstance;
public class InternalInferenceAggregationTests extends InternalAggregationTestCase<InternalInferenceAggregation> {
@Override
protected SearchPlugin registerPlugin() {
return new MachineLearning(Settings.EMPTY, null);
}
@Override
protected List<NamedWriteableRegistry.Entry> getNamedWriteables() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>(super.getNamedWriteables());
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return entries;
}
@Override
protected List<NamedXContentRegistry.Entry> getNamedXContents() {
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(super.getNamedXContents());
entries.add(new NamedXContentRegistry.Entry(Aggregation.class,
new ParseField(InferencePipelineAggregationBuilder.NAME), (p, c) -> ParsedInference.fromXContent(p, (String)c)));
return entries;
}
@Override
protected Predicate<String> excludePathsFromXContentInsertion() {
return p -> p.contains("top_classes") || p.contains("feature_importance");
}
@Override
protected InternalInferenceAggregation createTestInstance(String name, Map<String, Object> metadata) {
InferenceResults result;
if (randomBoolean()) {
// build a random result with the result field set to `value`
ClassificationInferenceResults randomResults = ClassificationInferenceResultsTests.createRandomResults();
result = new ClassificationInferenceResults(
randomResults.value(),
randomResults.getClassificationLabel(),
randomResults.getTopClasses(),
randomResults.getFeatureImportance(),
new ClassificationConfig(null, "value", null, null, randomResults.getPredictionFieldType())
);
} else if (randomBoolean()) {
// build a random result with the result field set to `value`
RegressionInferenceResults randomResults = RegressionInferenceResultsTests.createRandomResults();
result = new RegressionInferenceResults(
randomResults.value(),
"value",
randomResults.getFeatureImportance());
} else {
result = new WarningInferenceResults("this is a warning");
}
return new InternalInferenceAggregation(name, metadata, result);
}
@Override
public void testReduceRandom() {
expectThrows(UnsupportedOperationException.class, () -> createTestInstance("name", null).reduce(null, null));
}
@Override
protected void assertReduced(InternalInferenceAggregation reduced, List<InternalInferenceAggregation> inputs) {
// no test since reduce operation is unsupported
}
@Override
protected void assertFromXContent(InternalInferenceAggregation agg, ParsedAggregation parsedAggregation) {
ParsedInference parsed = ((ParsedInference) parsedAggregation);
InferenceResults result = agg.getInferenceResult();
if (result instanceof WarningInferenceResults) {
WarningInferenceResults warning = (WarningInferenceResults) result;
assertEquals(warning.getWarning(), parsed.getWarning());
} else if (result instanceof RegressionInferenceResults) {
RegressionInferenceResults regression = (RegressionInferenceResults) result;
assertEquals(regression.value(), parsed.getValue());
List<FeatureImportance> featureImportance = regression.getFeatureImportance();
if (featureImportance.isEmpty()) {
featureImportance = null;
}
assertEquals(featureImportance, parsed.getFeatureImportance());
} else if (result instanceof ClassificationInferenceResults) {
ClassificationInferenceResults classification = (ClassificationInferenceResults) result;
assertEquals(classification.predictedValue(), parsed.getValue());
List<FeatureImportance> featureImportance = classification.getFeatureImportance();
if (featureImportance.isEmpty()) {
featureImportance = null;
}
assertEquals(featureImportance, parsed.getFeatureImportance());
List<TopClassEntry> topClasses = classification.getTopClasses();
if (topClasses.isEmpty()) {
topClasses = null;
}
assertEquals(topClasses, parsed.getTopClasses());
}
}
public void testGetProperty_givenEmptyPath() {
InternalInferenceAggregation internalAgg = createTestInstance();
assertThat(internalAgg, sameInstance(internalAgg.getProperty(Collections.emptyList())));
}
public void testGetProperty_givenTooLongPath() {
InternalInferenceAggregation internalAgg = createTestInstance();
InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
() -> internalAgg.getProperty(Arrays.asList("one", "two")));
String message = "unknown property [one, two] for inference aggregation [" + internalAgg.getName() + "]";
assertEquals(message, e.getMessage());
}
public void testGetProperty_givenWrongPath() {
InternalInferenceAggregation internalAgg = createTestInstance();
InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
() -> internalAgg.getProperty(Collections.singletonList("bar")));
String message = "unknown property [bar] for inference aggregation [" + internalAgg.getName() + "]";
assertEquals(message, e.getMessage());
}
public void testGetProperty_value() {
{
ClassificationInferenceResults results = ClassificationInferenceResultsTests.createRandomResults();
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
assertEquals(results.predictedValue(), internalAgg.getProperty(Collections.singletonList("value")));
}
{
RegressionInferenceResults results = RegressionInferenceResultsTests.createRandomResults();
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
assertEquals(results.value(), internalAgg.getProperty(Collections.singletonList("value")));
}
{
WarningInferenceResults results = new WarningInferenceResults("a warning from history");
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
assertNull(internalAgg.getProperty(Collections.singletonList("value")));
}
}
public void testGetProperty_featureImportance() {
{
ClassificationInferenceResults results = ClassificationInferenceResultsTests.createRandomResults();
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
expectThrows(InvalidAggregationPathException.class,
() -> internalAgg.getProperty(Collections.singletonList("feature_importance")));
}
{
RegressionInferenceResults results = RegressionInferenceResultsTests.createRandomResults();
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
expectThrows(InvalidAggregationPathException.class,
() -> internalAgg.getProperty(Collections.singletonList("feature_importance")));
}
{
WarningInferenceResults results = new WarningInferenceResults("a warning from history");
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
expectThrows(InvalidAggregationPathException.class,
() -> internalAgg.getProperty(Collections.singletonList("feature_importance")));
}
}
public void testGetProperty_topClasses() {
{
ClassificationInferenceResults results = ClassificationInferenceResultsTests.createRandomResults();
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
expectThrows(InvalidAggregationPathException.class,
() -> internalAgg.getProperty(Collections.singletonList("top_classes")));
}
{
RegressionInferenceResults results = RegressionInferenceResultsTests.createRandomResults();
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
expectThrows(InvalidAggregationPathException.class,
() -> internalAgg.getProperty(Collections.singletonList("top_classes")));
}
{
WarningInferenceResults results = new WarningInferenceResults("a warning from history");
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
expectThrows(InvalidAggregationPathException.class,
() -> internalAgg.getProperty(Collections.singletonList("top_classes")));
}
}
}

View File

@ -291,7 +291,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression);
fail("should not have succeeded creating with duplicate fields");
} catch (Exception ex) {
assertThat(ex.getMessage(), equalTo("Cannot create processor as configured. " +
assertThat(ex.getMessage(), equalTo("Cannot apply inference config. " +
"More than one field is configured as [warning]"));
}
}

View File

@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
@ -92,9 +93,9 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
List<TopClassEntry> classes = new ArrayList<>(2);
classes.add(new TopClassEntry("foo", 0.6, 0.6));
classes.add(new TopClassEntry("bar", 0.4, 0.4));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
@ -102,7 +103,7 @@ public class InferenceProcessorTests extends ESTestCase {
inferenceProcessor.mutateDocument(response, document);
assertThat((List<Map<?,?>>)document.getFieldValue("ml.my_processor.top_classes", List.class),
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
contains(classes.stream().map(TopClassEntry::asValueMap).toArray(Map[]::new)));
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
}
@ -122,9 +123,9 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
List<TopClassEntry> classes = new ArrayList<>(2);
classes.add(new TopClassEntry("foo", 0.6, 0.6));
classes.add(new TopClassEntry("bar", 0.4, 0.4));
List<FeatureImportance> featureInfluence = new ArrayList<>();
featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
@ -163,9 +164,9 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
List<TopClassEntry> classes = new ArrayList<>(2);
classes.add(new TopClassEntry("foo", 0.6, 0.6));
classes.add(new TopClassEntry("bar", 0.4, 0.4));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
@ -173,7 +174,7 @@ public class InferenceProcessorTests extends ESTestCase {
inferenceProcessor.mutateDocument(response, document);
assertThat((List<Map<?,?>>)document.getFieldValue("ml.my_processor.tops", List.class),
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
contains(classes.stream().map(TopClassEntry::asValueMap).toArray(Map[]::new)));
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
assertThat(document.getFieldValue("ml.my_processor.result", String.class), equalTo("foo"));
}

View File

@ -0,0 +1,266 @@
setup:
- skip:
features: headers
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: a-complex-regression-model
body: >
{
"description": "super complex model for tests",
"input": {"field_names": ["avg_cost", "item"]},
"inference_config": {
"regression": {
"results_field": "regression-value",
"num_top_feature_importance_values": 2
}
},
"definition": {
"preprocessors" : [{
"one_hot_encoding": {
"field": "product_type",
"hot_map": {
"TV": "type_tv",
"VCR": "type_vcr",
"Laptop": "type_laptop"
}
}
}],
"trained_model": {
"ensemble": {
"feature_names": [],
"target_type": "regression",
"trained_models": [
{
"tree": {
"feature_names": [
"avg_cost", "type_tv", "type_vcr", "type_laptop"
],
"tree_structure": [
{
"node_index": 0,
"split_feature": 0,
"split_gain": 12,
"threshold": 38,
"decision_type": "lte",
"default_left": true,
"left_child": 1,
"right_child": 2
},
{
"node_index": 1,
"leaf_value": 5.0
},
{
"node_index": 2,
"leaf_value": 2.0
}
],
"target_type": "regression"
}
}
]
}
}
}
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
indices.create:
index: store
body:
mappings:
properties:
product:
type: keyword
cost:
type: integer
time:
type: date
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
Content-Type: application/json
bulk:
index: store
refresh: true
body: |
{ "index": {} }
{ "product": "TV", "cost": 300, "time": 1587501233000 }
{ "index": {} }
{ "product": "TV", "cost": 400, "time": 1587501233000}
{ "index": {} }
{ "product": "VCR", "cost": 150, "time": 1587501233000 }
{ "index": {} }
{ "product": "VCR", "cost": 180, "time": 1587501233000 }
{ "index": {} }
{ "product": "Laptop", "cost": 15000, "time": 1587501233000 }
---
"Test pipeline regression simple":
- do:
search:
index: store
body: |
{
"size": 0,
"aggs": {
"good": {
"terms": {
"field": "product",
"size": 10
},
"aggs": {
"avg_cost_agg": {
"avg": {
"field": "cost"
}
},
"regression_agg": {
"inference": {
"model_id": "a-complex-regression-model",
"inference_config": {
"regression": {
"results_field": "value"
}
},
"buckets_path": {
"avg_cost": "avg_cost_agg"
}
}
}
}
}
}
}
- match: { aggregations.good.buckets.0.regression_agg.value: 2.0 }
---
"Test pipeline agg referencing a single bucket":
- do:
search:
index: store
body: |
{
"size": 0,
"query": {
"match_all": {}
},
"aggs": {
"date_histo": {
"date_histogram": {
"field": "time",
"fixed_interval": "1d"
},
"aggs": {
"good": {
"terms": {
"field": "product",
"size": 10
},
"aggs": {
"avg_cost_agg": {
"avg": {
"field": "cost"
}
}
}
},
"regression_agg": {
"inference": {
"model_id": "a-complex-regression-model",
"buckets_path": {
"avg_cost": "good['TV']>avg_cost_agg",
"product_type": "good['TV']"
}
}
}
}
}
}
}
- match: { aggregations.date_histo.buckets.0.regression_agg.value: 2.0 }
---
"Test all fields missing warning":
- do:
search:
index: store
body: |
{
"size": 0,
"query": { "match_all" : { } },
"aggs": {
"good": {
"terms": {
"field": "product",
"size": 10
},
"aggs": {
"avg_cost_agg": {
"avg": {
"field": "cost"
}
},
"regression_agg" : {
"inference": {
"model_id": "a-complex-regression-model",
"buckets_path": {
"cost" : "avg_cost_agg"
}
}
}
}
}
}
}
- match: { aggregations.good.buckets.0.regression_agg.warning: "Model [a-complex-regression-model] could not be inferred as all fields were missing" }
---
"Test setting results field is invalid":
- do:
catch: /action_request_validation_exception/
search:
index: store
body: |
{
"size": 0,
"query": { "match_all" : { } },
"aggs": {
"good": {
"terms": {
"field": "product",
"size": 10
},
"aggs": {
"avg_cost_agg": {
"avg": {
"field": "cost"
}
},
"regression_agg" : {
"inference": {
"model_id": "a-complex-regression-model",
"inference_config": {
"regression": {
"results_field": "banana"
}
},
"buckets_path": {
"cost" : "avg_cost_agg"
}
}
}
}
}
}
}
- match: { error.root_cause.0.type: "action_request_validation_exception" }
- match: { error.root_cause.0.reason: "Validation Failed: 1: setting option [results_field] to [banana] is not valid for inference aggregations;" }