* [ML] Start gathering and storing inference stats (#53429)

This PR enables stats on inference to be gathered and stored in the `.ml-stats-*` indices. Each node + model_id pair gets its own running stats document, and these documents are summed together when _stats is returned to the user.

`.ml-stats-*` is ILM-managed (when possible), so the underlying index can change at any point. A stats document that is read in and later updated may therefore actually be a new doc in a new index. This complicates matters: keeping a running knowledge of seq_no and primary_term is all but impossible, because we never know the latest index name. We also need to strive for throughput, as this code sits in the middle of an ingest pipeline (or even a query).
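The escape hatch used below (see TrainedModelStatsService.buildUpdateRequest) is to avoid read-modify-write altogether: each node upserts its per-model document through the stats write alias and lets a painless script apply the counter deltas server-side, so no seq_no/primary_term bookkeeping is needed no matter which backing index the alias currently resolves to. A minimal sketch of that pattern, with a placeholder alias name and doc id (the real code resolves these via MlStatsIndex.writeAlias() and InferenceStats.docId()):

import java.util.HashMap;
import java.util.Map;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;

class StatsUpsertSketch {
    // Sketch: apply counter deltas server-side so the writer never needs the concrete
    // backing index, seq_no, or primary_term. The alias and doc id are placeholders.
    static UpdateRequest incrementInferenceCount(long delta) {
        Map<String, Object> params = new HashMap<>();
        params.put("inference_count", delta);
        return new UpdateRequest(".ml-stats-write", "inference_stats-my-model-node-1")
            .script(new Script(ScriptType.INLINE, "painless",
                "ctx._source.inference_count += params.inference_count;", params));
    }
}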
parent 7a8a66d9ae
commit c5c7ee9d73
GetTrainedModelsStatsAction.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.action;

 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionRequestBuilder;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.client.ElasticsearchClient;

@@ -20,6 +21,7 @@ import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
 import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;

 import java.io.IOException;
 import java.util.ArrayList;

@@ -37,6 +39,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat

     public static final ParseField MODEL_ID = new ParseField("model_id");
     public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
+    public static final ParseField INFERENCE_STATS = new ParseField("inference_stats");

     private GetTrainedModelsStatsAction() {
         super(NAME, GetTrainedModelsStatsAction.Response::new);

@@ -78,25 +81,32 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
     public static class TrainedModelStats implements ToXContentObject, Writeable {
         private final String modelId;
         private final IngestStats ingestStats;
+        private final InferenceStats inferenceStats;
         private final int pipelineCount;

         private static final IngestStats EMPTY_INGEST_STATS = new IngestStats(new IngestStats.Stats(0, 0, 0, 0),
             Collections.emptyList(),
             Collections.emptyMap());

-        public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount) {
+        public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount, InferenceStats inferenceStats) {
             this.modelId = Objects.requireNonNull(modelId);
             this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats;
             if (pipelineCount < 0) {
                 throw new ElasticsearchException("[{}] must be a greater than or equal to 0", PIPELINE_COUNT.getPreferredName());
             }
             this.pipelineCount = pipelineCount;
+            this.inferenceStats = inferenceStats;
         }

         public TrainedModelStats(StreamInput in) throws IOException {
             modelId = in.readString();
             ingestStats = new IngestStats(in);
             pipelineCount = in.readVInt();
+            if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
+                this.inferenceStats = in.readOptionalWriteable(InferenceStats::new);
+            } else {
+                this.inferenceStats = null;
+            }
         }

         public String getModelId() {

@@ -120,6 +130,9 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 // Ingest stats is a fragment
                 ingestStats.toXContent(builder, params);
             }
+            if (this.inferenceStats != null) {
+                builder.field(INFERENCE_STATS.getPreferredName(), this.inferenceStats);
+            }
             builder.endObject();
             return builder;
         }

@@ -129,11 +142,14 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
             out.writeString(modelId);
             ingestStats.writeTo(out);
             out.writeVInt(pipelineCount);
+            if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
+                out.writeOptionalWriteable(this.inferenceStats);
+            }
         }

         @Override
         public int hashCode() {
-            return Objects.hash(modelId, ingestStats, pipelineCount);
+            return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats);
         }

         @Override

@@ -147,7 +163,8 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
             TrainedModelStats other = (TrainedModelStats) obj;
             return Objects.equals(this.modelId, other.modelId)
                 && Objects.equals(this.ingestStats, other.ingestStats)
-                && Objects.equals(this.pipelineCount, other.pipelineCount);
+                && Objects.equals(this.pipelineCount, other.pipelineCount)
+                && Objects.equals(this.inferenceStats, other.inferenceStats);
         }
     }

@@ -171,6 +188,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
         private long totalModelCount;
         private Set<String> expandedIds;
         private Map<String, IngestStats> ingestStatsMap;
+        private Map<String, InferenceStats> inferenceStatsMap;

         public Builder setTotalModelCount(long totalModelCount) {
             this.totalModelCount = totalModelCount;

@@ -191,13 +209,23 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
             return this;
         }

+        public Builder setInferenceStatsByModelId(Map<String, InferenceStats> infereceStatsByModelId) {
+            this.inferenceStatsMap = infereceStatsByModelId;
+            return this;
+        }
+
         public Response build() {
             List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIds.size());
             expandedIds.forEach(id -> {
                 IngestStats ingestStats = ingestStatsMap.get(id);
-                trainedModelStats.add(new TrainedModelStats(id, ingestStats, ingestStats == null ?
+                InferenceStats inferenceStats = inferenceStatsMap.get(id);
+                trainedModelStats.add(new TrainedModelStats(
+                    id,
+                    ingestStats,
+                    ingestStats == null ?
                     0 :
-                    ingestStats.getPipelineStats().size()));
+                    ingestStats.getPipelineStats().size(),
+                    inferenceStats));
             });
             trainedModelStats.sort(Comparator.comparing(TrainedModelStats::getModelId));
             return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
InferenceStats.java (new file)

@@ -0,0 +1,249 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
import java.util.concurrent.atomic.LongAdder;

public class InferenceStats implements ToXContentObject, Writeable {

    public static final String NAME = "inference_stats";
    public static final ParseField MISSING_ALL_FIELDS_COUNT = new ParseField("missing_all_fields_count");
    public static final ParseField INFERENCE_COUNT = new ParseField("inference_count");
    public static final ParseField MODEL_ID = new ParseField("model_id");
    public static final ParseField NODE_ID = new ParseField("node_id");
    public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
    public static final ParseField TYPE = new ParseField("type");
    public static final ParseField TIMESTAMP = new ParseField("time_stamp");

    public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
        a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (String)a[3], (String)a[4], (Instant)a[5])
    );
    static {
        PARSER.declareLong(ConstructingObjectParser.constructorArg(), MISSING_ALL_FIELDS_COUNT);
        PARSER.declareLong(ConstructingObjectParser.constructorArg(), INFERENCE_COUNT);
        PARSER.declareLong(ConstructingObjectParser.constructorArg(), FAILURE_COUNT);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), NODE_ID);
        PARSER.declareField(ConstructingObjectParser.constructorArg(),
            p -> TimeUtils.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
            TIMESTAMP,
            ObjectParser.ValueType.VALUE);
    }
    public static InferenceStats emptyStats(String modelId, String nodeId) {
        return new InferenceStats(0L, 0L, 0L, modelId, nodeId, Instant.now());
    }

    public static String docId(String modelId, String nodeId) {
        return NAME + "-" + modelId + "-" + nodeId;
    }

    private final long missingAllFieldsCount;
    private final long inferenceCount;
    private final long failureCount;
    private final String modelId;
    private final String nodeId;
    private final Instant timeStamp;

    private InferenceStats(Long missingAllFieldsCount,
                           Long inferenceCount,
                           Long failureCount,
                           String modelId,
                           String nodeId,
                           Instant instant) {
        this(unbox(missingAllFieldsCount),
            unbox(inferenceCount),
            unbox(failureCount),
            modelId,
            nodeId,
            instant);
    }

    public InferenceStats(long missingAllFieldsCount,
                          long inferenceCount,
                          long failureCount,
                          String modelId,
                          String nodeId,
                          Instant timeStamp) {
        this.missingAllFieldsCount = missingAllFieldsCount;
        this.inferenceCount = inferenceCount;
        this.failureCount = failureCount;
        this.modelId = modelId;
        this.nodeId = nodeId;
        this.timeStamp = timeStamp == null ?
            Instant.ofEpochMilli(Instant.now().toEpochMilli()) :
            Instant.ofEpochMilli(timeStamp.toEpochMilli());
    }

    public InferenceStats(StreamInput in) throws IOException {
        this.missingAllFieldsCount = in.readVLong();
        this.inferenceCount = in.readVLong();
        this.failureCount = in.readVLong();
        this.modelId = in.readOptionalString();
        this.nodeId = in.readOptionalString();
        this.timeStamp = in.readInstant();
    }

    public long getMissingAllFieldsCount() {
        return missingAllFieldsCount;
    }

    public long getInferenceCount() {
        return inferenceCount;
    }

    public long getFailureCount() {
        return failureCount;
    }

    public String getModelId() {
        return modelId;
    }

    public String getNodeId() {
        return nodeId;
    }

    public Instant getTimeStamp() {
        return timeStamp;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
            assert modelId != null : "model_id cannot be null when storing inference stats";
            assert nodeId != null : "node_id cannot be null when storing inference stats";
            builder.field(TYPE.getPreferredName(), NAME);
            builder.field(MODEL_ID.getPreferredName(), modelId);
            builder.field(NODE_ID.getPreferredName(), nodeId);
        }
        builder.field(FAILURE_COUNT.getPreferredName(), failureCount);
        builder.field(INFERENCE_COUNT.getPreferredName(), inferenceCount);
        builder.field(MISSING_ALL_FIELDS_COUNT.getPreferredName(), missingAllFieldsCount);
        builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timeStamp.toEpochMilli());
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        InferenceStats that = (InferenceStats) o;
        return missingAllFieldsCount == that.missingAllFieldsCount
            && inferenceCount == that.inferenceCount
            && failureCount == that.failureCount
            && Objects.equals(modelId, that.modelId)
            && Objects.equals(nodeId, that.nodeId)
            && Objects.equals(timeStamp, that.timeStamp);
    }

    @Override
    public int hashCode() {
        return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, modelId, nodeId, timeStamp);
    }

    @Override
    public String toString() {
        return "InferenceStats{" +
            "missingAllFieldsCount=" + missingAllFieldsCount +
            ", inferenceCount=" + inferenceCount +
            ", failureCount=" + failureCount +
            ", modelId='" + modelId + '\'' +
            ", nodeId='" + nodeId + '\'' +
            ", timeStamp=" + timeStamp +
            '}';
    }

    private static long unbox(@Nullable Long value) {
        return value == null ? 0L : value;
    }

    public static Accumulator accumulator(InferenceStats stats) {
        return new Accumulator(stats);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeVLong(this.missingAllFieldsCount);
        out.writeVLong(this.inferenceCount);
        out.writeVLong(this.failureCount);
        out.writeOptionalString(this.modelId);
        out.writeOptionalString(this.nodeId);
        out.writeInstant(timeStamp);
    }

    public static class Accumulator {

        private final LongAdder missingFieldsAccumulator = new LongAdder();
        private final LongAdder inferenceAccumulator = new LongAdder();
        private final LongAdder failureCountAccumulator = new LongAdder();
        private final String modelId;
        private final String nodeId;

        public Accumulator(String modelId, String nodeId) {
            this.modelId = modelId;
            this.nodeId = nodeId;
        }

        public Accumulator(InferenceStats previousStats) {
            this.modelId = previousStats.modelId;
            this.nodeId = previousStats.nodeId;
            this.missingFieldsAccumulator.add(previousStats.missingAllFieldsCount);
            this.inferenceAccumulator.add(previousStats.inferenceCount);
            this.failureCountAccumulator.add(previousStats.failureCount);
        }

        public Accumulator merge(InferenceStats otherStats) {
            this.missingFieldsAccumulator.add(otherStats.missingAllFieldsCount);
            this.inferenceAccumulator.add(otherStats.inferenceCount);
            this.failureCountAccumulator.add(otherStats.failureCount);
            return this;
        }

        public void incMissingFields() {
            this.missingFieldsAccumulator.increment();
        }

        public void incInference() {
            this.inferenceAccumulator.increment();
        }

        public void incFailure() {
            this.failureCountAccumulator.increment();
        }

        public InferenceStats currentStats() {
            return currentStats(Instant.now());
        }

        public InferenceStats currentStats(Instant timeStamp) {
            return new InferenceStats(missingFieldsAccumulator.longValue(),
                inferenceAccumulator.longValue(),
                failureCountAccumulator.longValue(),
                modelId,
                nodeId,
                timeStamp);
        }
    }
}
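A note on the design: the class separates an immutable, wire-serializable snapshot (InferenceStats) from a LongAdder-backed Accumulator, which is what keeps counting cheap on the inference hot path. A hypothetical usage to illustrate the intended flow (the model and node ids are invented):

import java.time.Instant;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;

class AccumulatorSketch {
    static InferenceStats example() {
        // Hot path: contended increments go to LongAdders, no locking.
        InferenceStats.Accumulator acc = new InferenceStats.Accumulator("my-model", "node-1");
        acc.incInference();
        acc.incMissingFields();
        // Persistence path: snapshot the counters into an immutable, writeable object.
        InferenceStats snapshot = acc.currentStats(Instant.now());
        // Later snapshots for the same (model, node) pair can be summed via merge().
        return InferenceStats.accumulator(snapshot)
            .merge(InferenceStats.emptyStats("my-model", "node-1"))
            .currentStats();
    }
}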
.ml-stats index mappings (JSON resource)

@@ -85,6 +85,21 @@
       "peak_usage_bytes" : {
         "type" : "long"
       },
+      "model_id": {
+        "type": "keyword"
+      },
+      "node_id": {
+        "type": "keyword"
+      },
+      "inference_count": {
+        "type": "long"
+      },
+      "failure_count": {
+        "type": "long"
+      },
+      "missing_all_fields_count": {
+        "type": "long"
+      },
       "skipped_docs_count": {
         "type": "long"
       },
GetTrainedModelsStatsActionResponseTests.java

@@ -5,19 +5,23 @@
  */
 package org.elasticsearch.xpack.core.ml.action;

+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.ingest.IngestStats;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStatsTests;

+import java.util.List;
+import java.util.Map;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;

+import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD;
+
-public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
+public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSerializationTestCase<Response> {

     @Override
     protected Response createTestInstance() {

@@ -26,10 +30,11 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerial
             .limit(listSize).map(id ->
                 new Response.TrainedModelStats(id,
                     randomBoolean() ? randomIngestStats() : null,
-                    randomIntBetween(0, 10))
+                    randomIntBetween(0, 10),
+                    randomBoolean() ? InferenceStatsTests.createTestInstance(id, null) : null)
             )
             .collect(Collectors.toList());
-        return new Response(new QueryPage<>(trainedModelStats, randomLongBetween(listSize, 1000), Response.RESULTS_FIELD));
+        return new Response(new QueryPage<>(trainedModelStats, randomLongBetween(listSize, 1000), RESULTS_FIELD));
     }

     private IngestStats randomIngestStats() {

@@ -57,4 +62,37 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerial
     protected Writeable.Reader<Response> instanceReader() {
         return Response::new;
     }
+
+    @Override
+    protected Response mutateInstanceForVersion(Response instance, Version version) {
+        if (version.before(Version.V_7_8_0)) {
+            List<Response.TrainedModelStats> stats = instance.getResources()
+                .results()
+                .stream()
+                .map(s -> new Response.TrainedModelStats(s.getModelId(),
+                    adjustForVersion(s.getIngestStats(), version),
+                    s.getPipelineCount(),
+                    null))
+                .collect(Collectors.toList());
+            return new Response(new QueryPage<>(stats, instance.getResources().count(), RESULTS_FIELD));
+        }
+        return instance;
+    }
+
+    IngestStats adjustForVersion(IngestStats stats, Version version) {
+        if (version.before(Version.V_7_6_0)) {
+            return new IngestStats(stats.getTotalStats(),
+                stats.getPipelineStats(),
+                stats.getProcessorStats()
+                    .entrySet()
+                    .stream()
+                    .collect(Collectors.toMap(Map.Entry::getKey,
+                        (kv) -> kv.getValue()
+                            .stream()
+                            .map(pstats -> new IngestStats.ProcessorStat(pstats.getName(), "_NOT_AVAILABLE", pstats.getStats()))
+                            .collect(Collectors.toList()))));
+        }
+        return stats;
+    }
+
 }
TrainedModelConfigTests.java

@@ -320,9 +320,13 @@ public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase<Tr

     @Override
     protected TrainedModelConfig mutateInstanceForVersion(TrainedModelConfig instance, Version version) {
+        TrainedModelConfig.Builder builder = new TrainedModelConfig.Builder(instance);
+        if (version.before(Version.V_7_7_0)) {
+            builder.setDefaultFieldMap(null);
+        }
         if (version.before(Version.V_7_8_0)) {
-            return new TrainedModelConfig.Builder(instance).setInferenceConfig(null).build();
+            builder.setInferenceConfig(null);
         }
-        return instance;
+        return builder.build();
     }
 }
InferenceStatsTests.java (new file)

@@ -0,0 +1,57 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;

import java.io.IOException;
import java.time.Instant;
import java.util.Collections;

import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;

public class InferenceStatsTests extends AbstractSerializingTestCase<InferenceStats> {

    public static InferenceStats createTestInstance(String modelId, @Nullable String nodeId) {
        return new InferenceStats(randomNonNegativeLong(),
            randomNonNegativeLong(),
            randomNonNegativeLong(),
            modelId,
            nodeId,
            Instant.now()
        );
    }

    @Override
    protected InferenceStats doParseInstance(XContentParser parser) throws IOException {
        return InferenceStats.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected InferenceStats createTestInstance() {
        return createTestInstance(randomAlphaOfLength(10), randomAlphaOfLength(10));
    }

    @Override
    protected Writeable.Reader<InferenceStats> instanceReader() {
        return InferenceStats::new;
    }

    @Override
    protected ToXContent.Params getToXContentParams() {
        return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
    }

}
InferenceIngestIT.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.integration;
 import org.apache.http.util.EntityUtils;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.Response;
+import org.elasticsearch.client.ResponseException;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;

@@ -29,6 +30,7 @@ import org.junit.Before;
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;

 import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
 import static org.hamcrest.CoreMatchers.containsString;

@@ -67,7 +69,7 @@ public class InferenceIngestIT extends ESRestTestCase {
         client().performRequest(new Request("DELETE", "_ml/inference/test_regression"));
     }

-    public void testPipelineCreationAndDeletion() throws Exception {
+    public void testPathologicalPipelineCreationAndDeletion() throws Exception {

         for (int i = 0; i < 10; i++) {
             client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));

@@ -78,6 +80,24 @@ public class InferenceIngestIT extends ESRestTestCase {
             client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
             client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
         }
+        client().performRequest(new Request("POST", "index_for_inference_test/_refresh"));
+
+        Response searchResponse = client().performRequest(searchRequest("index_for_inference_test",
+            QueryBuilders.boolQuery()
+                .filter(
+                    QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))));
+        assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
+
+        searchResponse = client().performRequest(searchRequest("index_for_inference_test",
+            QueryBuilders.boolQuery()
+                .filter(
+                    QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));
+
+        assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54786")
+    public void testPipelineIngest() throws Exception {

         client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
         client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));

@@ -87,24 +107,42 @@ public class InferenceIngestIT extends ESRestTestCase {
             client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
         }

+        for (int i = 0; i < 5; i++) {
+            client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
+        }
+
         client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
         client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));

         client().performRequest(new Request("POST", "index_for_inference_test/_refresh"));

         Response searchResponse = client().performRequest(searchRequest("index_for_inference_test",
             QueryBuilders.boolQuery()
                 .filter(
                     QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))));
-        assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));
+        assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":15"));

         searchResponse = client().performRequest(searchRequest("index_for_inference_test",
             QueryBuilders.boolQuery()
                 .filter(
                     QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));

-        assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));
+        assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
+
+        assertBusy(() -> {
+            try {
+                Response statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_classification/_stats"));
+                assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
+                statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_regression/_stats"));
+                assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
+                // can get both
+                statsResponse = client().performRequest(new Request("GET", "_ml/inference/_stats"));
+                assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
+                assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
+            } catch (ResponseException ex) {
+                //this could just mean shard failures.
+            }
+        }, 30, TimeUnit.SECONDS);
     }

     public void testSimulate() throws IOException {
MachineLearning.java

@@ -211,6 +211,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactor
 import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
+import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

@@ -630,13 +631,20 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
         this.datafeedManager.set(datafeedManager);

         // Inference components
+        final TrainedModelStatsService trainedModelStatsService = new TrainedModelStatsService(resultsPersisterService,
+            originSettingClient,
+            indexNameExpressionResolver,
+            clusterService,
+            threadPool);
         final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
         final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
             inferenceAuditor,
             threadPool,
             clusterService,
             xContentRegistry,
-            settings);
+            trainedModelStatsService,
+            settings,
+            clusterService.getNodeName());

         // Data frame analytics components
         AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
TransportGetTrainedModelsStatsAction.java

@@ -26,6 +26,7 @@ import org.elasticsearch.ingest.Pipeline;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

@@ -38,12 +39,12 @@ import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Function;
 import java.util.stream.Collectors;

 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

-
 public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<GetTrainedModelsStatsAction.Request,
     GetTrainedModelsStatsAction.Response> {

@@ -73,13 +74,21 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction

         GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();

+        ActionListener<List<InferenceStats>> inferenceStatsListener = ActionListener.wrap(
+            inferenceStats -> listener.onResponse(responseBuilder.setInferenceStatsByModelId(inferenceStats.stream()
+                .collect(Collectors.toMap(InferenceStats::getModelId, Function.identity())))
+                .build()),
+            listener::onFailure
+        );
+
         ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(
             nodesStatsResponse -> {
-                Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByPipelineId(nodesStatsResponse,
+                Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByModelId(nodesStatsResponse,
                     pipelineIdsByModelIds(clusterService.state(),
                         ingestService,
                         responseBuilder.getExpandedIds()));
-                listener.onResponse(responseBuilder.setIngestStatsByModelId(modelIdIngestStats).build());
+                responseBuilder.setIngestStatsByModelId(modelIdIngestStats);
+                trainedModelProvider.getInferenceStats(responseBuilder.getExpandedIds().toArray(new String[0]), inferenceStatsListener);
             },
             listener::onFailure
         );

@@ -103,7 +112,7 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
             idsListener);
     }

-    static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse response,
+    static Map<String, IngestStats> inferenceIngestStatsByModelId(NodesStatsResponse response,
                                                                      Map<String, Set<String>> modelIdToPipelineId) {

         Map<String, IngestStats> ingestStatsMap = new HashMap<>();
TrainedModelStatsService.java (new file)

@@ -0,0 +1,180 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.inference;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;


public class TrainedModelStatsService {

    private static final Logger logger = LogManager.getLogger(TrainedModelStatsService.class);
    private static final TimeValue PERSISTENCE_INTERVAL = TimeValue.timeValueSeconds(1);

    // Script to only update if stats have increased since last persistence
    private static final String STATS_UPDATE_SCRIPT = "" +
        " ctx._source.missing_all_fields_count += params.missing_all_fields_count;\n" +
        " ctx._source.inference_count += params.inference_count;\n" +
        " ctx._source.failure_count += params.failure_count;\n" +
        " ctx._source.time_stamp = params.time_stamp;";
    private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS =
        new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));

    private final Map<String, InferenceStats> statsQueue;
    private final ResultsPersisterService resultsPersisterService;
    private final OriginSettingClient client;
    private final IndexNameExpressionResolver indexNameExpressionResolver;
    private final ThreadPool threadPool;
    private volatile Scheduler.Cancellable scheduledFuture;
    private volatile boolean verifiedStatsIndexCreated;
    private volatile boolean stopped;
    private volatile ClusterState clusterState;

    public TrainedModelStatsService(ResultsPersisterService resultsPersisterService,
                                    OriginSettingClient client,
                                    IndexNameExpressionResolver indexNameExpressionResolver,
                                    ClusterService clusterService,
                                    ThreadPool threadPool) {
        this.resultsPersisterService = resultsPersisterService;
        this.client = client;
        this.indexNameExpressionResolver = indexNameExpressionResolver;
        this.threadPool = threadPool;
        this.statsQueue = new ConcurrentHashMap<>();

        clusterService.addLifecycleListener(new LifecycleListener() {
            @Override
            public void beforeStart() {
                start();
            }

            @Override
            public void beforeStop() {
                stop();
            }
        });
        clusterService.addListener((event) -> this.clusterState = event.state());
    }

    public void queueStats(InferenceStats stats) {
        statsQueue.compute(InferenceStats.docId(stats.getModelId(), stats.getNodeId()),
            (k, previousStats) -> previousStats == null ?
                stats :
                InferenceStats.accumulator(stats).merge(previousStats).currentStats(stats.getTimeStamp()));
    }

    void stop() {
        stopped = true;
        statsQueue.clear();

        ThreadPool.Cancellable cancellable = this.scheduledFuture;
        if (cancellable != null) {
            cancellable.cancel();
        }
    }

    void start() {
        stopped = false;
        scheduledFuture = threadPool.scheduleWithFixedDelay(this::updateStats,
            PERSISTENCE_INTERVAL,
            MachineLearning.UTILITY_THREAD_POOL_NAME);
    }

    void updateStats() {
        if (clusterState == null || statsQueue.isEmpty()) {
            return;
        }
        if (verifiedStatsIndexCreated == false) {
            try {
                PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
                MlStatsIndex.createStatsIndexAndAliasIfNecessary(client, clusterState, indexNameExpressionResolver, listener);
                listener.actionGet();
                verifiedStatsIndexCreated = true;
            } catch (Exception e) {
                logger.error("failure creating ml stats index for storing model stats", e);
                return;
            }
        }

        List<InferenceStats> stats = new ArrayList<>(statsQueue.size());
        for(String k : statsQueue.keySet()) {
            InferenceStats inferenceStats = statsQueue.remove(k);
            if (inferenceStats != null) {
                stats.add(inferenceStats);
            }
        }
        if (stats.isEmpty()) {
            return;
        }
        BulkRequest bulkRequest = new BulkRequest();
        stats.stream().map(TrainedModelStatsService::buildUpdateRequest).filter(Objects::nonNull).forEach(bulkRequest::add);
        bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        if (bulkRequest.requests().isEmpty()) {
            return;
        }
        resultsPersisterService.bulkIndexWithRetry(bulkRequest,
            stats.stream().map(InferenceStats::getModelId).collect(Collectors.joining(",")),
            () -> stopped == false,
            (msg) -> {});
    }

    static UpdateRequest buildUpdateRequest(InferenceStats stats) {
        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
            Map<String, Object> params = new HashMap<>();
            params.put(InferenceStats.FAILURE_COUNT.getPreferredName(), stats.getFailureCount());
            params.put(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(), stats.getMissingAllFieldsCount());
            params.put(InferenceStats.TIMESTAMP.getPreferredName(), stats.getTimeStamp().toEpochMilli());
            params.put(InferenceStats.INFERENCE_COUNT.getPreferredName(), stats.getInferenceCount());
            stats.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);
            UpdateRequest updateRequest = new UpdateRequest();
            updateRequest.upsert(builder)
                .index(MlStatsIndex.writeAlias())
                .id(InferenceStats.docId(stats.getModelId(), stats.getNodeId()))
                .script(new Script(ScriptType.INLINE, "painless", STATS_UPDATE_SCRIPT, params));
            return updateRequest;
        } catch (IOException ex) {
            logger.error(
                () -> new ParameterizedMessage("[{}] [{}] failed to serialize stats for update.",
                    stats.getModelId(),
                    stats.getNodeId()),
                ex);
        }
        return null;
    }

}
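One behavior worth calling out: because statsQueue is keyed by InferenceStats.docId(modelId, nodeId), stats queued faster than the one-second persistence interval collapse into a single pending document instead of piling up. The merge semantics of queueStats, restated in isolation with invented ids:

import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;

class QueueMergeSketch {
    static void example() {
        Map<String, InferenceStats> queue = new ConcurrentHashMap<>();
        InferenceStats a = new InferenceStats(0L, 5L, 0L, "my-model", "node-1", Instant.now());
        InferenceStats b = new InferenceStats(1L, 3L, 2L, "my-model", "node-1", Instant.now());
        for (InferenceStats s : new InferenceStats[] {a, b}) {
            // Same compute-based merge queueStats performs: sum counters per (model, node) key.
            queue.compute(InferenceStats.docId(s.getModelId(), s.getNodeId()),
                (k, prev) -> prev == null
                    ? s
                    : InferenceStats.accumulator(s).merge(prev).currentStats(s.getTimeStamp()));
        }
        // queue now holds one entry: inference_count=8, failure_count=2, missing_all_fields_count=1
    }
}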
LocalModel.java

@@ -10,6 +10,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

@@ -17,11 +18,14 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
 import org.elasticsearch.xpack.core.ml.utils.MapHelper;
+import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;

 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.atomic.LongAdder;

 import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;

@@ -29,19 +33,30 @@ public class LocalModel<T extends InferenceConfig> implements Model<T> {

     private final TrainedModelDefinition trainedModelDefinition;
     private final String modelId;
+    private final String nodeId;
     private final Set<String> fieldNames;
     private final Map<String, String> defaultFieldMap;
+    private final AtomicReference<InferenceStats.Accumulator> statsAccumulator;
+    private final TrainedModelStatsService trainedModelStatsService;
+    private volatile long persistenceQuotient = 100;
+    private final LongAdder currentInferenceCount;
     private final T inferenceConfig;

     public LocalModel(String modelId,
+                      String nodeId,
                       TrainedModelDefinition trainedModelDefinition,
                       TrainedModelInput input,
                       Map<String, String> defaultFieldMap,
-                      T modelInferenceConfig) {
+                      T modelInferenceConfig,
+                      TrainedModelStatsService trainedModelStatsService ) {
         this.trainedModelDefinition = trainedModelDefinition;
         this.modelId = modelId;
+        this.nodeId = nodeId;
         this.fieldNames = new HashSet<>(input.getFieldNames());
+        this.statsAccumulator = new AtomicReference<>(new InferenceStats.Accumulator(modelId, nodeId));
+        this.trainedModelStatsService = trainedModelStatsService;
         this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
+        this.currentInferenceCount = new LongAdder();
         this.inferenceConfig = modelInferenceConfig;
     }

@@ -54,6 +69,12 @@ public class LocalModel<T extends InferenceConfig> implements Model<T> {
         return modelId;
     }

+    @Override
+    public InferenceStats getLatestStatsAndReset() {
+        InferenceStats.Accumulator toPersist = statsAccumulator.getAndSet(new InferenceStats.Accumulator(modelId, nodeId));
+        return toPersist.currentStats();
+    }
+
     @Override
     public String getResultsType() {
         switch (trainedModelDefinition.getTrainedModel().targetType()) {

@@ -68,6 +89,16 @@ public class LocalModel<T extends InferenceConfig> implements Model<T> {
         }
     }

+    void persistStats() {
+        trainedModelStatsService.queueStats(getLatestStatsAndReset());
+        if (persistenceQuotient < 1000 && currentInferenceCount.sum() > 1000) {
+            persistenceQuotient = 1000;
+        }
+        if (persistenceQuotient < 10_000 && currentInferenceCount.sum() > 10_000) {
+            persistenceQuotient = 10_000;
+        }
+    }
+
     @Override
     public void infer(Map<String, Object> fields, InferenceConfigUpdate<T> update, ActionListener<InferenceResults> listener) {
         if (update.isSupported(this.inferenceConfig) == false) {

@@ -79,14 +110,27 @@ public class LocalModel<T extends InferenceConfig> implements Model<T> {
             return;
         }
         try {
+            statsAccumulator.get().incInference();
+            currentInferenceCount.increment();
+
             Model.mapFieldsIfNecessary(fields, defaultFieldMap);
+
+            boolean shouldPersistStats = ((currentInferenceCount.sum() + 1) % persistenceQuotient == 0);
             if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
+                statsAccumulator.get().incMissingFields();
+                if (shouldPersistStats) {
+                    persistStats();
+                }
                 listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
                 return;
             }

-            listener.onResponse(trainedModelDefinition.infer(fields, update.apply(inferenceConfig)));
+            InferenceResults inferenceResults = trainedModelDefinition.infer(fields, update.apply(inferenceConfig));
+            if (shouldPersistStats) {
+                persistStats();
+            }
+            listener.onResponse(inferenceResults);
         } catch (Exception e) {
+            statsAccumulator.get().incFailure();
             listener.onFailure(e);
         }
     }
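The persistenceQuotient logic above deserves a note: stats are only queued on roughly every Nth inference, and N ratchets from 100 to 1,000 to 10,000 as the lifetime call count grows, so busy models produce proportionally fewer stats writes. A simplified standalone model of that schedule (in the real class the ratcheting happens inside persistStats() and the counter is a LongAdder):

class PersistenceCadenceSketch {
    private long quotient = 100;   // mirrors persistenceQuotient's starting value
    private long count = 0;        // stands in for currentInferenceCount.sum()

    // Returns true when this call would trigger a persistStats() in LocalModel.infer.
    boolean onInference() {
        count++;
        boolean persist = ((count + 1) % quotient == 0);
        if (quotient < 1_000 && count > 1_000) {
            quotient = 1_000;      // after ~1k calls, persist every 1,000th inference
        }
        if (quotient < 10_000 && count > 10_000) {
            quotient = 10_000;     // after ~10k calls, every 10,000th
        }
        return persist;
    }
}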
Model.java

@@ -10,6 +10,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.utils.MapHelper;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;

 import java.util.Map;

@@ -43,4 +44,6 @@ public interface Model<T extends InferenceConfig> {
         });
     }
-}
+
+    InferenceStats getLatestStatsAndReset();
+}
ModelLoadingService.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.apache.logging.log4j.util.MessageSupplier;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.ClusterChangedEvent;

@@ -33,6 +34,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

@@ -83,6 +85,7 @@ public class ModelLoadingService implements ClusterStateListener {
         Setting.Property.NodeScope);

     private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
+    private final TrainedModelStatsService modelStatsService;
     private final Cache<String, LocalModel<? extends InferenceConfig>> localModelCache;
     private final Set<String> referencedModels = new HashSet<>();
     private final Map<String, Queue<ActionListener<Model<? extends InferenceConfig>>>> loadingListeners = new HashMap<>();

@@ -92,17 +95,21 @@ public class ModelLoadingService implements ClusterStateListener {
     private final InferenceAuditor auditor;
     private final ByteSizeValue maxCacheSize;
     private final NamedXContentRegistry namedXContentRegistry;
+    private final String localNode;

     public ModelLoadingService(TrainedModelProvider trainedModelProvider,
                                InferenceAuditor auditor,
                                ThreadPool threadPool,
                                ClusterService clusterService,
                                NamedXContentRegistry namedXContentRegistry,
-                               Settings settings) {
+                               TrainedModelStatsService modelStatsService,
+                               Settings settings,
+                               String localNode) {
         this.provider = trainedModelProvider;
         this.threadPool = threadPool;
         this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings);
         this.auditor = auditor;
+        this.modelStatsService = modelStatsService;
         this.shouldNotAudit = new HashSet<>();
         this.namedXContentRegistry = namedXContentRegistry;
         this.localModelCache = CacheBuilder.<String, LocalModel<? extends InferenceConfig>>builder()

@@ -113,6 +120,7 @@ public class ModelLoadingService implements ClusterStateListener {
             .setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
             .build();
         clusterService.addListener(this);
+        this.localNode = localNode;
     }

     /**

@@ -136,13 +144,13 @@ public class ModelLoadingService implements ClusterStateListener {
         LocalModel<? extends InferenceConfig> cachedModel = localModelCache.get(modelId);
         if (cachedModel != null) {
             modelActionListener.onResponse(cachedModel);
-            logger.trace("[{}] loaded from cache", modelId);
+            logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId));
             return;
         }
         if (loadModelIfNecessary(modelId, modelActionListener) == false) {
             // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
             // by a simulated pipeline
-            logger.trace("[{}] not actively loading, eager loading without cache", modelId);
+            logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
             provider.getTrainedModel(modelId, true, ActionListener.wrap(
                 trainedModelConfig -> {
                     trainedModelConfig.ensureParsedDefinition(namedXContentRegistry);

@@ -151,15 +159,17 @@ public class ModelLoadingService implements ClusterStateListener {
                         trainedModelConfig.getInferenceConfig();
                     modelActionListener.onResponse(new LocalModel<>(
                         trainedModelConfig.getModelId(),
+                        localNode,
                         trainedModelConfig.getModelDefinition(),
                         trainedModelConfig.getInput(),
                         trainedModelConfig.getDefaultFieldMap(),
-                        inferenceConfig));
+                        inferenceConfig,
+                        modelStatsService));
                 },
                 modelActionListener::onFailure
             ));
         } else {
-            logger.trace("[{}] is loading or loaded, added new listener to queue", modelId);
+            logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId));
         }
     }

@@ -183,7 +193,7 @@ public class ModelLoadingService implements ClusterStateListener {
         if (loadingListeners.computeIfPresent(
             modelId,
             (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) {
-            logger.trace("[{}] attempting to load and cache", modelId);
+            logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
             loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
             loadModel(modelId);
         }

@@ -198,7 +208,7 @@ public class ModelLoadingService implements ClusterStateListener {
     private void loadModel(String modelId) {
         provider.getTrainedModel(modelId, true, ActionListener.wrap(
             trainedModelConfig -> {
-                logger.debug("[{}] successfully loaded model", modelId);
+                logger.debug(() -> new ParameterizedMessage("[{}] successfully loaded model", modelId));
                 handleLoadSuccess(modelId, trainedModelConfig);
             },
             failure -> {

@@ -216,10 +226,12 @@ public class ModelLoadingService implements ClusterStateListener {
             trainedModelConfig.getInferenceConfig();
         LocalModel<? extends InferenceConfig> loadedModel = new LocalModel<>(
             trainedModelConfig.getModelId(),
+            localNode,
             trainedModelConfig.getModelDefinition(),
             trainedModelConfig.getInput(),
             trainedModelConfig.getDefaultFieldMap(),
-            inferenceConfig);
+            inferenceConfig,
+            modelStatsService);
         synchronized (loadingListeners) {
             listeners = loadingListeners.remove(modelId);
             // If there is no loadingListener that means the loading was canceled and the listener was already notified as such

@@ -252,7 +264,7 @@ public class ModelLoadingService implements ClusterStateListener {

     private void cacheEvictionListener(RemovalNotification<String, LocalModel<? extends InferenceConfig>> notification) {
         if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
-            String msg = new ParameterizedMessage(
+            MessageSupplier msg = () -> new ParameterizedMessage(
                 "model cache entry evicted." +
                     "current cache [{}] current max [{}] model size [{}]. " +
                     "If this is undesired, consider updating setting [{}] or [{}].",

@@ -260,9 +272,10 @@ public class ModelLoadingService implements ClusterStateListener {
                 maxCacheSize.getStringRep(),
                 new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(),
                 INFERENCE_MODEL_CACHE_SIZE.getKey(),
-                INFERENCE_MODEL_CACHE_TTL.getKey()).getFormattedMessage();
+                INFERENCE_MODEL_CACHE_TTL.getKey());
             auditIfNecessary(notification.getKey(), msg);
         }
+        notification.getValue().persistStats();
     }

     @Override

@@ -335,14 +348,14 @@ public class ModelLoadingService implements ClusterStateListener {
         loadModels(allReferencedModelKeys);
     }

-    private void auditIfNecessary(String modelId, String msg) {
+    private void auditIfNecessary(String modelId, MessageSupplier msg) {
         if (shouldNotAudit.contains(modelId)) {
-            logger.trace("[{}] {}", modelId, msg);
+            logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage()));
            return;
         }
-        auditor.warning(modelId, msg);
+        auditor.warning(modelId, msg.get().getFormattedMessage());
         shouldNotAudit.add(modelId);
-        logger.warn("[{}] {}", modelId, msg);
+        logger.warn("[{}] {}", modelId, msg.get().getFormattedMessage());
     }

     private void loadModels(Set<String> modelIds) {
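A side cleanup in this file: the cache-eviction message used to be formatted eagerly via getFormattedMessage(), whether or not anything logged it. Wrapping it in a log4j MessageSupplier defers construction and formatting until a consumer actually asks for it. The before/after shape, reduced to essentials:

import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.MessageSupplier;

class LazyMessageSketch {
    static void example(String modelId) {
        // Before: the message string is built immediately, even if never consumed.
        String eager = new ParameterizedMessage("[{}] evicted", modelId).getFormattedMessage();

        // After: construction and formatting only happen when the supplier is invoked.
        MessageSupplier lazy = () -> new ParameterizedMessage("[{}] evicted", modelId);
        String formattedOnDemand = lazy.get().getFormattedMessage();
    }
}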
@ -19,6 +19,7 @@ import org.elasticsearch.action.bulk.BulkRequest;
|
|||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.search.MultiSearchAction;
|
||||
import org.elasticsearch.action.search.MultiSearchRequest;
|
||||
import org.elasticsearch.action.search.MultiSearchRequestBuilder;
|
||||
import org.elasticsearch.action.search.MultiSearchResponse;
|
||||
import org.elasticsearch.action.search.SearchAction;
|
||||
|
@@ -53,14 +54,19 @@ import org.elasticsearch.index.reindex.DeleteByQueryAction;
 import org.elasticsearch.index.reindex.DeleteByQueryRequest;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.metrics.Max;
+import org.elasticsearch.search.aggregations.metrics.Sum;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.sort.SortBuilders;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
 import org.elasticsearch.xpack.core.action.util.PageParams;
+import org.elasticsearch.xpack.core.ml.MlStatsIndex;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
@@ -68,7 +74,9 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 import java.io.IOException;
 import java.io.InputStream;
 import java.net.URL;
+import java.time.Instant;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
@@ -357,7 +365,7 @@ public class TrainedModelProvider {
         }
         DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);

-        request.indices(InferenceIndexConstants.INDEX_PATTERN);
+        request.indices(InferenceIndexConstants.INDEX_PATTERN, MlStatsIndex.indexPattern());
         QueryBuilder query = QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
         request.setQuery(query);
         request.setRefresh(true);
@@ -454,6 +462,99 @@ public class TrainedModelProvider {
             client::search);
     }

+    public void getInferenceStats(String[] modelIds, ActionListener<List<InferenceStats>> listener) {
+        MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
+        Arrays.stream(modelIds).map(this::buildStatsSearchRequest).forEach(multiSearchRequest::add);
+        if (multiSearchRequest.requests().isEmpty()) {
+            listener.onResponse(Collections.emptyList());
+            return;
+        }
+        executeAsyncWithOrigin(client.threadPool().getThreadContext(),
+            ML_ORIGIN,
+            multiSearchRequest,
+            ActionListener.<MultiSearchResponse>wrap(
+                responses -> {
+                    List<InferenceStats> allStats = new ArrayList<>(modelIds.length);
+                    int modelIndex = 0;
+                    assert responses.getResponses().length == modelIds.length :
+                        "mismatch between search response size and models requested";
+                    for (MultiSearchResponse.Item response : responses.getResponses()) {
+                        if (response.isFailure()) {
+                            if (ExceptionsHelper.unwrapCause(response.getFailure()) instanceof ResourceNotFoundException) {
+                                modelIndex++;
+                                continue;
+                            }
+                            logger.error(new ParameterizedMessage("[{}] search failed for models",
+                                    Strings.arrayToCommaDelimitedString(modelIds)),
+                                response.getFailure());
+                            listener.onFailure(ExceptionsHelper.serverError("Searching for stats for models [{}] failed",
+                                response.getFailure(),
+                                Strings.arrayToCommaDelimitedString(modelIds)));
+                            return;
+                        }
+                        try {
+                            InferenceStats inferenceStats = handleMultiNodeStatsResponse(response.getResponse(), modelIds[modelIndex++]);
+                            if (inferenceStats != null) {
+                                allStats.add(inferenceStats);
+                            }
+                        } catch (Exception e) {
+                            listener.onFailure(e);
+                            return;
+                        }
+                    }
+                    listener.onResponse(allStats);
+                },
+                e -> {
+                    Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
+                    if (unwrapped instanceof ResourceNotFoundException) {
+                        listener.onResponse(Collections.emptyList());
+                        return;
+                    }
+                    listener.onFailure((Exception)unwrapped);
+                }
+            ),
+            client::multiSearch);
+    }
+
+    private SearchRequest buildStatsSearchRequest(String modelId) {
+        BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery()
+            .filter(QueryBuilders.termQuery(InferenceStats.MODEL_ID.getPreferredName(), modelId))
+            .filter(QueryBuilders.termQuery(InferenceStats.TYPE.getPreferredName(), InferenceStats.NAME));
+        return new SearchRequest(MlStatsIndex.indexPattern())
+            .indicesOptions(IndicesOptions.lenientExpandOpen())
+            .allowPartialSearchResults(false)
+            .source(SearchSourceBuilder.searchSource()
+                .size(0)
+                .aggregation(AggregationBuilders.sum(InferenceStats.FAILURE_COUNT.getPreferredName())
+                    .field(InferenceStats.FAILURE_COUNT.getPreferredName()))
+                .aggregation(AggregationBuilders.sum(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName())
+                    .field(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName()))
+                .aggregation(AggregationBuilders.sum(InferenceStats.INFERENCE_COUNT.getPreferredName())
+                    .field(InferenceStats.INFERENCE_COUNT.getPreferredName()))
+                .aggregation(AggregationBuilders.max(InferenceStats.TIMESTAMP.getPreferredName())
+                    .field(InferenceStats.TIMESTAMP.getPreferredName()))
+                .query(queryBuilder));
+    }
+
+    private InferenceStats handleMultiNodeStatsResponse(SearchResponse response, String modelId) {
+        if (response.getAggregations() == null) {
+            logger.trace(() -> new ParameterizedMessage("[{}] no previously stored stats found", modelId));
+            return null;
+        }
+        Sum failures = response.getAggregations().get(InferenceStats.FAILURE_COUNT.getPreferredName());
+        Sum missing = response.getAggregations().get(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName());
+        Sum count = response.getAggregations().get(InferenceStats.INFERENCE_COUNT.getPreferredName());
+        Max timeStamp = response.getAggregations().get(InferenceStats.TIMESTAMP.getPreferredName());
+        return new InferenceStats(
+            missing == null ? 0L : Double.valueOf(missing.getValue()).longValue(),
+            count == null ? 0L : Double.valueOf(count.getValue()).longValue(),
+            failures == null ? 0L : Double.valueOf(failures.getValue()).longValue(),
+            modelId,
+            null,
+            timeStamp == null ? Instant.now() : Instant.ofEpochMilli(Double.valueOf(timeStamp.getValue()).longValue())
+        );
+    }
+
     static Set<String> collectIds(PageParams pageParams, Set<String> foundFromResources, Set<String> foundFromDocs) {
         // If there are no matching resource models, there was no buffering and the models from the docs
         // are paginated correctly.
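Note: because each node keeps its own running stats document, the per-model search above is a size-0, aggregation-only request: sum() collapses the per-node counters and max() keeps the newest timestamp, so a caller sees one InferenceStats per model regardless of how many nodes served it. A hypothetical caller, using the new method (the model ids are placeholders):

    trainedModelProvider.getInferenceStats(
        new String[] { "model-a", "model-b" },
        ActionListener.wrap(
            statsList -> statsList.forEach(stats ->
                logger.info("[{}] inference_count [{}]", stats.getModelId(), stats.getInferenceCount())),
            e -> logger.error("failed to fetch inference stats", e)));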
@@ -126,7 +126,7 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
     }


-    public void testInferenceIngestStatsByPipelineId() throws IOException {
+    public void testInferenceIngestStatsByModelId() {
         List<NodeStats> nodeStatsList = Arrays.asList(
             buildNodeStats(
                 new IngestStats.Stats(2, 2, 3, 4),
@@ -193,7 +193,7 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
             put("trained_model_1", Collections.singleton("pipeline1"));
             put("trained_model_2", new HashSet<>(Arrays.asList("pipeline1", "pipeline2")));
         }};
-        Map<String, IngestStats> ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByPipelineId(response,
+        Map<String, IngestStats> ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByModelId(response,
            pipelineIdsByModelIds);

        assertThat(ingestStatsMap.keySet(), equalTo(new HashSet<>(Arrays.asList("trained_model_1", "trained_model_2"))));
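Note: the rename reflects what the helper actually returns: ingest stats keyed by model id, with each model's numbers summed across every pipeline that references it. With the fixture above, the expected shape is roughly:

    Map<String, IngestStats> byModel =
        TransportGetTrainedModelsStatsAction.inferenceIngestStatsByModelId(response, pipelineIdsByModelIds);
    byModel.get("trained_model_1"); // stats from pipeline1 only
    byModel.get("trained_model_2"); // pipeline1 + pipeline2, summed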
@@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
@@ -28,6 +29,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
+import org.mockito.ArgumentMatcher;

 import java.util.Arrays;
 import java.util.Collections;
@@ -39,10 +42,18 @@ import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.argThat;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.internal.verification.VerificationModeFactory.times;

 public class LocalModelTests extends ESTestCase {

     public void testClassificationInfer() throws Exception {
+        TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
+        doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
         String modelId = "classification_model";
         List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
         TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
@@ -51,10 +62,12 @@ public class LocalModelTests extends ESTestCase {
             .build();

         Model<ClassificationConfig> model = new LocalModel<>(modelId,
+            "test-node",
             definition,
             new TrainedModelInput(inputFields),
             Collections.singletonMap("field.foo", "field.foo.keyword"),
-            ClassificationConfig.EMPTY_PARAMS);
+            ClassificationConfig.EMPTY_PARAMS,
+            modelStatsService);
         Map<String, Object> fields = new HashMap<String, Object>() {{
             put("field.foo", 1.0);
             put("field.bar", 0.5);
@@ -64,11 +77,13 @@ public class LocalModelTests extends ESTestCase {
         SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
         assertThat(result.value(), equalTo(0.0));
         assertThat(result.valueAsString(), is("0"));
+        assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));

         ClassificationInferenceResults classificationResult =
             (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null));
         assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
         assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0"));
+        assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));

         // Test with labels
         definition = new TrainedModelDefinition.Builder()
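Note: both getInferenceCount() assertions expect 1L because getLatestStatsAndReset() drains the counters as it reads them; each assertion observes only the single inference performed since the previous read. The same get-and-reset contract in miniature:

    AtomicLong counter = new AtomicLong();
    counter.incrementAndGet();             // one inference happened
    long first = counter.getAndSet(0);     // == 1
    long second = counter.getAndSet(0);    // == 0, already drained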
@@ -76,10 +91,12 @@ public class LocalModelTests extends ESTestCase {
             .setTrainedModel(buildClassification(true))
             .build();
         model = new LocalModel<>(modelId,
+            "test-node",
             definition,
             new TrainedModelInput(inputFields),
             Collections.singletonMap("field.foo", "field.foo.keyword"),
-            ClassificationConfig.EMPTY_PARAMS);
+            ClassificationConfig.EMPTY_PARAMS,
+            modelStatsService);
         result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
         assertThat(result.value(), equalTo(0.0));
         assertThat(result.valueAsString(), equalTo("not_to_be"));
@@ -89,29 +106,36 @@ public class LocalModelTests extends ESTestCase {
             new ClassificationConfigUpdate(1, null, null, null));
         assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
         assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
+        assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L));

         classificationResult = (ClassificationInferenceResults)getSingleValue(model,
             fields,
             new ClassificationConfigUpdate(2, null, null, null));
         assertThat(classificationResult.getTopClasses(), hasSize(2));
+        assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));

         classificationResult = (ClassificationInferenceResults)getSingleValue(model,
             fields,
             new ClassificationConfigUpdate(-1, null, null, null));
         assertThat(classificationResult.getTopClasses(), hasSize(2));
+        assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));
     }

     public void testRegression() throws Exception {
+        TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
+        doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
         List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
         TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
             .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
             .setTrainedModel(buildRegression())
             .build();
         Model<RegressionConfig> model = new LocalModel<>("regression_model",
+            "test-node",
             trainedModelDefinition,
             new TrainedModelInput(inputFields),
             Collections.singletonMap("bar", "bar.keyword"),
-            RegressionConfig.EMPTY_PARAMS);
+            RegressionConfig.EMPTY_PARAMS,
+            modelStatsService);

         Map<String, Object> fields = new HashMap<String, Object>() {{
             put("foo", 1.0);
@@ -124,6 +148,8 @@ public class LocalModelTests extends ESTestCase {
     }

     public void testAllFieldsMissing() throws Exception {
+        TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
+        doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
         List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
         TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
             .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
@@ -131,10 +157,12 @@ public class LocalModelTests extends ESTestCase {
             .build();
         Model<RegressionConfig> model = new LocalModel<>(
             "regression_model",
+            "test-node",
             trainedModelDefinition,
             new TrainedModelInput(inputFields),
             null,
-            RegressionConfig.EMPTY_PARAMS);
+            RegressionConfig.EMPTY_PARAMS,
+            modelStatsService);

         Map<String, Object> fields = new HashMap<String, Object>() {{
             put("something", 1.0);
@@ -145,6 +173,47 @@ public class LocalModelTests extends ESTestCase {
         WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfigUpdate.EMPTY_PARAMS);
         assertThat(results.getWarning(),
             equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model")));
+        assertThat(model.getLatestStatsAndReset().getMissingAllFieldsCount(), equalTo(1L));
     }

+    public void testInferPersistsStatsAfterNumberOfCalls() throws Exception {
+        TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
+        doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
+        String modelId = "classification_model";
+        List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
+        TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
+            .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
+            .setTrainedModel(buildClassification(false))
+            .build();
+
+        Model<ClassificationConfig> model = new LocalModel<>(modelId,
+            "test-node",
+            definition,
+            new TrainedModelInput(inputFields),
+            null,
+            ClassificationConfig.EMPTY_PARAMS,
+            modelStatsService
+        );
+        Map<String, Object> fields = new HashMap<String, Object>() {{
+            put("field.foo", 1.0);
+            put("field.bar", 0.5);
+            put("categorical", "dog");
+        }};
+
+        for(int i = 0; i < 100; i++) {
+            getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
+        }
+        SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
+        assertThat(result.value(), equalTo(0.0));
+        assertThat(result.valueAsString(), is("0"));
+        // Should have reset after persistence, so only 2 docs have been seen since last persistence
+        assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L));
+        verify(modelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
+            @Override
+            public boolean matches(Object o) {
+                return ((InferenceStats)o).getInferenceCount() == 99L;
+            }
+        }));
+    }
+
     private static <T extends InferenceConfig> SingleValueInferenceResults getSingleValue(Model<T> model,
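Note: this test pins down the flush behavior the commit description cares about (throughput in the middle of an ingest pipeline): stats are queued in batches rather than written per inference. After 101 calls, exactly one snapshot (99 inferences) was handed to the stats service and 2 inferences are still pending in the local counters. A sketch of the kind of threshold check that would produce those numbers; the constant and its placement are assumptions, not the actual LocalModel internals:

    private static final long PERSIST_EVERY = 100;  // hypothetical flush threshold

    private void maybeFlush() {
        // checked before the current call's increment lands, hence the 99 in the snapshot
        if (pendingInferenceCount.sum() >= PERSIST_EVERY - 1) {
            modelStatsService.queueStats(getLatestStatsAndReset());
        }
    }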
@@ -36,14 +36,17 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 import org.junit.After;
 import org.junit.Before;
+import org.mockito.ArgumentMatcher;
 import org.mockito.Mockito;

 import java.io.IOException;
@@ -59,10 +62,12 @@ import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.argThat;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.atMost;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -73,6 +78,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
     private ThreadPool threadPool;
     private ClusterService clusterService;
     private InferenceAuditor auditor;
+    private TrainedModelStatsService trainedModelStatsService;

     @Before
     public void setUpComponents() {
@@ -81,6 +87,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         trainedModelProvider = mock(TrainedModelProvider.class);
         clusterService = mock(ClusterService.class);
         auditor = mock(InferenceAuditor.class);
+        trainedModelStatsService = mock(TrainedModelStatsService.class);
         doAnswer(a -> null).when(auditor).error(any(String.class), any(String.class));
         doAnswer(a -> null).when(auditor).info(any(String.class), any(String.class));
         doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class));
@@ -106,7 +113,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
             threadPool,
             clusterService,
             NamedXContentRegistry.EMPTY,
-            Settings.EMPTY);
+            trainedModelStatsService,
+            Settings.EMPTY,
+            "test-node");

         modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));

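Note: the service constructor now also takes the local node name ("test-node" throughout these tests). Per the commit description, each node + model pair keeps its own running stats document, so the node name has to travel with every stats snapshot; something like the following id scheme would keep the per-node documents from colliding, though the exact format is an assumption:

    // hypothetical doc id: one running stats document per (model, node) pair
    String docId = InferenceStats.NAME + "-" + modelId + "-" + nodeId;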
@@ -131,6 +140,12 @@ public class ModelLoadingServiceTests extends ESTestCase {
             assertThat(future.get(), is(not(nullValue())));
         }

+        verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
+            @Override
+            public boolean matches(final Object o) {
+                return ((InferenceStats)o).getModelId().equals(model3);
+            }
+        }));
         verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
         verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
         // It is not referenced, so called eagerly
@@ -150,7 +165,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
             threadPool,
             clusterService,
             NamedXContentRegistry.EMPTY,
-            Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build());
+            trainedModelStatsService,
+            Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(),
+            "test-node");

         modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));

@@ -175,6 +192,24 @@ public class ModelLoadingServiceTests extends ESTestCase {
         verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any());
         // Only loaded requested once on the initial load from the change event
         verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
+        verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
+            @Override
+            public boolean matches(final Object o) {
+                return ((InferenceStats)o).getModelId().equals(model1);
+            }
+        }));
+        verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
+            @Override
+            public boolean matches(final Object o) {
+                return ((InferenceStats)o).getModelId().equals(model2);
+            }
+        }));
+        verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
+            @Override
+            public boolean matches(final Object o) {
+                return ((InferenceStats)o).getModelId().equals(model3);
+            }
+        }));

        // Load model 3, should invalidate 1
        for(int i = 0; i < 10; i++) {
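Note: the three anonymous ArgumentMatcher blocks above differ only by model id. With the Mockito version in use here, a small factory would keep that intent readable; this helper is hypothetical and not part of the change:

    private static ArgumentMatcher<InferenceStats> statsForModel(final String modelId) {
        return new ArgumentMatcher<InferenceStats>() {
            @Override
            public boolean matches(final Object o) {
                return ((InferenceStats) o).getModelId().equals(modelId);
            }
        };
    }

    // usage: verify(trainedModelStatsService, times(1)).queueStats(argThat(statsForModel(model3)));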
@@ -214,7 +249,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
         verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
         verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any());
-        verify(trainedModelProvider, Mockito.atMost(5)).getTrainedModel(eq(model3), eq(true), any());
+        verify(trainedModelProvider, atMost(5)).getTrainedModel(eq(model3), eq(true), any());
     }


@@ -227,7 +262,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
             threadPool,
             clusterService,
             NamedXContentRegistry.EMPTY,
-            Settings.EMPTY);
+            trainedModelStatsService,
+            Settings.EMPTY,
+            "test-node");

         modelLoadingService.clusterChanged(ingestChangedEvent(false, model1));

@@ -238,6 +275,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         }

         verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), eq(true), any());
+        verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
     }

     public void testGetCachedMissingModel() throws Exception {
@@ -249,7 +287,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
             threadPool,
             clusterService,
             NamedXContentRegistry.EMPTY,
-            Settings.EMPTY);
+            trainedModelStatsService,
+            Settings.EMPTY,
+            "test-node");
         modelLoadingService.clusterChanged(ingestChangedEvent(model));

         PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
@@ -263,6 +303,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         }

         verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(true), any());
+        verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
     }

     public void testGetMissingModel() {
@@ -274,7 +315,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
             threadPool,
             clusterService,
             NamedXContentRegistry.EMPTY,
-            Settings.EMPTY);
+            trainedModelStatsService,
+            Settings.EMPTY,
+            "test-node");

         PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
         modelLoadingService.getModel(model, future);
@@ -295,7 +338,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
             threadPool,
             clusterService,
             NamedXContentRegistry.EMPTY,
-            Settings.EMPTY);
+            trainedModelStatsService,
+            Settings.EMPTY,
+            "test-node");

         for(int i = 0; i < 3; i++) {
             PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
@@ -304,6 +349,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         }

         verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(true), any());
+        verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
     }

     @SuppressWarnings("unchecked")
@@ -312,6 +358,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         when(definition.ramBytesUsed()).thenReturn(size);
         TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
         when(trainedModelConfig.getModelDefinition()).thenReturn(definition);
         when(trainedModelConfig.getModelId()).thenReturn(modelId);
+        when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS);
         when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz")));
         doAnswer(invocationOnMock -> {