[7.x] [ML] calculate cache misses for inference and return in stats (#58252) (#58363)

When a local model is constructed, the cache miss count is incremented.

When a user calls _stats, we include the sum of the cache miss counts across ALL nodes. This statistic is important when comparing against the inference_count: if the cache miss count is near the inference_count, it indicates that the cache is overburdened or inappropriately configured.
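
To make that interpretation concrete, here is a minimal, standalone Java sketch (not part of this change; the 0.9 threshold is purely illustrative) of how a caller might flag an overburdened cache from the two counters this change exposes:

// Minimal sketch: flag a cache that misses on nearly every inference.
public final class CacheMissCheck {

    private CacheMissCheck() {}

    // True when cache misses are close to the total inference count,
    // i.e. nearly every inference call had to load the model again.
    public static boolean cacheLooksOverburdened(long cacheMissCount, long inferenceCount) {
        if (inferenceCount == 0L) {
            return false; // nothing has been inferred yet, so there is nothing to conclude
        }
        return (double) cacheMissCount / inferenceCount > 0.9; // illustrative threshold
    }
}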
Benjamin Trent 2020-06-19 09:46:51 -04:00 committed by GitHub
parent d8dc638a67
commit bf8641aa15
14 changed files with 471 additions and 86 deletions

File: org.elasticsearch.client.ml.inference.TrainedModelStats (client)

@@ -18,6 +18,7 @@
 */
package org.elasticsearch.client.ml.inference;
+import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@@ -38,32 +39,36 @@ public class TrainedModelStats implements ToXContentObject {
    public static final ParseField MODEL_ID = new ParseField("model_id");
    public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
    public static final ParseField INGEST_STATS = new ParseField("ingest");
+   public static final ParseField INFERENCE_STATS = new ParseField("inference_stats");
    private final String modelId;
    private final Map<String, Object> ingestStats;
    private final int pipelineCount;
+   private final InferenceStats inferenceStats;
    @SuppressWarnings("unchecked")
    static final ConstructingObjectParser<TrainedModelStats, Void> PARSER =
        new ConstructingObjectParser<>(
            "trained_model_stats",
            true,
-           args -> new TrainedModelStats((String) args[0], (Map<String, Object>) args[1], (Integer) args[2]));
+           args -> new TrainedModelStats((String) args[0], (Map<String, Object>) args[1], (Integer) args[2], (InferenceStats) args[3]));
    static {
        PARSER.declareString(constructorArg(), MODEL_ID);
        PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), INGEST_STATS);
        PARSER.declareInt(constructorArg(), PIPELINE_COUNT);
+       PARSER.declareObject(optionalConstructorArg(), InferenceStats.PARSER, INFERENCE_STATS);
    }
    public static TrainedModelStats fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }
-   public TrainedModelStats(String modelId, Map<String, Object> ingestStats, int pipelineCount) {
+   public TrainedModelStats(String modelId, Map<String, Object> ingestStats, int pipelineCount, InferenceStats inferenceStats) {
        this.modelId = modelId;
        this.ingestStats = ingestStats;
        this.pipelineCount = pipelineCount;
+       this.inferenceStats = inferenceStats;
    }
    /**
@@ -89,6 +94,13 @@ public class TrainedModelStats implements ToXContentObject {
        return pipelineCount;
    }
+   /**
+    * Inference statistics
+    */
+   public InferenceStats getInferenceStats() {
+       return inferenceStats;
+   }
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
@@ -97,13 +109,16 @@ public class TrainedModelStats implements ToXContentObject {
        if (ingestStats != null) {
            builder.field(INGEST_STATS.getPreferredName(), ingestStats);
        }
+       if (inferenceStats != null) {
+           builder.field(INFERENCE_STATS.getPreferredName(), inferenceStats);
+       }
        builder.endObject();
        return builder;
    }
    @Override
    public int hashCode() {
-       return Objects.hash(modelId, ingestStats, pipelineCount);
+       return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats);
    }
    @Override
@@ -117,7 +132,8 @@ public class TrainedModelStats implements ToXContentObject {
        TrainedModelStats other = (TrainedModelStats) obj;
        return Objects.equals(this.modelId, other.modelId)
            && Objects.equals(this.ingestStats, other.ingestStats)
-           && Objects.equals(this.pipelineCount, other.pipelineCount);
+           && Objects.equals(this.pipelineCount, other.pipelineCount)
+           && Objects.equals(this.inferenceStats, other.inferenceStats);
    }
}

File: org.elasticsearch.client.ml.inference.trainedmodel.InferenceStats (client, new file)

@@ -0,0 +1,170 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.time.Instant;
import java.util.Objects;

public class InferenceStats implements ToXContentObject {

    public static final String NAME = "inference_stats";

    public static final ParseField MISSING_ALL_FIELDS_COUNT = new ParseField("missing_all_fields_count");
    public static final ParseField INFERENCE_COUNT = new ParseField("inference_count");
    public static final ParseField CACHE_MISS_COUNT = new ParseField("cache_miss_count");
    public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
    public static final ParseField TIMESTAMP = new ParseField("timestamp");

    public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
        a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (Long)a[3], (Instant)a[4])
    );
    static {
        PARSER.declareLong(ConstructingObjectParser.constructorArg(), MISSING_ALL_FIELDS_COUNT);
        PARSER.declareLong(ConstructingObjectParser.constructorArg(), INFERENCE_COUNT);
        PARSER.declareLong(ConstructingObjectParser.constructorArg(), FAILURE_COUNT);
        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), CACHE_MISS_COUNT);
        PARSER.declareField(ConstructingObjectParser.constructorArg(),
            p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
            TIMESTAMP,
            ObjectParser.ValueType.VALUE);
    }

    private final long missingAllFieldsCount;
    private final long inferenceCount;
    private final long failureCount;
    private final long cacheMissCount;
    private final Instant timeStamp;

    private InferenceStats(Long missingAllFieldsCount,
                           Long inferenceCount,
                           Long failureCount,
                           Long cacheMissCount,
                           Instant instant) {
        this(unboxOrZero(missingAllFieldsCount),
            unboxOrZero(inferenceCount),
            unboxOrZero(failureCount),
            unboxOrZero(cacheMissCount),
            instant);
    }

    public InferenceStats(long missingAllFieldsCount,
                          long inferenceCount,
                          long failureCount,
                          long cacheMissCount,
                          Instant timeStamp) {
        this.missingAllFieldsCount = missingAllFieldsCount;
        this.inferenceCount = inferenceCount;
        this.failureCount = failureCount;
        this.cacheMissCount = cacheMissCount;
        this.timeStamp = timeStamp == null ?
            Instant.ofEpochMilli(Instant.now().toEpochMilli()) :
            Instant.ofEpochMilli(timeStamp.toEpochMilli());
    }

    /**
     * How many times this model attempted to infer with all its fields missing
     */
    public long getMissingAllFieldsCount() {
        return missingAllFieldsCount;
    }

    /**
     * How many inference calls were made against this model
     */
    public long getInferenceCount() {
        return inferenceCount;
    }

    /**
     * How many inference failures occurred.
     */
    public long getFailureCount() {
        return failureCount;
    }

    /**
     * How many cache misses occurred when inferring this model
     */
    public long getCacheMissCount() {
        return cacheMissCount;
    }

    /**
     * The timestamp of these statistics.
     */
    public Instant getTimeStamp() {
        return timeStamp;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FAILURE_COUNT.getPreferredName(), failureCount);
        builder.field(INFERENCE_COUNT.getPreferredName(), inferenceCount);
        builder.field(CACHE_MISS_COUNT.getPreferredName(), cacheMissCount);
        builder.field(MISSING_ALL_FIELDS_COUNT.getPreferredName(), missingAllFieldsCount);
        builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timeStamp.toEpochMilli());
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        InferenceStats that = (InferenceStats) o;
        return missingAllFieldsCount == that.missingAllFieldsCount
            && inferenceCount == that.inferenceCount
            && failureCount == that.failureCount
            && cacheMissCount == that.cacheMissCount
            && Objects.equals(timeStamp, that.timeStamp);
    }

    @Override
    public int hashCode() {
        return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, cacheMissCount, timeStamp);
    }

    @Override
    public String toString() {
        return "InferenceStats{" +
            "missingAllFieldsCount=" + missingAllFieldsCount +
            ", inferenceCount=" + inferenceCount +
            ", failureCount=" + failureCount +
            ", cacheMissCount=" + cacheMissCount +
            ", timeStamp=" + timeStamp +
            '}';
    }

    private static long unboxOrZero(@Nullable Long value) {
        return value == null ? 0L : value;
    }
}
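
A brief usage sketch for the new client-side class above. It assumes the TrainedModelStats instance was obtained from the get trained models stats API via the high-level REST client (that call is outside the scope of this sketch), and that `inference_stats` is optional in the response, so the getter added to TrainedModelStats may return null:

import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStats;

public final class InferenceStatsReporter {

    private InferenceStatsReporter() {}

    // Prints the per-model counters, tolerating responses that carry no inference_stats object.
    public static void report(String modelId, TrainedModelStats modelStats) {
        InferenceStats stats = modelStats.getInferenceStats();
        if (stats == null) {
            System.out.println(modelId + ": no inference stats reported");
            return;
        }
        System.out.println(modelId
            + " inference_count=" + stats.getInferenceCount()
            + " cache_miss_count=" + stats.getCacheMissCount()
            + " failure_count=" + stats.getFailureCount()
            + " missing_all_fields_count=" + stats.getMissingAllFieldsCount()
            + " timestamp=" + stats.getTimeStamp());
    }
}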

File: org.elasticsearch.client.ml.inference.TrainedModelStatsTests (client)

@@ -18,6 +18,7 @@
 */
package org.elasticsearch.client.ml.inference;
+import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStatsTests;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -58,7 +59,8 @@ public class TrainedModelStatsTests extends AbstractXContentTestCase<TrainedMode
        return new TrainedModelStats(
            randomAlphaOfLength(10),
            randomBoolean() ? null : randomIngestStats(),
-           randomInt());
+           randomInt(),
+           randomBoolean() ? null : InferenceStatsTests.randomInstance());
    }
    private Map<String, Object> randomIngestStats() {

File: org.elasticsearch.client.ml.inference.trainedmodel.InferenceStatsTests (client, new file)

@@ -0,0 +1,54 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.time.Instant;

public class InferenceStatsTests extends AbstractXContentTestCase<InferenceStats> {

    public static InferenceStats randomInstance() {
        return new InferenceStats(randomNonNegativeLong(),
            randomNonNegativeLong(),
            randomNonNegativeLong(),
            randomNonNegativeLong(),
            Instant.now()
        );
    }

    @Override
    protected InferenceStats doParseInstance(XContentParser parser) throws IOException {
        return InferenceStats.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected InferenceStats createTestInstance() {
        return randomInstance();
    }
}

File: docs — get trained models statistics API (ml-get-inference-stats)

@@ -66,6 +66,74 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from]
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size]
+[role="child_attributes"]
+[[ml-get-inference-stats-results]]
+==== {api-response-body-title}
+`count`::
+(integer)
+The total number of trained model statistics that matched the requested ID patterns.
+Could be higher than the number of items in the `trained_model_stats` array as the
+size of the array is restricted by the supplied `size` parameter.
+`trained_model_stats`::
+(array)
+An array of trained model statistics, which are sorted by the `model_id` value in
+ascending order.
++
+.Properties of trained model stats
+[%collapsible%open]
+====
+`model_id`:::
+(string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+`pipeline_count`:::
+(integer)
+The number of ingest pipelines that currently refer to the model.
+`inference_stats`:::
+(object)
+A collection of inference stats fields.
++
+.Properties of inference stats
+[%collapsible%open]
+=====
+`missing_all_fields_count`:::
+(integer)
+The number of inference calls where all the training features for the model
+were missing.
+`inference_count`:::
+(integer)
+The total number of times the model has been called for inference.
+This is across all inference contexts, including all pipelines.
+`cache_miss_count`:::
+(integer)
+The number of times the model was loaded for inference and was not retrieved from the
+cache. If this number is close to the `inference_count`, then the cache
+is not being appropriately used. This can be remedied by increasing the cache's size
+or its time-to-live (TTL). See <<general-ml-settings>> for the
+appropriate settings.
+`failure_count`:::
+(integer)
+The number of failures when using the model for inference.
+`timestamp`:::
+(<<time-units,time units>>)
+The time when the statistics were last updated.
+=====
+`ingest`:::
+(object)
+A collection of ingest stats for the model across all nodes. The values are
+summations of the individual node statistics. The format matches the `ingest`
+section in <<cluster-nodes-stats>>.
+====
[[ml-get-inference-stats-response-codes]]
==== {api-response-codes-title}
@@ -74,7 +142,6 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size]
If `allow_no_match` is `false`, this code indicates that there are no
resources that match the request or only partial matches for the request.
[[ml-get-inference-stats-example]]
==== {api-examples-title}
@@ -96,11 +163,25 @@ The API returns the following results:
  "trained_model_stats": [
    {
      "model_id": "flight-delay-prediction-1574775339910",
-     "pipeline_count": 0
+     "pipeline_count": 0,
+     "inference_stats": {
+       "failure_count": 0,
+       "inference_count": 4,
+       "cache_miss_count": 3,
+       "missing_all_fields_count": 0,
+       "timestamp": 1592399986979
+     }
    },
    {
      "model_id": "regression-job-one-1574775307356",
      "pipeline_count": 1,
+     "inference_stats": {
+       "failure_count": 0,
+       "inference_count": 178,
+       "cache_miss_count": 3,
+       "missing_all_fields_count": 0,
+       "timestamp": 1592399986979
+     },
      "ingest": {
        "total": {
          "count": 178,

File: GetTrainedModelsStatsAction (x-pack core)

@@ -121,6 +121,10 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
            return pipelineCount;
        }
+       public InferenceStats getInferenceStats() {
+           return inferenceStats;
+       }
        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            builder.startObject();

File: org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats (x-pack core)

@@ -5,6 +5,7 @@
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
@@ -20,15 +21,13 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
-import java.util.concurrent.atomic.LongAdder;
-import java.util.concurrent.locks.ReadWriteLock;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
public class InferenceStats implements ToXContentObject, Writeable {
    public static final String NAME = "inference_stats";
    public static final ParseField MISSING_ALL_FIELDS_COUNT = new ParseField("missing_all_fields_count");
    public static final ParseField INFERENCE_COUNT = new ParseField("inference_count");
+   public static final ParseField CACHE_MISS_COUNT = new ParseField("cache_miss_count");
    public static final ParseField MODEL_ID = new ParseField("model_id");
    public static final ParseField NODE_ID = new ParseField("node_id");
    public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
@@ -38,12 +37,13 @@ public class InferenceStats implements ToXContentObject, Writeable {
    public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
-       a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (String)a[3], (String)a[4], (Instant)a[5])
+       a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (Long)a[3], (String)a[4], (String)a[5], (Instant)a[6])
    );
    static {
        PARSER.declareLong(ConstructingObjectParser.constructorArg(), MISSING_ALL_FIELDS_COUNT);
        PARSER.declareLong(ConstructingObjectParser.constructorArg(), INFERENCE_COUNT);
        PARSER.declareLong(ConstructingObjectParser.constructorArg(), FAILURE_COUNT);
+       PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), CACHE_MISS_COUNT);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), NODE_ID);
        PARSER.declareField(ConstructingObjectParser.constructorArg(),
@@ -51,9 +51,6 @@ public class InferenceStats implements ToXContentObject, Writeable {
            TIMESTAMP,
            ObjectParser.ValueType.VALUE);
    }
-   public static InferenceStats emptyStats(String modelId, String nodeId) {
-       return new InferenceStats(0L, 0L, 0L, modelId, nodeId, Instant.now());
-   }
    public static String docId(String modelId, String nodeId) {
        return NAME + "-" + modelId + "-" + nodeId;
@@ -62,6 +59,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
    private final long missingAllFieldsCount;
    private final long inferenceCount;
    private final long failureCount;
+   private final long cacheMissCount;
    private final String modelId;
    private final String nodeId;
    private final Instant timeStamp;
@@ -69,12 +67,14 @@ public class InferenceStats implements ToXContentObject, Writeable {
    private InferenceStats(Long missingAllFieldsCount,
                           Long inferenceCount,
                           Long failureCount,
+                          Long cacheMissCount,
                           String modelId,
                           String nodeId,
                           Instant instant) {
-       this(unbox(missingAllFieldsCount),
-           unbox(inferenceCount),
-           unbox(failureCount),
+       this(unboxOrZero(missingAllFieldsCount),
+           unboxOrZero(inferenceCount),
+           unboxOrZero(failureCount),
+           unboxOrZero(cacheMissCount),
            modelId,
            nodeId,
            instant);
@@ -83,12 +83,14 @@ public class InferenceStats implements ToXContentObject, Writeable {
    public InferenceStats(long missingAllFieldsCount,
                          long inferenceCount,
                          long failureCount,
+                         long cacheMissCount,
                          String modelId,
                          String nodeId,
                          Instant timeStamp) {
        this.missingAllFieldsCount = missingAllFieldsCount;
        this.inferenceCount = inferenceCount;
        this.failureCount = failureCount;
+       this.cacheMissCount = cacheMissCount;
        this.modelId = modelId;
        this.nodeId = nodeId;
        this.timeStamp = timeStamp == null ?
@@ -100,6 +102,11 @@ public class InferenceStats implements ToXContentObject, Writeable {
        this.missingAllFieldsCount = in.readVLong();
        this.inferenceCount = in.readVLong();
        this.failureCount = in.readVLong();
+       if (in.getVersion().onOrAfter(Version.V_7_9_0)) {
+           this.cacheMissCount = in.readVLong();
+       } else {
+           this.cacheMissCount = 0L;
+       }
        this.modelId = in.readOptionalString();
        this.nodeId = in.readOptionalString();
        this.timeStamp = in.readInstant();
@@ -117,6 +124,10 @@ public class InferenceStats implements ToXContentObject, Writeable {
        return failureCount;
    }
+   public long getCacheMissCount() {
+       return cacheMissCount;
+   }
    public String getModelId() {
        return modelId;
    }
@@ -130,7 +141,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
    }
    public boolean hasStats() {
-       return missingAllFieldsCount > 0 || inferenceCount > 0 || failureCount > 0;
+       return missingAllFieldsCount > 0 || inferenceCount > 0 || failureCount > 0 || cacheMissCount > 0;
    }
    @Override
@@ -145,6 +156,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
        }
        builder.field(FAILURE_COUNT.getPreferredName(), failureCount);
        builder.field(INFERENCE_COUNT.getPreferredName(), inferenceCount);
+       builder.field(CACHE_MISS_COUNT.getPreferredName(), cacheMissCount);
        builder.field(MISSING_ALL_FIELDS_COUNT.getPreferredName(), missingAllFieldsCount);
        builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timeStamp.toEpochMilli());
        builder.endObject();
@@ -159,6 +171,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
        return missingAllFieldsCount == that.missingAllFieldsCount
            && inferenceCount == that.inferenceCount
            && failureCount == that.failureCount
+           && cacheMissCount == that.cacheMissCount
            && Objects.equals(modelId, that.modelId)
            && Objects.equals(nodeId, that.nodeId)
            && Objects.equals(timeStamp, that.timeStamp);
@@ -166,7 +179,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
    @Override
    public int hashCode() {
-       return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, modelId, nodeId, timeStamp);
+       return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, cacheMissCount, modelId, nodeId, timeStamp);
    }
    @Override
@@ -175,13 +188,14 @@ public class InferenceStats implements ToXContentObject, Writeable {
            "missingAllFieldsCount=" + missingAllFieldsCount +
            ", inferenceCount=" + inferenceCount +
            ", failureCount=" + failureCount +
+           ", cacheMissCount=" + cacheMissCount +
            ", modelId='" + modelId + '\'' +
            ", nodeId='" + nodeId + '\'' +
            ", timeStamp=" + timeStamp +
            '}';
    }
-   private static long unbox(@Nullable Long value) {
+   private static long unboxOrZero(@Nullable Long value) {
        return value == null ? 0L : value;
    }
@@ -194,6 +208,9 @@ public class InferenceStats implements ToXContentObject, Writeable {
        out.writeVLong(this.missingAllFieldsCount);
        out.writeVLong(this.inferenceCount);
        out.writeVLong(this.failureCount);
+       if (out.getVersion().onOrAfter(Version.V_7_9_0)) {
+           out.writeVLong(this.cacheMissCount);
+       }
        out.writeOptionalString(this.modelId);
        out.writeOptionalString(this.nodeId);
        out.writeInstant(timeStamp);
@@ -201,66 +218,55 @@ public class InferenceStats implements ToXContentObject, Writeable {
    public static class Accumulator {
-       private final LongAdder missingFieldsAccumulator = new LongAdder();
-       private final LongAdder inferenceAccumulator = new LongAdder();
-       private final LongAdder failureCountAccumulator = new LongAdder();
+       private long missingFieldsAccumulator = 0L;
+       private long inferenceAccumulator = 0L;
+       private long failureCountAccumulator = 0L;
+       private long cacheMissAccumulator = 0L;
        private final String modelId;
        private final String nodeId;
-       // curious reader
-       // you may be wondering why the lock set to the fair.
-       // When `currentStatsAndReset` is called, we want it guaranteed that it will eventually execute.
-       // If a ReadWriteLock is unfair, there are no such guarantees.
-       // A call for the `writelock::lock` could pause indefinitely.
-       private final ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true);
-       public Accumulator(String modelId, String nodeId) {
+       public Accumulator(String modelId, String nodeId, long cacheMisses) {
            this.modelId = modelId;
            this.nodeId = nodeId;
+           this.cacheMissAccumulator = cacheMisses;
        }
-       public Accumulator(InferenceStats previousStats) {
+       Accumulator(InferenceStats previousStats) {
            this.modelId = previousStats.modelId;
            this.nodeId = previousStats.nodeId;
-           this.missingFieldsAccumulator.add(previousStats.missingAllFieldsCount);
-           this.inferenceAccumulator.add(previousStats.inferenceCount);
-           this.failureCountAccumulator.add(previousStats.failureCount);
+           this.missingFieldsAccumulator += previousStats.missingAllFieldsCount;
+           this.inferenceAccumulator += previousStats.inferenceCount;
+           this.failureCountAccumulator += previousStats.failureCount;
+           this.cacheMissAccumulator += previousStats.cacheMissCount;
        }
+       /**
+        * NOT Thread Safe
+        *
+        * @param otherStats the other stats with which to increment the current stats
+        * @return Updated accumulator
+        */
        public Accumulator merge(InferenceStats otherStats) {
-           this.missingFieldsAccumulator.add(otherStats.missingAllFieldsCount);
-           this.inferenceAccumulator.add(otherStats.inferenceCount);
-           this.failureCountAccumulator.add(otherStats.failureCount);
+           this.missingFieldsAccumulator += otherStats.missingAllFieldsCount;
+           this.inferenceAccumulator += otherStats.inferenceCount;
+           this.failureCountAccumulator += otherStats.failureCount;
+           this.cacheMissAccumulator += otherStats.cacheMissCount;
            return this;
        }
-       public Accumulator incMissingFields() {
-           readWriteLock.readLock().lock();
-           try {
-               this.missingFieldsAccumulator.increment();
+       public synchronized Accumulator incMissingFields() {
+           this.missingFieldsAccumulator++;
            return this;
-           } finally {
-               readWriteLock.readLock().unlock();
-           }
        }
-       public Accumulator incInference() {
-           readWriteLock.readLock().lock();
-           try {
-               this.inferenceAccumulator.increment();
+       public synchronized Accumulator incInference() {
+           this.inferenceAccumulator++;
            return this;
-           } finally {
-               readWriteLock.readLock().unlock();
-           }
        }
-       public Accumulator incFailure() {
-           readWriteLock.readLock().lock();
-           try {
-               this.failureCountAccumulator.increment();
+       public synchronized Accumulator incFailure() {
+           this.failureCountAccumulator++;
            return this;
-           } finally {
-               readWriteLock.readLock().unlock();
-           }
        }
        /**
@@ -269,23 +275,20 @@ public class InferenceStats implements ToXContentObject, Writeable {
         * Returns the current stats and resets the values of all the counters.
         * @return The current stats
         */
-       public InferenceStats currentStatsAndReset() {
-           readWriteLock.writeLock().lock();
-           try {
+       public synchronized InferenceStats currentStatsAndReset() {
            InferenceStats stats = currentStats(Instant.now());
-           this.missingFieldsAccumulator.reset();
-           this.inferenceAccumulator.reset();
-           this.failureCountAccumulator.reset();
+           this.missingFieldsAccumulator = 0L;
+           this.inferenceAccumulator = 0L;
+           this.failureCountAccumulator = 0L;
+           this.cacheMissAccumulator = 0L;
            return stats;
-           } finally {
-               readWriteLock.writeLock().unlock();
-           }
        }
        public InferenceStats currentStats(Instant timeStamp) {
-           return new InferenceStats(missingFieldsAccumulator.longValue(),
-               inferenceAccumulator.longValue(),
-               failureCountAccumulator.longValue(),
+           return new InferenceStats(missingFieldsAccumulator,
+               inferenceAccumulator,
+               failureCountAccumulator,
+               cacheMissAccumulator,
                modelId,
                nodeId,
                timeStamp);
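
The rework above trades the LongAdder-plus-fair-ReadWriteLock combination for plain synchronized increments, which keeps the snapshot-and-reset in currentStatsAndReset atomic with respect to the increment methods. A rough usage sketch of the accumulator follows; the method names come from the diff above, while the model and node IDs and the wiring around the calls are illustrative:

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;

public final class AccumulatorSketch {

    public static void main(String[] args) {
        // A fresh accumulator is seeded with the cache misses seen so far
        // (LocalModel passes 1L because being constructed implies one cache miss).
        InferenceStats.Accumulator accumulator =
            new InferenceStats.Accumulator("my-model", "node-1", 1L);

        accumulator.incInference();      // one inference served
        accumulator.incMissingFields();  // one call had none of the model's fields

        // Snapshot and reset in a single synchronized step; the snapshot is what gets persisted.
        InferenceStats stats = accumulator.currentStatsAndReset();
        System.out.println(stats);
    }
}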

File: ML stats index mappings (JSON)

@@ -97,6 +97,9 @@
        "failure_count": {
          "type": "long"
        },
+       "cache_miss_count": {
+         "type": "long"
+       },
        "missing_all_fields_count": {
          "type": "long"
        },

File: GetTrainedModelsStatsActionResponseTests (x-pack core)

@@ -65,6 +65,9 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
    @Override
    protected Response mutateInstanceForVersion(Response instance, Version version) {
+       if (version.equals(Version.CURRENT)) {
+           return instance;
+       }
        if (version.before(Version.V_7_8_0)) {
            List<Response.TrainedModelStats> stats = instance.getResources()
                .results()
@@ -76,7 +79,15 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                .collect(Collectors.toList());
            return new Response(new QueryPage<>(stats, instance.getResources().count(), RESULTS_FIELD));
        }
-       return instance;
+       List<Response.TrainedModelStats> stats = instance.getResources()
+           .results()
+           .stream()
+           .map(s -> new Response.TrainedModelStats(s.getModelId(),
+               adjustForVersion(s.getIngestStats(), version),
+               s.getPipelineCount(),
+               InferenceStatsTests.mutateForVersion(s.getInferenceStats(), version)))
+           .collect(Collectors.toList());
+       return new Response(new QueryPage<>(stats, instance.getResources().count(), RESULTS_FIELD));
    }
    IngestStats adjustForVersion(IngestStats stats, Version version) {

File: org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStatsTests (x-pack core)

@@ -5,11 +5,12 @@
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import java.io.IOException;
import java.time.Instant;
@@ -17,10 +18,11 @@ import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
-public class InferenceStatsTests extends AbstractSerializingTestCase<InferenceStats> {
+public class InferenceStatsTests extends AbstractBWCSerializationTestCase<InferenceStats> {
    public static InferenceStats createTestInstance(String modelId, @Nullable String nodeId) {
        return new InferenceStats(randomNonNegativeLong(),
+           randomNonNegativeLong(),
            randomNonNegativeLong(),
            randomNonNegativeLong(),
            modelId,
@@ -29,6 +31,24 @@ public class InferenceStatsTests extends AbstractSerializingTestCase<InferenceSt
        );
    }
+   public static InferenceStats mutateForVersion(InferenceStats instance, Version version) {
+       if (instance == null) {
+           return null;
+       }
+       if (version.before(Version.V_7_9_0)) {
+           return new InferenceStats(
+               instance.getMissingAllFieldsCount(),
+               instance.getInferenceCount(),
+               instance.getFailureCount(),
+               0L,
+               instance.getModelId(),
+               instance.getNodeId(),
+               instance.getTimeStamp()
+           );
+       }
+       return instance;
+   }
    @Override
    protected InferenceStats doParseInstance(XContentParser parser) throws IOException {
        return InferenceStats.PARSER.apply(parser, null);
@@ -54,4 +74,8 @@ public class InferenceStatsTests extends AbstractSerializingTestCase<InferenceSt
        return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
    }
+   @Override
+   protected InferenceStats mutateInstanceForVersion(InferenceStats instance, Version version) {
+       return mutateForVersion(instance, version);
+   }
}

File: InferenceIngestIT (integration test)

@@ -117,9 +117,13 @@ public class InferenceIngestIT extends ESRestTestCase {
        try {
            Response statsResponse = client().performRequest(new Request("GET",
                "_ml/inference/" + classificationModelId + "/_stats"));
-           assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
+           String response = EntityUtils.toString(statsResponse.getEntity());
+           assertThat(response, containsString("\"inference_count\":10"));
+           assertThat(response, containsString("\"cache_miss_count\":30"));
            statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
-           assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
+           response = EntityUtils.toString(statsResponse.getEntity());
+           assertThat(response, containsString("\"inference_count\":10"));
+           assertThat(response, containsString("\"cache_miss_count\":30"));
        } catch (ResponseException ex) {
            //this could just mean shard failures.
            fail(ex.getMessage());
@@ -169,9 +173,13 @@ public class InferenceIngestIT extends ESRestTestCase {
        try {
            Response statsResponse = client().performRequest(new Request("GET",
                "_ml/inference/" + classificationModelId + "/_stats"));
-           assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
+           String response = EntityUtils.toString(statsResponse.getEntity());
+           assertThat(response, containsString("\"inference_count\":10"));
+           assertThat(response, containsString("\"cache_miss_count\":3"));
            statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
-           assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
+           response = EntityUtils.toString(statsResponse.getEntity());
+           assertThat(response, containsString("\"inference_count\":15"));
+           assertThat(response, containsString("\"cache_miss_count\":3"));
            // can get both
            statsResponse = client().performRequest(new Request("GET", "_ml/inference/_stats"));
            String entityString = EntityUtils.toString(statsResponse.getEntity());

File: TrainedModelStatsService

@@ -58,12 +58,14 @@ public class TrainedModelStatsService {
        " ctx._source.{0} += params.{0};\n" +
        " ctx._source.{1} += params.{1};\n" +
        " ctx._source.{2} += params.{2};\n" +
-       " ctx._source.{3} = params.{3};";
+       " ctx._source.{3} += params.{3};\n" +
+       " ctx._source.{4} = params.{4};";
    // Script to only update if stats have increased since last persistence
    private static final String STATS_UPDATE_SCRIPT = Messages.getMessage(STATS_UPDATE_SCRIPT_TEMPLATE,
        InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(),
        InferenceStats.INFERENCE_COUNT.getPreferredName(),
        InferenceStats.FAILURE_COUNT.getPreferredName(),
+       InferenceStats.CACHE_MISS_COUNT.getPreferredName(),
        InferenceStats.TIMESTAMP.getPreferredName());
    private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS =
        new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
@@ -224,6 +226,7 @@ public class TrainedModelStatsService {
        params.put(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(), stats.getMissingAllFieldsCount());
        params.put(InferenceStats.TIMESTAMP.getPreferredName(), stats.getTimeStamp().toEpochMilli());
        params.put(InferenceStats.INFERENCE_COUNT.getPreferredName(), stats.getInferenceCount());
+       params.put(InferenceStats.CACHE_MISS_COUNT.getPreferredName(), stats.getCacheMissCount());
        stats.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);
        UpdateRequest updateRequest = new UpdateRequest();
        updateRequest.upsert(builder)
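
For reference, once Messages.getMessage fills the five field names into STATS_UPDATE_SCRIPT_TEMPLATE above, the stored update script body comes out as follows (reconstructed from the template; leading whitespace approximate):

public final class StatsUpdateScriptSketch {

    private StatsUpdateScriptSketch() {}

    // The expanded script body: the four counters are summed into the stored document,
    // while the timestamp is simply overwritten with the latest value.
    public static final String EXPECTED_SCRIPT =
        "ctx._source.missing_all_fields_count += params.missing_all_fields_count;\n" +
        "ctx._source.inference_count += params.inference_count;\n" +
        "ctx._source.failure_count += params.failure_count;\n" +
        "ctx._source.cache_miss_count += params.cache_miss_count;\n" +
        "ctx._source.timestamp = params.timestamp;";
}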

File: LocalModel

@@ -46,11 +46,13 @@ public class LocalModel implements Model {
                      TrainedModelInput input,
                      Map<String, String> defaultFieldMap,
                      InferenceConfig modelInferenceConfig,
-                     TrainedModelStatsService trainedModelStatsService ) {
+                     TrainedModelStatsService trainedModelStatsService) {
        this.trainedModelDefinition = trainedModelDefinition;
        this.modelId = modelId;
        this.fieldNames = new HashSet<>(input.getFieldNames());
-       this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId);
+       // the ctor being called means a new instance was created.
+       // Consequently, it was not loaded from cache and on stats persist we should increment accordingly.
+       this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId, 1L);
        this.trainedModelStatsService = trainedModelStatsService;
        this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
        this.currentInferenceCount = new LongAdder();

File: TrainedModelProvider

@@ -634,6 +634,8 @@ public class TrainedModelProvider {
                .field(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName()))
            .aggregation(AggregationBuilders.sum(InferenceStats.INFERENCE_COUNT.getPreferredName())
                .field(InferenceStats.INFERENCE_COUNT.getPreferredName()))
+           .aggregation(AggregationBuilders.sum(InferenceStats.CACHE_MISS_COUNT.getPreferredName())
+               .field(InferenceStats.CACHE_MISS_COUNT.getPreferredName()))
            .aggregation(AggregationBuilders.max(InferenceStats.TIMESTAMP.getPreferredName())
                .field(InferenceStats.TIMESTAMP.getPreferredName()))
            .query(queryBuilder));
@@ -646,12 +648,14 @@ public class TrainedModelProvider {
        }
        Sum failures = response.getAggregations().get(InferenceStats.FAILURE_COUNT.getPreferredName());
        Sum missing = response.getAggregations().get(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName());
+       Sum cacheMiss = response.getAggregations().get(InferenceStats.CACHE_MISS_COUNT.getPreferredName());
        Sum count = response.getAggregations().get(InferenceStats.INFERENCE_COUNT.getPreferredName());
        Max timeStamp = response.getAggregations().get(InferenceStats.TIMESTAMP.getPreferredName());
        return new InferenceStats(
            missing == null ? 0L : Double.valueOf(missing.getValue()).longValue(),
            count == null ? 0L : Double.valueOf(count.getValue()).longValue(),
            failures == null ? 0L : Double.valueOf(failures.getValue()).longValue(),
+           cacheMiss == null ? 0L : Double.valueOf(cacheMiss.getValue()).longValue(),
            modelId,
            null,
            timeStamp == null || (Numbers.isValidDouble(timeStamp.getValue()) == false) ?