this adds the new field `feature_importance_baseline` and allows it to be optionally be included in the model's metadata. Related to: https://github.com/elastic/ml-cpp/pull/1522
This commit is contained in:
parent
dab1b14a10
commit
1e63313c19
|
@ -1,208 +0,0 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class TotalFeatureImportance implements ToXContentObject {
|
||||
|
||||
private static final String NAME = "total_feature_importance";
|
||||
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
|
||||
public static final ParseField IMPORTANCE = new ParseField("importance");
|
||||
public static final ParseField CLASSES = new ParseField("classes");
|
||||
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
|
||||
public static final ParseField MIN = new ParseField("min");
|
||||
public static final ParseField MAX = new ParseField("max");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<TotalFeatureImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||
true,
|
||||
a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List<ClassImportance>)a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
|
||||
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), Importance.PARSER, IMPORTANCE);
|
||||
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), ClassImportance.PARSER, CLASSES);
|
||||
}
|
||||
|
||||
public static TotalFeatureImportance fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public final String featureName;
|
||||
public final Importance importance;
|
||||
public final List<ClassImportance> classImportances;
|
||||
|
||||
TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List<ClassImportance> classImportances) {
|
||||
this.featureName = featureName;
|
||||
this.importance = importance;
|
||||
this.classImportances = classImportances == null ? Collections.emptyList() : classImportances;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FEATURE_NAME.getPreferredName(), featureName);
|
||||
if (importance != null) {
|
||||
builder.field(IMPORTANCE.getPreferredName(), importance);
|
||||
}
|
||||
if (classImportances.isEmpty() == false) {
|
||||
builder.field(CLASSES.getPreferredName(), classImportances);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TotalFeatureImportance that = (TotalFeatureImportance) o;
|
||||
return Objects.equals(that.importance, importance)
|
||||
&& Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(classImportances, that.classImportances);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(featureName, importance, classImportances);
|
||||
}
|
||||
|
||||
public static class Importance implements ToXContentObject {
|
||||
private static final String NAME = "importance";
|
||||
|
||||
public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||
true,
|
||||
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
|
||||
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
|
||||
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
|
||||
}
|
||||
|
||||
private final double meanMagnitude;
|
||||
private final double min;
|
||||
private final double max;
|
||||
|
||||
public Importance(double meanMagnitude, double min, double max) {
|
||||
this.meanMagnitude = meanMagnitude;
|
||||
this.min = min;
|
||||
this.max = max;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Importance that = (Importance) o;
|
||||
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
|
||||
Double.compare(that.min, min) == 0 &&
|
||||
Double.compare(that.max, max) == 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(meanMagnitude, min, max);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
|
||||
builder.field(MIN.getPreferredName(), min);
|
||||
builder.field(MAX.getPreferredName(), max);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
|
||||
public static class ClassImportance implements ToXContentObject {
|
||||
private static final String NAME = "total_class_importance";
|
||||
|
||||
public static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
public static final ParseField IMPORTANCE = new ParseField("importance");
|
||||
|
||||
public static final ConstructingObjectParser<ClassImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||
true,
|
||||
a -> new ClassImportance(a[0], (Importance)a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
||||
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
|
||||
return p.text();
|
||||
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
|
||||
return p.numberValue();
|
||||
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
|
||||
return p.booleanValue();
|
||||
}
|
||||
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
|
||||
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), Importance.PARSER, IMPORTANCE);
|
||||
}
|
||||
|
||||
public static ClassImportance fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public final Object className;
|
||||
public final Importance importance;
|
||||
|
||||
ClassImportance(Object className, Importance importance) {
|
||||
this.className = className;
|
||||
this.importance = importance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||
builder.field(IMPORTANCE.getPreferredName(), importance);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ClassImportance that = (ClassImportance) o;
|
||||
return Objects.equals(that.importance, importance) && Objects.equals(className, that.className);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(className, importance);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -1,71 +0,0 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
||||
public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalFeatureImportance> {
|
||||
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static TotalFeatureImportance randomInstance() {
|
||||
Supplier<Object> classNameGenerator = randomFrom(
|
||||
() -> randomAlphaOfLength(10),
|
||||
ESTestCase::randomBoolean,
|
||||
() -> randomIntBetween(0, 10)
|
||||
);
|
||||
return new TotalFeatureImportance(
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : randomImportance(),
|
||||
randomBoolean() ?
|
||||
null :
|
||||
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(classNameGenerator.get(), randomImportance()))
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList())
|
||||
);
|
||||
}
|
||||
|
||||
private static TotalFeatureImportance.Importance randomImportance() {
|
||||
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TotalFeatureImportance createTestInstance() {
|
||||
return randomInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException {
|
||||
return TotalFeatureImportance.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionType;
|
|||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
|
||||
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
|
||||
|
@ -34,41 +35,34 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
|
|||
super(NAME, Response::new);
|
||||
}
|
||||
|
||||
public static class Request extends AbstractGetResourcesRequest {
|
||||
|
||||
public static class Includes implements Writeable {
|
||||
static final String DEFINITION = "definition";
|
||||
static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
|
||||
static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline";
|
||||
private static final Set<String> KNOWN_INCLUDES;
|
||||
static {
|
||||
HashSet<String> includes = new HashSet<>(2, 1.0f);
|
||||
HashSet<String> includes = new HashSet<>(3, 1.0f);
|
||||
includes.add(DEFINITION);
|
||||
includes.add(TOTAL_FEATURE_IMPORTANCE);
|
||||
includes.add(FEATURE_IMPORTANCE_BASELINE);
|
||||
KNOWN_INCLUDES = Collections.unmodifiableSet(includes);
|
||||
}
|
||||
public static final ParseField INCLUDE = new ParseField("include");
|
||||
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
|
||||
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
|
||||
public static final ParseField TAGS = new ParseField("tags");
|
||||
|
||||
public static Includes forModelDefinition() {
|
||||
return new Includes(new HashSet<>(Collections.singletonList(DEFINITION)));
|
||||
}
|
||||
|
||||
public static Includes empty() {
|
||||
return new Includes(new HashSet<>());
|
||||
}
|
||||
|
||||
public static Includes all() {
|
||||
return new Includes(KNOWN_INCLUDES);
|
||||
}
|
||||
|
||||
private final Set<String> includes;
|
||||
private final List<String> tags;
|
||||
|
||||
@Deprecated
|
||||
public Request(String id, boolean includeModelDefinition, List<String> tags) {
|
||||
setResourceId(id);
|
||||
setAllowNoResources(true);
|
||||
this.tags = tags == null ? Collections.emptyList() : tags;
|
||||
if (includeModelDefinition) {
|
||||
this.includes = new HashSet<>(Collections.singletonList(DEFINITION));
|
||||
} else {
|
||||
this.includes = Collections.emptySet();
|
||||
}
|
||||
}
|
||||
|
||||
public Request(String id, List<String> tags, Set<String> includes) {
|
||||
setResourceId(id);
|
||||
setAllowNoResources(true);
|
||||
this.tags = tags == null ? Collections.emptyList() : tags;
|
||||
public Includes(Set<String> includes) {
|
||||
this.includes = includes == null ? Collections.emptySet() : includes;
|
||||
Set<String> unknownIncludes = Sets.difference(this.includes, KNOWN_INCLUDES);
|
||||
if (unknownIncludes.isEmpty() == false) {
|
||||
|
@ -79,16 +73,76 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
|
|||
}
|
||||
}
|
||||
|
||||
public Includes(StreamInput in) throws IOException {
|
||||
this.includes = in.readSet(StreamInput::readString);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeCollection(this.includes, StreamOutput::writeString);
|
||||
}
|
||||
|
||||
public boolean isIncludeModelDefinition() {
|
||||
return this.includes.contains(DEFINITION);
|
||||
}
|
||||
|
||||
public boolean isIncludeTotalFeatureImportance() {
|
||||
return this.includes.contains(TOTAL_FEATURE_IMPORTANCE);
|
||||
}
|
||||
|
||||
public boolean isIncludeFeatureImportanceBaseline() {
|
||||
return this.includes.contains(FEATURE_IMPORTANCE_BASELINE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Includes includes1 = (Includes) o;
|
||||
return Objects.equals(includes, includes1.includes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(includes);
|
||||
}
|
||||
}
|
||||
|
||||
public static class Request extends AbstractGetResourcesRequest {
|
||||
|
||||
public static final ParseField INCLUDE = new ParseField("include");
|
||||
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
|
||||
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
|
||||
public static final ParseField TAGS = new ParseField("tags");
|
||||
|
||||
private final Includes includes;
|
||||
private final List<String> tags;
|
||||
|
||||
@Deprecated
|
||||
public Request(String id, boolean includeModelDefinition, List<String> tags) {
|
||||
setResourceId(id);
|
||||
setAllowNoResources(true);
|
||||
this.tags = tags == null ? Collections.emptyList() : tags;
|
||||
if (includeModelDefinition) {
|
||||
this.includes = Includes.forModelDefinition();
|
||||
} else {
|
||||
this.includes = Includes.empty();
|
||||
}
|
||||
}
|
||||
|
||||
public Request(String id, List<String> tags, Set<String> includes) {
|
||||
setResourceId(id);
|
||||
setAllowNoResources(true);
|
||||
this.tags = tags == null ? Collections.emptyList() : tags;
|
||||
this.includes = new Includes(includes);
|
||||
}
|
||||
|
||||
public Request(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
this.includes = in.readSet(StreamInput::readString);
|
||||
this.includes = new Includes(in);
|
||||
} else {
|
||||
Set<String> includes = new HashSet<>();
|
||||
if (in.readBoolean()) {
|
||||
includes.add(DEFINITION);
|
||||
}
|
||||
this.includes = includes;
|
||||
this.includes = in.readBoolean() ? Includes.forModelDefinition() : Includes.empty();
|
||||
}
|
||||
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
this.tags = in.readStringList();
|
||||
|
@ -102,25 +156,21 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
|
|||
return TrainedModelConfig.MODEL_ID.getPreferredName();
|
||||
}
|
||||
|
||||
public boolean isIncludeModelDefinition() {
|
||||
return this.includes.contains(DEFINITION);
|
||||
}
|
||||
|
||||
public boolean isIncludeTotalFeatureImportance() {
|
||||
return this.includes.contains(TOTAL_FEATURE_IMPORTANCE);
|
||||
}
|
||||
|
||||
public List<String> getTags() {
|
||||
return tags;
|
||||
}
|
||||
|
||||
public Includes getIncludes() {
|
||||
return includes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeCollection(this.includes, StreamOutput::writeString);
|
||||
this.includes.writeTo(out);
|
||||
} else {
|
||||
out.writeBoolean(this.includes.contains(DEFINITION));
|
||||
out.writeBoolean(this.includes.isIncludeModelDefinition());
|
||||
}
|
||||
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
out.writeStringCollection(tags);
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
|
|||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaseline;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
@ -37,6 +38,7 @@ import java.time.Instant;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
@ -54,7 +56,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
||||
public static final String FOR_EXPORT = "for_export";
|
||||
public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
|
||||
private static final Set<String> RESERVED_METADATA_FIELDS = Collections.singleton(TOTAL_FEATURE_IMPORTANCE);
|
||||
public static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline";
|
||||
private static final Set<String> RESERVED_METADATA_FIELDS = new HashSet<>(Arrays.asList(
|
||||
TOTAL_FEATURE_IMPORTANCE,
|
||||
FEATURE_IMPORTANCE_BASELINE));
|
||||
|
||||
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
|
||||
|
||||
|
@ -487,6 +492,17 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setBaselineFeatureImportance(FeatureImportanceBaseline featureImportanceBaseline) {
|
||||
if (featureImportanceBaseline == null) {
|
||||
return this;
|
||||
}
|
||||
if (this.metadata == null) {
|
||||
this.metadata = new HashMap<>();
|
||||
}
|
||||
this.metadata.put(FEATURE_IMPORTANCE_BASELINE, featureImportanceBaseline.asMap());
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
|
||||
if (definition == null) {
|
||||
return this;
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class FeatureImportanceBaseline implements ToXContentObject, Writeable {
|
||||
|
||||
private static final String NAME = "feature_importance_baseline";
|
||||
public static final ParseField BASELINE = new ParseField("baseline");
|
||||
public static final ParseField CLASSES = new ParseField("classes");
|
||||
|
||||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
||||
public static final ConstructingObjectParser<FeatureImportanceBaseline, Void> LENIENT_PARSER = createParser(true);
|
||||
public static final ConstructingObjectParser<FeatureImportanceBaseline, Void> STRICT_PARSER = createParser(false);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static ConstructingObjectParser<FeatureImportanceBaseline, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<FeatureImportanceBaseline, Void> parser = new ConstructingObjectParser<>(NAME,
|
||||
ignoreUnknownFields,
|
||||
a -> new FeatureImportanceBaseline((Double)a[0], (List<ClassBaseline>)a[1]));
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
|
||||
parser.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(),
|
||||
ignoreUnknownFields ? ClassBaseline.LENIENT_PARSER : ClassBaseline.STRICT_PARSER,
|
||||
CLASSES);
|
||||
return parser;
|
||||
}
|
||||
|
||||
public static FeatureImportanceBaseline fromXContent(XContentParser parser, boolean lenient) throws IOException {
|
||||
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
|
||||
}
|
||||
|
||||
public final Double baseline;
|
||||
public final List<ClassBaseline> classBaselines;
|
||||
|
||||
public FeatureImportanceBaseline(StreamInput in) throws IOException {
|
||||
this.baseline = in.readOptionalDouble();
|
||||
this.classBaselines = in.readList(ClassBaseline::new);
|
||||
}
|
||||
|
||||
FeatureImportanceBaseline(Double baseline, List<ClassBaseline> classBaselines) {
|
||||
this.baseline = baseline;
|
||||
this.classBaselines = classBaselines == null ? Collections.emptyList() : classBaselines;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalDouble(baseline);
|
||||
out.writeList(classBaselines);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
return builder.map(asMap());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
FeatureImportanceBaseline that = (FeatureImportanceBaseline) o;
|
||||
return Objects.equals(that.baseline, baseline)
|
||||
&& Objects.equals(classBaselines, that.classBaselines);
|
||||
}
|
||||
|
||||
public Map<String, Object> asMap() {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
if (baseline != null) {
|
||||
map.put(BASELINE.getPreferredName(), baseline);
|
||||
}
|
||||
if (classBaselines.isEmpty() == false) {
|
||||
map.put(CLASSES.getPreferredName(), classBaselines.stream().map(ClassBaseline::asMap).collect(Collectors.toList()));
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(baseline, classBaselines);
|
||||
}
|
||||
|
||||
public static class ClassBaseline implements ToXContentObject, Writeable {
|
||||
private static final String NAME = "feature_importance_class_baseline";
|
||||
|
||||
public static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
|
||||
public static final ConstructingObjectParser<ClassBaseline, Void> LENIENT_PARSER = createParser(true);
|
||||
public static final ConstructingObjectParser<ClassBaseline, Void> STRICT_PARSER = createParser(false);
|
||||
|
||||
private static ConstructingObjectParser<ClassBaseline, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<ClassBaseline, Void> parser = new ConstructingObjectParser<>(NAME,
|
||||
ignoreUnknownFields,
|
||||
a -> new ClassBaseline(a[0], (double)a[1]));
|
||||
parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
||||
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
|
||||
return p.text();
|
||||
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
|
||||
return p.numberValue();
|
||||
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
|
||||
return p.booleanValue();
|
||||
}
|
||||
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
|
||||
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
|
||||
parser.declareDouble(ConstructingObjectParser.constructorArg(), BASELINE);
|
||||
return parser;
|
||||
}
|
||||
|
||||
public static ClassBaseline fromXContent(XContentParser parser, boolean lenient) throws IOException {
|
||||
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
|
||||
}
|
||||
|
||||
public final Object className;
|
||||
public final double baseline;
|
||||
|
||||
public ClassBaseline(StreamInput in) throws IOException {
|
||||
this.className = in.readGenericValue();
|
||||
this.baseline = in.readDouble();
|
||||
}
|
||||
|
||||
ClassBaseline(Object className, double baseline) {
|
||||
this.className = className;
|
||||
this.baseline = baseline;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeGenericValue(className);
|
||||
out.writeDouble(baseline);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
return builder.map(asMap());
|
||||
}
|
||||
|
||||
private Map<String, Object> asMap() {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
map.put(CLASS_NAME.getPreferredName(), className);
|
||||
map.put(BASELINE.getPreferredName(), baseline);
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ClassBaseline that = (ClassBaseline) o;
|
||||
return Objects.equals(that.className, className) && Objects.equals(baseline, that.baseline);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(className, baseline);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -27,6 +27,7 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
|
|||
|
||||
public static final String NAME = "trained_model_metadata";
|
||||
public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance");
|
||||
public static final ParseField FEATURE_IMPORTANCE_BASELINE = new ParseField("feature_importance_baseline");
|
||||
public static final ParseField MODEL_ID = new ParseField("model_id");
|
||||
|
||||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
||||
|
@ -37,11 +38,14 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
|
|||
private static ConstructingObjectParser<TrainedModelMetadata, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<TrainedModelMetadata, Void> parser = new ConstructingObjectParser<>(NAME,
|
||||
ignoreUnknownFields,
|
||||
a -> new TrainedModelMetadata((String)a[0], (List<TotalFeatureImportance>)a[1]));
|
||||
a -> new TrainedModelMetadata((String)a[0], (List<TotalFeatureImportance>)a[1], (FeatureImportanceBaseline)a[2]));
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
|
||||
parser.declareObjectArray(ConstructingObjectParser.constructorArg(),
|
||||
ignoreUnknownFields ? TotalFeatureImportance.LENIENT_PARSER : TotalFeatureImportance.STRICT_PARSER,
|
||||
TOTAL_FEATURE_IMPORTANCE);
|
||||
parser.declareObject(ConstructingObjectParser.optionalConstructorArg(),
|
||||
ignoreUnknownFields ? FeatureImportanceBaseline.LENIENT_PARSER : FeatureImportanceBaseline.STRICT_PARSER,
|
||||
FEATURE_IMPORTANCE_BASELINE);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -58,16 +62,21 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
|
|||
}
|
||||
|
||||
private final List<TotalFeatureImportance> totalFeatureImportances;
|
||||
private final FeatureImportanceBaseline featureImportanceBaselines;
|
||||
private final String modelId;
|
||||
|
||||
public TrainedModelMetadata(StreamInput in) throws IOException {
|
||||
this.modelId = in.readString();
|
||||
this.totalFeatureImportances = in.readList(TotalFeatureImportance::new);
|
||||
this.featureImportanceBaselines = in.readOptionalWriteable(FeatureImportanceBaseline::new);
|
||||
}
|
||||
|
||||
public TrainedModelMetadata(String modelId, List<TotalFeatureImportance> totalFeatureImportances) {
|
||||
public TrainedModelMetadata(String modelId,
|
||||
List<TotalFeatureImportance> totalFeatureImportances,
|
||||
FeatureImportanceBaseline featureImportanceBaselines) {
|
||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
||||
this.totalFeatureImportances = Collections.unmodifiableList(totalFeatureImportances);
|
||||
this.featureImportanceBaselines = featureImportanceBaselines;
|
||||
}
|
||||
|
||||
public String getModelId() {
|
||||
|
@ -82,24 +91,30 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
|
|||
return totalFeatureImportances;
|
||||
}
|
||||
|
||||
public FeatureImportanceBaseline getFeatureImportanceBaselines() {
|
||||
return featureImportanceBaselines;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TrainedModelMetadata that = (TrainedModelMetadata) o;
|
||||
return Objects.equals(totalFeatureImportances, that.totalFeatureImportances) &&
|
||||
Objects.equals(featureImportanceBaselines, that.featureImportanceBaselines) &&
|
||||
Objects.equals(modelId, that.modelId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(totalFeatureImportances, modelId);
|
||||
return Objects.hash(totalFeatureImportances, featureImportanceBaselines, modelId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(modelId);
|
||||
out.writeList(totalFeatureImportances);
|
||||
out.writeOptionalWriteable(featureImportanceBaselines);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -110,6 +125,9 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
|
|||
}
|
||||
builder.field(MODEL_ID.getPreferredName(), modelId);
|
||||
builder.field(TOTAL_FEATURE_IMPORTANCE.getPreferredName(), totalFeatureImportances);
|
||||
if (featureImportanceBaselines != null) {
|
||||
builder.field(FEATURE_IMPORTANCE_BASELINE.getPreferredName(), featureImportanceBaselines);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
|
|
@ -71,6 +71,19 @@
|
|||
"inference_config": {
|
||||
"enabled": false
|
||||
},
|
||||
"feature_importance_baseline": {
|
||||
"properties": {
|
||||
"baseline": {
|
||||
"type": "double"
|
||||
},
|
||||
"classes": {
|
||||
"properties": {
|
||||
"class_name": { "type": "keyword"},
|
||||
"baseline": {"type" : "double"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_feature_importance": {
|
||||
"type": "nested",
|
||||
"dynamic": "false",
|
||||
|
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
|
|||
import org.elasticsearch.xpack.core.action.util.PageParams;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Includes;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
@ -24,7 +25,9 @@ public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTe
|
|||
randomBoolean() ? null :
|
||||
randomList(10, () -> randomAlphaOfLength(10)),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(() -> randomFrom(Request.DEFINITION, Request.TOTAL_FEATURE_IMPORTANCE))
|
||||
Stream.generate(() -> randomFrom(Includes.DEFINITION,
|
||||
Includes.TOTAL_FEATURE_IMPORTANCE,
|
||||
Includes.FEATURE_IMPORTANCE_BASELINE))
|
||||
.limit(4)
|
||||
.collect(Collectors.toSet()));
|
||||
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
|
||||
|
@ -40,8 +43,8 @@ public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTe
|
|||
protected Request mutateInstanceForVersion(Request instance, Version version) {
|
||||
if (version.before(Version.V_7_10_0)) {
|
||||
Set<String> includes = new HashSet<>();
|
||||
if (instance.isIncludeModelDefinition()) {
|
||||
includes.add(Request.DEFINITION);
|
||||
if (instance.getIncludes().isIncludeModelDefinition()) {
|
||||
includes.add(Includes.DEFINITION);
|
||||
}
|
||||
Request request = new Request(
|
||||
instance.getResourceId(),
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
||||
public class FeatureImportanceBaselineTests extends AbstractBWCSerializationTestCase<FeatureImportanceBaseline> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static FeatureImportanceBaseline randomInstance() {
|
||||
Supplier<Object> classNameGenerator = randomFrom(
|
||||
() -> randomAlphaOfLength(10),
|
||||
ESTestCase::randomBoolean,
|
||||
() -> randomIntBetween(0, 10)
|
||||
);
|
||||
return new FeatureImportanceBaseline(
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ?
|
||||
null :
|
||||
Stream.generate(() -> new FeatureImportanceBaseline.ClassBaseline(classNameGenerator.get(), randomDouble()))
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList())
|
||||
);
|
||||
}
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FeatureImportanceBaseline createTestInstance() {
|
||||
return randomInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<FeatureImportanceBaseline> instanceReader() {
|
||||
return FeatureImportanceBaseline::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FeatureImportanceBaseline doParseInstance(XContentParser parser) throws IOException {
|
||||
return FeatureImportanceBaseline.fromXContent(parser, lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FeatureImportanceBaseline mutateInstanceForVersion(FeatureImportanceBaseline instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
}
|
|
@ -23,7 +23,8 @@ public class TrainedModelMetadataTests extends AbstractBWCSerializationTestCase<
|
|||
public static TrainedModelMetadata randomInstance() {
|
||||
return new TrainedModelMetadata(
|
||||
randomAlphaOfLength(10),
|
||||
Stream.generate(TotalFeatureImportanceTests::randomInstance).limit(randomIntBetween(1, 10)).collect(Collectors.toList()));
|
||||
Stream.generate(TotalFeatureImportanceTests::randomInstance).limit(randomIntBetween(1, 10)).collect(Collectors.toList()),
|
||||
randomBoolean() ? null : FeatureImportanceBaselineTests.randomInstance());
|
||||
}
|
||||
|
||||
@Before
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.common.collect.Tuple;
|
|||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.xpack.core.action.util.PageParams;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||
|
@ -22,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
|||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaselineTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
|
||||
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
|
||||
|
@ -90,7 +92,8 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
|
|||
}
|
||||
ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance)
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList()));
|
||||
.collect(Collectors.toList()),
|
||||
FeatureImportanceBaselineTests.randomInstance());
|
||||
persister.createAndIndexInferenceModelMetadata(modelMetadata);
|
||||
|
||||
PlainActionFuture<Tuple<Long, Set<String>>> getIdsFuture = new PlainActionFuture<>();
|
||||
|
@ -100,13 +103,14 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
|
|||
String inferenceModelId = ids.v2().iterator().next();
|
||||
|
||||
PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
|
||||
trainedModelProvider.getTrainedModel(inferenceModelId, true, true, getTrainedModelFuture);
|
||||
trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), getTrainedModelFuture);
|
||||
|
||||
TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
|
||||
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
|
||||
assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
|
||||
assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
|
||||
assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance"));
|
||||
assertThat(storedConfig.getMetadata(), hasKey("feature_importance_baseline"));
|
||||
|
||||
PlainActionFuture<Map<String, TrainedModelMetadata>> getTrainedMetadataFuture = new PlainActionFuture<>();
|
||||
trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture);
|
||||
|
|
|
@ -17,6 +17,7 @@ import org.elasticsearch.common.xcontent.XContentFactory;
|
|||
import org.elasticsearch.index.mapper.MapperService;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||
|
@ -91,7 +92,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
|||
|
||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||
blockingCall(
|
||||
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
|
||||
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
|
||||
getConfigHolder,
|
||||
exceptionHolder);
|
||||
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
|
||||
|
@ -125,7 +126,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
|||
|
||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||
blockingCall(listener ->
|
||||
trainedModelProvider.getTrainedModel(modelId, false, false, listener),
|
||||
trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), listener),
|
||||
getConfigHolder,
|
||||
exceptionHolder);
|
||||
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
|
||||
|
@ -139,7 +140,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
|||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
||||
blockingCall(
|
||||
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
|
||||
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
|
||||
getConfigHolder,
|
||||
exceptionHolder);
|
||||
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
||||
|
@ -164,7 +165,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
|||
|
||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||
blockingCall(
|
||||
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
|
||||
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
|
||||
getConfigHolder,
|
||||
exceptionHolder);
|
||||
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
||||
|
@ -206,7 +207,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
|||
|
||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||
blockingCall(
|
||||
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
|
||||
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
|
||||
getConfigHolder,
|
||||
exceptionHolder);
|
||||
assertThat(getConfigHolder.get(), is(nullValue()));
|
||||
|
@ -254,7 +255,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
|||
}
|
||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||
blockingCall(
|
||||
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
|
||||
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
|
||||
getConfigHolder,
|
||||
exceptionHolder);
|
||||
assertThat(getConfigHolder.get(), is(nullValue()));
|
||||
|
|
|
@ -49,18 +49,17 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
|
|||
return;
|
||||
}
|
||||
|
||||
if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) {
|
||||
if (request.getIncludes().isIncludeModelDefinition() && totalAndIds.v2().size() > 1) {
|
||||
listener.onFailure(
|
||||
ExceptionsHelper.badRequestException(Messages.INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (request.isIncludeModelDefinition()) {
|
||||
if (request.getIncludes().isIncludeModelDefinition()) {
|
||||
provider.getTrainedModel(
|
||||
totalAndIds.v2().iterator().next(),
|
||||
true,
|
||||
request.isIncludeTotalFeatureImportance(),
|
||||
request.getIncludes(),
|
||||
ActionListener.wrap(
|
||||
config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
|
||||
listener::onFailure
|
||||
|
@ -69,8 +68,8 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
|
|||
} else {
|
||||
provider.getTrainedModels(
|
||||
totalAndIds.v2(),
|
||||
request.getIncludes(),
|
||||
request.isAllowNoResources(),
|
||||
request.isIncludeTotalFeatureImportance(),
|
||||
ActionListener.wrap(
|
||||
configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
|
||||
listener::onFailure
|
||||
|
|
|
@ -16,6 +16,7 @@ import org.elasticsearch.tasks.Task;
|
|||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
|
||||
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
|
||||
|
@ -82,7 +83,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
|
|||
responseBuilder.setLicensed(true);
|
||||
this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
|
||||
} else {
|
||||
trainedModelProvider.getTrainedModel(request.getModelId(), false, false, ActionListener.wrap(
|
||||
trainedModelProvider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(
|
||||
trainedModelConfig -> {
|
||||
responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
|
||||
if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {
|
||||
|
|
|
@ -129,7 +129,8 @@ public class ChunkedTrainedModelPersister {
|
|||
return;
|
||||
}
|
||||
TrainedModelMetadata trainedModelMetadata = new TrainedModelMetadata(this.currentModelId.get(),
|
||||
modelMetadata.getFeatureImportances());
|
||||
modelMetadata.getFeatureImportances(),
|
||||
modelMetadata.getFeatureImportanceBaseline());
|
||||
|
||||
|
||||
CountDownLatch latch = storeTrainedModelMetadata(trainedModelMetadata);
|
||||
|
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.common.ParseField;
|
|||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaseline;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -17,47 +18,60 @@ import java.util.List;
|
|||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class ModelMetadata implements ToXContentObject {
|
||||
|
||||
public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance");
|
||||
public static final ParseField FEATURE_IMPORTANCE_BASELINE = new ParseField("feature_importance_baseline");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<ModelMetadata, Void> PARSER = new ConstructingObjectParser<>(
|
||||
"trained_model_metadata",
|
||||
a -> new ModelMetadata((List<TotalFeatureImportance>) a[0]));
|
||||
a -> new ModelMetadata((List<TotalFeatureImportance>) a[0], (FeatureImportanceBaseline) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), TotalFeatureImportance.STRICT_PARSER, TOTAL_FEATURE_IMPORTANCE);
|
||||
PARSER.declareObject(optionalConstructorArg(), FeatureImportanceBaseline.STRICT_PARSER, FEATURE_IMPORTANCE_BASELINE);
|
||||
}
|
||||
|
||||
private final List<TotalFeatureImportance> featureImportances;
|
||||
private final FeatureImportanceBaseline featureImportanceBaseline;
|
||||
|
||||
public ModelMetadata(List<TotalFeatureImportance> featureImportances) {
|
||||
public ModelMetadata(List<TotalFeatureImportance> featureImportances, FeatureImportanceBaseline featureImportanceBaseline) {
|
||||
this.featureImportances = featureImportances;
|
||||
this.featureImportanceBaseline = featureImportanceBaseline;
|
||||
}
|
||||
|
||||
public List<TotalFeatureImportance> getFeatureImportances() {
|
||||
return featureImportances;
|
||||
}
|
||||
|
||||
public FeatureImportanceBaseline getFeatureImportanceBaseline() {
|
||||
return featureImportanceBaseline;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ModelMetadata that = (ModelMetadata) o;
|
||||
return Objects.equals(featureImportances, that.featureImportances);
|
||||
return Objects.equals(featureImportances, that.featureImportances)
|
||||
&& Objects.equals(featureImportanceBaseline, that.featureImportanceBaseline);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(featureImportances);
|
||||
return Objects.hash(featureImportances, featureImportanceBaseline);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(TOTAL_FEATURE_IMPORTANCE.getPreferredName(), featureImportances);
|
||||
if (featureImportanceBaseline != null) {
|
||||
builder.field(FEATURE_IMPORTANCE_BASELINE.getPreferredName(), featureImportanceBaseline);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.elasticsearch.common.unit.TimeValue;
|
|||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.ingest.IngestMetadata;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
|
@ -270,7 +271,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
}
|
||||
|
||||
private void loadModel(String modelId, Consumer consumer) {
|
||||
provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
|
||||
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(
|
||||
trainedModelConfig -> {
|
||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
|
||||
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
|
||||
|
@ -306,7 +307,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
// If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
|
||||
// by a simulated pipeline
|
||||
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
|
||||
provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
|
||||
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(
|
||||
trainedModelConfig -> {
|
||||
// Verify we can pull the model into memory without causing OOM
|
||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
|
||||
|
|
|
@ -68,6 +68,7 @@ import org.elasticsearch.search.sort.SortOrder;
|
|||
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
|
||||
import org.elasticsearch.xpack.core.action.util.PageParams;
|
||||
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
|
@ -439,13 +440,12 @@ public class TrainedModelProvider {
|
|||
}
|
||||
|
||||
public void getTrainedModel(final String modelId,
|
||||
final boolean includeDefinition,
|
||||
final boolean includeTotalFeatureImportance,
|
||||
final GetTrainedModelsAction.Includes includes,
|
||||
final ActionListener<TrainedModelConfig> finalListener) {
|
||||
|
||||
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
||||
try {
|
||||
finalListener.onResponse(loadModelFromResource(modelId, includeDefinition == false).build());
|
||||
finalListener.onResponse(loadModelFromResource(modelId, includes.isIncludeModelDefinition() == false).build());
|
||||
return;
|
||||
} catch (ElasticsearchException ex) {
|
||||
finalListener.onFailure(ex);
|
||||
|
@ -455,7 +455,7 @@ public class TrainedModelProvider {
|
|||
|
||||
ActionListener<TrainedModelConfig.Builder> getTrainedModelListener = ActionListener.wrap(
|
||||
modelBuilder -> {
|
||||
if (includeTotalFeatureImportance == false) {
|
||||
if ((includes.isIncludeFeatureImportanceBaseline() || includes.isIncludeTotalFeatureImportance()) == false) {
|
||||
finalListener.onResponse(modelBuilder.build());
|
||||
return;
|
||||
}
|
||||
|
@ -463,8 +463,13 @@ public class TrainedModelProvider {
|
|||
metadata -> {
|
||||
TrainedModelMetadata modelMetadata = metadata.get(modelId);
|
||||
if (modelMetadata != null) {
|
||||
if (includes.isIncludeTotalFeatureImportance()) {
|
||||
modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
|
||||
}
|
||||
if (includes.isIncludeFeatureImportanceBaseline()) {
|
||||
modelBuilder.setBaselineFeatureImportance(modelMetadata.getFeatureImportanceBaselines());
|
||||
}
|
||||
}
|
||||
finalListener.onResponse(modelBuilder.build());
|
||||
},
|
||||
failure -> {
|
||||
|
@ -493,7 +498,7 @@ public class TrainedModelProvider {
|
|||
.setSize(1)
|
||||
.request());
|
||||
|
||||
if (includeDefinition) {
|
||||
if (includes.isIncludeModelDefinition()) {
|
||||
multiSearchRequestBuilder.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
||||
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
|
||||
.boolQuery()
|
||||
|
@ -527,7 +532,7 @@ public class TrainedModelProvider {
|
|||
return;
|
||||
}
|
||||
|
||||
if (includeDefinition) {
|
||||
if (includes.isIncludeModelDefinition()) {
|
||||
try {
|
||||
List<TrainedModelDefinitionDoc> docs = handleSearchItems(multiSearchResponse.getResponses()[1],
|
||||
modelId,
|
||||
|
@ -569,8 +574,8 @@ public class TrainedModelProvider {
|
|||
* It assumes that there are fewer than 10k.
|
||||
*/
|
||||
public void getTrainedModels(Set<String> modelIds,
|
||||
GetTrainedModelsAction.Includes includes,
|
||||
boolean allowNoResources,
|
||||
boolean includeTotalFeatureImportance,
|
||||
final ActionListener<List<TrainedModelConfig>> finalListener) {
|
||||
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
|
||||
|
||||
|
@ -601,7 +606,7 @@ public class TrainedModelProvider {
|
|||
|
||||
ActionListener<List<TrainedModelConfig.Builder>> getTrainedModelListener = ActionListener.wrap(
|
||||
modelBuilders -> {
|
||||
if (includeTotalFeatureImportance == false) {
|
||||
if ((includes.isIncludeFeatureImportanceBaseline() || includes.isIncludeTotalFeatureImportance()) == false) {
|
||||
finalListener.onResponse(modelBuilders.stream()
|
||||
.map(TrainedModelConfig.Builder::build)
|
||||
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
|
||||
|
@ -614,8 +619,13 @@ public class TrainedModelProvider {
|
|||
.map(builder -> {
|
||||
TrainedModelMetadata modelMetadata = metadata.get(builder.getModelId());
|
||||
if (modelMetadata != null) {
|
||||
if (includes.isIncludeTotalFeatureImportance()) {
|
||||
builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
|
||||
}
|
||||
if (includes.isIncludeFeatureImportanceBaseline()) {
|
||||
builder.setBaselineFeatureImportance(modelMetadata.getFeatureImportanceBaselines());
|
||||
}
|
||||
}
|
||||
return builder.build();
|
||||
})
|
||||
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
|
||||
|
|
|
@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaselineTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
|
||||
import org.elasticsearch.xpack.core.security.user.XPackUser;
|
||||
|
@ -103,7 +104,8 @@ public class ChunkedTrainedModelPersisterTests extends ESTestCase {
|
|||
TrainedModelDefinitionChunk chunk2 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 1, true);
|
||||
ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance)
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList()));
|
||||
.collect(Collectors.toList()),
|
||||
FeatureImportanceBaselineTests.randomInstance());
|
||||
|
||||
resultProcessor.createAndIndexInferenceModelConfig(modelSizeInfo);
|
||||
resultProcessor.createAndIndexInferenceModelDoc(chunk1);
|
||||
|
|
|
@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsageTests;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaselineTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests;
|
||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
|
@ -69,7 +70,8 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
|
|||
if (randomBoolean()) {
|
||||
builder.setModelMetadata(new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance)
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList())));
|
||||
.collect(Collectors.toList()),
|
||||
FeatureImportanceBaselineTests.randomInstance()));
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
|
||||
import org.elasticsearch.threadpool.TestThreadPool;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
|
@ -437,9 +438,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
// the loading occurred or which models are currently in the cache due to evictions.
|
||||
// Verify that we have at least loaded all three
|
||||
assertBusy(() -> {
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), eq(false), any());
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), eq(false), any());
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), eq(false), any());
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(GetTrainedModelsAction.Includes.empty()), any());
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(GetTrainedModelsAction.Includes.empty()), any());
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(GetTrainedModelsAction.Includes.empty()), any());
|
||||
});
|
||||
assertBusy(() -> {
|
||||
assertThat(circuitBreaker.getUsed(), equalTo(10L));
|
||||
|
@ -553,10 +554,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
||||
listener.onResponse(trainedModelConfig);
|
||||
return null;
|
||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
|
||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(GetTrainedModelsAction.Includes.empty()), any());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
@ -564,20 +565,20 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
if (randomBoolean()) {
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
||||
listener.onFailure(new ResourceNotFoundException(
|
||||
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
||||
return null;
|
||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
|
||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(GetTrainedModelsAction.Includes.empty()), any());
|
||||
} else {
|
||||
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
|
||||
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
||||
listener.onResponse(trainedModelConfig);
|
||||
return null;
|
||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
|
||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(GetTrainedModelsAction.Includes.empty()), any());
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.index.query.QueryBuilder;
|
|||
import org.elasticsearch.index.query.TermQueryBuilder;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.action.util.PageParams;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
|
||||
|
@ -57,14 +58,14 @@ public class TrainedModelProviderTests extends ESTestCase {
|
|||
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
||||
for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
|
||||
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
|
||||
trainedModelProvider.getTrainedModel(modelId, true, false, future);
|
||||
trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(),future);
|
||||
TrainedModelConfig configWithDefinition = future.actionGet();
|
||||
|
||||
assertThat(configWithDefinition.getModelId(), equalTo(modelId));
|
||||
assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
|
||||
|
||||
PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
|
||||
trainedModelProvider.getTrainedModel(modelId, false, false, futureNoDefinition);
|
||||
trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), futureNoDefinition);
|
||||
TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();
|
||||
|
||||
assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));
|
||||
|
|
|
@ -9,6 +9,7 @@ import org.elasticsearch.action.support.PlainActionFuture;
|
|||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
|
@ -33,7 +34,7 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
|
|||
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
||||
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
|
||||
// Should be OK as we don't make any client calls
|
||||
trainedModelProvider.getTrainedModel("lang_ident_model_1", true, false, future);
|
||||
trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
|
||||
TrainedModelConfig config = future.actionGet();
|
||||
|
||||
config.ensureParsedDefinition(xContentRegistry());
|
||||
|
|
Loading…
Reference in New Issue