[7.x] [ML] calculate cache misses for inference and return in stats (#58252) (#58363)

When a local model is constructed, the cache miss count is incremented.

When a user calls _stats, we include the cache miss count summed across ALL nodes. This statistic is important when comparing against the inference_count. If the cache miss count is near the inference_count, it indicates that the cache is overburdened or inappropriately configured.
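
For example, a consumer of the high-level REST client classes added in this change could flag an undersized cache by comparing the two counters. This is a minimal sketch; the 0.9 threshold and the surrounding wiring are illustrative, not part of the change:

import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStats;

// Sketch: report models whose cache miss count is close to their inference count.
static boolean cacheLooksUndersized(TrainedModelStats modelStats) {
    InferenceStats stats = modelStats.getInferenceStats();
    if (stats == null || stats.getInferenceCount() == 0) {
        return false; // no inference activity reported yet
    }
    double missRatio = (double) stats.getCacheMissCount() / stats.getInferenceCount();
    return missRatio > 0.9; // illustrative threshold
}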
Benjamin Trent 2020-06-19 09:46:51 -04:00 committed by GitHub
parent d8dc638a67
commit bf8641aa15
14 changed files with 471 additions and 86 deletions

View File

@ -18,6 +18,7 @@
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -38,32 +39,36 @@ public class TrainedModelStats implements ToXContentObject {
public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
public static final ParseField INGEST_STATS = new ParseField("ingest");
public static final ParseField INFERENCE_STATS = new ParseField("inference_stats");
private final String modelId;
private final Map<String, Object> ingestStats;
private final int pipelineCount;
private final InferenceStats inferenceStats;
@SuppressWarnings("unchecked")
static final ConstructingObjectParser<TrainedModelStats, Void> PARSER =
new ConstructingObjectParser<>(
"trained_model_stats",
true,
args -> new TrainedModelStats((String) args[0], (Map<String, Object>) args[1], (Integer) args[2]));
args -> new TrainedModelStats((String) args[0], (Map<String, Object>) args[1], (Integer) args[2], (InferenceStats) args[3]));
static {
PARSER.declareString(constructorArg(), MODEL_ID);
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), INGEST_STATS);
PARSER.declareInt(constructorArg(), PIPELINE_COUNT);
PARSER.declareObject(optionalConstructorArg(), InferenceStats.PARSER, INFERENCE_STATS);
}
public static TrainedModelStats fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public TrainedModelStats(String modelId, Map<String, Object> ingestStats, int pipelineCount) {
public TrainedModelStats(String modelId, Map<String, Object> ingestStats, int pipelineCount, InferenceStats inferenceStats) {
this.modelId = modelId;
this.ingestStats = ingestStats;
this.pipelineCount = pipelineCount;
this.inferenceStats = inferenceStats;
}
/**
@ -89,6 +94,13 @@ public class TrainedModelStats implements ToXContentObject {
return pipelineCount;
}
/**
* Inference statistics
*/
public InferenceStats getInferenceStats() {
return inferenceStats;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -97,13 +109,16 @@ public class TrainedModelStats implements ToXContentObject {
if (ingestStats != null) {
builder.field(INGEST_STATS.getPreferredName(), ingestStats);
}
if (inferenceStats != null) {
builder.field(INFERENCE_STATS.getPreferredName(), inferenceStats);
}
builder.endObject();
return builder;
}
@Override
public int hashCode() {
return Objects.hash(modelId, ingestStats, pipelineCount);
return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats);
}
@Override
@ -117,7 +132,8 @@ public class TrainedModelStats implements ToXContentObject {
TrainedModelStats other = (TrainedModelStats) obj;
return Objects.equals(this.modelId, other.modelId)
&& Objects.equals(this.ingestStats, other.ingestStats)
&& Objects.equals(this.pipelineCount, other.pipelineCount);
&& Objects.equals(this.pipelineCount, other.pipelineCount)
&& Objects.equals(this.inferenceStats, other.inferenceStats);
}
}

View File

@ -0,0 +1,170 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
public class InferenceStats implements ToXContentObject {
public static final String NAME = "inference_stats";
public static final ParseField MISSING_ALL_FIELDS_COUNT = new ParseField("missing_all_fields_count");
public static final ParseField INFERENCE_COUNT = new ParseField("inference_count");
public static final ParseField CACHE_MISS_COUNT = new ParseField("cache_miss_count");
public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
public static final ParseField TIMESTAMP = new ParseField("timestamp");
public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (Long)a[3], (Instant)a[4])
);
static {
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MISSING_ALL_FIELDS_COUNT);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), INFERENCE_COUNT);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), FAILURE_COUNT);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), CACHE_MISS_COUNT);
PARSER.declareField(ConstructingObjectParser.constructorArg(),
p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
TIMESTAMP,
ObjectParser.ValueType.VALUE);
}
private final long missingAllFieldsCount;
private final long inferenceCount;
private final long failureCount;
private final long cacheMissCount;
private final Instant timeStamp;
private InferenceStats(Long missingAllFieldsCount,
Long inferenceCount,
Long failureCount,
Long cacheMissCount,
Instant instant) {
this(unboxOrZero(missingAllFieldsCount),
unboxOrZero(inferenceCount),
unboxOrZero(failureCount),
unboxOrZero(cacheMissCount),
instant);
}
public InferenceStats(long missingAllFieldsCount,
long inferenceCount,
long failureCount,
long cacheMissCount,
Instant timeStamp) {
this.missingAllFieldsCount = missingAllFieldsCount;
this.inferenceCount = inferenceCount;
this.failureCount = failureCount;
this.cacheMissCount = cacheMissCount;
this.timeStamp = timeStamp == null ?
Instant.ofEpochMilli(Instant.now().toEpochMilli()) :
Instant.ofEpochMilli(timeStamp.toEpochMilli());
}
/**
* How many times this model attempted to infer with all its fields missing
*/
public long getMissingAllFieldsCount() {
return missingAllFieldsCount;
}
/**
* How many inference calls were made against this model
*/
public long getInferenceCount() {
return inferenceCount;
}
/**
* How many inference failures occurred.
*/
public long getFailureCount() {
return failureCount;
}
/**
* How many cache misses occurred when inferring this model
*/
public long getCacheMissCount() {
return cacheMissCount;
}
/**
* The timestamp of these statistics.
*/
public Instant getTimeStamp() {
return timeStamp;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FAILURE_COUNT.getPreferredName(), failureCount);
builder.field(INFERENCE_COUNT.getPreferredName(), inferenceCount);
builder.field(CACHE_MISS_COUNT.getPreferredName(), cacheMissCount);
builder.field(MISSING_ALL_FIELDS_COUNT.getPreferredName(), missingAllFieldsCount);
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timeStamp.toEpochMilli());
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferenceStats that = (InferenceStats) o;
return missingAllFieldsCount == that.missingAllFieldsCount
&& inferenceCount == that.inferenceCount
&& failureCount == that.failureCount
&& cacheMissCount == that.cacheMissCount
&& Objects.equals(timeStamp, that.timeStamp);
}
@Override
public int hashCode() {
return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, cacheMissCount, timeStamp);
}
@Override
public String toString() {
return "InferenceStats{" +
"missingAllFieldsCount=" + missingAllFieldsCount +
", inferenceCount=" + inferenceCount +
", failureCount=" + failureCount +
", cacheMissCount=" + cacheMissCount +
", timeStamp=" + timeStamp +
'}';
}
private static long unboxOrZero(@Nullable Long value) {
return value == null ? 0L : value;
}
}
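
As a quick usage sketch, the new client-side InferenceStats can be parsed from a JSON fragment with the standard x-content plumbing. The parser setup below is generic boilerplate, and the JSON literal simply mirrors the field names declared above:

import java.io.IOException;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;

// Parse a stats fragment and return the new cache_miss_count field.
static long parseCacheMisses(String json) throws IOException {
    try (XContentParser parser = XContentType.JSON.xContent().createParser(
            NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json)) {
        return InferenceStats.PARSER.apply(parser, null).getCacheMissCount();
    }
}

// e.g. parseCacheMisses("{\"missing_all_fields_count\":0,\"inference_count\":4,"
//         + "\"failure_count\":0,\"cache_miss_count\":3,\"timestamp\":1592399986979}") == 3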

View File

@ -18,6 +18,7 @@
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStatsTests;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -58,7 +59,8 @@ public class TrainedModelStatsTests extends AbstractXContentTestCase<TrainedMode
return new TrainedModelStats(
randomAlphaOfLength(10),
randomBoolean() ? null : randomIngestStats(),
randomInt());
randomInt(),
randomBoolean() ? null : InferenceStatsTests.randomInstance());
}
private Map<String, Object> randomIngestStats() {

View File

@ -0,0 +1,54 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.time.Instant;
public class InferenceStatsTests extends AbstractXContentTestCase<InferenceStats> {
public static InferenceStats randomInstance() {
return new InferenceStats(randomNonNegativeLong(),
randomNonNegativeLong(),
randomNonNegativeLong(),
randomNonNegativeLong(),
Instant.now()
);
}
@Override
protected InferenceStats doParseInstance(XContentParser parser) throws IOException {
return InferenceStats.PARSER.apply(parser, null);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected InferenceStats createTestInstance() {
return randomInstance();
}
}

View File

@ -66,6 +66,74 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from]
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size]
[role="child_attributes"]
[[ml-get-inference-stats-results]]
==== {api-response-body-title}
`count`::
(integer)
The total number of trained model statistics that matched the requested ID patterns.
Could be higher than the number of items in the `trained_model_stats` array as the
size of the array is restricted by the supplied `size` parameter.
`trained_model_stats`::
(array)
An array of trained model statistics, which are sorted by the `model_id` value in
ascending order.
+
.Properties of trained model stats
[%collapsible%open]
====
`model_id`:::
(string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
`pipeline_count`:::
(integer)
The number of ingest pipelines that currently refer to the model.
`inference_stats`:::
(object)
A collection of inference stats fields.
+
.Properties of inference stats
[%collapsible%open]
=====
`missing_all_fields_count`:::
(integer)
The number of inference calls where all the training features for the model
were missing.
`inference_count`:::
(integer)
The total number of times the model has been called for inference.
This is across all inference contexts, including all pipelines.
`cache_miss_count`:::
(integer)
The number of times the model was loaded for inference and was not retrieved from the
cache. If this number is close to the `inference_count`, then the cache
is not being appropriately used. This can be remedied by increasing the cache's size
or its time-to-live (TTL). See <<general-ml-settings>> for the
appropriate settings.
`failure_count`:::
(integer)
The number of failures when using the model for inference.
`timestamp`:::
(<<time-units,time units>>)
The time when the statistics were last updated.
=====
`ingest`:::
(object)
A collection of ingest stats for the model across all nodes. The values are
summations of the individual node statistics. The format matches the `ingest`
section in <<cluster-nodes-stats>>.
====
[[ml-get-inference-stats-response-codes]]
==== {api-response-codes-title}
@ -73,7 +141,6 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size]
`404` (Missing resources)::
If `allow_no_match` is `false`, this code indicates that there are no
resources that match the request or only partial matches for the request.
[[ml-get-inference-stats-example]]
==== {api-examples-title}
@ -96,11 +163,25 @@ The API returns the following results:
"trained_model_stats": [
{
"model_id": "flight-delay-prediction-1574775339910",
"pipeline_count": 0
"pipeline_count": 0,
"inference_stats": {
"failure_count": 0,
"inference_count": 4,
"cache_miss_count": 3,
"missing_all_fields_count": 0,
"timestamp": 1592399986979
}
},
{
"model_id": "regression-job-one-1574775307356",
"pipeline_count": 1,
"inference_stats": {
"failure_count": 0,
"inference_count": 178,
"cache_miss_count": 3,
"missing_all_fields_count": 0,
"timestamp": 1592399986979
},
"ingest": {
"total": {
"count": 178,
@ -134,4 +215,4 @@ The API returns the following results:
]
}
----
// NOTCONSOLE
// NOTCONSOLE

View File

@ -121,6 +121,10 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
return pipelineCount;
}
public InferenceStats getInferenceStats() {
return inferenceStats;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
@ -20,15 +21,13 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
import java.util.concurrent.atomic.LongAdder;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
public class InferenceStats implements ToXContentObject, Writeable {
public static final String NAME = "inference_stats";
public static final ParseField MISSING_ALL_FIELDS_COUNT = new ParseField("missing_all_fields_count");
public static final ParseField INFERENCE_COUNT = new ParseField("inference_count");
public static final ParseField CACHE_MISS_COUNT = new ParseField("cache_miss_count");
public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField NODE_ID = new ParseField("node_id");
public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
@ -38,12 +37,13 @@ public class InferenceStats implements ToXContentObject, Writeable {
public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (String)a[3], (String)a[4], (Instant)a[5])
a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (Long)a[3], (String)a[4], (String)a[5], (Instant)a[6])
);
static {
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MISSING_ALL_FIELDS_COUNT);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), INFERENCE_COUNT);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), FAILURE_COUNT);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), CACHE_MISS_COUNT);
PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
PARSER.declareString(ConstructingObjectParser.constructorArg(), NODE_ID);
PARSER.declareField(ConstructingObjectParser.constructorArg(),
@ -51,9 +51,6 @@ public class InferenceStats implements ToXContentObject, Writeable {
TIMESTAMP,
ObjectParser.ValueType.VALUE);
}
public static InferenceStats emptyStats(String modelId, String nodeId) {
return new InferenceStats(0L, 0L, 0L, modelId, nodeId, Instant.now());
}
public static String docId(String modelId, String nodeId) {
return NAME + "-" + modelId + "-" + nodeId;
@ -62,6 +59,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
private final long missingAllFieldsCount;
private final long inferenceCount;
private final long failureCount;
private final long cacheMissCount;
private final String modelId;
private final String nodeId;
private final Instant timeStamp;
@ -69,12 +67,14 @@ public class InferenceStats implements ToXContentObject, Writeable {
private InferenceStats(Long missingAllFieldsCount,
Long inferenceCount,
Long failureCount,
Long cacheMissCount,
String modelId,
String nodeId,
Instant instant) {
this(unbox(missingAllFieldsCount),
unbox(inferenceCount),
unbox(failureCount),
this(unboxOrZero(missingAllFieldsCount),
unboxOrZero(inferenceCount),
unboxOrZero(failureCount),
unboxOrZero(cacheMissCount),
modelId,
nodeId,
instant);
@ -83,12 +83,14 @@ public class InferenceStats implements ToXContentObject, Writeable {
public InferenceStats(long missingAllFieldsCount,
long inferenceCount,
long failureCount,
long cacheMissCount,
String modelId,
String nodeId,
Instant timeStamp) {
this.missingAllFieldsCount = missingAllFieldsCount;
this.inferenceCount = inferenceCount;
this.failureCount = failureCount;
this.cacheMissCount = cacheMissCount;
this.modelId = modelId;
this.nodeId = nodeId;
this.timeStamp = timeStamp == null ?
@ -100,6 +102,11 @@ public class InferenceStats implements ToXContentObject, Writeable {
this.missingAllFieldsCount = in.readVLong();
this.inferenceCount = in.readVLong();
this.failureCount = in.readVLong();
if (in.getVersion().onOrAfter(Version.V_7_9_0)) {
this.cacheMissCount = in.readVLong();
} else {
this.cacheMissCount = 0L;
}
this.modelId = in.readOptionalString();
this.nodeId = in.readOptionalString();
this.timeStamp = in.readInstant();
@ -117,6 +124,10 @@ public class InferenceStats implements ToXContentObject, Writeable {
return failureCount;
}
public long getCacheMissCount() {
return cacheMissCount;
}
public String getModelId() {
return modelId;
}
@ -130,7 +141,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
}
public boolean hasStats() {
return missingAllFieldsCount > 0 || inferenceCount > 0 || failureCount > 0;
return missingAllFieldsCount > 0 || inferenceCount > 0 || failureCount > 0 || cacheMissCount > 0;
}
@Override
@ -145,6 +156,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
}
builder.field(FAILURE_COUNT.getPreferredName(), failureCount);
builder.field(INFERENCE_COUNT.getPreferredName(), inferenceCount);
builder.field(CACHE_MISS_COUNT.getPreferredName(), cacheMissCount);
builder.field(MISSING_ALL_FIELDS_COUNT.getPreferredName(), missingAllFieldsCount);
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timeStamp.toEpochMilli());
builder.endObject();
@ -159,6 +171,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
return missingAllFieldsCount == that.missingAllFieldsCount
&& inferenceCount == that.inferenceCount
&& failureCount == that.failureCount
&& cacheMissCount == that.cacheMissCount
&& Objects.equals(modelId, that.modelId)
&& Objects.equals(nodeId, that.nodeId)
&& Objects.equals(timeStamp, that.timeStamp);
@ -166,7 +179,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
@Override
public int hashCode() {
return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, modelId, nodeId, timeStamp);
return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, cacheMissCount, modelId, nodeId, timeStamp);
}
@Override
@ -175,13 +188,14 @@ public class InferenceStats implements ToXContentObject, Writeable {
"missingAllFieldsCount=" + missingAllFieldsCount +
", inferenceCount=" + inferenceCount +
", failureCount=" + failureCount +
", cacheMissCount=" + cacheMissCount +
", modelId='" + modelId + '\'' +
", nodeId='" + nodeId + '\'' +
", timeStamp=" + timeStamp +
'}';
}
private static long unbox(@Nullable Long value) {
private static long unboxOrZero(@Nullable Long value) {
return value == null ? 0L : value;
}
@ -194,6 +208,9 @@ public class InferenceStats implements ToXContentObject, Writeable {
out.writeVLong(this.missingAllFieldsCount);
out.writeVLong(this.inferenceCount);
out.writeVLong(this.failureCount);
if (out.getVersion().onOrAfter(Version.V_7_9_0)) {
out.writeVLong(this.cacheMissCount);
}
out.writeOptionalString(this.modelId);
out.writeOptionalString(this.nodeId);
out.writeInstant(timeStamp);
@ -201,66 +218,55 @@ public class InferenceStats implements ToXContentObject, Writeable {
public static class Accumulator {
private final LongAdder missingFieldsAccumulator = new LongAdder();
private final LongAdder inferenceAccumulator = new LongAdder();
private final LongAdder failureCountAccumulator = new LongAdder();
private long missingFieldsAccumulator = 0L;
private long inferenceAccumulator = 0L;
private long failureCountAccumulator = 0L;
private long cacheMissAccumulator = 0L;
private final String modelId;
private final String nodeId;
// curious reader
// you may be wondering why the lock set to the fair.
// When `currentStatsAndReset` is called, we want it guaranteed that it will eventually execute.
// If a ReadWriteLock is unfair, there are no such guarantees.
// A call for the `writelock::lock` could pause indefinitely.
private final ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true);
public Accumulator(String modelId, String nodeId) {
public Accumulator(String modelId, String nodeId, long cacheMisses) {
this.modelId = modelId;
this.nodeId = nodeId;
this.cacheMissAccumulator = cacheMisses;
}
public Accumulator(InferenceStats previousStats) {
Accumulator(InferenceStats previousStats) {
this.modelId = previousStats.modelId;
this.nodeId = previousStats.nodeId;
this.missingFieldsAccumulator.add(previousStats.missingAllFieldsCount);
this.inferenceAccumulator.add(previousStats.inferenceCount);
this.failureCountAccumulator.add(previousStats.failureCount);
this.missingFieldsAccumulator += previousStats.missingAllFieldsCount;
this.inferenceAccumulator += previousStats.inferenceCount;
this.failureCountAccumulator += previousStats.failureCount;
this.cacheMissAccumulator += previousStats.cacheMissCount;
}
/**
* NOT Thread Safe
*
* @param otherStats the other stats with which to increment the current stats
* @return Updated accumulator
*/
public Accumulator merge(InferenceStats otherStats) {
this.missingFieldsAccumulator.add(otherStats.missingAllFieldsCount);
this.inferenceAccumulator.add(otherStats.inferenceCount);
this.failureCountAccumulator.add(otherStats.failureCount);
this.missingFieldsAccumulator += otherStats.missingAllFieldsCount;
this.inferenceAccumulator += otherStats.inferenceCount;
this.failureCountAccumulator += otherStats.failureCount;
this.cacheMissAccumulator += otherStats.cacheMissCount;
return this;
}
public Accumulator incMissingFields() {
readWriteLock.readLock().lock();
try {
this.missingFieldsAccumulator.increment();
return this;
} finally {
readWriteLock.readLock().unlock();
}
public synchronized Accumulator incMissingFields() {
this.missingFieldsAccumulator++;
return this;
}
public Accumulator incInference() {
readWriteLock.readLock().lock();
try {
this.inferenceAccumulator.increment();
return this;
} finally {
readWriteLock.readLock().unlock();
}
public synchronized Accumulator incInference() {
this.inferenceAccumulator++;
return this;
}
public Accumulator incFailure() {
readWriteLock.readLock().lock();
try {
this.failureCountAccumulator.increment();
return this;
} finally {
readWriteLock.readLock().unlock();
}
public synchronized Accumulator incFailure() {
this.failureCountAccumulator++;
return this;
}
/**
@ -269,23 +275,20 @@ public class InferenceStats implements ToXContentObject, Writeable {
* Returns the current stats and resets the values of all the counters.
* @return The current stats
*/
public InferenceStats currentStatsAndReset() {
readWriteLock.writeLock().lock();
try {
InferenceStats stats = currentStats(Instant.now());
this.missingFieldsAccumulator.reset();
this.inferenceAccumulator.reset();
this.failureCountAccumulator.reset();
return stats;
} finally {
readWriteLock.writeLock().unlock();
}
public synchronized InferenceStats currentStatsAndReset() {
InferenceStats stats = currentStats(Instant.now());
this.missingFieldsAccumulator = 0L;
this.inferenceAccumulator = 0L;
this.failureCountAccumulator = 0L;
this.cacheMissAccumulator = 0L;
return stats;
}
public InferenceStats currentStats(Instant timeStamp) {
return new InferenceStats(missingFieldsAccumulator.longValue(),
inferenceAccumulator.longValue(),
failureCountAccumulator.longValue(),
return new InferenceStats(missingFieldsAccumulator,
inferenceAccumulator,
failureCountAccumulator,
cacheMissAccumulator,
modelId,
nodeId,
timeStamp);
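
The accumulator now uses synchronized increments over plain long fields instead of LongAdders guarded by a fair ReadWriteLock. A minimal sketch of how the class shown above might be driven, using only the methods visible in this diff (the persistence step is only hinted at):

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;

// Construction counts as one cache miss for this model/node pair.
InferenceStats.Accumulator accumulator = new InferenceStats.Accumulator("my-model", "node-1", 1L);

// Per inference request:
accumulator.incInference();
// accumulator.incMissingFields();  // when every input field was absent
// accumulator.incFailure();        // when inference threw

// Periodically, e.g. from the stats persister:
InferenceStats snapshot = accumulator.currentStatsAndReset(); // read-and-zero under the object lock
if (snapshot.hasStats()) {
    // hand the snapshot to TrainedModelStatsService for indexing
}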

View File

@ -97,6 +97,9 @@
"failure_count": {
"type": "long"
},
"cache_miss_count": {
"type": "long"
},
"missing_all_fields_count": {
"type": "long"
},

View File

@ -65,6 +65,9 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
@Override
protected Response mutateInstanceForVersion(Response instance, Version version) {
if (version.equals(Version.CURRENT)) {
return instance;
}
if (version.before(Version.V_7_8_0)) {
List<Response.TrainedModelStats> stats = instance.getResources()
.results()
@ -76,7 +79,15 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
.collect(Collectors.toList());
return new Response(new QueryPage<>(stats, instance.getResources().count(), RESULTS_FIELD));
}
return instance;
List<Response.TrainedModelStats> stats = instance.getResources()
.results()
.stream()
.map(s -> new Response.TrainedModelStats(s.getModelId(),
adjustForVersion(s.getIngestStats(), version),
s.getPipelineCount(),
InferenceStatsTests.mutateForVersion(s.getInferenceStats(), version)))
.collect(Collectors.toList());
return new Response(new QueryPage<>(stats, instance.getResources().count(), RESULTS_FIELD));
}
IngestStats adjustForVersion(IngestStats stats, Version version) {

View File

@ -5,11 +5,12 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import java.io.IOException;
import java.time.Instant;
@ -17,10 +18,11 @@ import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
public class InferenceStatsTests extends AbstractSerializingTestCase<InferenceStats> {
public class InferenceStatsTests extends AbstractBWCSerializationTestCase<InferenceStats> {
public static InferenceStats createTestInstance(String modelId, @Nullable String nodeId) {
return new InferenceStats(randomNonNegativeLong(),
randomNonNegativeLong(),
randomNonNegativeLong(),
randomNonNegativeLong(),
modelId,
@ -29,6 +31,24 @@ public class InferenceStatsTests extends AbstractSerializingTestCase<InferenceSt
);
}
public static InferenceStats mutateForVersion(InferenceStats instance, Version version) {
if (instance == null) {
return null;
}
if (version.before(Version.V_7_9_0)) {
return new InferenceStats(
instance.getMissingAllFieldsCount(),
instance.getInferenceCount(),
instance.getFailureCount(),
0L,
instance.getModelId(),
instance.getNodeId(),
instance.getTimeStamp()
);
}
return instance;
}
@Override
protected InferenceStats doParseInstance(XContentParser parser) throws IOException {
return InferenceStats.PARSER.apply(parser, null);
@ -54,4 +74,8 @@ public class InferenceStatsTests extends AbstractSerializingTestCase<InferenceSt
return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
}
@Override
protected InferenceStats mutateInstanceForVersion(InferenceStats instance, Version version) {
return mutateForVersion(instance, version);
}
}

View File

@ -117,9 +117,13 @@ public class InferenceIngestIT extends ESRestTestCase {
try {
Response statsResponse = client().performRequest(new Request("GET",
"_ml/inference/" + classificationModelId + "/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
String response = EntityUtils.toString(statsResponse.getEntity());
assertThat(response, containsString("\"inference_count\":10"));
assertThat(response, containsString("\"cache_miss_count\":30"));
statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
response = EntityUtils.toString(statsResponse.getEntity());
assertThat(response, containsString("\"inference_count\":10"));
assertThat(response, containsString("\"cache_miss_count\":30"));
} catch (ResponseException ex) {
//this could just mean shard failures.
fail(ex.getMessage());
@ -169,9 +173,13 @@ public class InferenceIngestIT extends ESRestTestCase {
try {
Response statsResponse = client().performRequest(new Request("GET",
"_ml/inference/" + classificationModelId + "/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
String response = EntityUtils.toString(statsResponse.getEntity());
assertThat(response, containsString("\"inference_count\":10"));
assertThat(response, containsString("\"cache_miss_count\":3"));
statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
response = EntityUtils.toString(statsResponse.getEntity());
assertThat(response, containsString("\"inference_count\":15"));
assertThat(response, containsString("\"cache_miss_count\":3"));
// can get both
statsResponse = client().performRequest(new Request("GET", "_ml/inference/_stats"));
String entityString = EntityUtils.toString(statsResponse.getEntity());

View File

@ -58,12 +58,14 @@ public class TrainedModelStatsService {
" ctx._source.{0} += params.{0};\n" +
" ctx._source.{1} += params.{1};\n" +
" ctx._source.{2} += params.{2};\n" +
" ctx._source.{3} = params.{3};";
" ctx._source.{3} += params.{3};\n" +
" ctx._source.{4} = params.{4};";
// Script to only update if stats have increased since last persistence
private static final String STATS_UPDATE_SCRIPT = Messages.getMessage(STATS_UPDATE_SCRIPT_TEMPLATE,
InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(),
InferenceStats.INFERENCE_COUNT.getPreferredName(),
InferenceStats.FAILURE_COUNT.getPreferredName(),
InferenceStats.CACHE_MISS_COUNT.getPreferredName(),
InferenceStats.TIMESTAMP.getPreferredName());
private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS =
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
@ -224,6 +226,7 @@ public class TrainedModelStatsService {
params.put(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(), stats.getMissingAllFieldsCount());
params.put(InferenceStats.TIMESTAMP.getPreferredName(), stats.getTimeStamp().toEpochMilli());
params.put(InferenceStats.INFERENCE_COUNT.getPreferredName(), stats.getInferenceCount());
params.put(InferenceStats.CACHE_MISS_COUNT.getPreferredName(), stats.getCacheMissCount());
stats.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);
UpdateRequest updateRequest = new UpdateRequest();
updateRequest.upsert(builder)
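
For reference, once the field names are substituted into the template (Messages.getMessage does MessageFormat-style substitution), the visible part of the update script expands to roughly the sketch below; the earlier, unshown portion of the template is omitted:

import java.text.MessageFormat;
import java.util.Locale;

String visibleTemplate =
      " ctx._source.{0} += params.{0};\n"
    + " ctx._source.{1} += params.{1};\n"
    + " ctx._source.{2} += params.{2};\n"
    + " ctx._source.{3} += params.{3};\n"
    + " ctx._source.{4} = params.{4};";
String rendered = new MessageFormat(visibleTemplate, Locale.ROOT).format(new Object[] {
    "missing_all_fields_count", "inference_count", "failure_count", "cache_miss_count", "timestamp"
});
// rendered (roughly):
//  ctx._source.missing_all_fields_count += params.missing_all_fields_count;
//  ctx._source.inference_count += params.inference_count;
//  ctx._source.failure_count += params.failure_count;
//  ctx._source.cache_miss_count += params.cache_miss_count;
//  ctx._source.timestamp = params.timestamp;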

View File

@ -46,11 +46,13 @@ public class LocalModel implements Model {
TrainedModelInput input,
Map<String, String> defaultFieldMap,
InferenceConfig modelInferenceConfig,
TrainedModelStatsService trainedModelStatsService ) {
TrainedModelStatsService trainedModelStatsService) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
this.fieldNames = new HashSet<>(input.getFieldNames());
this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId);
// the ctor being called means a new instance was created.
// Consequently, it was not loaded from cache and on stats persist we should increment accordingly.
this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId, 1L);
this.trainedModelStatsService = trainedModelStatsService;
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
this.currentInferenceCount = new LongAdder();
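
Because the constructor is the only place the cache-miss accumulator starts above zero, serving an already-cached LocalModel adds nothing to cache_miss_count. A self-contained, hypothetical illustration of that accounting rule (not the actual ModelLoadingService code):

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

final class CacheMissDemo {
    // Hypothetical stand-in for a cached model: constructing it is the cache-miss event.
    static final class CachedModel {
        static final AtomicLong CACHE_MISSES = new AtomicLong();
        CachedModel(String id) {
            CACHE_MISSES.incrementAndGet();
        }
    }

    public static void main(String[] args) {
        ConcurrentHashMap<String, CachedModel> cache = new ConcurrentHashMap<>();
        cache.computeIfAbsent("model-a", CachedModel::new); // miss: constructor runs, counter -> 1
        cache.computeIfAbsent("model-a", CachedModel::new); // hit: no construction, counter stays 1
        System.out.println(CachedModel.CACHE_MISSES.get()); // prints 1
    }
}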

View File

@ -634,6 +634,8 @@ public class TrainedModelProvider {
.field(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName()))
.aggregation(AggregationBuilders.sum(InferenceStats.INFERENCE_COUNT.getPreferredName())
.field(InferenceStats.INFERENCE_COUNT.getPreferredName()))
.aggregation(AggregationBuilders.sum(InferenceStats.CACHE_MISS_COUNT.getPreferredName())
.field(InferenceStats.CACHE_MISS_COUNT.getPreferredName()))
.aggregation(AggregationBuilders.max(InferenceStats.TIMESTAMP.getPreferredName())
.field(InferenceStats.TIMESTAMP.getPreferredName()))
.query(queryBuilder));
@ -646,12 +648,14 @@ public class TrainedModelProvider {
}
Sum failures = response.getAggregations().get(InferenceStats.FAILURE_COUNT.getPreferredName());
Sum missing = response.getAggregations().get(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName());
Sum cacheMiss = response.getAggregations().get(InferenceStats.CACHE_MISS_COUNT.getPreferredName());
Sum count = response.getAggregations().get(InferenceStats.INFERENCE_COUNT.getPreferredName());
Max timeStamp = response.getAggregations().get(InferenceStats.TIMESTAMP.getPreferredName());
return new InferenceStats(
missing == null ? 0L : Double.valueOf(missing.getValue()).longValue(),
count == null ? 0L : Double.valueOf(count.getValue()).longValue(),
failures == null ? 0L : Double.valueOf(failures.getValue()).longValue(),
cacheMiss == null ? 0L : Double.valueOf(cacheMiss.getValue()).longValue(),
modelId,
null,
timeStamp == null || (Numbers.isValidDouble(timeStamp.getValue()) == false) ?