[7.x] Pipeline Inference Aggregation (#58965)
Adds a pipeline aggregation that loads a model and performs inference on the input aggregation results.
This commit is contained in:
parent
d22dd437f1
commit
f6a0c2c59d
|
@ -0,0 +1,74 @@
|
|||
[role="xpack"]
|
||||
[testenv="basic"]
|
||||
[[search-aggregations-pipeline-inference-bucket-aggregation]]
|
||||
=== Inference Bucket Aggregation
|
||||
|
||||
A parent pipeline aggregation which loads a pre-trained model and performs inference on the
|
||||
collated result field from the parent bucket aggregation.
|
||||
|
||||
[[inference-bucket-agg-syntax]]
|
||||
==== Syntax
|
||||
|
||||
An `inference` aggregation looks like this in isolation:
|
||||
|
||||
[source,js]
|
||||
--------------------------------------------------
|
||||
{
|
||||
"inference": {
|
||||
"model_id": "a_model_for_inference", <1>
|
||||
"inference_config": { <2>
|
||||
"regression_config": {
|
||||
"num_top_feature_importance_values": 2
|
||||
}
|
||||
},
|
||||
"buckets_path": {
|
||||
"avg_cost": "avg_agg", <3>
|
||||
"max_cost": "max_agg"
|
||||
}
|
||||
}
|
||||
}
|
||||
--------------------------------------------------
|
||||
// NOTCONSOLE
|
||||
<1> The ID of the model to use.
|
||||
<2> The optional inference config which overrides the model's default settings.
|
||||
<3> Map the value of `avg_agg` to the model's input field `avg_cost`
|
||||
|
||||
[[inference-bucket-params]]
|
||||
.`inference` Parameters
|
||||
[options="header"]
|
||||
|===
|
||||
|Parameter Name |Description |Required |Default Value
|
||||
| `model_id` | The ID of the model to load and infer against | Required | -
|
||||
| `inference_config` | Contains the inference type and its options. There are two types: <<inference-agg-regression-opt,`regression`>> and <<inference-agg-classification-opt,`classification`>> | Optional | -
|
||||
| `buckets_path` | Defines the paths to the input aggregations and maps the aggregation names to the field names expected by the model.
|
||||
See <<buckets-path-syntax>> for more details | Required | -
|
||||
|===
|
||||
|
||||
|
||||
==== Configuration options for {infer} models
|
||||
The `inference_config` setting is optional and usually isn't required as the pre-trained models come equipped with sensible defaults.
|
||||
In the context of aggregations, some options can be overridden for each of the two types of model.
|
||||
|
||||
[discrete]
|
||||
[[inference-agg-regression-opt]]
|
||||
===== Configuration options for {regression} models
|
||||
|
||||
`num_top_feature_importance_values`::
|
||||
(Optional, integer)
|
||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values]
|
||||
|
||||
[discrete]
|
||||
[[inference-agg-classification-opt]]
|
||||
===== Configuration options for {classification} models
|
||||
|
||||
`num_top_classes`::
|
||||
(Optional, integer)
|
||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes]
|
||||
|
||||
`num_top_feature_importance_values`::
|
||||
(Optional, integer)
|
||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values]
|
||||
|
||||
`prediction_field_type`::
|
||||
(Optional, string)
|
||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type]
|
|
@ -84,8 +84,11 @@ public abstract class BasePipelineAggregationTestCase<AF extends AbstractPipelin
|
|||
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
|
||||
entries.addAll(indicesModule.getNamedWriteables());
|
||||
entries.addAll(searchModule.getNamedWriteables());
|
||||
entries.addAll(additionalNamedWriteables());
|
||||
namedWriteableRegistry = new NamedWriteableRegistry(entries);
|
||||
xContentRegistry = new NamedXContentRegistry(searchModule.getNamedXContents());
|
||||
List<NamedXContentRegistry.Entry> xContentEntries = searchModule.getNamedXContents();
|
||||
xContentEntries.addAll(additionalNamedContents());
|
||||
xContentRegistry = new NamedXContentRegistry(xContentEntries);
|
||||
//create some random type with some default field, those types will stick around for all of the subclasses
|
||||
currentTypes = new String[randomIntBetween(0, 5)];
|
||||
for (int i = 0; i < currentTypes.length; i++) {
|
||||
|
@ -101,6 +104,20 @@ public abstract class BasePipelineAggregationTestCase<AF extends AbstractPipelin
|
|||
return emptyList();
|
||||
}
|
||||
|
||||
/**
|
||||
* Any extra named writeables required not registered by {@link SearchModule}
|
||||
*/
|
||||
protected List<NamedWriteableRegistry.Entry> additionalNamedWriteables() {
|
||||
return emptyList();
|
||||
}
|
||||
|
||||
/**
|
||||
* Any extra named xcontents required not registered by {@link SearchModule}
|
||||
*/
|
||||
protected List<NamedXContentRegistry.Entry> additionalNamedContents() {
|
||||
return emptyList();
|
||||
}
|
||||
|
||||
/**
|
||||
* Generic test that creates new AggregatorFactory from the test
|
||||
* AggregatorFactory and checks both for equality and asserts equality on
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbeddi
|
|||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
|
@ -20,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInf
|
|||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
||||
|
@ -121,6 +123,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
|
|||
ClassificationConfigUpdate::fromXContentStrict));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
|
||||
RegressionConfigUpdate::fromXContentStrict));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ResultsFieldUpdate.NAME,
|
||||
ResultsFieldUpdate::fromXContent));
|
||||
|
||||
// Inference models
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
|
||||
|
@ -170,6 +174,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
|
|||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
|
||||
RegressionInferenceResults.NAME,
|
||||
RegressionInferenceResults::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
|
||||
WarningInferenceResults.NAME,
|
||||
WarningInferenceResults::new));
|
||||
|
||||
// Inference Configs
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
|
||||
|
|
|
@ -6,10 +6,9 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
|
@ -18,9 +17,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
@ -85,6 +82,10 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
return topClasses;
|
||||
}
|
||||
|
||||
public PredictionFieldType getPredictionFieldType() {
|
||||
return predictionFieldType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
|
@ -127,6 +128,11 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
return classificationLabel == null ? super.valueAsString() : classificationLabel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object predictedValue() {
|
||||
return predictionFieldType.transformPredictedValue(value(), valueAsString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String parentResultField) {
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
|
@ -138,7 +144,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
|
||||
}
|
||||
if (getFeatureImportance().size() > 0) {
|
||||
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
|
||||
document.setFieldValue(parentResultField + "." + FEATURE_IMPORTANCE, getFeatureImportance()
|
||||
.stream()
|
||||
.map(FeatureImportance::toMap)
|
||||
.collect(Collectors.toList()));
|
||||
|
@ -150,74 +156,15 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
return NAME;
|
||||
}
|
||||
|
||||
public static class TopClassEntry implements Writeable {
|
||||
|
||||
public final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
|
||||
public final ParseField CLASS_SCORE = new ParseField("class_score");
|
||||
|
||||
private final Object classification;
|
||||
private final double probability;
|
||||
private final double score;
|
||||
|
||||
public TopClassEntry(Object classification, double probability, double score) {
|
||||
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
|
||||
this.probability = probability;
|
||||
this.score = score;
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
|
||||
if (topClasses.size() > 0) {
|
||||
builder.field(topNumClassesField, topClasses);
|
||||
}
|
||||
|
||||
public TopClassEntry(StreamInput in) throws IOException {
|
||||
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
|
||||
this.classification = in.readGenericValue();
|
||||
} else {
|
||||
this.classification = in.readString();
|
||||
}
|
||||
this.probability = in.readDouble();
|
||||
this.score = in.readDouble();
|
||||
}
|
||||
|
||||
public Object getClassification() {
|
||||
return classification;
|
||||
}
|
||||
|
||||
public double getProbability() {
|
||||
return probability;
|
||||
}
|
||||
|
||||
public double getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
public Map<String, Object> asValueMap() {
|
||||
Map<String, Object> map = new HashMap<>(3, 1.0f);
|
||||
map.put(CLASS_NAME.getPreferredName(), classification);
|
||||
map.put(CLASS_PROBABILITY.getPreferredName(), probability);
|
||||
map.put(CLASS_SCORE.getPreferredName(), score);
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
|
||||
out.writeGenericValue(classification);
|
||||
} else {
|
||||
out.writeString(classification.toString());
|
||||
}
|
||||
out.writeDouble(probability);
|
||||
out.writeDouble(score);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
TopClassEntry that = (TopClassEntry) object;
|
||||
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(classification, probability, score);
|
||||
if (getFeatureImportance().size() > 0) {
|
||||
builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,23 +5,33 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class FeatureImportance implements Writeable {
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class FeatureImportance implements Writeable, ToXContentObject {
|
||||
|
||||
private final Map<String, Double> classImportance;
|
||||
private final double importance;
|
||||
private final String featureName;
|
||||
private static final String IMPORTANCE = "importance";
|
||||
private static final String FEATURE_NAME = "feature_name";
|
||||
static final String IMPORTANCE = "importance";
|
||||
static final String FEATURE_NAME = "feature_name";
|
||||
static final String CLASS_IMPORTANCE = "class_importance";
|
||||
|
||||
public static FeatureImportance forRegression(String featureName, double importance) {
|
||||
return new FeatureImportance(featureName, importance, null);
|
||||
|
@ -31,7 +41,24 @@ public class FeatureImportance implements Writeable {
|
|||
return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
|
||||
}
|
||||
|
||||
private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
|
||||
new ConstructingObjectParser<>("feature_importance",
|
||||
a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
|
||||
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
|
||||
new ParseField(FeatureImportance.CLASS_IMPORTANCE));
|
||||
}
|
||||
|
||||
public static FeatureImportance fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
|
||||
this.featureName = Objects.requireNonNull(featureName);
|
||||
this.importance = importance;
|
||||
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
|
||||
|
@ -79,6 +106,22 @@ public class FeatureImportance implements Writeable {
|
|||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FEATURE_NAME, featureName);
|
||||
builder.field(IMPORTANCE, importance);
|
||||
if (classImportance != null && classImportance.isEmpty() == false) {
|
||||
builder.startObject(CLASS_IMPORTANCE);
|
||||
for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
|
||||
builder.field(entry.getKey(), entry.getValue());
|
||||
}
|
||||
builder.endObject();
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
|
@ -93,5 +136,4 @@ public class FeatureImportance implements Writeable {
|
|||
public int hashCode() {
|
||||
return Objects.hash(featureName, importance, classImportance);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -6,10 +6,12 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentFragment;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
|
||||
public interface InferenceResults extends NamedWriteable {
|
||||
public interface InferenceResults extends NamedWriteable, ToXContentFragment {
|
||||
|
||||
void writeResult(IngestDocument document, String parentResultField);
|
||||
|
||||
Object predictedValue();
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -56,9 +57,18 @@ public class RawInferenceResults implements InferenceResults {
|
|||
throw new UnsupportedOperationException("[raw] does not support writing inference results");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object predictedValue() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
throw new UnsupportedOperationException("[raw] does not support toXContent");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results;
|
|||
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
|
@ -25,18 +26,28 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
|
|||
private final String resultsField;
|
||||
|
||||
public RegressionInferenceResults(double value, InferenceConfig config) {
|
||||
this(value, (RegressionConfig) config, Collections.emptyList());
|
||||
this(value, config, Collections.emptyList());
|
||||
}
|
||||
|
||||
public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
|
||||
this(value, (RegressionConfig)config, featureImportance);
|
||||
this(value, ((RegressionConfig)config).getResultsField(),
|
||||
((RegressionConfig)config).getNumTopFeatureImportanceValues(), featureImportance);
|
||||
}
|
||||
|
||||
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, List<FeatureImportance> featureImportance) {
|
||||
public RegressionInferenceResults(double value, String resultsField) {
|
||||
this(value, resultsField, 0, Collections.emptyList());
|
||||
}
|
||||
|
||||
public RegressionInferenceResults(double value, String resultsField,
|
||||
List<FeatureImportance> featureImportance) {
|
||||
this(value, resultsField, featureImportance.size(), featureImportance);
|
||||
}
|
||||
|
||||
public RegressionInferenceResults(double value, String resultsField, int topNFeatures,
|
||||
List<FeatureImportance> featureImportance) {
|
||||
super(value,
|
||||
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
|
||||
regressionConfig.getNumTopFeatureImportanceValues()));
|
||||
this.resultsField = regressionConfig.getResultsField();
|
||||
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance, topNFeatures));
|
||||
this.resultsField = resultsField;
|
||||
}
|
||||
|
||||
public RegressionInferenceResults(StreamInput in) throws IOException {
|
||||
|
@ -65,6 +76,11 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
|
|||
return Objects.hash(value(), resultsField, getFeatureImportance());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object predictedValue() {
|
||||
return super.value();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String parentResultField) {
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
|
@ -78,9 +94,17 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(resultsField, value());
|
||||
if (getFeatureImportance().size() > 0) {
|
||||
builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -16,6 +16,8 @@ import java.util.stream.Collectors;
|
|||
|
||||
public abstract class SingleValueInferenceResults implements InferenceResults {
|
||||
|
||||
public static final String FEATURE_IMPORTANCE = "feature_importance";
|
||||
|
||||
private final double value;
|
||||
private final List<FeatureImportance> featureImportance;
|
||||
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
public class TopClassEntry implements Writeable, ToXContentObject {
|
||||
|
||||
public static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
public static final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
|
||||
public static final ParseField CLASS_SCORE = new ParseField("class_score");
|
||||
|
||||
public static final String NAME = "top_class";
|
||||
|
||||
private static final ConstructingObjectParser<TopClassEntry, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME, a -> new TopClassEntry(a[0], (Double) a[1], (Double) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareField(constructorArg(), (p, n) -> {
|
||||
Object o;
|
||||
XContentParser.Token token = p.currentToken();
|
||||
if (token == XContentParser.Token.VALUE_STRING) {
|
||||
o = p.text();
|
||||
} else if (token == XContentParser.Token.VALUE_BOOLEAN) {
|
||||
o = p.booleanValue();
|
||||
} else if (token == XContentParser.Token.VALUE_NUMBER) {
|
||||
o = p.doubleValue();
|
||||
} else {
|
||||
throw new XContentParseException(p.getTokenLocation(),
|
||||
"[" + NAME + "] failed to parse field [" + CLASS_NAME + "] value [" + token
|
||||
+ "] is not a string, boolean or number");
|
||||
}
|
||||
return o;
|
||||
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
|
||||
PARSER.declareDouble(constructorArg(), CLASS_PROBABILITY);
|
||||
PARSER.declareDouble(constructorArg(), CLASS_SCORE);
|
||||
}
|
||||
|
||||
public static TopClassEntry fromXContent(XContentParser parser) throws IOException {
|
||||
return PARSER.parse(parser, null);
|
||||
}
|
||||
|
||||
private final Object classification;
|
||||
private final double probability;
|
||||
private final double score;
|
||||
|
||||
public TopClassEntry(Object classification, double probability, double score) {
|
||||
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
|
||||
this.probability = probability;
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
public TopClassEntry(StreamInput in) throws IOException {
|
||||
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
|
||||
this.classification = in.readGenericValue();
|
||||
} else {
|
||||
this.classification = in.readString();
|
||||
}
|
||||
this.probability = in.readDouble();
|
||||
this.score = in.readDouble();
|
||||
}
|
||||
|
||||
public Object getClassification() {
|
||||
return classification;
|
||||
}
|
||||
|
||||
public double getProbability() {
|
||||
return probability;
|
||||
}
|
||||
|
||||
public double getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
public Map<String, Object> asValueMap() {
|
||||
Map<String, Object> map = new HashMap<>(3, 1.0f);
|
||||
map.put(CLASS_NAME.getPreferredName(), classification);
|
||||
map.put(CLASS_PROBABILITY.getPreferredName(), probability);
|
||||
map.put(CLASS_SCORE.getPreferredName(), score);
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASS_NAME.getPreferredName(), classification);
|
||||
builder.field(CLASS_PROBABILITY.getPreferredName(), probability);
|
||||
builder.field(CLASS_SCORE.getPreferredName(), score);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
|
||||
out.writeGenericValue(classification);
|
||||
} else {
|
||||
out.writeString(classification.toString());
|
||||
}
|
||||
out.writeDouble(probability);
|
||||
out.writeDouble(score);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
TopClassEntry that = (TopClassEntry) object;
|
||||
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(classification, probability, score);
|
||||
}
|
||||
}
|
|
@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.results;
|
|||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
|
@ -55,12 +56,22 @@ public class WarningInferenceResults implements InferenceResults {
|
|||
public void writeResult(IngestDocument document, String parentResultField) {
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
ExceptionsHelper.requireNonNull(parentResultField, "resultField");
|
||||
document.setFieldValue(parentResultField + "." + "warning", warning);
|
||||
document.setFieldValue(parentResultField + "." + NAME, warning);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object predictedValue() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(NAME, warning);
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -63,18 +63,18 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
|
|||
config.getPredictionFieldType());
|
||||
}
|
||||
|
||||
private static final ObjectParser<ClassificationConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
|
||||
private static final ObjectParser<Builder, Void> STRICT_PARSER = createParser(false);
|
||||
|
||||
private static ObjectParser<ClassificationConfigUpdate.Builder, Void> createParser(boolean lenient) {
|
||||
ObjectParser<ClassificationConfigUpdate.Builder, Void> parser = new ObjectParser<>(
|
||||
private static ObjectParser<Builder, Void> createParser(boolean lenient) {
|
||||
ObjectParser<Builder, Void> parser = new ObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
ClassificationConfigUpdate.Builder::new);
|
||||
parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopClasses, NUM_TOP_CLASSES);
|
||||
parser.declareString(ClassificationConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
|
||||
parser.declareString(ClassificationConfigUpdate.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
|
||||
parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
|
||||
parser.declareString(ClassificationConfigUpdate.Builder::setPredictionFieldType, PREDICTION_FIELD_TYPE);
|
||||
Builder::new);
|
||||
parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
|
||||
parser.declareString(Builder::setResultsField, RESULTS_FIELD);
|
||||
parser.declareString(Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
|
||||
parser.declareInt(Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
|
||||
parser.declareString(Builder::setPredictionFieldType, PREDICTION_FIELD_TYPE);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -96,6 +96,8 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
|
|||
}
|
||||
this.numTopFeatureImportanceValues = featureImportance;
|
||||
this.predictionFieldType = predictionFieldType;
|
||||
|
||||
InferenceConfigUpdate.checkFieldUniqueness(resultsField, topClassesResultsField);
|
||||
}
|
||||
|
||||
public ClassificationConfigUpdate(StreamInput in) throws IOException {
|
||||
|
@ -118,6 +120,16 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
|
|||
return resultsField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
|
||||
return new Builder()
|
||||
.setNumTopClasses(numTopClasses)
|
||||
.setTopClassesResultsField(topClassesResultsField)
|
||||
.setResultsField(resultsField)
|
||||
.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues)
|
||||
.setPredictionFieldType(predictionFieldType);
|
||||
}
|
||||
|
||||
public Integer getNumTopFeatureImportanceValues() {
|
||||
return numTopFeatureImportanceValues;
|
||||
}
|
||||
|
@ -235,14 +247,14 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
|
|||
&& (predictionFieldType == null || predictionFieldType.equals(originalConfig.getPredictionFieldType()));
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
public static class Builder implements InferenceConfigUpdate.Builder<Builder, ClassificationConfigUpdate> {
|
||||
private Integer numTopClasses;
|
||||
private String topClassesResultsField;
|
||||
private String resultsField;
|
||||
private Integer numTopFeatureImportanceValues;
|
||||
private PredictionFieldType predictionFieldType;
|
||||
|
||||
public Builder setNumTopClasses(int numTopClasses) {
|
||||
public Builder setNumTopClasses(Integer numTopClasses) {
|
||||
this.numTopClasses = numTopClasses;
|
||||
return this;
|
||||
}
|
||||
|
@ -252,12 +264,13 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
|
|||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder setResultsField(String resultsField) {
|
||||
this.resultsField = resultsField;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setNumTopFeatureImportanceValues(int numTopFeatureImportanceValues) {
|
||||
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
|
||||
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||
return this;
|
||||
}
|
||||
|
@ -271,6 +284,7 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
|
|||
return setPredictionFieldType(PredictionFieldType.fromString(predictionFieldType));
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClassificationConfigUpdate build() {
|
||||
return new ClassificationConfigUpdate(numTopClasses,
|
||||
resultsField,
|
||||
|
|
|
@ -6,14 +6,53 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
|
||||
public interface InferenceConfigUpdate extends NamedXContentObject, NamedWriteable {
|
||||
Set<String> RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList(
|
||||
WarningInferenceResults.WARNING.getPreferredName(),
|
||||
TrainedModelConfig.MODEL_ID.getPreferredName()));
|
||||
|
||||
InferenceConfig apply(InferenceConfig originalConfig);
|
||||
|
||||
InferenceConfig toConfig();
|
||||
|
||||
boolean isSupported(InferenceConfig config);
|
||||
|
||||
String getResultsField();
|
||||
|
||||
interface Builder<T extends Builder<T, U>, U extends InferenceConfigUpdate> {
|
||||
U build();
|
||||
T setResultsField(String resultsField);
|
||||
}
|
||||
|
||||
Builder<? extends Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder();
|
||||
|
||||
static void checkFieldUniqueness(String... fieldNames) {
|
||||
Set<String> duplicatedFieldNames = new HashSet<>();
|
||||
Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
|
||||
for(String fieldName : fieldNames) {
|
||||
if (fieldName == null) {
|
||||
continue;
|
||||
}
|
||||
if (currentFieldNames.contains(fieldName)) {
|
||||
duplicatedFieldNames.add(fieldName);
|
||||
} else {
|
||||
currentFieldNames.add(fieldName);
|
||||
}
|
||||
}
|
||||
if (duplicatedFieldNames.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException("Cannot apply inference config." +
|
||||
" More than one field is configured as {}",
|
||||
duplicatedFieldNames);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,8 +7,8 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
|||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -28,11 +28,11 @@ public final class InferenceHelpers {
|
|||
/**
|
||||
* @return Tuple of the highest scored index and the top classes
|
||||
*/
|
||||
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(double[] probabilities,
|
||||
List<String> classificationLabels,
|
||||
@Nullable double[] classificationWeights,
|
||||
int numToInclude,
|
||||
PredictionFieldType predictionFieldType) {
|
||||
public static Tuple<Integer, List<TopClassEntry>> topClasses(double[] probabilities,
|
||||
List<String> classificationLabels,
|
||||
@Nullable double[] classificationWeights,
|
||||
int numToInclude,
|
||||
PredictionFieldType predictionFieldType) {
|
||||
|
||||
if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
|
||||
throw ExceptionsHelper
|
||||
|
@ -65,10 +65,10 @@ public final class InferenceHelpers {
|
|||
classificationLabels;
|
||||
|
||||
int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);
|
||||
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
|
||||
List<TopClassEntry> topClassEntries = new ArrayList<>(count);
|
||||
for(int i = 0; i < count; i++) {
|
||||
int idx = sortedIndices[i];
|
||||
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(
|
||||
topClassEntries.add(new TopClassEntry(
|
||||
predictionFieldType.transformPredictedValue((double)idx, labels.get(idx)),
|
||||
probabilities[idx],
|
||||
scores[idx]));
|
||||
|
|
|
@ -18,7 +18,6 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.DEFAULT_RESULTS_FIELD;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.RESULTS_FIELD;
|
||||
|
||||
|
@ -68,6 +67,8 @@ public class RegressionConfigUpdate implements InferenceConfigUpdate {
|
|||
"] must be greater than or equal to 0");
|
||||
}
|
||||
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||
|
||||
InferenceConfigUpdate.checkFieldUniqueness(resultsField);
|
||||
}
|
||||
|
||||
public RegressionConfigUpdate(StreamInput in) throws IOException {
|
||||
|
@ -75,12 +76,19 @@ public class RegressionConfigUpdate implements InferenceConfigUpdate {
|
|||
this.numTopFeatureImportanceValues = in.readOptionalVInt();
|
||||
}
|
||||
|
||||
public int getNumTopFeatureImportanceValues() {
|
||||
return numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues;
|
||||
public Integer getNumTopFeatureImportanceValues() {
|
||||
return numTopFeatureImportanceValues;
|
||||
}
|
||||
|
||||
public String getResultsField() {
|
||||
return resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
|
||||
return resultsField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
|
||||
return new Builder()
|
||||
.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues)
|
||||
.setResultsField(resultsField);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -165,10 +173,11 @@ public class RegressionConfigUpdate implements InferenceConfigUpdate {
|
|||
|| originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
public static class Builder implements InferenceConfigUpdate.Builder<Builder, RegressionConfigUpdate> {
|
||||
private String resultsField;
|
||||
private Integer numTopFeatureImportanceValues;
|
||||
|
||||
@Override
|
||||
public Builder setResultsField(String resultsField) {
|
||||
this.resultsField = resultsField;
|
||||
return this;
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.RESULTS_FIELD;
|
||||
|
||||
/**
|
||||
* A config update that sets the results field only.
|
||||
* Supports any type of {@link InferenceConfig}
|
||||
*/
|
||||
public class ResultsFieldUpdate implements InferenceConfigUpdate {
|
||||
|
||||
public static final ParseField NAME = new ParseField("field_update");
|
||||
|
||||
private static final ConstructingObjectParser<ResultsFieldUpdate, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ResultsFieldUpdate((String) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), RESULTS_FIELD);
|
||||
}
|
||||
|
||||
public static ResultsFieldUpdate fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final String resultsField;
|
||||
|
||||
public ResultsFieldUpdate(String resultsField) {
|
||||
this.resultsField = Objects.requireNonNull(resultsField);
|
||||
}
|
||||
|
||||
public ResultsFieldUpdate(StreamInput in) throws IOException {
|
||||
resultsField = in.readString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public InferenceConfig apply(InferenceConfig originalConfig) {
|
||||
if (originalConfig instanceof ClassificationConfig) {
|
||||
ClassificationConfigUpdate update = new ClassificationConfigUpdate(null, resultsField, null, null, null);
|
||||
return update.apply(originalConfig);
|
||||
} else if (originalConfig instanceof RegressionConfig) {
|
||||
RegressionConfigUpdate update = new RegressionConfigUpdate(resultsField, null);
|
||||
return update.apply(originalConfig);
|
||||
} else {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"Inference config of unknown type [{}] can not be updated", originalConfig.getName());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public InferenceConfig toConfig() {
|
||||
return new RegressionConfig(resultsField);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupported(InferenceConfig config) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getResultsField() {
|
||||
return resultsField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
|
||||
return new Builder().setResultsField(resultsField);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(resultsField);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ResultsFieldUpdate that = (ResultsFieldUpdate) o;
|
||||
return Objects.equals(resultsField, that.resultsField);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(resultsField);
|
||||
}
|
||||
|
||||
public static class Builder implements InferenceConfigUpdate.Builder<Builder, ResultsFieldUpdate> {
|
||||
private String resultsField;
|
||||
|
||||
@Override
|
||||
public Builder setResultsField(String resultsField) {
|
||||
this.resultsField = resultsField;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ResultsFieldUpdate build() {
|
||||
return new ResultsFieldUpdate(resultsField);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -7,6 +7,7 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
||||
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
@ -14,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
|
|||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
|
||||
|
@ -85,12 +87,12 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|||
private EnsembleInferenceModel(List<InferenceModel> models,
|
||||
OutputAggregator outputAggregator,
|
||||
TargetType targetType,
|
||||
List<String> classificationLabels,
|
||||
@Nullable List<String> classificationLabels,
|
||||
List<Double> classificationWeights) {
|
||||
this.models = ExceptionsHelper.requireNonNull(models, TRAINED_MODELS);
|
||||
this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
|
||||
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
|
||||
this.classificationLabels = classificationLabels == null ? null : classificationLabels;
|
||||
this.classificationLabels = classificationLabels;
|
||||
this.classificationWeights = classificationWeights == null ?
|
||||
null :
|
||||
classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
|
||||
|
@ -204,7 +206,7 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|||
ClassificationConfig classificationConfig = (ClassificationConfig) config;
|
||||
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
|
||||
// Adjust the probabilities according to the thresholds
|
||||
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
processedInferences,
|
||||
classificationLabels,
|
||||
classificationWeights,
|
||||
|
|
|
@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
|
|||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
|
||||
|
@ -173,7 +174,7 @@ public class TreeInferenceModel implements InferenceModel {
|
|||
switch (targetType) {
|
||||
case CLASSIFICATION:
|
||||
ClassificationConfig classificationConfig = (ClassificationConfig) config;
|
||||
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
classificationProbability(value),
|
||||
classificationLabels,
|
||||
null,
|
||||
|
|
|
@ -16,6 +16,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
|
|||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
|
||||
|
@ -134,7 +135,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
|
|||
double[] probabilities = softMax(scores);
|
||||
|
||||
ClassificationConfig classificationConfig = (ClassificationConfig) config;
|
||||
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
|
||||
probabilities,
|
||||
LANGUAGE_NAMES,
|
||||
null,
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
|
@ -31,21 +32,24 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
FeatureImportanceTests::randomClassification :
|
||||
FeatureImportanceTests::randomRegression;
|
||||
|
||||
return new ClassificationInferenceResults(randomDouble(),
|
||||
ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig();
|
||||
Double value = randomDouble();
|
||||
if (config.getPredictionFieldType() == PredictionFieldType.BOOLEAN) {
|
||||
// value must be close to 0 or 1
|
||||
value = randomBoolean() ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
return new ClassificationInferenceResults(value,
|
||||
randomBoolean() ? null : randomAlphaOfLength(10),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry)
|
||||
Stream.generate(TopClassEntryTests::createRandomTopClassEntry)
|
||||
.limit(randomIntBetween(0, 10))
|
||||
.collect(Collectors.toList()),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(featureImportanceCtor)
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList()),
|
||||
ClassificationConfigTests.randomClassificationConfig());
|
||||
}
|
||||
|
||||
private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() {
|
||||
return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble(), randomDouble());
|
||||
config);
|
||||
}
|
||||
|
||||
public void testWriteResultsWithClassificationLabel() {
|
||||
|
@ -70,10 +74,10 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testWriteResultsWithTopClasses() {
|
||||
List<ClassificationInferenceResults.TopClassEntry> entries = Arrays.asList(
|
||||
new ClassificationInferenceResults.TopClassEntry("foo", 0.7, 0.7),
|
||||
new ClassificationInferenceResults.TopClassEntry("bar", 0.2, 0.2),
|
||||
new ClassificationInferenceResults.TopClassEntry("baz", 0.1, 0.1));
|
||||
List<TopClassEntry> entries = Arrays.asList(
|
||||
new TopClassEntry("foo", 0.7, 0.7),
|
||||
new TopClassEntry("bar", 0.2, 0.2),
|
||||
new TopClassEntry("baz", 0.1, 0.1));
|
||||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
|
||||
"foo",
|
||||
entries,
|
||||
|
@ -84,8 +88,8 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
List<?> list = document.getFieldValue("result_field.bar", List.class);
|
||||
assertThat(list.size(), equalTo(3));
|
||||
|
||||
for(int i = 0; i < 3; i++) {
|
||||
Map<String, Object> map = (Map<String, Object>)list.get(i);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
Map<String, Object> map = (Map<String, Object>) list.get(i);
|
||||
assertThat(map, equalTo(entries.get(i).asValueMap()));
|
||||
}
|
||||
|
||||
|
@ -110,11 +114,11 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo"));
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>)document.getFieldValue(
|
||||
List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>) document.getFieldValue(
|
||||
"result_field.feature_importance",
|
||||
List.class);
|
||||
assertThat(writtenImportance, hasSize(3));
|
||||
importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
|
||||
importanceList.sort((l, r) -> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
|
||||
for (int i = 0; i < 3; i++) {
|
||||
Map<String, Object> objectMap = writtenImportance.get(i);
|
||||
FeatureImportance importance = importanceList.get(i);
|
||||
|
@ -135,4 +139,39 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
protected Writeable.Reader<ClassificationInferenceResults> instanceReader() {
|
||||
return ClassificationInferenceResults::new;
|
||||
}
|
||||
|
||||
public void testToXContent() {
|
||||
ClassificationConfig toStringConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.STRING);
|
||||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, null, toStringConfig);
|
||||
String stringRep = Strings.toString(result);
|
||||
String expected = "{\"predicted_value\":\"1.0\"}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
ClassificationConfig toDoubleConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.NUMBER);
|
||||
result = new ClassificationInferenceResults(1.0, null, null, toDoubleConfig);
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"predicted_value\":1.0}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
ClassificationConfig boolFieldConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.BOOLEAN);
|
||||
result = new ClassificationInferenceResults(1.0, null, null, boolFieldConfig);
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"predicted_value\":true}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
ClassificationConfig config = new ClassificationConfig(1);
|
||||
result = new ClassificationInferenceResults(1.0, "label1", null, config);
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"predicted_value\":\"label1\"}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
|
||||
TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
|
||||
result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
|
||||
Collections.singletonList(fi), config);
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"predicted_value\":\"label1\"," +
|
||||
"\"top_classes\":[{\"class_name\":\"class\",\"class_probability\":1.0,\"class_score\":1.0}]}";
|
||||
assertEquals(expected, stringRep);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,14 +6,15 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
||||
public class FeatureImportanceTests extends AbstractWireSerializingTestCase<FeatureImportance> {
|
||||
public class FeatureImportanceTests extends AbstractSerializingTestCase<FeatureImportance> {
|
||||
|
||||
public static FeatureImportance createRandomInstance() {
|
||||
return randomBoolean() ? randomClassification() : randomRegression();
|
||||
|
@ -29,7 +30,6 @@ public class FeatureImportanceTests extends AbstractWireSerializingTestCase<Feat
|
|||
Stream.generate(() -> randomAlphaOfLength(10))
|
||||
.limit(randomLongBetween(2, 10))
|
||||
.collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -41,4 +41,9 @@ public class FeatureImportanceTests extends AbstractWireSerializingTestCase<Feat
|
|||
protected Writeable.Reader<FeatureImportance> instanceReader() {
|
||||
return FeatureImportance::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FeatureImportance doParseInstance(XContentParser parser) throws IOException {
|
||||
return FeatureImportance.fromXContent(parser);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,12 +5,14 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
@ -75,4 +77,18 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
|
|||
protected Writeable.Reader<RegressionInferenceResults> instanceReader() {
|
||||
return RegressionInferenceResults::new;
|
||||
}
|
||||
|
||||
public void testToXContent() {
|
||||
String resultsField = "ml.results";
|
||||
RegressionInferenceResults result = new RegressionInferenceResults(1.0, resultsField);
|
||||
String stringRep = Strings.toString(result);
|
||||
String expected = "{\"" + resultsField + "\":1.0}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
|
||||
result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi));
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";
|
||||
assertEquals(expected, stringRep);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class TopClassEntryTests extends AbstractSerializingTestCase<TopClassEntry> {
|
||||
|
||||
public static TopClassEntry createRandomTopClassEntry() {
|
||||
Object classification;
|
||||
if (randomBoolean()) {
|
||||
classification = randomAlphaOfLength(10);
|
||||
} else if (randomBoolean()) {
|
||||
classification = randomBoolean();
|
||||
} else {
|
||||
classification = randomDouble();
|
||||
}
|
||||
return new TopClassEntry(classification, randomDouble(), randomDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TopClassEntry doParseInstance(XContentParser parser) throws IOException {
|
||||
return TopClassEntry.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<TopClassEntry> instanceReader() {
|
||||
return TopClassEntry::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TopClassEntry createTestInstance() {
|
||||
return createRandomTopClassEntry();
|
||||
}
|
||||
}
|
|
@ -5,15 +5,29 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class WarningInferenceResultsTests extends AbstractWireSerializingTestCase<WarningInferenceResults> {
|
||||
public class WarningInferenceResultsTests extends AbstractSerializingTestCase<WarningInferenceResults> {
|
||||
|
||||
private static final ConstructingObjectParser<WarningInferenceResults, Void> PARSER =
|
||||
new ConstructingObjectParser<>("inference_warning",
|
||||
a -> new WarningInferenceResults((String) a[0])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), new ParseField(WarningInferenceResults.NAME));
|
||||
}
|
||||
|
||||
public static WarningInferenceResults createRandomResults() {
|
||||
return new WarningInferenceResults(randomAlphaOfLength(10));
|
||||
|
@ -36,4 +50,9 @@ public class WarningInferenceResultsTests extends AbstractWireSerializingTestCas
|
|||
protected Writeable.Reader<WarningInferenceResults> instanceReader() {
|
||||
return WarningInferenceResults::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected WarningInferenceResults doParseInstance(XContentParser parser) throws IOException {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
@ -18,6 +19,7 @@ import java.util.Map;
|
|||
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests.randomClassificationConfig;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
||||
public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<ClassificationConfigUpdate> {
|
||||
|
||||
|
@ -74,6 +76,30 @@ public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTes
|
|||
));
|
||||
}
|
||||
|
||||
public void testDuplicateFieldNamesThrow() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new ClassificationConfigUpdate(5, "foo", "foo", 1, PredictionFieldType.BOOLEAN));
|
||||
|
||||
assertEquals("Cannot apply inference config. More than one field is configured as [foo]", e.getMessage());
|
||||
}
|
||||
|
||||
public void testDuplicateWithResultsField() {
|
||||
ClassificationConfigUpdate update = randomClassificationConfigUpdate();
|
||||
String newFieldName = update.getResultsField() + "_value";
|
||||
|
||||
InferenceConfigUpdate updateWithField = update.newBuilder().setResultsField(newFieldName).build();
|
||||
|
||||
assertNotSame(updateWithField, update);
|
||||
assertEquals(newFieldName, updateWithField.getResultsField());
|
||||
// other fields are the same
|
||||
assertThat(updateWithField, instanceOf(ClassificationConfigUpdate.class));
|
||||
ClassificationConfigUpdate classUpdate = (ClassificationConfigUpdate)updateWithField;
|
||||
assertEquals(update.getTopClassesResultsField(), classUpdate.getTopClassesResultsField());
|
||||
assertEquals(update.getNumTopClasses(), classUpdate.getNumTopClasses());
|
||||
assertEquals(update.getPredictionFieldType(), classUpdate.getPredictionFieldType());
|
||||
assertEquals(update.getNumTopFeatureImportanceValues(), classUpdate.getNumTopFeatureImportanceValues());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClassificationConfigUpdate createTestInstance() {
|
||||
return randomClassificationConfigUpdate();
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
@ -18,6 +19,7 @@ import java.util.Map;
|
|||
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests.randomRegressionConfig;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
||||
public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCase<RegressionConfigUpdate> {
|
||||
|
||||
|
@ -60,6 +62,26 @@ public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCas
|
|||
));
|
||||
}
|
||||
|
||||
public void testInvalidResultFieldNotUnique() {
|
||||
ElasticsearchStatusException e =
|
||||
expectThrows(ElasticsearchStatusException.class, () -> new RegressionConfigUpdate("warning", 0));
|
||||
assertEquals("Cannot apply inference config. More than one field is configured as [warning]", e.getMessage());
|
||||
}
|
||||
|
||||
public void testNewBuilder() {
|
||||
RegressionConfigUpdate update = randomRegressionConfigUpdate();
|
||||
String newFieldName = update.getResultsField() + "_value";
|
||||
|
||||
InferenceConfigUpdate updateWithField = update.newBuilder().setResultsField(newFieldName).build();
|
||||
|
||||
assertNotSame(updateWithField, update);
|
||||
assertEquals(newFieldName, updateWithField.getResultsField());
|
||||
// other fields are the same
|
||||
assertThat(updateWithField, instanceOf(RegressionConfigUpdate.class));
|
||||
assertEquals(update.getNumTopFeatureImportanceValues(),
|
||||
((RegressionConfigUpdate)updateWithField).getNumTopFeatureImportanceValues());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RegressionConfigUpdate createTestInstance() {
|
||||
return randomRegressionConfigUpdate();
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class ResultsFieldUpdateTests extends AbstractSerializingTestCase<ResultsFieldUpdate> {
|
||||
|
||||
@Override
|
||||
protected ResultsFieldUpdate doParseInstance(XContentParser parser) throws IOException {
|
||||
return ResultsFieldUpdate.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<ResultsFieldUpdate> instanceReader() {
|
||||
return ResultsFieldUpdate::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ResultsFieldUpdate createTestInstance() {
|
||||
return new ResultsFieldUpdate(randomAlphaOfLength(4));
|
||||
}
|
||||
|
||||
public void testIsSupported() {
|
||||
ResultsFieldUpdate update = new ResultsFieldUpdate("foo");
|
||||
assertTrue(update.isSupported(mock(InferenceConfig.class)));
|
||||
}
|
||||
|
||||
public void testApply_OnlyTheResultsFieldIsChanged() {
|
||||
if (randomBoolean()) {
|
||||
ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig();
|
||||
String newResultsField = config.getResultsField() + "foobar";
|
||||
ResultsFieldUpdate update = new ResultsFieldUpdate(newResultsField);
|
||||
InferenceConfig applied = update.apply(config);
|
||||
|
||||
assertThat(applied, instanceOf(ClassificationConfig.class));
|
||||
ClassificationConfig appliedConfig = (ClassificationConfig)applied;
|
||||
assertEquals(newResultsField, appliedConfig.getResultsField());
|
||||
|
||||
assertEquals(appliedConfig, new ClassificationConfig.Builder(config).setResultsField(newResultsField).build());
|
||||
} else {
|
||||
RegressionConfig config = RegressionConfigTests.randomRegressionConfig();
|
||||
String newResultsField = config.getResultsField() + "foobar";
|
||||
ResultsFieldUpdate update = new ResultsFieldUpdate(newResultsField);
|
||||
InferenceConfig applied = update.apply(config);
|
||||
|
||||
assertThat(applied, instanceOf(RegressionConfig.class));
|
||||
RegressionConfig appliedConfig = (RegressionConfig)applied;
|
||||
assertEquals(newResultsField, appliedConfig.getResultsField());
|
||||
|
||||
assertEquals(appliedConfig, new RegressionConfig.Builder(config).setResultsField(newResultsField).build());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
|
@ -137,7 +138,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
|
|||
List<Double> expected = Arrays.asList(0.768524783, 0.231475216);
|
||||
List<Double> scores = Arrays.asList(0.230557435, 0.162032651);
|
||||
double eps = 0.000001;
|
||||
List<ClassificationInferenceResults.TopClassEntry> probabilities =
|
||||
List<TopClassEntry> probabilities =
|
||||
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
|
||||
.getTopClasses();
|
||||
for(int i = 0; i < expected.size(); i++) {
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
|
@ -179,7 +180,7 @@ public class TreeInferenceModelTests extends ESTestCase {
|
|||
List<Double> expectedProbs = Arrays.asList(1.0, 0.0);
|
||||
List<String> expectedFields = Arrays.asList("dog", "cat");
|
||||
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
|
||||
List<ClassificationInferenceResults.TopClassEntry> probabilities =
|
||||
List<TopClassEntry> probabilities =
|
||||
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
|
||||
.getTopClasses();
|
||||
for(int i = 0; i < expectedProbs.size(); i++) {
|
||||
|
|
|
@ -11,7 +11,7 @@ dependencies {
|
|||
// bring in machine learning rest test suite
|
||||
restResources {
|
||||
restApi {
|
||||
includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest'
|
||||
includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest', 'bulk'
|
||||
includeXpack 'ml', 'cat'
|
||||
}
|
||||
restTests {
|
||||
|
|
|
@ -7,7 +7,7 @@ minimal:
|
|||
# Give all users involved in these tests access to the indices where the data to
|
||||
# be analyzed is stored, because the ML roles alone do not provide access to
|
||||
# non-ML indices
|
||||
- names: [ 'airline-data', 'index-*', 'unavailable-data', 'utopia' ]
|
||||
- names: [ 'airline-data', 'index-*', 'unavailable-data', 'utopia', 'store' ]
|
||||
privileges:
|
||||
- create_index
|
||||
- indices:admin/refresh
|
||||
|
|
|
@ -8,6 +8,8 @@ package org.elasticsearch.smoketest;
|
|||
import com.carrotsearch.randomizedtesting.annotations.Name;
|
||||
|
||||
import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
|
||||
import org.elasticsearch.test.rest.yaml.section.DoSection;
|
||||
import org.elasticsearch.test.rest.yaml.section.ExecutableSection;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
@ -16,8 +18,11 @@ import static org.hamcrest.Matchers.either;
|
|||
|
||||
public class MlWithSecurityInsufficientRoleIT extends MlWithSecurityIT {
|
||||
|
||||
private final ClientYamlTestCandidate testCandidate;
|
||||
|
||||
public MlWithSecurityInsufficientRoleIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
|
||||
super(testCandidate);
|
||||
this.testCandidate = testCandidate;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -26,7 +31,18 @@ public class MlWithSecurityInsufficientRoleIT extends MlWithSecurityIT {
|
|||
// Cannot use expectThrows here because blacklisted tests will throw an
|
||||
// InternalAssumptionViolatedException rather than an AssertionError
|
||||
super.test();
|
||||
fail("should have failed because of missing role");
|
||||
|
||||
// We should have got here if and only if no ML endpoints were called
|
||||
for (ExecutableSection section : testCandidate.getTestSection().getExecutableSections()) {
|
||||
if (section instanceof DoSection) {
|
||||
String apiName = ((DoSection) section).getApiCallSection().getApi();
|
||||
|
||||
if (apiName.startsWith("ml.")) {
|
||||
fail("call to ml endpoint should have failed because of missing role");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} catch (AssertionError ae) {
|
||||
// Some tests assert on searches of wildcarded ML indices rather than on ML endpoints. For these we expect no hits.
|
||||
if (ae.getMessage().contains("hits.total didn't match expected value")) {
|
||||
|
|
|
@ -47,7 +47,7 @@ public class MlWithSecurityUserRoleIT extends MlWithSecurityIT {
|
|||
if (section instanceof DoSection) {
|
||||
String apiName = ((DoSection) section).getApiCallSection().getApi();
|
||||
|
||||
if (((DoSection) section).getApiCallSection().getApi().startsWith("ml.") && isAllowed(apiName) == false) {
|
||||
if (apiName.startsWith("ml.") && isAllowed(apiName) == false) {
|
||||
fail("should have failed because of missing role");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,8 +22,8 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
|
|||
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.inject.Module;
|
||||
import org.elasticsearch.common.breaker.CircuitBreaker;
|
||||
import org.elasticsearch.common.inject.Module;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.settings.ClusterSettings;
|
||||
import org.elasticsearch.common.settings.IndexScopedSettings;
|
||||
|
@ -35,6 +35,7 @@ import org.elasticsearch.common.settings.SettingsModule;
|
|||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.ContextParser;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.env.Environment;
|
||||
import org.elasticsearch.env.NodeEnvironment;
|
||||
|
@ -53,11 +54,13 @@ import org.elasticsearch.plugins.CircuitBreakerPlugin;
|
|||
import org.elasticsearch.plugins.IngestPlugin;
|
||||
import org.elasticsearch.plugins.PersistentTaskPlugin;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.plugins.SearchPlugin;
|
||||
import org.elasticsearch.plugins.SystemIndexPlugin;
|
||||
import org.elasticsearch.repositories.RepositoriesService;
|
||||
import org.elasticsearch.rest.RestController;
|
||||
import org.elasticsearch.rest.RestHandler;
|
||||
import org.elasticsearch.script.ScriptService;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.threadpool.ExecutorBuilder;
|
||||
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
|
@ -220,6 +223,8 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationP
|
|||
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
|
||||
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
|
||||
import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
|
||||
import org.elasticsearch.xpack.ml.inference.aggs.InternalInferenceAggregation;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
|
||||
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
|
||||
|
@ -336,7 +341,8 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
|
|||
AnalysisPlugin,
|
||||
CircuitBreakerPlugin,
|
||||
IngestPlugin,
|
||||
PersistentTaskPlugin {
|
||||
PersistentTaskPlugin,
|
||||
SearchPlugin {
|
||||
public static final String NAME = "ml";
|
||||
public static final String BASE_PATH = "/_ml/";
|
||||
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
|
||||
|
@ -454,6 +460,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
|
|||
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
|
||||
private final SetOnce<ActionFilter> mlUpgradeModeActionFilter = new SetOnce<>();
|
||||
private final SetOnce<CircuitBreaker> inferenceModelBreaker = new SetOnce<>();
|
||||
private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();
|
||||
|
||||
public MachineLearning(Settings settings, Path configPath) {
|
||||
this.settings = settings;
|
||||
|
@ -683,6 +690,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
|
|||
settings,
|
||||
clusterService.getNodeName(),
|
||||
inferenceModelBreaker.get());
|
||||
this.modelLoadingService.set(modelLoadingService);
|
||||
|
||||
// Data frame analytics components
|
||||
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
|
||||
|
@ -962,6 +970,18 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
|
|||
return Collections.singletonMap(MlClassicTokenizer.NAME, MlClassicTokenizerFactory::new);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<PipelineAggregationSpec> getPipelineAggregations() {
|
||||
PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
|
||||
in -> new InferencePipelineAggregationBuilder(in, modelLoadingService),
|
||||
(ContextParser<String, ? extends PipelineAggregationBuilder>)
|
||||
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, name, parser
|
||||
));
|
||||
spec.addResultReader(InternalInferenceAggregation::new);
|
||||
|
||||
return Collections.singletonList(spec);
|
||||
}
|
||||
|
||||
@Override
|
||||
public UnaryOperator<Map<String, IndexTemplateMetadata>> getIndexTemplateMetadataUpgrader() {
|
||||
return UnaryOperator.identity();
|
||||
|
|
|
@ -0,0 +1,214 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.inference.aggs;
|
||||
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.LatchedActionListener;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.TreeMap;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {
|
||||
|
||||
public static String NAME = "inference";
|
||||
|
||||
public static final ParseField MODEL_ID = new ParseField("model_id");
|
||||
private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
|
||||
|
||||
static String AGGREGATIONS_RESULTS_FIELD = "value";
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder,
|
||||
Tuple<SetOnce<ModelLoadingService>, String>> PARSER = new ConstructingObjectParser<>(
|
||||
NAME, false,
|
||||
(args, context) -> new InferencePipelineAggregationBuilder(context.v2(), context.v1(), (Map<String, String>) args[0])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareObject(constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD);
|
||||
PARSER.declareString(InferencePipelineAggregationBuilder::setModelId, MODEL_ID);
|
||||
PARSER.declareNamedObject(InferencePipelineAggregationBuilder::setInferenceConfig,
|
||||
(p, c, n) -> p.namedObject(InferenceConfigUpdate.class, n, c), INFERENCE_CONFIG);
|
||||
}
|
||||
|
||||
private final Map<String, String> bucketPathMap;
|
||||
private String modelId;
|
||||
private InferenceConfigUpdate inferenceConfig;
|
||||
private final SetOnce<ModelLoadingService> modelLoadingService;
|
||||
|
||||
public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService,
|
||||
String pipelineAggregatorName,
|
||||
XContentParser parser) {
|
||||
Tuple<SetOnce<ModelLoadingService>, String> context = new Tuple<>(modelLoadingService, pipelineAggregatorName);
|
||||
return PARSER.apply(parser, context);
|
||||
}
|
||||
|
||||
public InferencePipelineAggregationBuilder(String name, SetOnce<ModelLoadingService> modelLoadingService,
|
||||
Map<String, String> bucketsPath) {
|
||||
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
|
||||
this.modelLoadingService = modelLoadingService;
|
||||
this.bucketPathMap = bucketsPath;
|
||||
}
|
||||
|
||||
public InferencePipelineAggregationBuilder(StreamInput in, SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
|
||||
super(in, NAME);
|
||||
modelId = in.readString();
|
||||
bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
|
||||
inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
|
||||
this.modelLoadingService = modelLoadingService;
|
||||
}
|
||||
|
||||
void setModelId(String modelId) {
|
||||
this.modelId = modelId;
|
||||
}
|
||||
|
||||
void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
|
||||
this.inferenceConfig = inferenceConfig;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void validate(ValidationContext context) {
|
||||
context.validateHasParent(NAME, name);
|
||||
if (modelId == null) {
|
||||
context.addValidationError("[model_id] must be set");
|
||||
}
|
||||
|
||||
if (inferenceConfig != null) {
|
||||
// error if the results field is set and not equal to the only acceptable value
|
||||
String resultsField = inferenceConfig.getResultsField();
|
||||
if (Strings.isNullOrEmpty(resultsField) == false && AGGREGATIONS_RESULTS_FIELD.equals(resultsField) == false) {
|
||||
context.addValidationError("setting option [" + ClassificationConfig.RESULTS_FIELD.getPreferredName()
|
||||
+ "] to [" + resultsField + "] is not valid for inference aggregations");
|
||||
}
|
||||
|
||||
if (inferenceConfig instanceof ClassificationConfigUpdate) {
|
||||
ClassificationConfigUpdate classUpdate = (ClassificationConfigUpdate)inferenceConfig;
|
||||
|
||||
// error if the top classes result field is set and not equal to the only acceptable value
|
||||
String topClassesField = classUpdate.getTopClassesResultsField();
|
||||
if (Strings.isNullOrEmpty(topClassesField) == false &&
|
||||
ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD.equals(topClassesField) == false) {
|
||||
context.addValidationError("setting option [" + ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD
|
||||
+ "] to [" + topClassesField + "] is not valid for inference aggregations");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doWriteTo(StreamOutput out) throws IOException {
|
||||
out.writeString(modelId);
|
||||
out.writeMap(bucketPathMap, StreamOutput::writeString, StreamOutput::writeString);
|
||||
out.writeOptionalNamedWriteable(inferenceConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected PipelineAggregator createInternal(Map<String, Object> metaData) {
|
||||
|
||||
SetOnce<LocalModel> model = new SetOnce<>();
|
||||
SetOnce<Exception> error = new SetOnce<>();
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
ActionListener<LocalModel> listener = new LatchedActionListener<>(
|
||||
ActionListener.wrap(model::set, error::set), latch);
|
||||
|
||||
modelLoadingService.get().getModelForSearch(modelId, listener);
|
||||
try {
|
||||
// TODO Avoid the blocking wait
|
||||
latch.await();
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new RuntimeException("Inference aggregation interrupted loading model", e);
|
||||
}
|
||||
|
||||
Exception e = error.get();
|
||||
if (e != null) {
|
||||
if (e instanceof RuntimeException) {
|
||||
throw (RuntimeException)e;
|
||||
} else {
|
||||
throw new RuntimeException(error.get());
|
||||
}
|
||||
}
|
||||
|
||||
InferenceConfigUpdate update = adaptForAggregation(inferenceConfig);
|
||||
|
||||
return new InferencePipelineAggregator(name, bucketPathMap, metaData, update, model.get());
|
||||
}
|
||||
|
||||
static InferenceConfigUpdate adaptForAggregation(InferenceConfigUpdate originalUpdate) {
|
||||
InferenceConfigUpdate updated;
|
||||
if (originalUpdate == null) {
|
||||
updated = new ResultsFieldUpdate(AGGREGATIONS_RESULTS_FIELD);
|
||||
} else {
|
||||
// Create an update that changes the default results field.
|
||||
// This isn't necessary for top classes as the default is the same one used here
|
||||
updated = originalUpdate.newBuilder().setResultsField(AGGREGATIONS_RESULTS_FIELD).build();
|
||||
}
|
||||
|
||||
return updated;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean overrideBucketsPath() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(MODEL_ID.getPreferredName(), modelId);
|
||||
builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketPathMap);
|
||||
if (inferenceConfig != null) {
|
||||
builder.startObject(INFERENCE_CONFIG.getPreferredName());
|
||||
builder.field(inferenceConfig.getName(), inferenceConfig);
|
||||
builder.endObject();
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(super.hashCode(), bucketPathMap, modelId, inferenceConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) return true;
|
||||
if (obj == null || getClass() != obj.getClass()) return false;
|
||||
if (super.equals(obj) == false) return false;
|
||||
|
||||
InferencePipelineAggregationBuilder other = (InferencePipelineAggregationBuilder) obj;
|
||||
return Objects.equals(bucketPathMap, other.bucketPathMap)
|
||||
&& Objects.equals(modelId, other.modelId)
|
||||
&& Objects.equals(inferenceConfig, other.inferenceConfig);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.inference.aggs;
|
||||
|
||||
import org.elasticsearch.search.aggregations.AggregationExecutionException;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregation;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregations;
|
||||
import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
|
||||
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
|
||||
import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation;
|
||||
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.search.aggregations.support.AggregationPath;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
public class InferencePipelineAggregator extends PipelineAggregator {
|
||||
|
||||
private final Map<String, String> bucketPathMap;
|
||||
private final InferenceConfigUpdate configUpdate;
|
||||
private final LocalModel model;
|
||||
|
||||
public InferencePipelineAggregator(String name, Map<String,
|
||||
String> bucketPathMap,
|
||||
Map<String, Object> metaData,
|
||||
InferenceConfigUpdate configUpdate,
|
||||
LocalModel model) {
|
||||
super(name, bucketPathMap.values().toArray(new String[] {}), metaData);
|
||||
this.bucketPathMap = bucketPathMap;
|
||||
this.configUpdate = configUpdate;
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
@SuppressWarnings({"rawtypes", "unchecked"})
|
||||
@Override
|
||||
public InternalAggregation reduce(InternalAggregation aggregation, InternalAggregation.ReduceContext reduceContext) {
|
||||
|
||||
InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg =
|
||||
(InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket>) aggregation;
|
||||
List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = originalAgg.getBuckets();
|
||||
|
||||
List<InternalMultiBucketAggregation.InternalBucket> newBuckets = new ArrayList<>();
|
||||
for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
|
||||
Map<String, Object> inputFields = new HashMap<>();
|
||||
|
||||
if (bucket.getDocCount() == 0) {
|
||||
// ignore this empty bucket unless the doc count is used
|
||||
if (bucketPathMap.containsKey("_count") == false) {
|
||||
newBuckets.add(bucket);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (Map.Entry<String, String> entry : bucketPathMap.entrySet()) {
|
||||
String aggName = entry.getKey();
|
||||
String bucketPath = entry.getValue();
|
||||
Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath);
|
||||
|
||||
if (propertyValue instanceof Number) {
|
||||
double doubleVal = ((Number) propertyValue).doubleValue();
|
||||
// NaN or infinite values indicate a missing value or a
|
||||
// valid result of an invalid calculation. Either way only
|
||||
// a valid number will do
|
||||
if (Double.isFinite(doubleVal)) {
|
||||
inputFields.put(aggName, doubleVal);
|
||||
}
|
||||
} else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) {
|
||||
double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value();
|
||||
if (Double.isFinite(doubleVal)) {
|
||||
inputFields.put(aggName, doubleVal);
|
||||
}
|
||||
} else if (propertyValue instanceof StringTerms.Bucket) {
|
||||
StringTerms.Bucket b = (StringTerms.Bucket) propertyValue;
|
||||
inputFields.put(aggName, b.getKeyAsString());
|
||||
} else if (propertyValue instanceof String) {
|
||||
inputFields.put(aggName, propertyValue);
|
||||
} else if (propertyValue != null) {
|
||||
// Doubles, String terms or null are valid, any other type is an error
|
||||
throw invalidAggTypeError(bucketPath, propertyValue);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
InferenceResults inference;
|
||||
try {
|
||||
inference = model.infer(inputFields, configUpdate);
|
||||
} catch (Exception e) {
|
||||
inference = new WarningInferenceResults(e.getMessage());
|
||||
}
|
||||
|
||||
final List<InternalAggregation> aggs = bucket.getAggregations().asList().stream().map(
|
||||
(p) -> (InternalAggregation) p).collect(Collectors.toList());
|
||||
|
||||
InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference);
|
||||
aggs.add(aggResult);
|
||||
InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(new InternalAggregations(aggs), bucket);
|
||||
newBuckets.add(newBucket);
|
||||
}
|
||||
|
||||
return originalAgg.create(newBuckets);
|
||||
}
|
||||
|
||||
public static Object resolveBucketValue(MultiBucketsAggregation agg,
|
||||
InternalMultiBucketAggregation.InternalBucket bucket,
|
||||
String aggPath) {
|
||||
|
||||
List<String> aggPathsList = AggregationPath.parse(aggPath).getPathElementsAsStringList();
|
||||
return bucket.getProperty(agg.getName(), aggPathsList);
|
||||
}
|
||||
|
||||
private static AggregationExecutionException invalidAggTypeError(String aggPath, Object propertyValue) {
|
||||
|
||||
String msg = AbstractPipelineAggregationBuilder.BUCKETS_PATH_FIELD.getPreferredName() +
|
||||
" must reference either a number value, a single value numeric metric aggregation or a string: got [" +
|
||||
propertyValue + "] of type [" + propertyValue.getClass().getSimpleName() + "] " +
|
||||
"] at aggregation [" + aggPath + "]";
|
||||
return new AggregationExecutionException(msg);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.inference.aggs;
|
||||
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregation;
|
||||
import org.elasticsearch.search.aggregations.InvalidAggregationPathException;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class InternalInferenceAggregation extends InternalAggregation {
|
||||
|
||||
private final InferenceResults inferenceResult;
|
||||
|
||||
protected InternalInferenceAggregation(String name, Map<String, Object> metadata,
|
||||
InferenceResults inferenceResult) {
|
||||
super(name, metadata);
|
||||
this.inferenceResult = inferenceResult;
|
||||
}
|
||||
|
||||
public InternalInferenceAggregation(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
inferenceResult = in.readNamedWriteable(InferenceResults.class);
|
||||
}
|
||||
|
||||
public InferenceResults getInferenceResult() {
|
||||
return inferenceResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doWriteTo(StreamOutput out) throws IOException {
|
||||
out.writeNamedWriteable(inferenceResult);
|
||||
}
|
||||
|
||||
@Override
|
||||
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
|
||||
throw new UnsupportedOperationException("Reducing an inference aggregation is not supported");
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Object getProperty(List<String> path) {
|
||||
Object propertyValue;
|
||||
|
||||
if (path.isEmpty()) {
|
||||
propertyValue = this;
|
||||
} else if (path.size() == 1) {
|
||||
if (CommonFields.VALUE.getPreferredName().equals(path.get(0))) {
|
||||
propertyValue = inferenceResult.predictedValue();
|
||||
} else {
|
||||
throw invalidPathException(path);
|
||||
}
|
||||
} else {
|
||||
throw invalidPathException(path);
|
||||
}
|
||||
|
||||
return propertyValue;
|
||||
}
|
||||
|
||||
private InvalidAggregationPathException invalidPathException(List<String> path) {
|
||||
return new InvalidAggregationPathException("unknown property " + path + " for " +
|
||||
InferencePipelineAggregationBuilder.NAME + " aggregation [" + getName() + "]");
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
|
||||
return inferenceResult.toXContent(builder, params);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return "inference";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(super.hashCode(), inferenceResult);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) return true;
|
||||
if (obj == null || getClass() != obj.getClass()) return false;
|
||||
if (super.equals(obj) == false) return false;
|
||||
InternalInferenceAggregation other = (InternalInferenceAggregation) obj;
|
||||
return Objects.equals(inferenceResult, other.inferenceResult);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.inference.aggs;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.ParsedAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults.FEATURE_IMPORTANCE;
|
||||
|
||||
|
||||
/**
 * There isn't enough information in the toXContent representation of the
 * {@link org.elasticsearch.xpack.core.ml.inference.results.InferenceResults}
 * objects to fully reconstruct them. In particular, depending on which
 * fields are written (result value, feature importance) it is not possible to
 * distinguish between a Regression result and a Classification result.
 *
 * This class parses the union of all possible fields that may be written by
 * InferenceResults.
 *
 * The warning field is mutually exclusive with all the other fields.
 */
public class ParsedInference extends ParsedAggregation {

    // NOTE: the constructor argument indices (args[0]..args[3]) are fixed by the
    // order the fields are declared on the parser in the static block below:
    // value, feature importance, top classes, warning.
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
        new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
            args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
                (List<TopClassEntry>) args[2], (String) args[3]));

    static {
        // The predicted value may be serialized as a string, boolean or number;
        // accept all three and reject anything else.
        PARSER.declareField(optionalConstructorArg(), (p, n) -> {
            Object o;
            XContentParser.Token token = p.currentToken();
            if (token == XContentParser.Token.VALUE_STRING) {
                o = p.text();
            } else if (token == XContentParser.Token.VALUE_BOOLEAN) {
                o = p.booleanValue();
            } else if (token == XContentParser.Token.VALUE_NUMBER) {
                o = p.doubleValue();
            } else {
                throw new XContentParseException(p.getTokenLocation(),
                    "[" + ParsedInference.class.getSimpleName() + "] failed to parse field [" + CommonFields.VALUE + "] "
                        + "value [" + token + "] is not a string, boolean or number");
            }
            return o;
        }, CommonFields.VALUE, ObjectParser.ValueType.VALUE);
        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p),
            new ParseField(SingleValueInferenceResults.FEATURE_IMPORTANCE));
        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p),
            new ParseField(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD));
        PARSER.declareString(optionalConstructorArg(), new ParseField(WarningInferenceResults.NAME));
        // Standard aggregation fields (e.g. metadata) shared by all parsed aggs.
        declareAggregationFields(PARSER);
    }

    public static ParsedInference fromXContent(XContentParser parser, final String name) {
        ParsedInference parsed = PARSER.apply(parser, null);
        parsed.setName(name);
        return parsed;
    }

    // Parsed predicted value: a String, Boolean or Double depending on the model.
    private final Object value;
    private final List<FeatureImportance> featureImportance;
    private final List<TopClassEntry> topClasses;
    // Set when inference failed; mutually exclusive with the fields above.
    private final String warning;

    ParsedInference(Object value,
                    List<FeatureImportance> featureImportance,
                    List<TopClassEntry> topClasses,
                    String warning) {
        this.value = value;
        this.warning = warning;
        this.featureImportance = featureImportance;
        this.topClasses = topClasses;
    }

    public Object getValue() {
        return value;
    }

    public List<FeatureImportance> getFeatureImportance() {
        return featureImportance;
    }

    public List<TopClassEntry> getTopClasses() {
        return topClasses;
    }

    public String getWarning() {
        return warning;
    }

    @Override
    protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
        // A warning replaces the entire result body; otherwise write the value
        // plus any non-empty optional sections.
        if (warning != null) {
            builder.field(WarningInferenceResults.WARNING.getPreferredName(), warning);
        } else {
            builder.field(CommonFields.VALUE.getPreferredName(), value);
            if (topClasses != null && topClasses.size() > 0) {
                builder.field(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses);
            }
            if (featureImportance != null && featureImportance.size() > 0) {
                builder.field(FEATURE_IMPORTANCE, featureImportance);
            }
        }
        return builder;
    }

    @Override
    public String getType() {
        return InferencePipelineAggregationBuilder.NAME;
    }
}
|
|
@ -28,7 +28,6 @@ import org.elasticsearch.ingest.PipelineConfiguration;
|
|||
import org.elasticsearch.ingest.Processor;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
|
@ -42,10 +41,8 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
|||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Consumer;
|
||||
|
@ -172,10 +169,6 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
private static final int MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS = 10;
|
||||
private static final Logger logger = LogManager.getLogger(Factory.class);
|
||||
|
||||
private static final Set<String> RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList(
|
||||
WarningInferenceResults.WARNING.getPreferredName(),
|
||||
MODEL_ID));
|
||||
|
||||
private final Client client;
|
||||
private final InferenceAuditor auditor;
|
||||
private volatile int currentInferenceProcessors;
|
||||
|
@ -333,12 +326,10 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
if (inferenceConfig.containsKey(ClassificationConfig.NAME.getPreferredName())) {
|
||||
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
|
||||
ClassificationConfigUpdate config = ClassificationConfigUpdate.fromMap(valueMap);
|
||||
checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
|
||||
return config;
|
||||
} else if (inferenceConfig.containsKey(RegressionConfig.NAME.getPreferredName())) {
|
||||
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
|
||||
RegressionConfigUpdate config = RegressionConfigUpdate.fromMap(valueMap);
|
||||
checkFieldUniqueness(config.getResultsField());
|
||||
return config;
|
||||
} else {
|
||||
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
|
||||
|
@ -347,26 +338,6 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
private static void checkFieldUniqueness(String... fieldNames) {
|
||||
Set<String> duplicatedFieldNames = new HashSet<>();
|
||||
Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
|
||||
for(String fieldName : fieldNames) {
|
||||
if (fieldName == null) {
|
||||
continue;
|
||||
}
|
||||
if (currentFieldNames.contains(fieldName)) {
|
||||
duplicatedFieldNames.add(fieldName);
|
||||
} else {
|
||||
currentFieldNames.add(fieldName);
|
||||
}
|
||||
}
|
||||
if (duplicatedFieldNames.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException("Cannot create processor as configured." +
|
||||
" More than one field is configured as {}",
|
||||
duplicatedFieldNames);
|
||||
}
|
||||
}
|
||||
|
||||
void checkSupportedVersion(InferenceConfig config) {
|
||||
if (config.getMinimalSupportedVersion().after(minNodeVersion)) {
|
||||
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION,
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.HashMap;
|
|||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.concurrent.atomic.LongAdder;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
|
||||
|
@ -78,7 +79,7 @@ public class LocalModel {
|
|||
persistenceQuotient = 10_000;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public void infer(Map<String, Object> fields, InferenceConfigUpdate update, ActionListener<InferenceResults> listener) {
|
||||
if (update.isSupported(this.inferenceConfig) == false) {
|
||||
listener.onFailure(ExceptionsHelper.badRequestException(
|
||||
|
@ -116,6 +117,22 @@ public class LocalModel {
|
|||
}
|
||||
}
|
||||
|
||||
public InferenceResults infer(Map<String, Object> fields, InferenceConfigUpdate update) throws Exception {
|
||||
AtomicReference<InferenceResults> result = new AtomicReference<>();
|
||||
AtomicReference<Exception> exception = new AtomicReference<>();
|
||||
ActionListener<InferenceResults> listener = ActionListener.wrap(
|
||||
result::set,
|
||||
exception::set
|
||||
);
|
||||
|
||||
infer(fields, update, listener);
|
||||
if (exception.get() != null) {
|
||||
throw exception.get();
|
||||
}
|
||||
|
||||
return result.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Used for translating field names in according to the passed `fieldMappings` parameter.
|
||||
*
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.inference.aggs;
|
||||
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.SearchPlugin;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class InferencePipelineAggregationBuilderTests extends BasePipelineAggregationTestCase<InferencePipelineAggregationBuilder> {
|
||||
|
||||
private static final String NAME = "inf-agg";
|
||||
|
||||
@Override
|
||||
protected List<SearchPlugin> plugins() {
|
||||
return Collections.singletonList(new MachineLearning(Settings.EMPTY, null));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<NamedXContentRegistry.Entry> additionalNamedContents() {
|
||||
return new MlInferenceNamedXContentProvider().getNamedXContentParsers();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<NamedWriteableRegistry.Entry> additionalNamedWriteables() {
|
||||
return new MlInferenceNamedXContentProvider().getNamedWriteables();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected InferencePipelineAggregationBuilder createTestAggregatorFactory() {
|
||||
Map<String, String> bucketPaths = Stream.generate(() -> randomAlphaOfLength(8))
|
||||
.limit(randomIntBetween(1, 4))
|
||||
.collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5)));
|
||||
|
||||
InferencePipelineAggregationBuilder builder =
|
||||
new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)), bucketPaths);
|
||||
builder.setModelId(randomAlphaOfLength(6));
|
||||
|
||||
if (randomBoolean()) {
|
||||
InferenceConfigUpdate config;
|
||||
if (randomBoolean()) {
|
||||
config = ClassificationConfigUpdateTests.randomClassificationConfigUpdate();
|
||||
} else {
|
||||
config = RegressionConfigUpdateTests.randomRegressionConfigUpdate();
|
||||
}
|
||||
builder.setInferenceConfig(config);
|
||||
}
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
public void testAdaptForAggregation_givenNull() {
|
||||
InferenceConfigUpdate update = InferencePipelineAggregationBuilder.adaptForAggregation(null);
|
||||
assertThat(update, is(instanceOf(ResultsFieldUpdate.class)));
|
||||
assertEquals(InferencePipelineAggregationBuilder.AGGREGATIONS_RESULTS_FIELD, update.getResultsField());
|
||||
}
|
||||
|
||||
public void testAdaptForAggregation() {
|
||||
RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate(null, 20);
|
||||
InferenceConfigUpdate update = InferencePipelineAggregationBuilder.adaptForAggregation(regressionConfigUpdate);
|
||||
assertEquals(InferencePipelineAggregationBuilder.AGGREGATIONS_RESULTS_FIELD, update.getResultsField());
|
||||
|
||||
ClassificationConfigUpdate configUpdate = new ClassificationConfigUpdate(1, null, null, null, null);
|
||||
update = InferencePipelineAggregationBuilder.adaptForAggregation(configUpdate);
|
||||
assertEquals(InferencePipelineAggregationBuilder.AGGREGATIONS_RESULTS_FIELD, update.getResultsField());
|
||||
}
|
||||
|
||||
public void testValidate() {
|
||||
InferencePipelineAggregationBuilder aggregationBuilder = createTestAggregatorFactory();
|
||||
PipelineAggregationBuilder.ValidationContext validationContext =
|
||||
PipelineAggregationBuilder.ValidationContext.forInsideTree(mock(AggregationBuilder.class), null);
|
||||
|
||||
aggregationBuilder.setModelId(null);
|
||||
aggregationBuilder.validate(validationContext);
|
||||
List<String> errors = validationContext.getValidationException().validationErrors();
|
||||
assertEquals("[model_id] must be set", errors.get(0));
|
||||
}
|
||||
|
||||
public void testValidate_invalidResultsField() {
|
||||
InferencePipelineAggregationBuilder aggregationBuilder = createTestAggregatorFactory();
|
||||
PipelineAggregationBuilder.ValidationContext validationContext =
|
||||
PipelineAggregationBuilder.ValidationContext.forInsideTree(mock(AggregationBuilder.class), null);
|
||||
|
||||
RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", null);
|
||||
aggregationBuilder.setInferenceConfig(regressionConfigUpdate);
|
||||
aggregationBuilder.validate(validationContext);
|
||||
List<String> errors = validationContext.getValidationException().validationErrors();
|
||||
assertEquals("setting option [results_field] to [foo] is not valid for inference aggregations", errors.get(0));
|
||||
}
|
||||
|
||||
public void testValidate_invalidTopClassesField() {
|
||||
InferencePipelineAggregationBuilder aggregationBuilder = createTestAggregatorFactory();
|
||||
PipelineAggregationBuilder.ValidationContext validationContext =
|
||||
PipelineAggregationBuilder.ValidationContext.forInsideTree(mock(AggregationBuilder.class), null);
|
||||
|
||||
ClassificationConfigUpdate configUpdate = new ClassificationConfigUpdate(1, null, "some_other_field", null, null);
|
||||
aggregationBuilder.setInferenceConfig(configUpdate);
|
||||
aggregationBuilder.validate(validationContext);
|
||||
List<String> errors = validationContext.getValidationException().validationErrors();
|
||||
assertEquals("setting option [top_classes] to [some_other_field] is not valid for inference aggregations", errors.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,227 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.inference.aggs;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.SearchPlugin;
|
||||
import org.elasticsearch.search.aggregations.Aggregation;
|
||||
import org.elasticsearch.search.aggregations.InvalidAggregationPathException;
|
||||
import org.elasticsearch.search.aggregations.ParsedAggregation;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
import static org.hamcrest.Matchers.sameInstance;
|
||||
|
||||
public class InternalInferenceAggregationTests extends InternalAggregationTestCase<InternalInferenceAggregation> {
|
||||
|
||||
    @Override
    protected SearchPlugin registerPlugin() {
        // The ML plugin supplies the inference aggregation's registrations.
        return new MachineLearning(Settings.EMPTY, null);
    }
|
||||
|
||||
    @Override
    protected List<NamedWriteableRegistry.Entry> getNamedWriteables() {
        // Add the ML InferenceResults writeables on top of the defaults so the
        // carried result can be serialized in round-trip tests.
        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>(super.getNamedWriteables());
        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
        return entries;
    }
|
||||
|
||||
    @Override
    protected List<NamedXContentRegistry.Entry> getNamedXContents() {
        // Register the ParsedInference parser under the aggregation's name so
        // the xcontent round-trip test can parse the rendered aggregation back.
        List<NamedXContentRegistry.Entry> entries = new ArrayList<>(super.getNamedXContents());
        entries.add(new NamedXContentRegistry.Entry(Aggregation.class,
            new ParseField(InferencePipelineAggregationBuilder.NAME), (p, c) -> ParsedInference.fromXContent(p, (String)c)));

        return entries;
    }
|
||||
|
||||
    @Override
    protected Predicate<String> excludePathsFromXContentInsertion() {
        // NOTE(review): presumably random field insertion inside these object
        // arrays would break ParsedInference's sub-object parsers — confirm.
        return p -> p.contains("top_classes") || p.contains("feature_importance");
    }
|
||||
|
||||
    @Override
    protected InternalInferenceAggregation createTestInstance(String name, Map<String, Object> metadata) {
        // Picks one of classification (50%), regression (25%) or warning (25%).
        // Note: the exact sequence of random calls must not change or seeded
        // test reproducibility breaks.
        InferenceResults result;

        if (randomBoolean()) {
            // build a random result with the result field set to `value`
            ClassificationInferenceResults randomResults = ClassificationInferenceResultsTests.createRandomResults();
            result = new ClassificationInferenceResults(
                randomResults.value(),
                randomResults.getClassificationLabel(),
                randomResults.getTopClasses(),
                randomResults.getFeatureImportance(),
                new ClassificationConfig(null, "value", null, null, randomResults.getPredictionFieldType())
            );
        } else if (randomBoolean()) {
            // build a random result with the result field set to `value`
            RegressionInferenceResults randomResults = RegressionInferenceResultsTests.createRandomResults();
            result = new RegressionInferenceResults(
                randomResults.value(),
                "value",
                randomResults.getFeatureImportance());
        } else {
            result = new WarningInferenceResults("this is a warning");
        }

        return new InternalInferenceAggregation(name, metadata, result);
    }
|
||||
|
||||
    @Override
    public void testReduceRandom() {
        // reduce() is unsupported for this aggregation, so the inherited random
        // reduce test is replaced with an assertion that it throws.
        expectThrows(UnsupportedOperationException.class, () -> createTestInstance("name", null).reduce(null, null));
    }
|
||||
|
||||
    @Override
    protected void assertReduced(InternalInferenceAggregation reduced, List<InternalInferenceAggregation> inputs) {
        // no test since reduce operation is unsupported
    }
|
||||
|
||||
    @Override
    protected void assertFromXContent(InternalInferenceAggregation agg, ParsedAggregation parsedAggregation) {
        // Checks the parsed (lenient, union-of-fields) form against the original
        // result, branching on the concrete InferenceResults subtype.
        ParsedInference parsed = ((ParsedInference) parsedAggregation);

        InferenceResults result = agg.getInferenceResult();
        if (result instanceof WarningInferenceResults) {
            WarningInferenceResults warning = (WarningInferenceResults) result;
            assertEquals(warning.getWarning(), parsed.getWarning());
        } else if (result instanceof RegressionInferenceResults) {
            RegressionInferenceResults regression = (RegressionInferenceResults) result;
            assertEquals(regression.value(), parsed.getValue());
            // An empty list is not rendered to xcontent, so it parses back as null.
            List<FeatureImportance> featureImportance = regression.getFeatureImportance();
            if (featureImportance.isEmpty()) {
                featureImportance = null;
            }
            assertEquals(featureImportance, parsed.getFeatureImportance());
        } else if (result instanceof ClassificationInferenceResults) {
            ClassificationInferenceResults classification = (ClassificationInferenceResults) result;
            assertEquals(classification.predictedValue(), parsed.getValue());

            // Empty feature importance is omitted from xcontent, hence null after parsing.
            List<FeatureImportance> featureImportance = classification.getFeatureImportance();
            if (featureImportance.isEmpty()) {
                featureImportance = null;
            }
            assertEquals(featureImportance, parsed.getFeatureImportance());

            // Same normalization applies to top classes.
            List<TopClassEntry> topClasses = classification.getTopClasses();
            if (topClasses.isEmpty()) {
                topClasses = null;
            }
            assertEquals(topClasses, parsed.getTopClasses());
        }
    }
|
||||
|
||||
public void testGetProperty_givenEmptyPath() {
|
||||
InternalInferenceAggregation internalAgg = createTestInstance();
|
||||
assertThat(internalAgg, sameInstance(internalAgg.getProperty(Collections.emptyList())));
|
||||
}
|
||||
|
||||
public void testGetProperty_givenTooLongPath() {
|
||||
InternalInferenceAggregation internalAgg = createTestInstance();
|
||||
InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
|
||||
() -> internalAgg.getProperty(Arrays.asList("one", "two")));
|
||||
|
||||
String message = "unknown property [one, two] for inference aggregation [" + internalAgg.getName() + "]";
|
||||
assertEquals(message, e.getMessage());
|
||||
}
|
||||
|
||||
public void testGetProperty_givenWrongPath() {
|
||||
InternalInferenceAggregation internalAgg = createTestInstance();
|
||||
InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
|
||||
() -> internalAgg.getProperty(Collections.singletonList("bar")));
|
||||
|
||||
String message = "unknown property [bar] for inference aggregation [" + internalAgg.getName() + "]";
|
||||
assertEquals(message, e.getMessage());
|
||||
}
|
||||
|
||||
public void testGetProperty_value() {
|
||||
{
|
||||
ClassificationInferenceResults results = ClassificationInferenceResultsTests.createRandomResults();
|
||||
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
|
||||
assertEquals(results.predictedValue(), internalAgg.getProperty(Collections.singletonList("value")));
|
||||
}
|
||||
|
||||
{
|
||||
RegressionInferenceResults results = RegressionInferenceResultsTests.createRandomResults();
|
||||
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
|
||||
assertEquals(results.value(), internalAgg.getProperty(Collections.singletonList("value")));
|
||||
}
|
||||
|
||||
{
|
||||
WarningInferenceResults results = new WarningInferenceResults("a warning from history");
|
||||
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
|
||||
assertNull(internalAgg.getProperty(Collections.singletonList("value")));
|
||||
}
|
||||
}
|
||||
|
||||
public void testGetProperty_featureImportance() {
|
||||
{
|
||||
ClassificationInferenceResults results = ClassificationInferenceResultsTests.createRandomResults();
|
||||
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
|
||||
expectThrows(InvalidAggregationPathException.class,
|
||||
() -> internalAgg.getProperty(Collections.singletonList("feature_importance")));
|
||||
}
|
||||
|
||||
{
|
||||
RegressionInferenceResults results = RegressionInferenceResultsTests.createRandomResults();
|
||||
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
|
||||
expectThrows(InvalidAggregationPathException.class,
|
||||
() -> internalAgg.getProperty(Collections.singletonList("feature_importance")));
|
||||
}
|
||||
|
||||
{
|
||||
WarningInferenceResults results = new WarningInferenceResults("a warning from history");
|
||||
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
|
||||
expectThrows(InvalidAggregationPathException.class,
|
||||
() -> internalAgg.getProperty(Collections.singletonList("feature_importance")));
|
||||
}
|
||||
}
|
||||
|
||||
public void testGetProperty_topClasses() {
|
||||
{
|
||||
ClassificationInferenceResults results = ClassificationInferenceResultsTests.createRandomResults();
|
||||
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
|
||||
expectThrows(InvalidAggregationPathException.class,
|
||||
() -> internalAgg.getProperty(Collections.singletonList("top_classes")));
|
||||
}
|
||||
|
||||
{
|
||||
RegressionInferenceResults results = RegressionInferenceResultsTests.createRandomResults();
|
||||
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
|
||||
expectThrows(InvalidAggregationPathException.class,
|
||||
() -> internalAgg.getProperty(Collections.singletonList("top_classes")));
|
||||
}
|
||||
|
||||
{
|
||||
WarningInferenceResults results = new WarningInferenceResults("a warning from history");
|
||||
InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
|
||||
expectThrows(InvalidAggregationPathException.class,
|
||||
() -> internalAgg.getProperty(Collections.singletonList("top_classes")));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -291,7 +291,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression);
|
||||
fail("should not have succeeded creating with duplicate fields");
|
||||
} catch (Exception ex) {
|
||||
assertThat(ex.getMessage(), equalTo("Cannot create processor as configured. " +
|
||||
assertThat(ex.getMessage(), equalTo("Cannot apply inference config. " +
|
||||
"More than one field is configured as [warning]"));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
|
|||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
||||
|
@ -92,9 +93,9 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
Map<String, Object> ingestMetadata = new HashMap<>();
|
||||
IngestDocument document = new IngestDocument(source, ingestMetadata);
|
||||
|
||||
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
|
||||
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
|
||||
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
|
||||
List<TopClassEntry> classes = new ArrayList<>(2);
|
||||
classes.add(new TopClassEntry("foo", 0.6, 0.6));
|
||||
classes.add(new TopClassEntry("bar", 0.4, 0.4));
|
||||
|
||||
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
|
||||
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
|
||||
|
@ -102,7 +103,7 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
inferenceProcessor.mutateDocument(response, document);
|
||||
|
||||
assertThat((List<Map<?,?>>)document.getFieldValue("ml.my_processor.top_classes", List.class),
|
||||
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
|
||||
contains(classes.stream().map(TopClassEntry::asValueMap).toArray(Map[]::new)));
|
||||
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
|
||||
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
|
||||
}
|
||||
|
@ -122,9 +123,9 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
Map<String, Object> ingestMetadata = new HashMap<>();
|
||||
IngestDocument document = new IngestDocument(source, ingestMetadata);
|
||||
|
||||
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
|
||||
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
|
||||
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
|
||||
List<TopClassEntry> classes = new ArrayList<>(2);
|
||||
classes.add(new TopClassEntry("foo", 0.6, 0.6));
|
||||
classes.add(new TopClassEntry("bar", 0.4, 0.4));
|
||||
|
||||
List<FeatureImportance> featureInfluence = new ArrayList<>();
|
||||
featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
|
||||
|
@ -163,9 +164,9 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
Map<String, Object> ingestMetadata = new HashMap<>();
|
||||
IngestDocument document = new IngestDocument(source, ingestMetadata);
|
||||
|
||||
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
|
||||
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
|
||||
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
|
||||
List<TopClassEntry> classes = new ArrayList<>(2);
|
||||
classes.add(new TopClassEntry("foo", 0.6, 0.6));
|
||||
classes.add(new TopClassEntry("bar", 0.4, 0.4));
|
||||
|
||||
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
|
||||
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
|
||||
|
@ -173,7 +174,7 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
inferenceProcessor.mutateDocument(response, document);
|
||||
|
||||
assertThat((List<Map<?,?>>)document.getFieldValue("ml.my_processor.tops", List.class),
|
||||
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
|
||||
contains(classes.stream().map(TopClassEntry::asValueMap).toArray(Map[]::new)));
|
||||
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
|
||||
assertThat(document.getFieldValue("ml.my_processor.result", String.class), equalTo("foo"));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,266 @@
|
|||
setup:
|
||||
- skip:
|
||||
features: headers
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
ml.put_trained_model:
|
||||
model_id: a-complex-regression-model
|
||||
body: >
|
||||
{
|
||||
"description": "super complex model for tests",
|
||||
"input": {"field_names": ["avg_cost", "item"]},
|
||||
"inference_config": {
|
||||
"regression": {
|
||||
"results_field": "regression-value",
|
||||
"num_top_feature_importance_values": 2
|
||||
}
|
||||
},
|
||||
"definition": {
|
||||
"preprocessors" : [{
|
||||
"one_hot_encoding": {
|
||||
"field": "product_type",
|
||||
"hot_map": {
|
||||
"TV": "type_tv",
|
||||
"VCR": "type_vcr",
|
||||
"Laptop": "type_laptop"
|
||||
}
|
||||
}
|
||||
}],
|
||||
"trained_model": {
|
||||
"ensemble": {
|
||||
"feature_names": [],
|
||||
"target_type": "regression",
|
||||
"trained_models": [
|
||||
{
|
||||
"tree": {
|
||||
"feature_names": [
|
||||
"avg_cost", "type_tv", "type_vcr", "type_laptop"
|
||||
],
|
||||
"tree_structure": [
|
||||
{
|
||||
"node_index": 0,
|
||||
"split_feature": 0,
|
||||
"split_gain": 12,
|
||||
"threshold": 38,
|
||||
"decision_type": "lte",
|
||||
"default_left": true,
|
||||
"left_child": 1,
|
||||
"right_child": 2
|
||||
},
|
||||
{
|
||||
"node_index": 1,
|
||||
"leaf_value": 5.0
|
||||
},
|
||||
{
|
||||
"node_index": 2,
|
||||
"leaf_value": 2.0
|
||||
}
|
||||
],
|
||||
"target_type": "regression"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
indices.create:
|
||||
index: store
|
||||
body:
|
||||
mappings:
|
||||
properties:
|
||||
product:
|
||||
type: keyword
|
||||
cost:
|
||||
type: integer
|
||||
time:
|
||||
type: date
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
Content-Type: application/json
|
||||
bulk:
|
||||
index: store
|
||||
refresh: true
|
||||
body: |
|
||||
{ "index": {} }
|
||||
{ "product": "TV", "cost": 300, "time": 1587501233000 }
|
||||
{ "index": {} }
|
||||
{ "product": "TV", "cost": 400, "time": 1587501233000}
|
||||
{ "index": {} }
|
||||
{ "product": "VCR", "cost": 150, "time": 1587501233000 }
|
||||
{ "index": {} }
|
||||
{ "product": "VCR", "cost": 180, "time": 1587501233000 }
|
||||
{ "index": {} }
|
||||
{ "product": "Laptop", "cost": 15000, "time": 1587501233000 }
|
||||
|
||||
---
|
||||
"Test pipeline regression simple":
|
||||
|
||||
- do:
|
||||
search:
|
||||
index: store
|
||||
body: |
|
||||
{
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"good": {
|
||||
"terms": {
|
||||
"field": "product",
|
||||
"size": 10
|
||||
},
|
||||
"aggs": {
|
||||
"avg_cost_agg": {
|
||||
"avg": {
|
||||
"field": "cost"
|
||||
}
|
||||
},
|
||||
"regression_agg": {
|
||||
"inference": {
|
||||
"model_id": "a-complex-regression-model",
|
||||
"inference_config": {
|
||||
"regression": {
|
||||
"results_field": "value"
|
||||
}
|
||||
},
|
||||
"buckets_path": {
|
||||
"avg_cost": "avg_cost_agg"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
- match: { aggregations.good.buckets.0.regression_agg.value: 2.0 }
|
||||
---
|
||||
"Test pipeline agg referencing a single bucket":
|
||||
|
||||
- do:
|
||||
search:
|
||||
index: store
|
||||
body: |
|
||||
{
|
||||
"size": 0,
|
||||
"query": {
|
||||
"match_all": {}
|
||||
},
|
||||
"aggs": {
|
||||
"date_histo": {
|
||||
"date_histogram": {
|
||||
"field": "time",
|
||||
"fixed_interval": "1d"
|
||||
},
|
||||
"aggs": {
|
||||
"good": {
|
||||
"terms": {
|
||||
"field": "product",
|
||||
"size": 10
|
||||
},
|
||||
"aggs": {
|
||||
"avg_cost_agg": {
|
||||
"avg": {
|
||||
"field": "cost"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"regression_agg": {
|
||||
"inference": {
|
||||
"model_id": "a-complex-regression-model",
|
||||
"buckets_path": {
|
||||
"avg_cost": "good['TV']>avg_cost_agg",
|
||||
"product_type": "good['TV']"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
- match: { aggregations.date_histo.buckets.0.regression_agg.value: 2.0 }
|
||||
|
||||
---
|
||||
"Test all fields missing warning":
|
||||
|
||||
- do:
|
||||
search:
|
||||
index: store
|
||||
body: |
|
||||
{
|
||||
"size": 0,
|
||||
"query": { "match_all" : { } },
|
||||
"aggs": {
|
||||
"good": {
|
||||
"terms": {
|
||||
"field": "product",
|
||||
"size": 10
|
||||
},
|
||||
"aggs": {
|
||||
"avg_cost_agg": {
|
||||
"avg": {
|
||||
"field": "cost"
|
||||
}
|
||||
},
|
||||
"regression_agg" : {
|
||||
"inference": {
|
||||
"model_id": "a-complex-regression-model",
|
||||
"buckets_path": {
|
||||
"cost" : "avg_cost_agg"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
- match: { aggregations.good.buckets.0.regression_agg.warning: "Model [a-complex-regression-model] could not be inferred as all fields were missing" }
|
||||
|
||||
---
|
||||
"Test setting results field is invalid":
|
||||
|
||||
- do:
|
||||
catch: /action_request_validation_exception/
|
||||
search:
|
||||
index: store
|
||||
body: |
|
||||
{
|
||||
"size": 0,
|
||||
"query": { "match_all" : { } },
|
||||
"aggs": {
|
||||
"good": {
|
||||
"terms": {
|
||||
"field": "product",
|
||||
"size": 10
|
||||
},
|
||||
"aggs": {
|
||||
"avg_cost_agg": {
|
||||
"avg": {
|
||||
"field": "cost"
|
||||
}
|
||||
},
|
||||
"regression_agg" : {
|
||||
"inference": {
|
||||
"model_id": "a-complex-regression-model",
|
||||
"inference_config": {
|
||||
"regression": {
|
||||
"results_field": "banana"
|
||||
}
|
||||
},
|
||||
"buckets_path": {
|
||||
"cost" : "avg_cost_agg"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
- match: { error.root_cause.0.type: "action_request_validation_exception" }
|
||||
- match: { error.root_cause.0.reason: "Validation Failed: 1: setting option [results_field] to [banana] is not valid for inference aggregations;" }
|
Loading…
Reference in New Issue