[7.x] [ML] Start gathering and storing inference stats (#53429) (#54738)

* [ML] Start gathering and storing inference stats (#53429)

This PR enables stats on inference to be gathered and stored in the `.ml-stats-*` indices.

Each node + model_id will have its own running stats document and these will later be summed together when returning _stats to the user.

`.ml-stats-*` is ILM managed (when possible). So, at any point the underlying index could change. This means that a stats document that is read in and then later updated will actually be a new doc in a new index. This complicates matters as this means that having a running knowledge of seq_no and primary_term is complicated and almost impossible. This is because we don't know the latest index name.

We should also strive for throughput, as this code sits in the middle of an ingest pipeline (or even a query).
This commit is contained in:
Benjamin Trent 2020-04-13 08:15:46 -04:00 committed by GitHub
parent 7a8a66d9ae
commit c5c7ee9d73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 958 additions and 55 deletions

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.action; package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionType; import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.client.ElasticsearchClient;
@ -20,6 +21,7 @@ import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -37,6 +39,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
public static final ParseField MODEL_ID = new ParseField("model_id"); public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count"); public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
public static final ParseField INFERENCE_STATS = new ParseField("inference_stats");
private GetTrainedModelsStatsAction() { private GetTrainedModelsStatsAction() {
super(NAME, GetTrainedModelsStatsAction.Response::new); super(NAME, GetTrainedModelsStatsAction.Response::new);
@ -78,25 +81,32 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
public static class TrainedModelStats implements ToXContentObject, Writeable { public static class TrainedModelStats implements ToXContentObject, Writeable {
private final String modelId; private final String modelId;
private final IngestStats ingestStats; private final IngestStats ingestStats;
private final InferenceStats inferenceStats;
private final int pipelineCount; private final int pipelineCount;
private static final IngestStats EMPTY_INGEST_STATS = new IngestStats(new IngestStats.Stats(0, 0, 0, 0), private static final IngestStats EMPTY_INGEST_STATS = new IngestStats(new IngestStats.Stats(0, 0, 0, 0),
Collections.emptyList(), Collections.emptyList(),
Collections.emptyMap()); Collections.emptyMap());
public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount) { public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount, InferenceStats inferenceStats) {
this.modelId = Objects.requireNonNull(modelId); this.modelId = Objects.requireNonNull(modelId);
this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats; this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats;
if (pipelineCount < 0) { if (pipelineCount < 0) {
throw new ElasticsearchException("[{}] must be a greater than or equal to 0", PIPELINE_COUNT.getPreferredName()); throw new ElasticsearchException("[{}] must be a greater than or equal to 0", PIPELINE_COUNT.getPreferredName());
} }
this.pipelineCount = pipelineCount; this.pipelineCount = pipelineCount;
this.inferenceStats = inferenceStats;
} }
public TrainedModelStats(StreamInput in) throws IOException { public TrainedModelStats(StreamInput in) throws IOException {
modelId = in.readString(); modelId = in.readString();
ingestStats = new IngestStats(in); ingestStats = new IngestStats(in);
pipelineCount = in.readVInt(); pipelineCount = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.inferenceStats = in.readOptionalWriteable(InferenceStats::new);
} else {
this.inferenceStats = null;
}
} }
public String getModelId() { public String getModelId() {
@ -120,6 +130,9 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
// Ingest stats is a fragment // Ingest stats is a fragment
ingestStats.toXContent(builder, params); ingestStats.toXContent(builder, params);
} }
if (this.inferenceStats != null) {
builder.field(INFERENCE_STATS.getPreferredName(), this.inferenceStats);
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -129,11 +142,14 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
out.writeString(modelId); out.writeString(modelId);
ingestStats.writeTo(out); ingestStats.writeTo(out);
out.writeVInt(pipelineCount); out.writeVInt(pipelineCount);
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeOptionalWriteable(this.inferenceStats);
}
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(modelId, ingestStats, pipelineCount); return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats);
} }
@Override @Override
@ -147,7 +163,8 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
TrainedModelStats other = (TrainedModelStats) obj; TrainedModelStats other = (TrainedModelStats) obj;
return Objects.equals(this.modelId, other.modelId) return Objects.equals(this.modelId, other.modelId)
&& Objects.equals(this.ingestStats, other.ingestStats) && Objects.equals(this.ingestStats, other.ingestStats)
&& Objects.equals(this.pipelineCount, other.pipelineCount); && Objects.equals(this.pipelineCount, other.pipelineCount)
&& Objects.equals(this.inferenceStats, other.inferenceStats);
} }
} }
@ -171,6 +188,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
private long totalModelCount; private long totalModelCount;
private Set<String> expandedIds; private Set<String> expandedIds;
private Map<String, IngestStats> ingestStatsMap; private Map<String, IngestStats> ingestStatsMap;
private Map<String, InferenceStats> inferenceStatsMap;
public Builder setTotalModelCount(long totalModelCount) { public Builder setTotalModelCount(long totalModelCount) {
this.totalModelCount = totalModelCount; this.totalModelCount = totalModelCount;
@ -191,13 +209,23 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
return this; return this;
} }
public Builder setInferenceStatsByModelId(Map<String, InferenceStats> infereceStatsByModelId) {
this.inferenceStatsMap = infereceStatsByModelId;
return this;
}
public Response build() { public Response build() {
List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIds.size()); List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIds.size());
expandedIds.forEach(id -> { expandedIds.forEach(id -> {
IngestStats ingestStats = ingestStatsMap.get(id); IngestStats ingestStats = ingestStatsMap.get(id);
trainedModelStats.add(new TrainedModelStats(id, ingestStats, ingestStats == null ? InferenceStats inferenceStats = inferenceStatsMap.get(id);
trainedModelStats.add(new TrainedModelStats(
id,
ingestStats,
ingestStats == null ?
0 : 0 :
ingestStats.getPipelineStats().size())); ingestStats.getPipelineStats().size(),
inferenceStats));
}); });
trainedModelStats.sort(Comparator.comparing(TrainedModelStats::getModelId)); trainedModelStats.sort(Comparator.comparing(TrainedModelStats::getModelId));
return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD)); return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));

View File

@ -0,0 +1,249 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
import java.util.concurrent.atomic.LongAdder;
public class InferenceStats implements ToXContentObject, Writeable {
public static final String NAME = "inference_stats";
public static final ParseField MISSING_ALL_FIELDS_COUNT = new ParseField("missing_all_fields_count");
public static final ParseField INFERENCE_COUNT = new ParseField("inference_count");
public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField NODE_ID = new ParseField("node_id");
public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
public static final ParseField TYPE = new ParseField("type");
public static final ParseField TIMESTAMP = new ParseField("time_stamp");
public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (String)a[3], (String)a[4], (Instant)a[5])
);
static {
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MISSING_ALL_FIELDS_COUNT);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), INFERENCE_COUNT);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), FAILURE_COUNT);
PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
PARSER.declareString(ConstructingObjectParser.constructorArg(), NODE_ID);
PARSER.declareField(ConstructingObjectParser.constructorArg(),
p -> TimeUtils.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
TIMESTAMP,
ObjectParser.ValueType.VALUE);
}
public static InferenceStats emptyStats(String modelId, String nodeId) {
return new InferenceStats(0L, 0L, 0L, modelId, nodeId, Instant.now());
}
public static String docId(String modelId, String nodeId) {
return NAME + "-" + modelId + "-" + nodeId;
}
private final long missingAllFieldsCount;
private final long inferenceCount;
private final long failureCount;
private final String modelId;
private final String nodeId;
private final Instant timeStamp;
private InferenceStats(Long missingAllFieldsCount,
Long inferenceCount,
Long failureCount,
String modelId,
String nodeId,
Instant instant) {
this(unbox(missingAllFieldsCount),
unbox(inferenceCount),
unbox(failureCount),
modelId,
nodeId,
instant);
}
public InferenceStats(long missingAllFieldsCount,
long inferenceCount,
long failureCount,
String modelId,
String nodeId,
Instant timeStamp) {
this.missingAllFieldsCount = missingAllFieldsCount;
this.inferenceCount = inferenceCount;
this.failureCount = failureCount;
this.modelId = modelId;
this.nodeId = nodeId;
this.timeStamp = timeStamp == null ?
Instant.ofEpochMilli(Instant.now().toEpochMilli()) :
Instant.ofEpochMilli(timeStamp.toEpochMilli());
}
public InferenceStats(StreamInput in) throws IOException {
this.missingAllFieldsCount = in.readVLong();
this.inferenceCount = in.readVLong();
this.failureCount = in.readVLong();
this.modelId = in.readOptionalString();
this.nodeId = in.readOptionalString();
this.timeStamp = in.readInstant();
}
public long getMissingAllFieldsCount() {
return missingAllFieldsCount;
}
public long getInferenceCount() {
return inferenceCount;
}
public long getFailureCount() {
return failureCount;
}
public String getModelId() {
return modelId;
}
public String getNodeId() {
return nodeId;
}
public Instant getTimeStamp() {
return timeStamp;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
assert modelId != null : "model_id cannot be null when storing inference stats";
assert nodeId != null : "node_id cannot be null when storing inference stats";
builder.field(TYPE.getPreferredName(), NAME);
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(NODE_ID.getPreferredName(), nodeId);
}
builder.field(FAILURE_COUNT.getPreferredName(), failureCount);
builder.field(INFERENCE_COUNT.getPreferredName(), inferenceCount);
builder.field(MISSING_ALL_FIELDS_COUNT.getPreferredName(), missingAllFieldsCount);
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timeStamp.toEpochMilli());
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferenceStats that = (InferenceStats) o;
return missingAllFieldsCount == that.missingAllFieldsCount
&& inferenceCount == that.inferenceCount
&& failureCount == that.failureCount
&& Objects.equals(modelId, that.modelId)
&& Objects.equals(nodeId, that.nodeId)
&& Objects.equals(timeStamp, that.timeStamp);
}
@Override
public int hashCode() {
return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, modelId, nodeId, timeStamp);
}
@Override
public String toString() {
return "InferenceStats{" +
"missingAllFieldsCount=" + missingAllFieldsCount +
", inferenceCount=" + inferenceCount +
", failureCount=" + failureCount +
", modelId='" + modelId + '\'' +
", nodeId='" + nodeId + '\'' +
", timeStamp=" + timeStamp +
'}';
}
private static long unbox(@Nullable Long value) {
return value == null ? 0L : value;
}
public static Accumulator accumulator(InferenceStats stats) {
return new Accumulator(stats);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(this.missingAllFieldsCount);
out.writeVLong(this.inferenceCount);
out.writeVLong(this.failureCount);
out.writeOptionalString(this.modelId);
out.writeOptionalString(this.nodeId);
out.writeInstant(timeStamp);
}
public static class Accumulator {
private final LongAdder missingFieldsAccumulator = new LongAdder();
private final LongAdder inferenceAccumulator = new LongAdder();
private final LongAdder failureCountAccumulator = new LongAdder();
private final String modelId;
private final String nodeId;
public Accumulator(String modelId, String nodeId) {
this.modelId = modelId;
this.nodeId = nodeId;
}
public Accumulator(InferenceStats previousStats) {
this.modelId = previousStats.modelId;
this.nodeId = previousStats.nodeId;
this.missingFieldsAccumulator.add(previousStats.missingAllFieldsCount);
this.inferenceAccumulator.add(previousStats.inferenceCount);
this.failureCountAccumulator.add(previousStats.failureCount);
}
public Accumulator merge(InferenceStats otherStats) {
this.missingFieldsAccumulator.add(otherStats.missingAllFieldsCount);
this.inferenceAccumulator.add(otherStats.inferenceCount);
this.failureCountAccumulator.add(otherStats.failureCount);
return this;
}
public void incMissingFields() {
this.missingFieldsAccumulator.increment();
}
public void incInference() {
this.inferenceAccumulator.increment();
}
public void incFailure() {
this.failureCountAccumulator.increment();
}
public InferenceStats currentStats() {
return currentStats(Instant.now());
}
public InferenceStats currentStats(Instant timeStamp) {
return new InferenceStats(missingFieldsAccumulator.longValue(),
inferenceAccumulator.longValue(),
failureCountAccumulator.longValue(),
modelId,
nodeId,
timeStamp);
}
}
}

View File

@ -85,6 +85,21 @@
"peak_usage_bytes" : { "peak_usage_bytes" : {
"type" : "long" "type" : "long"
}, },
"model_id": {
"type": "keyword"
},
"node_id": {
"type": "keyword"
},
"inference_count": {
"type": "long"
},
"failure_count": {
"type": "long"
},
"missing_all_fields_count": {
"type": "long"
},
"skipped_docs_count": { "skipped_docs_count": {
"type": "long" "type": "long"
}, },

View File

@ -5,19 +5,23 @@
*/ */
package org.elasticsearch.xpack.core.ml.action; package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestStats; import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStatsTests;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD;
public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> { public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSerializationTestCase<Response> {
@Override @Override
protected Response createTestInstance() { protected Response createTestInstance() {
@ -26,10 +30,11 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerial
.limit(listSize).map(id -> .limit(listSize).map(id ->
new Response.TrainedModelStats(id, new Response.TrainedModelStats(id,
randomBoolean() ? randomIngestStats() : null, randomBoolean() ? randomIngestStats() : null,
randomIntBetween(0, 10)) randomIntBetween(0, 10),
randomBoolean() ? InferenceStatsTests.createTestInstance(id, null) : null)
) )
.collect(Collectors.toList()); .collect(Collectors.toList());
return new Response(new QueryPage<>(trainedModelStats, randomLongBetween(listSize, 1000), Response.RESULTS_FIELD)); return new Response(new QueryPage<>(trainedModelStats, randomLongBetween(listSize, 1000), RESULTS_FIELD));
} }
private IngestStats randomIngestStats() { private IngestStats randomIngestStats() {
@ -57,4 +62,37 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerial
protected Writeable.Reader<Response> instanceReader() { protected Writeable.Reader<Response> instanceReader() {
return Response::new; return Response::new;
} }
@Override
protected Response mutateInstanceForVersion(Response instance, Version version) {
if (version.before(Version.V_7_8_0)) {
List<Response.TrainedModelStats> stats = instance.getResources()
.results()
.stream()
.map(s -> new Response.TrainedModelStats(s.getModelId(),
adjustForVersion(s.getIngestStats(), version),
s.getPipelineCount(),
null))
.collect(Collectors.toList());
return new Response(new QueryPage<>(stats, instance.getResources().count(), RESULTS_FIELD));
}
return instance;
}
IngestStats adjustForVersion(IngestStats stats, Version version) {
if (version.before(Version.V_7_6_0)) {
return new IngestStats(stats.getTotalStats(),
stats.getPipelineStats(),
stats.getProcessorStats()
.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey,
(kv) -> kv.getValue()
.stream()
.map(pstats -> new IngestStats.ProcessorStat(pstats.getName(), "_NOT_AVAILABLE", pstats.getStats()))
.collect(Collectors.toList()))));
}
return stats;
}
} }

View File

@ -320,9 +320,13 @@ public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase<Tr
@Override @Override
protected TrainedModelConfig mutateInstanceForVersion(TrainedModelConfig instance, Version version) { protected TrainedModelConfig mutateInstanceForVersion(TrainedModelConfig instance, Version version) {
TrainedModelConfig.Builder builder = new TrainedModelConfig.Builder(instance);
if (version.before(Version.V_7_7_0)) {
builder.setDefaultFieldMap(null);
}
if (version.before(Version.V_7_8_0)) { if (version.before(Version.V_7_8_0)) {
return new TrainedModelConfig.Builder(instance).setInferenceConfig(null).build(); builder.setInferenceConfig(null);
} }
return instance; return builder.build();
} }
} }

View File

@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
public class InferenceStatsTests extends AbstractSerializingTestCase<InferenceStats> {
public static InferenceStats createTestInstance(String modelId, @Nullable String nodeId) {
return new InferenceStats(randomNonNegativeLong(),
randomNonNegativeLong(),
randomNonNegativeLong(),
modelId,
nodeId,
Instant.now()
);
}
@Override
protected InferenceStats doParseInstance(XContentParser parser) throws IOException {
return InferenceStats.PARSER.apply(parser, null);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected InferenceStats createTestInstance() {
return createTestInstance(randomAlphaOfLength(10), randomAlphaOfLength(10));
}
@Override
protected Writeable.Reader<InferenceStats> instanceReader() {
return InferenceStats::new;
}
@Override
protected ToXContent.Params getToXContentParams() {
return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
}
}

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.integration;
import org.apache.http.util.EntityUtils; import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request; import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response; import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
@ -29,6 +30,7 @@ import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
@ -67,7 +69,7 @@ public class InferenceIngestIT extends ESRestTestCase {
client().performRequest(new Request("DELETE", "_ml/inference/test_regression")); client().performRequest(new Request("DELETE", "_ml/inference/test_regression"));
} }
public void testPipelineCreationAndDeletion() throws Exception { public void testPathologicalPipelineCreationAndDeletion() throws Exception {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE)); client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
@ -78,6 +80,24 @@ public class InferenceIngestIT extends ESRestTestCase {
client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc())); client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline")); client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
} }
client().performRequest(new Request("POST", "index_for_inference_test/_refresh"));
Response searchResponse = client().performRequest(searchRequest("index_for_inference_test",
QueryBuilders.boolQuery()
.filter(
QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))));
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
searchResponse = client().performRequest(searchRequest("index_for_inference_test",
QueryBuilders.boolQuery()
.filter(
QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54786")
public void testPipelineIngest() throws Exception {
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE)); client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE)); client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
@ -87,24 +107,42 @@ public class InferenceIngestIT extends ESRestTestCase {
client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc())); client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
} }
for (int i = 0; i < 5; i++) {
client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
}
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline")); client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline")); client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));
client().performRequest(new Request("POST", "index_for_inference_test/_refresh")); client().performRequest(new Request("POST", "index_for_inference_test/_refresh"));
Response searchResponse = client().performRequest(searchRequest("index_for_inference_test", Response searchResponse = client().performRequest(searchRequest("index_for_inference_test",
QueryBuilders.boolQuery() QueryBuilders.boolQuery()
.filter( .filter(
QueryBuilders.existsQuery("ml.inference.regression.predicted_value")))); QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))));
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20")); assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":15"));
searchResponse = client().performRequest(searchRequest("index_for_inference_test", searchResponse = client().performRequest(searchRequest("index_for_inference_test",
QueryBuilders.boolQuery() QueryBuilders.boolQuery()
.filter( .filter(
QueryBuilders.existsQuery("ml.inference.classification.predicted_value")))); QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20")); assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
assertBusy(() -> {
try {
Response statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_classification/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_regression/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
// can get both
statsResponse = client().performRequest(new Request("GET", "_ml/inference/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
} catch (ResponseException ex) {
//this could just mean shard failures.
}
}, 30, TimeUnit.SECONDS);
} }
public void testSimulate() throws IOException { public void testSimulate() throws IOException {

View File

@ -211,6 +211,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactor
import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -630,13 +631,20 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
this.datafeedManager.set(datafeedManager); this.datafeedManager.set(datafeedManager);
// Inference components // Inference components
final TrainedModelStatsService trainedModelStatsService = new TrainedModelStatsService(resultsPersisterService,
originSettingClient,
indexNameExpressionResolver,
clusterService,
threadPool);
final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
inferenceAuditor, inferenceAuditor,
threadPool, threadPool,
clusterService, clusterService,
xContentRegistry, xContentRegistry,
settings); trainedModelStatsService,
settings,
clusterService.getNodeName());
// Data frame analytics components // Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory, AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,

View File

@ -26,6 +26,7 @@ import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -38,12 +39,12 @@ import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<GetTrainedModelsStatsAction.Request, public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<GetTrainedModelsStatsAction.Request,
GetTrainedModelsStatsAction.Response> { GetTrainedModelsStatsAction.Response> {
@ -73,13 +74,21 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder(); GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();
ActionListener<List<InferenceStats>> inferenceStatsListener = ActionListener.wrap(
inferenceStats -> listener.onResponse(responseBuilder.setInferenceStatsByModelId(inferenceStats.stream()
.collect(Collectors.toMap(InferenceStats::getModelId, Function.identity())))
.build()),
listener::onFailure
);
ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap( ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(
nodesStatsResponse -> { nodesStatsResponse -> {
Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByPipelineId(nodesStatsResponse, Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByModelId(nodesStatsResponse,
pipelineIdsByModelIds(clusterService.state(), pipelineIdsByModelIds(clusterService.state(),
ingestService, ingestService,
responseBuilder.getExpandedIds())); responseBuilder.getExpandedIds()));
listener.onResponse(responseBuilder.setIngestStatsByModelId(modelIdIngestStats).build()); responseBuilder.setIngestStatsByModelId(modelIdIngestStats);
trainedModelProvider.getInferenceStats(responseBuilder.getExpandedIds().toArray(new String[0]), inferenceStatsListener);
}, },
listener::onFailure listener::onFailure
); );
@ -103,7 +112,7 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
idsListener); idsListener);
} }
static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse response, static Map<String, IngestStats> inferenceIngestStatsByModelId(NodesStatsResponse response,
Map<String, Set<String>> modelIdToPipelineId) { Map<String, Set<String>> modelIdToPipelineId) {
Map<String, IngestStats> ingestStatsMap = new HashMap<>(); Map<String, IngestStats> ingestStatsMap = new HashMap<>();

View File

@ -0,0 +1,180 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
public class TrainedModelStatsService {
private static final Logger logger = LogManager.getLogger(TrainedModelStatsService.class);
private static final TimeValue PERSISTENCE_INTERVAL = TimeValue.timeValueSeconds(1);
// Script to only update if stats have increased since last persistence
private static final String STATS_UPDATE_SCRIPT = "" +
" ctx._source.missing_all_fields_count += params.missing_all_fields_count;\n" +
" ctx._source.inference_count += params.inference_count;\n" +
" ctx._source.failure_count += params.failure_count;\n" +
" ctx._source.time_stamp = params.time_stamp;";
private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS =
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
private final Map<String, InferenceStats> statsQueue;
private final ResultsPersisterService resultsPersisterService;
private final OriginSettingClient client;
private final IndexNameExpressionResolver indexNameExpressionResolver;
private final ThreadPool threadPool;
private volatile Scheduler.Cancellable scheduledFuture;
private volatile boolean verifiedStatsIndexCreated;
private volatile boolean stopped;
private volatile ClusterState clusterState;
public TrainedModelStatsService(ResultsPersisterService resultsPersisterService,
OriginSettingClient client,
IndexNameExpressionResolver indexNameExpressionResolver,
ClusterService clusterService,
ThreadPool threadPool) {
this.resultsPersisterService = resultsPersisterService;
this.client = client;
this.indexNameExpressionResolver = indexNameExpressionResolver;
this.threadPool = threadPool;
this.statsQueue = new ConcurrentHashMap<>();
clusterService.addLifecycleListener(new LifecycleListener() {
@Override
public void beforeStart() {
start();
}
@Override
public void beforeStop() {
stop();
}
});
clusterService.addListener((event) -> this.clusterState = event.state());
}
public void queueStats(InferenceStats stats) {
statsQueue.compute(InferenceStats.docId(stats.getModelId(), stats.getNodeId()),
(k, previousStats) -> previousStats == null ?
stats :
InferenceStats.accumulator(stats).merge(previousStats).currentStats(stats.getTimeStamp()));
}
void stop() {
stopped = true;
statsQueue.clear();
ThreadPool.Cancellable cancellable = this.scheduledFuture;
if (cancellable != null) {
cancellable.cancel();
}
}
void start() {
stopped = false;
scheduledFuture = threadPool.scheduleWithFixedDelay(this::updateStats,
PERSISTENCE_INTERVAL,
MachineLearning.UTILITY_THREAD_POOL_NAME);
}
void updateStats() {
if (clusterState == null || statsQueue.isEmpty()) {
return;
}
if (verifiedStatsIndexCreated == false) {
try {
PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
MlStatsIndex.createStatsIndexAndAliasIfNecessary(client, clusterState, indexNameExpressionResolver, listener);
listener.actionGet();
verifiedStatsIndexCreated = true;
} catch (Exception e) {
logger.error("failure creating ml stats index for storing model stats", e);
return;
}
}
List<InferenceStats> stats = new ArrayList<>(statsQueue.size());
for(String k : statsQueue.keySet()) {
InferenceStats inferenceStats = statsQueue.remove(k);
if (inferenceStats != null) {
stats.add(inferenceStats);
}
}
if (stats.isEmpty()) {
return;
}
BulkRequest bulkRequest = new BulkRequest();
stats.stream().map(TrainedModelStatsService::buildUpdateRequest).filter(Objects::nonNull).forEach(bulkRequest::add);
bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
if (bulkRequest.requests().isEmpty()) {
return;
}
resultsPersisterService.bulkIndexWithRetry(bulkRequest,
stats.stream().map(InferenceStats::getModelId).collect(Collectors.joining(",")),
() -> stopped == false,
(msg) -> {});
}
static UpdateRequest buildUpdateRequest(InferenceStats stats) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
Map<String, Object> params = new HashMap<>();
params.put(InferenceStats.FAILURE_COUNT.getPreferredName(), stats.getFailureCount());
params.put(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(), stats.getMissingAllFieldsCount());
params.put(InferenceStats.TIMESTAMP.getPreferredName(), stats.getTimeStamp().toEpochMilli());
params.put(InferenceStats.INFERENCE_COUNT.getPreferredName(), stats.getInferenceCount());
stats.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);
UpdateRequest updateRequest = new UpdateRequest();
updateRequest.upsert(builder)
.index(MlStatsIndex.writeAlias())
.id(InferenceStats.docId(stats.getModelId(), stats.getNodeId()))
.script(new Script(ScriptType.INLINE, "painless", STATS_UPDATE_SCRIPT, params));
return updateRequest;
} catch (IOException ex) {
logger.error(
() -> new ParameterizedMessage("[{}] [{}] failed to serialize stats for update.",
stats.getModelId(),
stats.getNodeId()),
ex);
}
return null;
}
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -17,11 +18,14 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.MapHelper; import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAdder;
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING; import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
@ -29,19 +33,30 @@ public class LocalModel<T extends InferenceConfig> implements Model<T> {
private final TrainedModelDefinition trainedModelDefinition; private final TrainedModelDefinition trainedModelDefinition;
private final String modelId; private final String modelId;
private final String nodeId;
private final Set<String> fieldNames; private final Set<String> fieldNames;
private final Map<String, String> defaultFieldMap; private final Map<String, String> defaultFieldMap;
private final AtomicReference<InferenceStats.Accumulator> statsAccumulator;
private final TrainedModelStatsService trainedModelStatsService;
private volatile long persistenceQuotient = 100;
private final LongAdder currentInferenceCount;
private final T inferenceConfig; private final T inferenceConfig;
public LocalModel(String modelId, public LocalModel(String modelId,
String nodeId,
TrainedModelDefinition trainedModelDefinition, TrainedModelDefinition trainedModelDefinition,
TrainedModelInput input, TrainedModelInput input,
Map<String, String> defaultFieldMap, Map<String, String> defaultFieldMap,
T modelInferenceConfig) { T modelInferenceConfig,
TrainedModelStatsService trainedModelStatsService ) {
this.trainedModelDefinition = trainedModelDefinition; this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId; this.modelId = modelId;
this.nodeId = nodeId;
this.fieldNames = new HashSet<>(input.getFieldNames()); this.fieldNames = new HashSet<>(input.getFieldNames());
this.statsAccumulator = new AtomicReference<>(new InferenceStats.Accumulator(modelId, nodeId));
this.trainedModelStatsService = trainedModelStatsService;
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap); this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
this.currentInferenceCount = new LongAdder();
this.inferenceConfig = modelInferenceConfig; this.inferenceConfig = modelInferenceConfig;
} }
@ -54,6 +69,12 @@ public class LocalModel<T extends InferenceConfig> implements Model<T> {
return modelId; return modelId;
} }
@Override
public InferenceStats getLatestStatsAndReset() {
InferenceStats.Accumulator toPersist = statsAccumulator.getAndSet(new InferenceStats.Accumulator(modelId, nodeId));
return toPersist.currentStats();
}
@Override @Override
public String getResultsType() { public String getResultsType() {
switch (trainedModelDefinition.getTrainedModel().targetType()) { switch (trainedModelDefinition.getTrainedModel().targetType()) {
@ -68,6 +89,16 @@ public class LocalModel<T extends InferenceConfig> implements Model<T> {
} }
} }
void persistStats() {
trainedModelStatsService.queueStats(getLatestStatsAndReset());
if (persistenceQuotient < 1000 && currentInferenceCount.sum() > 1000) {
persistenceQuotient = 1000;
}
if (persistenceQuotient < 10_000 && currentInferenceCount.sum() > 10_000) {
persistenceQuotient = 10_000;
}
}
@Override @Override
public void infer(Map<String, Object> fields, InferenceConfigUpdate<T> update, ActionListener<InferenceResults> listener) { public void infer(Map<String, Object> fields, InferenceConfigUpdate<T> update, ActionListener<InferenceResults> listener) {
if (update.isSupported(this.inferenceConfig) == false) { if (update.isSupported(this.inferenceConfig) == false) {
@ -79,14 +110,27 @@ public class LocalModel<T extends InferenceConfig> implements Model<T> {
return; return;
} }
try { try {
statsAccumulator.get().incInference();
currentInferenceCount.increment();
Model.mapFieldsIfNecessary(fields, defaultFieldMap); Model.mapFieldsIfNecessary(fields, defaultFieldMap);
boolean shouldPersistStats = ((currentInferenceCount.sum() + 1) % persistenceQuotient == 0);
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) { if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
statsAccumulator.get().incMissingFields();
if (shouldPersistStats) {
persistStats();
}
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId))); listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
return; return;
} }
InferenceResults inferenceResults = trainedModelDefinition.infer(fields, update.apply(inferenceConfig));
listener.onResponse(trainedModelDefinition.infer(fields, update.apply(inferenceConfig))); if (shouldPersistStats) {
persistStats();
}
listener.onResponse(inferenceResults);
} catch (Exception e) { } catch (Exception e) {
statsAccumulator.get().incFailure();
listener.onFailure(e); listener.onFailure(e);
} }
} }

View File

@ -10,6 +10,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.MapHelper; import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import java.util.Map; import java.util.Map;
@ -43,4 +44,6 @@ public interface Model<T extends InferenceConfig> {
}); });
} }
} }
InferenceStats getLatestStatsAndReset();
} }

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.MessageSupplier;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterChangedEvent;
@ -33,6 +34,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
@ -83,6 +85,7 @@ public class ModelLoadingService implements ClusterStateListener {
Setting.Property.NodeScope); Setting.Property.NodeScope);
private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
private final TrainedModelStatsService modelStatsService;
private final Cache<String, LocalModel<? extends InferenceConfig>> localModelCache; private final Cache<String, LocalModel<? extends InferenceConfig>> localModelCache;
private final Set<String> referencedModels = new HashSet<>(); private final Set<String> referencedModels = new HashSet<>();
private final Map<String, Queue<ActionListener<Model<? extends InferenceConfig>>>> loadingListeners = new HashMap<>(); private final Map<String, Queue<ActionListener<Model<? extends InferenceConfig>>>> loadingListeners = new HashMap<>();
@ -92,17 +95,21 @@ public class ModelLoadingService implements ClusterStateListener {
private final InferenceAuditor auditor; private final InferenceAuditor auditor;
private final ByteSizeValue maxCacheSize; private final ByteSizeValue maxCacheSize;
private final NamedXContentRegistry namedXContentRegistry; private final NamedXContentRegistry namedXContentRegistry;
private final String localNode;
public ModelLoadingService(TrainedModelProvider trainedModelProvider, public ModelLoadingService(TrainedModelProvider trainedModelProvider,
InferenceAuditor auditor, InferenceAuditor auditor,
ThreadPool threadPool, ThreadPool threadPool,
ClusterService clusterService, ClusterService clusterService,
NamedXContentRegistry namedXContentRegistry, NamedXContentRegistry namedXContentRegistry,
Settings settings) { TrainedModelStatsService modelStatsService,
Settings settings,
String localNode) {
this.provider = trainedModelProvider; this.provider = trainedModelProvider;
this.threadPool = threadPool; this.threadPool = threadPool;
this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings); this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings);
this.auditor = auditor; this.auditor = auditor;
this.modelStatsService = modelStatsService;
this.shouldNotAudit = new HashSet<>(); this.shouldNotAudit = new HashSet<>();
this.namedXContentRegistry = namedXContentRegistry; this.namedXContentRegistry = namedXContentRegistry;
this.localModelCache = CacheBuilder.<String, LocalModel<? extends InferenceConfig>>builder() this.localModelCache = CacheBuilder.<String, LocalModel<? extends InferenceConfig>>builder()
@ -113,6 +120,7 @@ public class ModelLoadingService implements ClusterStateListener {
.setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings)) .setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
.build(); .build();
clusterService.addListener(this); clusterService.addListener(this);
this.localNode = localNode;
} }
/** /**
@ -136,13 +144,13 @@ public class ModelLoadingService implements ClusterStateListener {
LocalModel<? extends InferenceConfig> cachedModel = localModelCache.get(modelId); LocalModel<? extends InferenceConfig> cachedModel = localModelCache.get(modelId);
if (cachedModel != null) { if (cachedModel != null) {
modelActionListener.onResponse(cachedModel); modelActionListener.onResponse(cachedModel);
logger.trace("[{}] loaded from cache", modelId); logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId));
return; return;
} }
if (loadModelIfNecessary(modelId, modelActionListener) == false) { if (loadModelIfNecessary(modelId, modelActionListener) == false) {
// If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline // by a simulated pipeline
logger.trace("[{}] not actively loading, eager loading without cache", modelId); logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
provider.getTrainedModel(modelId, true, ActionListener.wrap( provider.getTrainedModel(modelId, true, ActionListener.wrap(
trainedModelConfig -> { trainedModelConfig -> {
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry); trainedModelConfig.ensureParsedDefinition(namedXContentRegistry);
@ -151,15 +159,17 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelConfig.getInferenceConfig(); trainedModelConfig.getInferenceConfig();
modelActionListener.onResponse(new LocalModel<>( modelActionListener.onResponse(new LocalModel<>(
trainedModelConfig.getModelId(), trainedModelConfig.getModelId(),
localNode,
trainedModelConfig.getModelDefinition(), trainedModelConfig.getModelDefinition(),
trainedModelConfig.getInput(), trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(), trainedModelConfig.getDefaultFieldMap(),
inferenceConfig)); inferenceConfig,
modelStatsService));
}, },
modelActionListener::onFailure modelActionListener::onFailure
)); ));
} else { } else {
logger.trace("[{}] is loading or loaded, added new listener to queue", modelId); logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId));
} }
} }
@ -183,7 +193,7 @@ public class ModelLoadingService implements ClusterStateListener {
if (loadingListeners.computeIfPresent( if (loadingListeners.computeIfPresent(
modelId, modelId,
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) { (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) {
logger.trace("[{}] attempting to load and cache", modelId); logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener)); loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
loadModel(modelId); loadModel(modelId);
} }
@ -198,7 +208,7 @@ public class ModelLoadingService implements ClusterStateListener {
private void loadModel(String modelId) { private void loadModel(String modelId) {
provider.getTrainedModel(modelId, true, ActionListener.wrap( provider.getTrainedModel(modelId, true, ActionListener.wrap(
trainedModelConfig -> { trainedModelConfig -> {
logger.debug("[{}] successfully loaded model", modelId); logger.debug(() -> new ParameterizedMessage("[{}] successfully loaded model", modelId));
handleLoadSuccess(modelId, trainedModelConfig); handleLoadSuccess(modelId, trainedModelConfig);
}, },
failure -> { failure -> {
@ -216,10 +226,12 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelConfig.getInferenceConfig(); trainedModelConfig.getInferenceConfig();
LocalModel<? extends InferenceConfig> loadedModel = new LocalModel<>( LocalModel<? extends InferenceConfig> loadedModel = new LocalModel<>(
trainedModelConfig.getModelId(), trainedModelConfig.getModelId(),
localNode,
trainedModelConfig.getModelDefinition(), trainedModelConfig.getModelDefinition(),
trainedModelConfig.getInput(), trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(), trainedModelConfig.getDefaultFieldMap(),
inferenceConfig); inferenceConfig,
modelStatsService);
synchronized (loadingListeners) { synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId); listeners = loadingListeners.remove(modelId);
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such // If there is no loadingListener that means the loading was canceled and the listener was already notified as such
@ -252,7 +264,7 @@ public class ModelLoadingService implements ClusterStateListener {
private void cacheEvictionListener(RemovalNotification<String, LocalModel<? extends InferenceConfig>> notification) { private void cacheEvictionListener(RemovalNotification<String, LocalModel<? extends InferenceConfig>> notification) {
if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
String msg = new ParameterizedMessage( MessageSupplier msg = () -> new ParameterizedMessage(
"model cache entry evicted." + "model cache entry evicted." +
"current cache [{}] current max [{}] model size [{}]. " + "current cache [{}] current max [{}] model size [{}]. " +
"If this is undesired, consider updating setting [{}] or [{}].", "If this is undesired, consider updating setting [{}] or [{}].",
@ -260,9 +272,10 @@ public class ModelLoadingService implements ClusterStateListener {
maxCacheSize.getStringRep(), maxCacheSize.getStringRep(),
new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(), new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(),
INFERENCE_MODEL_CACHE_SIZE.getKey(), INFERENCE_MODEL_CACHE_SIZE.getKey(),
INFERENCE_MODEL_CACHE_TTL.getKey()).getFormattedMessage(); INFERENCE_MODEL_CACHE_TTL.getKey());
auditIfNecessary(notification.getKey(), msg); auditIfNecessary(notification.getKey(), msg);
} }
notification.getValue().persistStats();
} }
@Override @Override
@ -335,14 +348,14 @@ public class ModelLoadingService implements ClusterStateListener {
loadModels(allReferencedModelKeys); loadModels(allReferencedModelKeys);
} }
private void auditIfNecessary(String modelId, String msg) { private void auditIfNecessary(String modelId, MessageSupplier msg) {
if (shouldNotAudit.contains(modelId)) { if (shouldNotAudit.contains(modelId)) {
logger.trace("[{}] {}", modelId, msg); logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage()));
return; return;
} }
auditor.warning(modelId, msg); auditor.warning(modelId, msg.get().getFormattedMessage());
shouldNotAudit.add(modelId); shouldNotAudit.add(modelId);
logger.warn("[{}] {}", modelId, msg); logger.warn("[{}] {}", modelId, msg.get().getFormattedMessage());
} }
private void loadModels(Set<String> modelIds) { private void loadModels(Set<String> modelIds) {

View File

@ -19,6 +19,7 @@ import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.MultiSearchAction; import org.elasticsearch.action.search.MultiSearchAction;
import org.elasticsearch.action.search.MultiSearchRequest;
import org.elasticsearch.action.search.MultiSearchRequestBuilder; import org.elasticsearch.action.search.MultiSearchRequestBuilder;
import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchAction;
@ -53,14 +54,19 @@ import org.elasticsearch.index.reindex.DeleteByQueryAction;
import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.metrics.Max;
import org.elasticsearch.search.aggregations.metrics.Sum;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
@ -68,7 +74,9 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.URL; import java.net.URL;
import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
@ -357,7 +365,7 @@ public class TrainedModelProvider {
} }
DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);
request.indices(InferenceIndexConstants.INDEX_PATTERN); request.indices(InferenceIndexConstants.INDEX_PATTERN, MlStatsIndex.indexPattern());
QueryBuilder query = QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); QueryBuilder query = QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
request.setQuery(query); request.setQuery(query);
request.setRefresh(true); request.setRefresh(true);
@ -454,6 +462,99 @@ public class TrainedModelProvider {
client::search); client::search);
} }
public void getInferenceStats(String[] modelIds, ActionListener<List<InferenceStats>> listener) {
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
Arrays.stream(modelIds).map(this::buildStatsSearchRequest).forEach(multiSearchRequest::add);
if (multiSearchRequest.requests().isEmpty()) {
listener.onResponse(Collections.emptyList());
return;
}
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ML_ORIGIN,
multiSearchRequest,
ActionListener.<MultiSearchResponse>wrap(
responses -> {
List<InferenceStats> allStats = new ArrayList<>(modelIds.length);
int modelIndex = 0;
assert responses.getResponses().length == modelIds.length :
"mismatch between search response size and models requested";
for (MultiSearchResponse.Item response : responses.getResponses()) {
if (response.isFailure()) {
if (ExceptionsHelper.unwrapCause(response.getFailure()) instanceof ResourceNotFoundException) {
modelIndex++;
continue;
}
logger.error(new ParameterizedMessage("[{}] search failed for models",
Strings.arrayToCommaDelimitedString(modelIds)),
response.getFailure());
listener.onFailure(ExceptionsHelper.serverError("Searching for stats for models [{}] failed",
response.getFailure(),
Strings.arrayToCommaDelimitedString(modelIds)));
return;
}
try {
InferenceStats inferenceStats = handleMultiNodeStatsResponse(response.getResponse(), modelIds[modelIndex++]);
if (inferenceStats != null) {
allStats.add(inferenceStats);
}
} catch (Exception e) {
listener.onFailure(e);
return;
}
}
listener.onResponse(allStats);
},
e -> {
Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
if (unwrapped instanceof ResourceNotFoundException) {
listener.onResponse(Collections.emptyList());
return;
}
listener.onFailure((Exception)unwrapped);
}
),
client::multiSearch);
}
private SearchRequest buildStatsSearchRequest(String modelId) {
BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery(InferenceStats.MODEL_ID.getPreferredName(), modelId))
.filter(QueryBuilders.termQuery(InferenceStats.TYPE.getPreferredName(), InferenceStats.NAME));
return new SearchRequest(MlStatsIndex.indexPattern())
.indicesOptions(IndicesOptions.lenientExpandOpen())
.allowPartialSearchResults(false)
.source(SearchSourceBuilder.searchSource()
.size(0)
.aggregation(AggregationBuilders.sum(InferenceStats.FAILURE_COUNT.getPreferredName())
.field(InferenceStats.FAILURE_COUNT.getPreferredName()))
.aggregation(AggregationBuilders.sum(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName())
.field(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName()))
.aggregation(AggregationBuilders.sum(InferenceStats.INFERENCE_COUNT.getPreferredName())
.field(InferenceStats.INFERENCE_COUNT.getPreferredName()))
.aggregation(AggregationBuilders.max(InferenceStats.TIMESTAMP.getPreferredName())
.field(InferenceStats.TIMESTAMP.getPreferredName()))
.query(queryBuilder));
}
private InferenceStats handleMultiNodeStatsResponse(SearchResponse response, String modelId) {
if (response.getAggregations() == null) {
logger.trace(() -> new ParameterizedMessage("[{}] no previously stored stats found", modelId));
return null;
}
Sum failures = response.getAggregations().get(InferenceStats.FAILURE_COUNT.getPreferredName());
Sum missing = response.getAggregations().get(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName());
Sum count = response.getAggregations().get(InferenceStats.INFERENCE_COUNT.getPreferredName());
Max timeStamp = response.getAggregations().get(InferenceStats.TIMESTAMP.getPreferredName());
return new InferenceStats(
missing == null ? 0L : Double.valueOf(missing.getValue()).longValue(),
count == null ? 0L : Double.valueOf(count.getValue()).longValue(),
failures == null ? 0L : Double.valueOf(failures.getValue()).longValue(),
modelId,
null,
timeStamp == null ? Instant.now() : Instant.ofEpochMilli(Double.valueOf(timeStamp.getValue()).longValue())
);
}
static Set<String> collectIds(PageParams pageParams, Set<String> foundFromResources, Set<String> foundFromDocs) { static Set<String> collectIds(PageParams pageParams, Set<String> foundFromResources, Set<String> foundFromDocs) {
// If there are no matching resource models, there was no buffering and the models from the docs // If there are no matching resource models, there was no buffering and the models from the docs
// are paginated correctly. // are paginated correctly.

View File

@ -126,7 +126,7 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
} }
public void testInferenceIngestStatsByPipelineId() throws IOException { public void testInferenceIngestStatsByModelId() {
List<NodeStats> nodeStatsList = Arrays.asList( List<NodeStats> nodeStatsList = Arrays.asList(
buildNodeStats( buildNodeStats(
new IngestStats.Stats(2, 2, 3, 4), new IngestStats.Stats(2, 2, 3, 4),
@ -193,7 +193,7 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
put("trained_model_1", Collections.singleton("pipeline1")); put("trained_model_1", Collections.singleton("pipeline1"));
put("trained_model_2", new HashSet<>(Arrays.asList("pipeline1", "pipeline2"))); put("trained_model_2", new HashSet<>(Arrays.asList("pipeline1", "pipeline2")));
}}; }};
Map<String, IngestStats> ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByPipelineId(response, Map<String, IngestStats> ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByModelId(response,
pipelineIdsByModelIds); pipelineIdsByModelIds);
assertThat(ingestStatsMap.keySet(), equalTo(new HashSet<>(Arrays.asList("trained_model_1", "trained_model_2")))); assertThat(ingestStatsMap.keySet(), equalTo(new HashSet<>(Arrays.asList("trained_model_1", "trained_model_2"))));

View File

@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
@ -28,6 +29,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.mockito.ArgumentMatcher;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -39,10 +42,18 @@ import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.argThat;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.internal.verification.VerificationModeFactory.times;
public class LocalModelTests extends ESTestCase { public class LocalModelTests extends ESTestCase {
public void testClassificationInfer() throws Exception { public void testClassificationInfer() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
String modelId = "classification_model"; String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical"); List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder() TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
@ -51,10 +62,12 @@ public class LocalModelTests extends ESTestCase {
.build(); .build();
Model<ClassificationConfig> model = new LocalModel<>(modelId, Model<ClassificationConfig> model = new LocalModel<>(modelId,
"test-node",
definition, definition,
new TrainedModelInput(inputFields), new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"), Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS); ClassificationConfig.EMPTY_PARAMS,
modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{ Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0); put("field.foo", 1.0);
put("field.bar", 0.5); put("field.bar", 0.5);
@ -64,11 +77,13 @@ public class LocalModelTests extends ESTestCase {
SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null)); SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
assertThat(result.value(), equalTo(0.0)); assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), is("0")); assertThat(result.valueAsString(), is("0"));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));
ClassificationInferenceResults classificationResult = ClassificationInferenceResults classificationResult =
(ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null)); (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null));
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0"));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));
// Test with labels // Test with labels
definition = new TrainedModelDefinition.Builder() definition = new TrainedModelDefinition.Builder()
@ -76,10 +91,12 @@ public class LocalModelTests extends ESTestCase {
.setTrainedModel(buildClassification(true)) .setTrainedModel(buildClassification(true))
.build(); .build();
model = new LocalModel<>(modelId, model = new LocalModel<>(modelId,
"test-node",
definition, definition,
new TrainedModelInput(inputFields), new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"), Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS); ClassificationConfig.EMPTY_PARAMS,
modelStatsService);
result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null)); result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
assertThat(result.value(), equalTo(0.0)); assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), equalTo("not_to_be")); assertThat(result.valueAsString(), equalTo("not_to_be"));
@ -89,29 +106,36 @@ public class LocalModelTests extends ESTestCase {
new ClassificationConfigUpdate(1, null, null, null)); new ClassificationConfigUpdate(1, null, null, null));
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L));
classificationResult = (ClassificationInferenceResults)getSingleValue(model, classificationResult = (ClassificationInferenceResults)getSingleValue(model,
fields, fields,
new ClassificationConfigUpdate(2, null, null, null)); new ClassificationConfigUpdate(2, null, null, null));
assertThat(classificationResult.getTopClasses(), hasSize(2)); assertThat(classificationResult.getTopClasses(), hasSize(2));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));
classificationResult = (ClassificationInferenceResults)getSingleValue(model, classificationResult = (ClassificationInferenceResults)getSingleValue(model,
fields, fields,
new ClassificationConfigUpdate(-1, null, null, null)); new ClassificationConfigUpdate(-1, null, null, null));
assertThat(classificationResult.getTopClasses(), hasSize(2)); assertThat(classificationResult.getTopClasses(), hasSize(2));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));
} }
public void testRegression() throws Exception { public void testRegression() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
List<String> inputFields = Arrays.asList("foo", "bar", "categorical"); List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildRegression()) .setTrainedModel(buildRegression())
.build(); .build();
Model<RegressionConfig> model = new LocalModel<>("regression_model", Model<RegressionConfig> model = new LocalModel<>("regression_model",
"test-node",
trainedModelDefinition, trainedModelDefinition,
new TrainedModelInput(inputFields), new TrainedModelInput(inputFields),
Collections.singletonMap("bar", "bar.keyword"), Collections.singletonMap("bar", "bar.keyword"),
RegressionConfig.EMPTY_PARAMS); RegressionConfig.EMPTY_PARAMS,
modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{ Map<String, Object> fields = new HashMap<String, Object>() {{
put("foo", 1.0); put("foo", 1.0);
@ -124,6 +148,8 @@ public class LocalModelTests extends ESTestCase {
} }
public void testAllFieldsMissing() throws Exception { public void testAllFieldsMissing() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
List<String> inputFields = Arrays.asList("foo", "bar", "categorical"); List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
@ -131,10 +157,12 @@ public class LocalModelTests extends ESTestCase {
.build(); .build();
Model<RegressionConfig> model = new LocalModel<>( Model<RegressionConfig> model = new LocalModel<>(
"regression_model", "regression_model",
"test-node",
trainedModelDefinition, trainedModelDefinition,
new TrainedModelInput(inputFields), new TrainedModelInput(inputFields),
null, null,
RegressionConfig.EMPTY_PARAMS); RegressionConfig.EMPTY_PARAMS,
modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{ Map<String, Object> fields = new HashMap<String, Object>() {{
put("something", 1.0); put("something", 1.0);
@ -145,6 +173,47 @@ public class LocalModelTests extends ESTestCase {
WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfigUpdate.EMPTY_PARAMS); WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfigUpdate.EMPTY_PARAMS);
assertThat(results.getWarning(), assertThat(results.getWarning(),
equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model"))); equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model")));
assertThat(model.getLatestStatsAndReset().getMissingAllFieldsCount(), equalTo(1L));
}
public void testInferPersistsStatsAfterNumberOfCalls() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(false))
.build();
Model<ClassificationConfig> model = new LocalModel<>(modelId,
"test-node",
definition,
new TrainedModelInput(inputFields),
null,
ClassificationConfig.EMPTY_PARAMS,
modelStatsService
);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
put("field.bar", 0.5);
put("categorical", "dog");
}};
for(int i = 0; i < 100; i++) {
getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
}
SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), is("0"));
// Should have reset after persistence, so only 2 docs have been seen since last persistence
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L));
verify(modelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(Object o) {
return ((InferenceStats)o).getInferenceCount() == 99L;
}
}));
} }
private static <T extends InferenceConfig> SingleValueInferenceResults getSingleValue(Model<T> model, private static <T extends InferenceConfig> SingleValueInferenceResults getSingleValue(Model<T> model,

View File

@ -36,14 +36,17 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.mockito.ArgumentMatcher;
import org.mockito.Mockito; import org.mockito.Mockito;
import java.io.IOException; import java.io.IOException;
@ -59,10 +62,12 @@ import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -73,6 +78,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
private ThreadPool threadPool; private ThreadPool threadPool;
private ClusterService clusterService; private ClusterService clusterService;
private InferenceAuditor auditor; private InferenceAuditor auditor;
private TrainedModelStatsService trainedModelStatsService;
@Before @Before
public void setUpComponents() { public void setUpComponents() {
@ -81,6 +87,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
trainedModelProvider = mock(TrainedModelProvider.class); trainedModelProvider = mock(TrainedModelProvider.class);
clusterService = mock(ClusterService.class); clusterService = mock(ClusterService.class);
auditor = mock(InferenceAuditor.class); auditor = mock(InferenceAuditor.class);
trainedModelStatsService = mock(TrainedModelStatsService.class);
doAnswer(a -> null).when(auditor).error(any(String.class), any(String.class)); doAnswer(a -> null).when(auditor).error(any(String.class), any(String.class));
doAnswer(a -> null).when(auditor).info(any(String.class), any(String.class)); doAnswer(a -> null).when(auditor).info(any(String.class), any(String.class));
doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class)); doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class));
@ -106,7 +113,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
threadPool, threadPool,
clusterService, clusterService,
NamedXContentRegistry.EMPTY, NamedXContentRegistry.EMPTY,
Settings.EMPTY); trainedModelStatsService,
Settings.EMPTY,
"test-node");
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
@ -131,6 +140,12 @@ public class ModelLoadingServiceTests extends ESTestCase {
assertThat(future.get(), is(not(nullValue()))); assertThat(future.get(), is(not(nullValue())));
} }
verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model3);
}
}));
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
// It is not referenced, so called eagerly // It is not referenced, so called eagerly
@ -150,7 +165,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
threadPool, threadPool,
clusterService, clusterService,
NamedXContentRegistry.EMPTY, NamedXContentRegistry.EMPTY,
Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build()); trainedModelStatsService,
Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(),
"test-node");
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
@ -175,6 +192,24 @@ public class ModelLoadingServiceTests extends ESTestCase {
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any()); verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any());
// Only loaded requested once on the initial load from the change event // Only loaded requested once on the initial load from the change event
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model1);
}
}));
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model2);
}
}));
verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model3);
}
}));
// Load model 3, should invalidate 1 // Load model 3, should invalidate 1
for(int i = 0; i < 10; i++) { for(int i = 0; i < 10; i++) {
@ -214,7 +249,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any()); verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any()); verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any()); verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, Mockito.atMost(5)).getTrainedModel(eq(model3), eq(true), any()); verify(trainedModelProvider, atMost(5)).getTrainedModel(eq(model3), eq(true), any());
} }
@ -227,7 +262,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
threadPool, threadPool,
clusterService, clusterService,
NamedXContentRegistry.EMPTY, NamedXContentRegistry.EMPTY,
Settings.EMPTY); trainedModelStatsService,
Settings.EMPTY,
"test-node");
modelLoadingService.clusterChanged(ingestChangedEvent(false, model1)); modelLoadingService.clusterChanged(ingestChangedEvent(false, model1));
@ -238,6 +275,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
} }
verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), eq(true), any()); verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
} }
public void testGetCachedMissingModel() throws Exception { public void testGetCachedMissingModel() throws Exception {
@ -249,7 +287,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
threadPool, threadPool,
clusterService, clusterService,
NamedXContentRegistry.EMPTY, NamedXContentRegistry.EMPTY,
Settings.EMPTY); trainedModelStatsService,
Settings.EMPTY,
"test-node");
modelLoadingService.clusterChanged(ingestChangedEvent(model)); modelLoadingService.clusterChanged(ingestChangedEvent(model));
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>(); PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
@ -263,6 +303,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
} }
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(true), any()); verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(true), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
} }
public void testGetMissingModel() { public void testGetMissingModel() {
@ -274,7 +315,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
threadPool, threadPool,
clusterService, clusterService,
NamedXContentRegistry.EMPTY, NamedXContentRegistry.EMPTY,
Settings.EMPTY); trainedModelStatsService,
Settings.EMPTY,
"test-node");
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>(); PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future); modelLoadingService.getModel(model, future);
@ -295,7 +338,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
threadPool, threadPool,
clusterService, clusterService,
NamedXContentRegistry.EMPTY, NamedXContentRegistry.EMPTY,
Settings.EMPTY); trainedModelStatsService,
Settings.EMPTY,
"test-node");
for(int i = 0; i < 3; i++) { for(int i = 0; i < 3; i++) {
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>(); PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
@ -304,6 +349,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
} }
verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(true), any()); verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(true), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -312,6 +358,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
when(definition.ramBytesUsed()).thenReturn(size); when(definition.ramBytesUsed()).thenReturn(size);
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getModelDefinition()).thenReturn(definition); when(trainedModelConfig.getModelDefinition()).thenReturn(definition);
when(trainedModelConfig.getModelId()).thenReturn(modelId);
when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS); when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS);
when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz"))); when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz")));
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {