[ML] adds feature_importance_baseline object to model metadata (#63172) (#63237)

this adds the new field `feature_importance_baseline` and allows it to be optionally be included in the model's metadata.

Related to: https://github.com/elastic/ml-cpp/pull/1522
This commit is contained in:
Benjamin Trent 2020-10-05 09:33:38 -04:00 committed by GitHub
parent dab1b14a10
commit 1e63313c19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 481 additions and 372 deletions

View File

@ -1,208 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
public class TotalFeatureImportance implements ToXContentObject {
private static final String NAME = "total_feature_importance";
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField IMPORTANCE = new ParseField("importance");
public static final ParseField CLASSES = new ParseField("classes");
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
public static final ParseField MIN = new ParseField("min");
public static final ParseField MAX = new ParseField("max");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<TotalFeatureImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List<ClassImportance>)a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), Importance.PARSER, IMPORTANCE);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), ClassImportance.PARSER, CLASSES);
}
public static TotalFeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public final String featureName;
public final Importance importance;
public final List<ClassImportance> classImportances;
TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List<ClassImportance> classImportances) {
this.featureName = featureName;
this.importance = importance;
this.classImportances = classImportances == null ? Collections.emptyList() : classImportances;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FEATURE_NAME.getPreferredName(), featureName);
if (importance != null) {
builder.field(IMPORTANCE.getPreferredName(), importance);
}
if (classImportances.isEmpty() == false) {
builder.field(CLASSES.getPreferredName(), classImportances);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TotalFeatureImportance that = (TotalFeatureImportance) o;
return Objects.equals(that.importance, importance)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(classImportances, that.classImportances);
}
@Override
public int hashCode() {
return Objects.hash(featureName, importance, classImportances);
}
public static class Importance implements ToXContentObject {
private static final String NAME = "importance";
public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
static {
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
}
private final double meanMagnitude;
private final double min;
private final double max;
public Importance(double meanMagnitude, double min, double max) {
this.meanMagnitude = meanMagnitude;
this.min = min;
this.max = max;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Importance that = (Importance) o;
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
Double.compare(that.min, min) == 0 &&
Double.compare(that.max, max) == 0;
}
@Override
public int hashCode() {
return Objects.hash(meanMagnitude, min, max);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
builder.field(MIN.getPreferredName(), min);
builder.field(MAX.getPreferredName(), max);
builder.endObject();
return builder;
}
}
public static class ClassImportance implements ToXContentObject {
private static final String NAME = "total_class_importance";
public static final ParseField CLASS_NAME = new ParseField("class_name");
public static final ParseField IMPORTANCE = new ParseField("importance");
public static final ConstructingObjectParser<ClassImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new ClassImportance(a[0], (Importance)a[1]));
static {
PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.numberValue();
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
return p.booleanValue();
}
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), Importance.PARSER, IMPORTANCE);
}
public static ClassImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public final Object className;
public final Importance importance;
ClassImportance(Object className, Importance importance) {
this.className = className;
this.importance = importance;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(IMPORTANCE.getPreferredName(), importance);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassImportance that = (ClassImportance) o;
return Objects.equals(that.importance, importance) && Objects.equals(className, that.className);
}
@Override
public int hashCode() {
return Objects.hash(className, importance);
}
}
}

View File

@ -1,71 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalFeatureImportance> {
@SuppressWarnings("unchecked")
public static TotalFeatureImportance randomInstance() {
Supplier<Object> classNameGenerator = randomFrom(
() -> randomAlphaOfLength(10),
ESTestCase::randomBoolean,
() -> randomIntBetween(0, 10)
);
return new TotalFeatureImportance(
randomAlphaOfLength(10),
randomBoolean() ? null : randomImportance(),
randomBoolean() ?
null :
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(classNameGenerator.get(), randomImportance()))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList())
);
}
private static TotalFeatureImportance.Importance randomImportance() {
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
}
@Override
protected TotalFeatureImportance createTestInstance() {
return randomInstance();
}
@Override
protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException {
return TotalFeatureImportance.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
@ -34,41 +35,34 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
super(NAME, Response::new);
}
public static class Request extends AbstractGetResourcesRequest {
public static class Includes implements Writeable {
static final String DEFINITION = "definition";
static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline";
private static final Set<String> KNOWN_INCLUDES;
static {
HashSet<String> includes = new HashSet<>(2, 1.0f);
HashSet<String> includes = new HashSet<>(3, 1.0f);
includes.add(DEFINITION);
includes.add(TOTAL_FEATURE_IMPORTANCE);
includes.add(FEATURE_IMPORTANCE_BASELINE);
KNOWN_INCLUDES = Collections.unmodifiableSet(includes);
}
public static final ParseField INCLUDE = new ParseField("include");
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
public static final ParseField TAGS = new ParseField("tags");
private final Set<String> includes;
private final List<String> tags;
@Deprecated
public Request(String id, boolean includeModelDefinition, List<String> tags) {
setResourceId(id);
setAllowNoResources(true);
this.tags = tags == null ? Collections.emptyList() : tags;
if (includeModelDefinition) {
this.includes = new HashSet<>(Collections.singletonList(DEFINITION));
} else {
this.includes = Collections.emptySet();
}
public static Includes forModelDefinition() {
return new Includes(new HashSet<>(Collections.singletonList(DEFINITION)));
}
public Request(String id, List<String> tags, Set<String> includes) {
setResourceId(id);
setAllowNoResources(true);
this.tags = tags == null ? Collections.emptyList() : tags;
public static Includes empty() {
return new Includes(new HashSet<>());
}
public static Includes all() {
return new Includes(KNOWN_INCLUDES);
}
private final Set<String> includes;
public Includes(Set<String> includes) {
this.includes = includes == null ? Collections.emptySet() : includes;
Set<String> unknownIncludes = Sets.difference(this.includes, KNOWN_INCLUDES);
if (unknownIncludes.isEmpty() == false) {
@ -79,16 +73,76 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
}
}
public Includes(StreamInput in) throws IOException {
this.includes = in.readSet(StreamInput::readString);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(this.includes, StreamOutput::writeString);
}
public boolean isIncludeModelDefinition() {
return this.includes.contains(DEFINITION);
}
public boolean isIncludeTotalFeatureImportance() {
return this.includes.contains(TOTAL_FEATURE_IMPORTANCE);
}
public boolean isIncludeFeatureImportanceBaseline() {
return this.includes.contains(FEATURE_IMPORTANCE_BASELINE);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Includes includes1 = (Includes) o;
return Objects.equals(includes, includes1.includes);
}
@Override
public int hashCode() {
return Objects.hash(includes);
}
}
public static class Request extends AbstractGetResourcesRequest {
public static final ParseField INCLUDE = new ParseField("include");
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
public static final ParseField TAGS = new ParseField("tags");
private final Includes includes;
private final List<String> tags;
@Deprecated
public Request(String id, boolean includeModelDefinition, List<String> tags) {
setResourceId(id);
setAllowNoResources(true);
this.tags = tags == null ? Collections.emptyList() : tags;
if (includeModelDefinition) {
this.includes = Includes.forModelDefinition();
} else {
this.includes = Includes.empty();
}
}
public Request(String id, List<String> tags, Set<String> includes) {
setResourceId(id);
setAllowNoResources(true);
this.tags = tags == null ? Collections.emptyList() : tags;
this.includes = new Includes(includes);
}
public Request(StreamInput in) throws IOException {
super(in);
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.includes = in.readSet(StreamInput::readString);
this.includes = new Includes(in);
} else {
Set<String> includes = new HashSet<>();
if (in.readBoolean()) {
includes.add(DEFINITION);
}
this.includes = includes;
this.includes = in.readBoolean() ? Includes.forModelDefinition() : Includes.empty();
}
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.tags = in.readStringList();
@ -102,25 +156,21 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
return TrainedModelConfig.MODEL_ID.getPreferredName();
}
public boolean isIncludeModelDefinition() {
return this.includes.contains(DEFINITION);
}
public boolean isIncludeTotalFeatureImportance() {
return this.includes.contains(TOTAL_FEATURE_IMPORTANCE);
}
public List<String> getTags() {
return tags;
}
public Includes getIncludes() {
return includes;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeCollection(this.includes, StreamOutput::writeString);
this.includes.writeTo(out);
} else {
out.writeBoolean(this.includes.contains(DEFINITION));
out.writeBoolean(this.includes.isIncludeModelDefinition());
}
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeStringCollection(tags);

View File

@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaseline;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -37,6 +38,7 @@ import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -54,7 +56,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String FOR_EXPORT = "for_export";
public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
private static final Set<String> RESERVED_METADATA_FIELDS = Collections.singleton(TOTAL_FEATURE_IMPORTANCE);
public static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline";
private static final Set<String> RESERVED_METADATA_FIELDS = new HashSet<>(Arrays.asList(
TOTAL_FEATURE_IMPORTANCE,
FEATURE_IMPORTANCE_BASELINE));
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
@ -487,6 +492,17 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
public Builder setBaselineFeatureImportance(FeatureImportanceBaseline featureImportanceBaseline) {
if (featureImportanceBaseline == null) {
return this;
}
if (this.metadata == null) {
this.metadata = new HashMap<>();
}
this.metadata.put(FEATURE_IMPORTANCE_BASELINE, featureImportanceBaseline.asMap());
return this;
}
public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
if (definition == null) {
return this;

View File

@ -0,0 +1,178 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
public class FeatureImportanceBaseline implements ToXContentObject, Writeable {
private static final String NAME = "feature_importance_baseline";
public static final ParseField BASELINE = new ParseField("baseline");
public static final ParseField CLASSES = new ParseField("classes");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ConstructingObjectParser<FeatureImportanceBaseline, Void> LENIENT_PARSER = createParser(true);
public static final ConstructingObjectParser<FeatureImportanceBaseline, Void> STRICT_PARSER = createParser(false);
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<FeatureImportanceBaseline, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<FeatureImportanceBaseline, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new FeatureImportanceBaseline((Double)a[0], (List<ClassBaseline>)a[1]));
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
parser.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(),
ignoreUnknownFields ? ClassBaseline.LENIENT_PARSER : ClassBaseline.STRICT_PARSER,
CLASSES);
return parser;
}
public static FeatureImportanceBaseline fromXContent(XContentParser parser, boolean lenient) throws IOException {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}
public final Double baseline;
public final List<ClassBaseline> classBaselines;
public FeatureImportanceBaseline(StreamInput in) throws IOException {
this.baseline = in.readOptionalDouble();
this.classBaselines = in.readList(ClassBaseline::new);
}
FeatureImportanceBaseline(Double baseline, List<ClassBaseline> classBaselines) {
this.baseline = baseline;
this.classBaselines = classBaselines == null ? Collections.emptyList() : classBaselines;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalDouble(baseline);
out.writeList(classBaselines);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.map(asMap());
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FeatureImportanceBaseline that = (FeatureImportanceBaseline) o;
return Objects.equals(that.baseline, baseline)
&& Objects.equals(classBaselines, that.classBaselines);
}
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
if (baseline != null) {
map.put(BASELINE.getPreferredName(), baseline);
}
if (classBaselines.isEmpty() == false) {
map.put(CLASSES.getPreferredName(), classBaselines.stream().map(ClassBaseline::asMap).collect(Collectors.toList()));
}
return map;
}
@Override
public int hashCode() {
return Objects.hash(baseline, classBaselines);
}
public static class ClassBaseline implements ToXContentObject, Writeable {
private static final String NAME = "feature_importance_class_baseline";
public static final ParseField CLASS_NAME = new ParseField("class_name");
public static final ConstructingObjectParser<ClassBaseline, Void> LENIENT_PARSER = createParser(true);
public static final ConstructingObjectParser<ClassBaseline, Void> STRICT_PARSER = createParser(false);
private static ConstructingObjectParser<ClassBaseline, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<ClassBaseline, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new ClassBaseline(a[0], (double)a[1]));
parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.numberValue();
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
return p.booleanValue();
}
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
parser.declareDouble(ConstructingObjectParser.constructorArg(), BASELINE);
return parser;
}
public static ClassBaseline fromXContent(XContentParser parser, boolean lenient) throws IOException {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}
public final Object className;
public final double baseline;
public ClassBaseline(StreamInput in) throws IOException {
this.className = in.readGenericValue();
this.baseline = in.readDouble();
}
ClassBaseline(Object className, double baseline) {
this.className = className;
this.baseline = baseline;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeGenericValue(className);
out.writeDouble(baseline);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.map(asMap());
}
private Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(CLASS_NAME.getPreferredName(), className);
map.put(BASELINE.getPreferredName(), baseline);
return map;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassBaseline that = (ClassBaseline) o;
return Objects.equals(that.className, className) && Objects.equals(baseline, that.baseline);
}
@Override
public int hashCode() {
return Objects.hash(className, baseline);
}
}
}

View File

@ -27,6 +27,7 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
public static final String NAME = "trained_model_metadata";
public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance");
public static final ParseField FEATURE_IMPORTANCE_BASELINE = new ParseField("feature_importance_baseline");
public static final ParseField MODEL_ID = new ParseField("model_id");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
@ -37,11 +38,14 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
private static ConstructingObjectParser<TrainedModelMetadata, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<TrainedModelMetadata, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new TrainedModelMetadata((String)a[0], (List<TotalFeatureImportance>)a[1]));
a -> new TrainedModelMetadata((String)a[0], (List<TotalFeatureImportance>)a[1], (FeatureImportanceBaseline)a[2]));
parser.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
parser.declareObjectArray(ConstructingObjectParser.constructorArg(),
ignoreUnknownFields ? TotalFeatureImportance.LENIENT_PARSER : TotalFeatureImportance.STRICT_PARSER,
TOTAL_FEATURE_IMPORTANCE);
parser.declareObject(ConstructingObjectParser.optionalConstructorArg(),
ignoreUnknownFields ? FeatureImportanceBaseline.LENIENT_PARSER : FeatureImportanceBaseline.STRICT_PARSER,
FEATURE_IMPORTANCE_BASELINE);
return parser;
}
@ -58,16 +62,21 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
}
private final List<TotalFeatureImportance> totalFeatureImportances;
private final FeatureImportanceBaseline featureImportanceBaselines;
private final String modelId;
public TrainedModelMetadata(StreamInput in) throws IOException {
this.modelId = in.readString();
this.totalFeatureImportances = in.readList(TotalFeatureImportance::new);
this.featureImportanceBaselines = in.readOptionalWriteable(FeatureImportanceBaseline::new);
}
public TrainedModelMetadata(String modelId, List<TotalFeatureImportance> totalFeatureImportances) {
public TrainedModelMetadata(String modelId,
List<TotalFeatureImportance> totalFeatureImportances,
FeatureImportanceBaseline featureImportanceBaselines) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.totalFeatureImportances = Collections.unmodifiableList(totalFeatureImportances);
this.featureImportanceBaselines = featureImportanceBaselines;
}
public String getModelId() {
@ -82,24 +91,30 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
return totalFeatureImportances;
}
public FeatureImportanceBaseline getFeatureImportanceBaselines() {
return featureImportanceBaselines;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelMetadata that = (TrainedModelMetadata) o;
return Objects.equals(totalFeatureImportances, that.totalFeatureImportances) &&
Objects.equals(featureImportanceBaselines, that.featureImportanceBaselines) &&
Objects.equals(modelId, that.modelId);
}
@Override
public int hashCode() {
return Objects.hash(totalFeatureImportances, modelId);
return Objects.hash(totalFeatureImportances, featureImportanceBaselines, modelId);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeList(totalFeatureImportances);
out.writeOptionalWriteable(featureImportanceBaselines);
}
@Override
@ -110,6 +125,9 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
}
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(TOTAL_FEATURE_IMPORTANCE.getPreferredName(), totalFeatureImportances);
if (featureImportanceBaselines != null) {
builder.field(FEATURE_IMPORTANCE_BASELINE.getPreferredName(), featureImportanceBaselines);
}
builder.endObject();
return builder;
}

View File

@ -71,6 +71,19 @@
"inference_config": {
"enabled": false
},
"feature_importance_baseline": {
"properties": {
"baseline": {
"type": "double"
},
"classes": {
"properties": {
"class_name": { "type": "keyword"},
"baseline": {"type" : "double"}
}
}
}
},
"total_feature_importance": {
"type": "nested",
"dynamic": "false",

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Includes;
import java.util.HashSet;
import java.util.Set;
@ -24,7 +25,9 @@ public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTe
randomBoolean() ? null :
randomList(10, () -> randomAlphaOfLength(10)),
randomBoolean() ? null :
Stream.generate(() -> randomFrom(Request.DEFINITION, Request.TOTAL_FEATURE_IMPORTANCE))
Stream.generate(() -> randomFrom(Includes.DEFINITION,
Includes.TOTAL_FEATURE_IMPORTANCE,
Includes.FEATURE_IMPORTANCE_BASELINE))
.limit(4)
.collect(Collectors.toSet()));
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
@ -40,8 +43,8 @@ public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTe
protected Request mutateInstanceForVersion(Request instance, Version version) {
if (version.before(Version.V_7_10_0)) {
Set<String> includes = new HashSet<>();
if (instance.isIncludeModelDefinition()) {
includes.add(Request.DEFINITION);
if (instance.getIncludes().isIncludeModelDefinition()) {
includes.add(Includes.DEFINITION);
}
Request request = new Request(
instance.getResourceId(),

View File

@ -0,0 +1,71 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class FeatureImportanceBaselineTests extends AbstractBWCSerializationTestCase<FeatureImportanceBaseline> {
private boolean lenient;
@SuppressWarnings("unchecked")
public static FeatureImportanceBaseline randomInstance() {
Supplier<Object> classNameGenerator = randomFrom(
() -> randomAlphaOfLength(10),
ESTestCase::randomBoolean,
() -> randomIntBetween(0, 10)
);
return new FeatureImportanceBaseline(
randomBoolean() ? null : randomDouble(),
randomBoolean() ?
null :
Stream.generate(() -> new FeatureImportanceBaseline.ClassBaseline(classNameGenerator.get(), randomDouble()))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList())
);
}
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
protected FeatureImportanceBaseline createTestInstance() {
return randomInstance();
}
@Override
protected Writeable.Reader<FeatureImportanceBaseline> instanceReader() {
return FeatureImportanceBaseline::new;
}
@Override
protected FeatureImportanceBaseline doParseInstance(XContentParser parser) throws IOException {
return FeatureImportanceBaseline.fromXContent(parser, lenient);
}
@Override
protected boolean supportsUnknownFields() {
return lenient;
}
@Override
protected FeatureImportanceBaseline mutateInstanceForVersion(FeatureImportanceBaseline instance, Version version) {
return instance;
}
}

View File

@ -23,7 +23,8 @@ public class TrainedModelMetadataTests extends AbstractBWCSerializationTestCase<
public static TrainedModelMetadata randomInstance() {
return new TrainedModelMetadata(
randomAlphaOfLength(10),
Stream.generate(TotalFeatureImportanceTests::randomInstance).limit(randomIntBetween(1, 10)).collect(Collectors.toList()));
Stream.generate(TotalFeatureImportanceTests::randomInstance).limit(randomIntBetween(1, 10)).collect(Collectors.toList()),
randomBoolean() ? null : FeatureImportanceBaselineTests.randomInstance());
}
@Before

View File

@ -12,6 +12,7 @@ import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
@ -22,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaselineTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
@ -90,7 +92,8 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
}
ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance)
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()));
.collect(Collectors.toList()),
FeatureImportanceBaselineTests.randomInstance());
persister.createAndIndexInferenceModelMetadata(modelMetadata);
PlainActionFuture<Tuple<Long, Set<String>>> getIdsFuture = new PlainActionFuture<>();
@ -100,13 +103,14 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
String inferenceModelId = ids.v2().iterator().next();
PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(inferenceModelId, true, true, getTrainedModelFuture);
trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), getTrainedModelFuture);
TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance"));
assertThat(storedConfig.getMetadata(), hasKey("feature_importance_baseline"));
PlainActionFuture<Map<String, TrainedModelMetadata>> getTrainedMetadataFuture = new PlainActionFuture<>();
trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture);

View File

@ -17,6 +17,7 @@ import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.license.License;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
@ -91,7 +92,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
getConfigHolder,
exceptionHolder);
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
@ -125,7 +126,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener ->
trainedModelProvider.getTrainedModel(modelId, false, false, listener),
trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), listener),
getConfigHolder,
exceptionHolder);
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
@ -139,7 +140,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
getConfigHolder,
exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue())));
@ -164,7 +165,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
getConfigHolder,
exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue())));
@ -206,7 +207,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
getConfigHolder,
exceptionHolder);
assertThat(getConfigHolder.get(), is(nullValue()));
@ -254,7 +255,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
}
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
getConfigHolder,
exceptionHolder);
assertThat(getConfigHolder.get(), is(nullValue()));

View File

@ -49,18 +49,17 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
return;
}
if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) {
if (request.getIncludes().isIncludeModelDefinition() && totalAndIds.v2().size() > 1) {
listener.onFailure(
ExceptionsHelper.badRequestException(Messages.INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED)
);
return;
}
if (request.isIncludeModelDefinition()) {
if (request.getIncludes().isIncludeModelDefinition()) {
provider.getTrainedModel(
totalAndIds.v2().iterator().next(),
true,
request.isIncludeTotalFeatureImportance(),
request.getIncludes(),
ActionListener.wrap(
config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
listener::onFailure
@ -69,8 +68,8 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
} else {
provider.getTrainedModels(
totalAndIds.v2(),
request.getIncludes(),
request.isAllowNoResources(),
request.isIncludeTotalFeatureImportance(),
ActionListener.wrap(
configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
listener::onFailure

View File

@ -16,6 +16,7 @@ import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
@ -82,7 +83,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
responseBuilder.setLicensed(true);
this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
} else {
trainedModelProvider.getTrainedModel(request.getModelId(), false, false, ActionListener.wrap(
trainedModelProvider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(
trainedModelConfig -> {
responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {

View File

@ -129,7 +129,8 @@ public class ChunkedTrainedModelPersister {
return;
}
TrainedModelMetadata trainedModelMetadata = new TrainedModelMetadata(this.currentModelId.get(),
modelMetadata.getFeatureImportances());
modelMetadata.getFeatureImportances(),
modelMetadata.getFeatureImportanceBaseline());
CountDownLatch latch = storeTrainedModelMetadata(trainedModelMetadata);

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaseline;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
import java.io.IOException;
@ -17,47 +18,60 @@ import java.util.List;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class ModelMetadata implements ToXContentObject {
public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance");
public static final ParseField FEATURE_IMPORTANCE_BASELINE = new ParseField("feature_importance_baseline");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<ModelMetadata, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_metadata",
a -> new ModelMetadata((List<TotalFeatureImportance>) a[0]));
a -> new ModelMetadata((List<TotalFeatureImportance>) a[0], (FeatureImportanceBaseline) a[1]));
static {
PARSER.declareObjectArray(constructorArg(), TotalFeatureImportance.STRICT_PARSER, TOTAL_FEATURE_IMPORTANCE);
PARSER.declareObject(optionalConstructorArg(), FeatureImportanceBaseline.STRICT_PARSER, FEATURE_IMPORTANCE_BASELINE);
}
private final List<TotalFeatureImportance> featureImportances;
private final FeatureImportanceBaseline featureImportanceBaseline;
public ModelMetadata(List<TotalFeatureImportance> featureImportances) {
public ModelMetadata(List<TotalFeatureImportance> featureImportances, FeatureImportanceBaseline featureImportanceBaseline) {
this.featureImportances = featureImportances;
this.featureImportanceBaseline = featureImportanceBaseline;
}
public List<TotalFeatureImportance> getFeatureImportances() {
return featureImportances;
}
public FeatureImportanceBaseline getFeatureImportanceBaseline() {
return featureImportanceBaseline;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ModelMetadata that = (ModelMetadata) o;
return Objects.equals(featureImportances, that.featureImportances);
return Objects.equals(featureImportances, that.featureImportances)
&& Objects.equals(featureImportanceBaseline, that.featureImportanceBaseline);
}
@Override
public int hashCode() {
return Objects.hash(featureImportances);
return Objects.hash(featureImportances, featureImportanceBaseline);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TOTAL_FEATURE_IMPORTANCE.getPreferredName(), featureImportances);
if (featureImportanceBaseline != null) {
builder.field(FEATURE_IMPORTANCE_BASELINE.getPreferredName(), featureImportanceBaseline);
}
builder.endObject();
return builder;
}

View File

@ -28,6 +28,7 @@ import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
@ -270,7 +271,7 @@ public class ModelLoadingService implements ClusterStateListener {
}
private void loadModel(String modelId, Consumer consumer) {
provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(
trainedModelConfig -> {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
@ -306,7 +307,7 @@ public class ModelLoadingService implements ClusterStateListener {
// If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(
trainedModelConfig -> {
// Verify we can pull the model into memory without causing OOM
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);

View File

@ -68,6 +68,7 @@ import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
@ -439,13 +440,12 @@ public class TrainedModelProvider {
}
public void getTrainedModel(final String modelId,
final boolean includeDefinition,
final boolean includeTotalFeatureImportance,
final GetTrainedModelsAction.Includes includes,
final ActionListener<TrainedModelConfig> finalListener) {
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
try {
finalListener.onResponse(loadModelFromResource(modelId, includeDefinition == false).build());
finalListener.onResponse(loadModelFromResource(modelId, includes.isIncludeModelDefinition() == false).build());
return;
} catch (ElasticsearchException ex) {
finalListener.onFailure(ex);
@ -455,7 +455,7 @@ public class TrainedModelProvider {
ActionListener<TrainedModelConfig.Builder> getTrainedModelListener = ActionListener.wrap(
modelBuilder -> {
if (includeTotalFeatureImportance == false) {
if ((includes.isIncludeFeatureImportanceBaseline() || includes.isIncludeTotalFeatureImportance()) == false) {
finalListener.onResponse(modelBuilder.build());
return;
}
@ -463,7 +463,12 @@ public class TrainedModelProvider {
metadata -> {
TrainedModelMetadata modelMetadata = metadata.get(modelId);
if (modelMetadata != null) {
modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
if (includes.isIncludeTotalFeatureImportance()) {
modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
}
if (includes.isIncludeFeatureImportanceBaseline()) {
modelBuilder.setBaselineFeatureImportance(modelMetadata.getFeatureImportanceBaselines());
}
}
finalListener.onResponse(modelBuilder.build());
},
@ -493,7 +498,7 @@ public class TrainedModelProvider {
.setSize(1)
.request());
if (includeDefinition) {
if (includes.isIncludeModelDefinition()) {
multiSearchRequestBuilder.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
.boolQuery()
@ -527,7 +532,7 @@ public class TrainedModelProvider {
return;
}
if (includeDefinition) {
if (includes.isIncludeModelDefinition()) {
try {
List<TrainedModelDefinitionDoc> docs = handleSearchItems(multiSearchResponse.getResponses()[1],
modelId,
@ -569,8 +574,8 @@ public class TrainedModelProvider {
* It assumes that there are fewer than 10k.
*/
public void getTrainedModels(Set<String> modelIds,
GetTrainedModelsAction.Includes includes,
boolean allowNoResources,
boolean includeTotalFeatureImportance,
final ActionListener<List<TrainedModelConfig>> finalListener) {
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
@ -601,7 +606,7 @@ public class TrainedModelProvider {
ActionListener<List<TrainedModelConfig.Builder>> getTrainedModelListener = ActionListener.wrap(
modelBuilders -> {
if (includeTotalFeatureImportance == false) {
if ((includes.isIncludeFeatureImportanceBaseline() || includes.isIncludeTotalFeatureImportance()) == false) {
finalListener.onResponse(modelBuilders.stream()
.map(TrainedModelConfig.Builder::build)
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
@ -614,7 +619,12 @@ public class TrainedModelProvider {
.map(builder -> {
TrainedModelMetadata modelMetadata = metadata.get(builder.getModelId());
if (modelMetadata != null) {
builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
if (includes.isIncludeTotalFeatureImportance()) {
builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
}
if (includes.isIncludeFeatureImportanceBaseline()) {
builder.setBaselineFeatureImportance(modelMetadata.getFeatureImportanceBaselines());
}
}
return builder.build();
})

View File

@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaselineTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
import org.elasticsearch.xpack.core.security.user.XPackUser;
@ -103,7 +104,8 @@ public class ChunkedTrainedModelPersisterTests extends ESTestCase {
TrainedModelDefinitionChunk chunk2 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 1, true);
ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance)
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()));
.collect(Collectors.toList()),
FeatureImportanceBaselineTests.randomInstance());
resultProcessor.createAndIndexInferenceModelConfig(modelSizeInfo);
resultProcessor.createAndIndexInferenceModelDoc(chunk1);

View File

@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsageTests;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaselineTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
@ -69,7 +70,8 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
if (randomBoolean()) {
builder.setModelMetadata(new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance)
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList())));
.collect(Collectors.toList()),
FeatureImportanceBaselineTests.randomInstance()));
}
return builder.build();
}

View File

@ -34,6 +34,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
@ -437,9 +438,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
// the loading occurred or which models are currently in the cache due to evictions.
// Verify that we have at least loaded all three
assertBusy(() -> {
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(GetTrainedModelsAction.Includes.empty()), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(GetTrainedModelsAction.Includes.empty()), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(GetTrainedModelsAction.Includes.empty()), any());
});
assertBusy(() -> {
assertThat(circuitBreaker.getUsed(), equalTo(10L));
@ -553,10 +554,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
listener.onResponse(trainedModelConfig);
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(GetTrainedModelsAction.Includes.empty()), any());
}
@SuppressWarnings("unchecked")
@ -564,20 +565,20 @@ public class ModelLoadingServiceTests extends ESTestCase {
if (randomBoolean()) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(GetTrainedModelsAction.Includes.empty()), any());
} else {
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
listener.onResponse(trainedModelConfig);
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(GetTrainedModelsAction.Includes.empty()), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];

View File

@ -15,6 +15,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
@ -57,14 +58,14 @@ public class TrainedModelProviderTests extends ESTestCase {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(modelId, true, false, future);
trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(),future);
TrainedModelConfig configWithDefinition = future.actionGet();
assertThat(configWithDefinition.getModelId(), equalTo(modelId));
assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(modelId, false, false, futureNoDefinition);
trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), futureNoDefinition);
TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();
assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));

View File

@ -9,6 +9,7 @@ import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
@ -33,7 +34,7 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
// Should be OK as we don't make any client calls
trainedModelProvider.getTrainedModel("lang_ident_model_1", true, false, future);
trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
TrainedModelConfig config = future.actionGet();
config.ensureParsedDefinition(xContentRegistry());